mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 04:42:38 +08:00
Merge pull request #449 from quantopian/bug_5089
expand get_environment
This commit is contained in:
+25
-3
@@ -136,6 +136,28 @@ class TestMiscellaneousAPI(TestCase):
|
||||
concurrent=True,
|
||||
)
|
||||
|
||||
def test_get_environment(self):
|
||||
expected_env = {
|
||||
'arena': 'backtest',
|
||||
'data_frequency': 'minute',
|
||||
'start': pd.Timestamp('2006-01-03 14:31:00+0000', tz='UTC'),
|
||||
'end': pd.Timestamp('2006-01-04 21:00:00+0000', tz='UTC'),
|
||||
'capital_base': 100000.0,
|
||||
'platform': 'zipline'
|
||||
}
|
||||
|
||||
def initialize(algo):
|
||||
self.assertEqual('zipline', algo.get_environment())
|
||||
self.assertEqual(expected_env, algo.get_environment('*'))
|
||||
|
||||
def handle_data(algo, data):
|
||||
pass
|
||||
|
||||
algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
sim_params=self.sim_params)
|
||||
algo.run(self.source)
|
||||
|
||||
def test_get_open_orders(self):
|
||||
|
||||
def initialize(algo):
|
||||
@@ -481,11 +503,11 @@ class TestAlgoScript(TestCase):
|
||||
algo.run(self.df)
|
||||
|
||||
def test_api_get_environment(self):
|
||||
environment = 'zipline'
|
||||
platform = 'zipline'
|
||||
algo = TradingAlgorithm(script=api_get_environment_algo,
|
||||
environment=environment)
|
||||
platform=platform)
|
||||
algo.run(self.df)
|
||||
self.assertEqual(algo.environment, environment)
|
||||
self.assertEqual(algo.environment, platform)
|
||||
|
||||
def test_api_symbol(self):
|
||||
algo = TradingAlgorithm(script=api_symbol_algo)
|
||||
|
||||
+14
-3
@@ -152,7 +152,7 @@ class TradingAlgorithm(object):
|
||||
self._recorded_vars = {}
|
||||
self.namespace = kwargs.get('namespace', {})
|
||||
|
||||
self._environment = kwargs.pop('environment', 'zipline')
|
||||
self._platform = kwargs.pop('platform', 'zipline')
|
||||
|
||||
self.logger = None
|
||||
|
||||
@@ -549,8 +549,19 @@ class TradingAlgorithm(object):
|
||||
self.add_history(bars, freq, 'volume')
|
||||
|
||||
@api_method
|
||||
def get_environment(self):
|
||||
return self._environment
|
||||
def get_environment(self, field='platform'):
|
||||
env = {
|
||||
'arena': self.sim_params.arena,
|
||||
'data_frequency': self.sim_params.data_frequency,
|
||||
'start': self.sim_params.first_open,
|
||||
'end': self.sim_params.last_close,
|
||||
'capital_base': self.sim_params.capital_base,
|
||||
'platform': self._platform
|
||||
}
|
||||
if field == '*':
|
||||
return env
|
||||
else:
|
||||
return env[field]
|
||||
|
||||
def add_event(self, rule=None, callback=None):
|
||||
"""
|
||||
|
||||
@@ -396,6 +396,9 @@ class SimulationParameters(object):
|
||||
self.data_frequency = data_frequency
|
||||
self.sids = sids
|
||||
|
||||
# copied to algorithm's environment for runtime access
|
||||
self.arena = 'backtest'
|
||||
|
||||
self._update_internal()
|
||||
|
||||
def _update_internal(self):
|
||||
|
||||
Reference in New Issue
Block a user