From 7be2cf8652d3ab5f3879217dae37e05484f22fc4 Mon Sep 17 00:00:00 2001 From: Stewart Douglas Date: Fri, 14 Aug 2015 14:39:41 -0400 Subject: [PATCH] MAINT: Allow algo.run() to write to db --- tests/test_algorithm.py | 40 ++++++++++++++++++++++++++++++---------- zipline/algorithm.py | 6 +++++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index e2984020..8460613f 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -489,7 +489,8 @@ class TestTransformAlgorithm(TestCase): def setUp(self): setup_logger(self) self.sim_params = factory.create_simulation_parameters(num_days=4) - trading.environment.write_data(equities_identifiers=[0, 1, 133]) + trading.environment = trading.TradingEnvironment() + trading.environment.write_data(equities_identifiers=[133]) trade_history = factory.create_trade_history( 133, @@ -499,7 +500,6 @@ class TestTransformAlgorithm(TestCase): self.sim_params ) self.source = SpecificEquityTrades(event_list=trade_history) - self.df_source, self.df = \ factory.create_test_df_source(self.sim_params) @@ -526,13 +526,14 @@ class TestTransformAlgorithm(TestCase): algo.run(self.source) def test_multi_source_as_input(self): + trading.environment.write_data(equities_identifiers=[0, 1]) sim_params = SimulationParameters( self.df.index[0], self.df.index[-1] ) algo = TestRegisterTransformAlgorithm( sim_params=sim_params, - sids=[0, 1, 133] + sids=[0, 1] ) algo.run([self.source, self.df_source], overwrite_sim_params=False) self.assertEqual(len(algo.sources), 2) @@ -540,7 +541,6 @@ class TestTransformAlgorithm(TestCase): def test_df_as_input(self): algo = TestRegisterTransformAlgorithm( sim_params=self.sim_params, - sids=[0, 1] ) algo.run(self.df) assert isinstance(algo.sources[0], DataFrameSource) @@ -553,13 +553,23 @@ class TestTransformAlgorithm(TestCase): assert isinstance(algo.sources[0], DataPanelSource) def test_run_twice(self): - algo = TestRegisterTransformAlgorithm( + algo1 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, sids=[0, 1] ) - res1 = algo.run(self.df) - res2 = algo.run(self.df) + res1 = algo1.run(self.df) + + # Create a new trading environment + trading.environment = trading.TradingEnvironment() + # Create a new trading algorithm, which will + # use the newly instantiated environment. + algo2 = TestRegisterTransformAlgorithm( + sim_params=self.sim_params, + sids=[0, 1] + ) + + res2 = algo2.run(self.df) np.testing.assert_array_equal(res1, res2) @@ -617,6 +627,7 @@ class TestTransformAlgorithm(TestCase): 'order_target_value'] for name in method_names_to_test: + trading.environment = trading.TradingEnvironment() algo = TestOrderStyleForwardingAlgorithm( sim_params=self.sim_params, instant_fill=False, @@ -630,6 +641,7 @@ class TestTransformAlgorithm(TestCase): algo.run(self.df) def test_minute_data(self): + trading.environment.write_data(equities_identifiers=[0, 1]) source = RandomWalkSource(freq='minute', start=pd.Timestamp('2000-1-3', tz='UTC'), @@ -646,7 +658,7 @@ class TestPositions(TestCase): def setUp(self): setup_logger(self) self.sim_params = factory.create_simulation_parameters(num_days=4) - trading.environment.write_data(equities_identifiers=[0, 1, 133]) + trading.environment.write_data(equities_identifiers=[1, 133]) trade_history = factory.create_trade_history( 1, @@ -691,10 +703,11 @@ class TestPositions(TestCase): class TestAlgoScript(TestCase): def setUp(self): days = 251 - # Note that create_simulation_parameters creates a new TradingEnvironment + # Note that create_simulation_parameters creates + # a new TradingEnvironment self.sim_params = factory.create_simulation_parameters(num_days=days) setup_logger(self) - trading.environment.write_data(equities_identifiers=[0, 1, 133]) + trading.environment.write_data(equities_identifiers=[1, 133]) trade_history = factory.create_trade_history( 133, [10.0] * days, @@ -755,6 +768,7 @@ class TestAlgoScript(TestCase): def test_fixed_slippage(self): # verify order -> transaction -> portfolio position. # -------------- + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=""" from zipline.api import (slippage, @@ -809,6 +823,7 @@ def handle_data(context, data): def test_volshare_slippage(self): # verify order -> transaction -> portfolio position. # -------------- + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=""" from zipline.api import * @@ -876,6 +891,7 @@ def handle_data(context, data): self.zipline_test_config['algorithm'] = test_algo self.zipline_test_config['trade_count'] = 200 + trading.environment.write_data(equities_identifiers=[0]) zipline = simfactory.create_test_zipline( **self.zipline_test_config) output, _ = drain_zipline(self, zipline) @@ -902,6 +918,7 @@ def handle_data(context, data): test_algo.record(foo=MagicMock()) def _algo_record_float_magic_should_pass(self, var_type): + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=record_float_magic % var_type, sim_params=self.sim_params, @@ -928,6 +945,7 @@ def handle_data(context, data): Only test that order methods can be called without error. Correct filling of orders is tested in zipline. """ + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=call_all_order_methods, sim_params=self.sim_params, @@ -959,6 +977,7 @@ def handle_data(context, data): """ Test that accessing portfolio in init doesn't break. """ + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=access_portfolio_in_init, sim_params=self.sim_params, @@ -977,6 +996,7 @@ def handle_data(context, data): """ Test that accessing account in init doesn't break. """ + trading.environment.write_data(equities_identifiers=[0]) test_algo = TradingAlgorithm( script=access_account_in_init, sim_params=self.sim_params, diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 5ef67116..e850f164 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -31,6 +31,7 @@ from six import ( ) from operator import attrgetter + from zipline.errors import ( AddTermPostInit, OrderDuringInitialize, @@ -192,7 +193,6 @@ class TradingAlgorithm(object): # set the capital base self.capital_base = kwargs.pop('capital_base', DEFAULT_CAPITAL_BASE) - self.sim_params = kwargs.pop('sim_params', None) if self.sim_params is None: self.sim_params = create_simulation_parameters( @@ -500,6 +500,8 @@ class TradingAlgorithm(object): # if DataFrame provided, map columns to sids and wrap # in DataFrameSource copy_frame = source.copy() + self.trading_environment.write_data( + equities_identifiers=source.columns) copy_frame.columns = \ self.asset_finder.map_identifier_index_to_sids( source.columns, source.index[0] @@ -510,6 +512,8 @@ class TradingAlgorithm(object): # If Panel provided, map items to sids and wrap # in DataPanelSource copy_panel = source.copy() + self.trading_environment.write_data( + equities_identifiers=source.items) copy_panel.items = self.asset_finder.map_identifier_index_to_sids( source.items, source.major_axis[0] )