diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index dc3db16e..2d3a243e 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -92,7 +92,6 @@ from zipline.testing import ( create_data_portal, create_data_portal_from_trade_history, create_minute_df_for_asset, - empty_trading_env, make_test_handler, make_trade_data_for_asset_info, parameter_space, @@ -108,7 +107,6 @@ from zipline.testing.fixtures import ( WithSimParams, WithTradingEnvironment, WithTmpDir, - WithTradingCalendars, ZiplineTestCase, ) from zipline.test_algorithms import ( @@ -786,7 +784,8 @@ def log_nyse_close(context, data): for i, date in enumerate(dates) ] ) - with tmp_trading_env(equities=metadata) as env: + with tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: algo = TradingAlgorithm(env=env) # Set the period end to a date after the period end @@ -853,7 +852,8 @@ class TestTransformAlgorithm(WithLogger, def init_class_fixtures(cls): super(TestTransformAlgorithm, cls).init_class_fixtures() cls.futures_env = cls.enter_class_context( - tmp_trading_env(futures=cls.make_futures_info()), + tmp_trading_env(futures=cls.make_futures_info(), + load=cls.make_load_function()), ) def test_invalid_order_parameters(self): @@ -1065,7 +1065,8 @@ def before_trading_start(context, data): }] * 2) equities['symbol'] = ['A', 'B'] with TempDirectory() as tempdir, \ - tmp_trading_env(equities=equities) as env: + tmp_trading_env(equities=equities, + load=self.make_load_function()) as env: sim_params = SimulationParameters( start_session=start_session, end_session=period_end, @@ -3175,7 +3176,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): orient='index', ) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: sim_params = factory.create_simulation_parameters( start=start, num_days=4, @@ -3302,7 +3304,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: algo = SetAssetDateBoundsAlgorithm( sim_params=self.sim_params, env=env, @@ -3324,7 +3327,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: data_portal = create_data_portal( env.asset_finder, tempdir, @@ -3347,7 +3351,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase): 'sid': 999, }]) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: data_portal = create_data_portal( env.asset_finder, tempdir, @@ -3774,7 +3779,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase): self.assertEqual(txn['price'], expected_price) -class TestTradingAlgorithm(ZiplineTestCase): +class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase): def test_analyze_called(self): self.perf_ref = None @@ -3794,9 +3799,8 @@ class TestTradingAlgorithm(ZiplineTestCase): env=self.env, ) - with empty_trading_env() as env: - data_portal = FakeDataPortal(env) - results = algo.run(data_portal) + data_portal = FakeDataPortal(self.env) + results = algo.run(data_portal) self.assertIs(results, self.perf_ref) @@ -3996,7 +4000,7 @@ class TestOrderCancelation(WithDataPortal, self.assertFalse(log_catcher.has_warnings) -class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase): +class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase): """ Tests if delisted equities are properly removed from a portfolio holding positions in said equities. @@ -4027,7 +4031,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase): sids = asset_info.index - env = self.enter_instance_context(tmp_trading_env(equities=asset_info)) + env = self.enter_instance_context( + tmp_trading_env(equities=asset_info, + load=self.make_load_function()) + ) if frequency == 'daily': dates = self.test_days @@ -4680,7 +4687,7 @@ class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase): ) -class TestPanelData(ZiplineTestCase): +class TestPanelData(WithTradingEnvironment, ZiplineTestCase): @parameterized.expand([ ('daily', @@ -4702,6 +4709,9 @@ class TestPanelData(ZiplineTestCase): def dt_transform(dt): return dt + else: + raise AssertionError('Unexpected data_frequency: %s' % + data_frequency) sids = range(1, 3) dfs = {} @@ -4742,11 +4752,13 @@ class TestPanelData(ZiplineTestCase): 'prev_close']].values.astype('float64') ) - trading_algo = TradingAlgorithm(initialize=initialize, - handle_data=handle_data) - trading_algo.run(data=panel) - check_panels() - price_record.loc[:] = np.nan + with tmp_trading_env(load=self.make_load_function()) as env: + trading_algo = TradingAlgorithm(initialize=initialize, + handle_data=handle_data, + env=env) + trading_algo.run(data=panel) + check_panels() + price_record.loc[:] = np.nan run_algorithm( start=start_dt, diff --git a/tests/test_finance.py b/tests/test_finance.py index 40667326..42c5b872 100644 --- a/tests/test_finance.py +++ b/tests/test_finance.py @@ -190,7 +190,8 @@ class FinanceTestCase(WithLogger, asset1 = self.asset_finder.retrieve_asset(1) metadata = make_simple_equity_info([asset1.sid], self.start, self.end) with TempDirectory() as tempdir, \ - tmp_trading_env(equities=metadata) as env: + tmp_trading_env(equities=metadata, + load=self.make_load_function()) as env: if trade_interval < timedelta(days=1): sim_params = factory.create_simulation_parameters( diff --git a/tests/test_perf_tracking.py b/tests/test_perf_tracking.py index 53f52367..ebc068ee 100644 --- a/tests/test_perf_tracking.py +++ b/tests/test_perf_tracking.py @@ -57,7 +57,6 @@ from zipline.testing.fixtures import ( WithSimParams, WithTmpDir, WithTradingEnvironment, - WithTradingCalendars, ZiplineTestCase, ) from zipline.utils.calendars import get_calendar @@ -1029,7 +1028,8 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance): END_DATE = pd.Timestamp('2003-12-08', tz='utc') -class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars, +class TestPositionPerformance(WithInstanceTmpDir, + WithTradingEnvironment, ZiplineTestCase): def create_environment_stuff(self, @@ -1054,6 +1054,7 @@ class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars, self.env = self.enter_instance_context(tmp_trading_env( equities=equities, futures=futures, + load=self.make_load_function(), )) self.sim_params = create_simulation_parameters( start=start, diff --git a/tests/test_security_list.py b/tests/test_security_list.py index b381adc3..ac7b99fe 100644 --- a/tests/test_security_list.py +++ b/tests/test_security_list.py @@ -15,7 +15,7 @@ from zipline.testing import ( ) from zipline.testing.fixtures import ( WithLogger, - WithTradingCalendars, + WithTradingEnvironment, ZiplineTestCase, ) from zipline.utils import factory @@ -82,7 +82,9 @@ class IterateRLAlgo(TradingAlgorithm): self.found = True -class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): +class SecurityListTestCase(WithLogger, + WithTradingEnvironment, + ZiplineTestCase): @classmethod def init_class_fixtures(cls): @@ -103,6 +105,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): 'symbol': symbol, 'exchange': "TEST", } for symbol in symbols]), + load=cls.make_load_function(), )) cls.sim_params = factory.create_simulation_parameters( @@ -122,6 +125,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): 'symbol': symbol, 'exchange': "TEST", } for symbol in symbols]), + load=cls.make_load_function(), )) cls.tempdir = cls.enter_class_context(tmp_dir()) @@ -304,7 +308,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase): }]) with TempDirectory() as new_tempdir, \ security_list_copy(), \ - tmp_trading_env(equities=equities) as env: + tmp_trading_env(equities=equities, + load=self.make_load_function()) as env: # add a delete statement removing bzq # write a new delete statement file to disk add_security_data([], ['BZQ']) diff --git a/zipline/testing/core.py b/zipline/testing/core.py index f8014bf4..695ece7e 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -859,6 +859,8 @@ class tmp_trading_env(tmp_asset_finder): Parameters ---------- + load : callable, optional + Function that returns benchmark returns and treasury curves. finder_cls : type, optional The type of asset finder to create from the assets db. **frames @@ -869,8 +871,13 @@ class tmp_trading_env(tmp_asset_finder): empty_trading_env tmp_asset_finder """ + def __init__(self, load=None, *args, **kwargs): + super(tmp_trading_env, self).__init__(*args, **kwargs) + self._load = load + def __enter__(self): return TradingEnvironment( + load=self._load, asset_db_path=super(tmp_trading_env, self).__enter__().engine, )