diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 73eee214..b9c874e7 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -136,6 +136,28 @@ class TestMiscellaneousAPI(TestCase): concurrent=True, ) + def test_get_environment(self): + expected_env = { + 'arena': 'backtest', + 'data_frequency': 'minute', + 'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'), + 'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'), + 'capital_base': 100000.0, + 'platform': 'zipline' + } + + def initialize(algo): + self.assertEqual('zipline', algo.get_environment()) + self.assertEqual(expected_env, algo.get_environment('*')) + + def handle_data(algo, data): + pass + + algo = TradingAlgorithm(initialize=initialize, + handle_data=handle_data, + sim_params=self.sim_params) + algo.run(self.source) + def test_get_open_orders(self): def initialize(algo): @@ -481,11 +503,11 @@ class TestAlgoScript(TestCase): algo.run(self.df) def test_api_get_environment(self): - environment = 'zipline' + platform = 'zipline' algo = TradingAlgorithm(script=api_get_environment_algo, - environment=environment) + platform=platform) algo.run(self.df) - self.assertEqual(algo.environment, environment) + self.assertEqual(algo.environment, platform) def test_api_symbol(self): algo = TradingAlgorithm(script=api_symbol_algo) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index b0843bbc..1c17bde9 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -152,7 +152,7 @@ class TradingAlgorithm(object): self._recorded_vars = {} self.namespace = kwargs.get('namespace', {}) - self._environment = kwargs.pop('environment', 'zipline') + self._platform = kwargs.pop('platform', 'zipline') self.logger = None @@ -549,8 +549,19 @@ class TradingAlgorithm(object): self.add_history(bars, freq, 'volume') @api_method - def get_environment(self): - return self._environment + def get_environment(self, field='platform'): + env = { + 'arena': self.sim_params.arena, + 'data_frequency': self.sim_params.data_frequency, + 'start': self.sim_params.first_open, + 'end': self.sim_params.last_close, + 'capital_base': self.sim_params.capital_base, + 'platform': self._platform + } + if field == '*': + return env + else: + return env[field] def add_event(self, rule=None, callback=None): """ diff --git a/zipline/finance/trading.py b/zipline/finance/trading.py index d5327ac2..42490b80 100644 --- a/zipline/finance/trading.py +++ b/zipline/finance/trading.py @@ -396,6 +396,9 @@ class SimulationParameters(object): self.data_frequency = data_frequency self.sids = sids + # copied to algorithm's environment for runtime access + self.arena = 'backtest' + self._update_internal() def _update_internal(self):