Merge pull request #449 from quantopian/bug_5089

expand get_environment
This commit is contained in:
fawce
2014-12-12 17:42:58 -05:00
3 changed files with 42 additions and 6 deletions
+25 -3
View File
@@ -136,6 +136,28 @@ class TestMiscellaneousAPI(TestCase):
concurrent=True,
)
def test_get_environment(self):
expected_env = {
'arena': 'backtest',
'data_frequency': 'minute',
'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
'capital_base': 100000.0,
'platform': 'zipline'
}
def initialize(algo):
self.assertEqual('zipline', algo.get_environment())
self.assertEqual(expected_env, algo.get_environment('*'))
def handle_data(algo, data):
pass
algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
sim_params=self.sim_params)
algo.run(self.source)
def test_get_open_orders(self):
def initialize(algo):
@@ -481,11 +503,11 @@ class TestAlgoScript(TestCase):
algo.run(self.df)
def test_api_get_environment(self):
environment = 'zipline'
platform = 'zipline'
algo = TradingAlgorithm(script=api_get_environment_algo,
environment=environment)
platform=platform)
algo.run(self.df)
self.assertEqual(algo.environment, environment)
self.assertEqual(algo.environment, platform)
def test_api_symbol(self):
algo = TradingAlgorithm(script=api_symbol_algo)
+14 -3
View File
@@ -152,7 +152,7 @@ class TradingAlgorithm(object):
self._recorded_vars = {}
self.namespace = kwargs.get('namespace', {})
self._environment = kwargs.pop('environment', 'zipline')
self._platform = kwargs.pop('platform', 'zipline')
self.logger = None
@@ -549,8 +549,19 @@ class TradingAlgorithm(object):
self.add_history(bars, freq, 'volume')
@api_method
def get_environment(self):
return self._environment
def get_environment(self, field='platform'):
env = {
'arena': self.sim_params.arena,
'data_frequency': self.sim_params.data_frequency,
'start': self.sim_params.first_open,
'end': self.sim_params.last_close,
'capital_base': self.sim_params.capital_base,
'platform': self._platform
}
if field == '*':
return env
else:
return env[field]
def add_event(self, rule=None, callback=None):
"""
+3
View File
@@ -396,6 +396,9 @@ class SimulationParameters(object):
self.data_frequency = data_frequency
self.sids = sids
# copied to algorithm's environment for runtime access
self.arena = 'backtest'
self._update_internal()
def _update_internal(self):