From c5b3ceecc1217b735e698d3f3bda796932e51c87 Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Wed, 17 May 2017 09:36:46 -0400 Subject: [PATCH] TST: Use fixture's trading env for FakeDataPortal or TradingAlgo to avoid a new trading env needing to download data unnecessarily --- tests/pipeline/test_pipeline_algo.py | 6 +++--- tests/test_algorithm.py | 22 +++++++++++++++------- tests/test_tradesimulation.py | 19 +++++++++++-------- zipline/testing/core.py | 5 +---- 4 files changed, 30 insertions(+), 22 deletions(-) diff --git a/tests/pipeline/test_pipeline_algo.py b/tests/pipeline/test_pipeline_algo.py index b45f7742..4721236e 100644 --- a/tests/pipeline/test_pipeline_algo.py +++ b/tests/pipeline/test_pipeline_algo.py @@ -566,7 +566,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), # Yes, I really do want to use the start and end dates I passed to # TradingAlgorithm. overwrite_sim_params=False, @@ -606,7 +606,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), overwrite_sim_params=False, ) @@ -654,7 +654,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs, ) algo.run( - FakeDataPortal(), + FakeDataPortal(self.env), overwrite_sim_params=False, ) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 4c89a9b8..dc3db16e 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -314,6 +314,7 @@ def handle_data(algo, data): initialize=lambda context: None, handle_data=lambda context, data: None, sim_params=self.sim_params, + env=self.env, ) # Verify that api methods get resolved dynamically by patching them out @@ -892,7 +893,8 @@ def before_trading_start(context, data): def test_run_twice(self): algo1 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, - sids=[0, 1] + sids=[0, 1], + env=self.env, ) res1 = algo1.run(self.data_portal) @@ -901,7 +903,8 @@ def before_trading_start(context, data): # use the newly instantiated environment. algo2 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, - sids=[0, 1] + sids=[0, 1], + env=self.env, ) res2 = algo2.run(self.data_portal) @@ -1569,15 +1572,16 @@ class TestAlgoScript(WithLogger, def test_noop(self): algo = TradingAlgorithm(initialize=initialize_noop, - handle_data=handle_data_noop) + handle_data=handle_data_noop, + env=self.env) algo.run(self.data_portal) def test_noop_string(self): - algo = TradingAlgorithm(script=noop_algo) + algo = TradingAlgorithm(script=noop_algo, env=self.env) algo.run(self.data_portal) def test_no_handle_data(self): - algo = TradingAlgorithm(script=no_handle_data) + algo = TradingAlgorithm(script=no_handle_data, env=self.env) algo.run(self.data_portal) def test_api_calls(self): @@ -1593,7 +1597,8 @@ class TestAlgoScript(WithLogger, def test_api_get_environment(self): platform = 'zipline' algo = TradingAlgorithm(script=api_get_environment_algo, - platform=platform) + platform=platform, + env=self.env) algo.run(self.data_portal) self.assertEqual(algo.environment, platform) @@ -1779,6 +1784,7 @@ def handle_data(context, data): test_algo = TradingAlgorithm( script=record_variables, sim_params=self.sim_params, + env=self.env, ) set_algo_instance(test_algo) @@ -3785,6 +3791,7 @@ class TestTradingAlgorithm(ZiplineTestCase): initialize=initialize, handle_data=handle_data, analyze=analyze, + env=self.env, ) with empty_trading_env() as env: @@ -4642,7 +4649,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase): self.assertEqual(expected_message, w.message) -class AlgoInputValidationTestCase(ZiplineTestCase): +class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase): def test_reject_passing_both_api_methods_and_script(self): script = dedent( @@ -4668,6 +4675,7 @@ class AlgoInputValidationTestCase(ZiplineTestCase): with self.assertRaises(ValueError): TradingAlgorithm( script=script, + env=self.env, **{method: lambda *args, **kwargs: None} ) diff --git a/tests/test_tradesimulation.py b/tests/test_tradesimulation.py index 6277b8e2..a7dc8e69 100644 --- a/tests/test_tradesimulation.py +++ b/tests/test_tradesimulation.py @@ -19,7 +19,6 @@ from mock import patch from nose_parameterized import parameterized from six.moves import range -from unittest import TestCase from zipline import TradingAlgorithm from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR @@ -28,8 +27,12 @@ from zipline.finance.asset_restrictions import NoRestrictions from zipline.gens.tradesimulation import AlgorithmSimulator from zipline.sources.benchmark_source import BenchmarkSource from zipline.test_algorithms import NoopAlgorithm -from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \ - WithDataPortal +from zipline.testing.fixtures import ( + WithDataPortal, + WithSimParams, + WithTradingEnvironment, + ZiplineTestCase, +) from zipline.utils import factory from zipline.testing.core import FakeDataPortal from zipline.utils.calendars.trading_calendar import days_at_time @@ -50,7 +53,7 @@ class BeforeTradingAlgorithm(TradingAlgorithm): FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute -class TestTradeSimulation(TestCase): +class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase): def fake_minutely_benchmark(self, dt): return 0.01 @@ -61,8 +64,8 @@ class TestTradeSimulation(TestCase): emission_rate='minute') with patch.object(BenchmarkSource, "get_value", self.fake_minutely_benchmark): - algo = NoopAlgorithm(sim_params=params) - algo.run(FakeDataPortal()) + algo = NoopAlgorithm(sim_params=params, env=self.env) + algo.run(FakeDataPortal(self.env)) self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1) @parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate), @@ -82,8 +85,8 @@ class TestTradeSimulation(TestCase): with patch.object(BenchmarkSource, "get_value", self.fake_minutely_benchmark): - algo = BeforeTradingAlgorithm(sim_params=params) - algo.run(FakeDataPortal()) + algo = BeforeTradingAlgorithm(sim_params=params, env=self.env) + algo.run(FakeDataPortal(self.env)) self.assertEqual( len(algo.perf_tracker.sim_params.sessions), diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 6afe1e4e..f8014bf4 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -695,11 +695,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar, class FakeDataPortal(DataPortal): - def __init__(self, env=None, trading_calendar=None, + def __init__(self, env, trading_calendar=None, first_trading_day=None): - if env is None: - env = TradingEnvironment() - if trading_calendar is None: trading_calendar = get_calendar("NYSE")