TST: Use fixture's trading env for FakeDataPortal or TradingAlgo

to avoid a new trading env needing to download data unnecessarily
This commit is contained in:
Richard Frank
2017-05-17 09:36:46 -04:00
parent dde4974705
commit c5b3ceecc1
4 changed files with 30 additions and 22 deletions
+3 -3
View File
@@ -566,7 +566,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
)
algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
# Yes, I really do want to use the start and end dates I passed to
# TradingAlgorithm.
overwrite_sim_params=False,
@@ -606,7 +606,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
)
algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
overwrite_sim_params=False,
)
@@ -654,7 +654,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
)
algo.run(
FakeDataPortal(),
FakeDataPortal(self.env),
overwrite_sim_params=False,
)
+15 -7
View File
@@ -314,6 +314,7 @@ def handle_data(algo, data):
initialize=lambda context: None,
handle_data=lambda context, data: None,
sim_params=self.sim_params,
env=self.env,
)
# Verify that api methods get resolved dynamically by patching them out
@@ -892,7 +893,8 @@ def before_trading_start(context, data):
def test_run_twice(self):
algo1 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
sids=[0, 1],
env=self.env,
)
res1 = algo1.run(self.data_portal)
@@ -901,7 +903,8 @@ def before_trading_start(context, data):
# use the newly instantiated environment.
algo2 = TestRegisterTransformAlgorithm(
sim_params=self.sim_params,
sids=[0, 1]
sids=[0, 1],
env=self.env,
)
res2 = algo2.run(self.data_portal)
@@ -1569,15 +1572,16 @@ class TestAlgoScript(WithLogger,
def test_noop(self):
algo = TradingAlgorithm(initialize=initialize_noop,
handle_data=handle_data_noop)
handle_data=handle_data_noop,
env=self.env)
algo.run(self.data_portal)
def test_noop_string(self):
algo = TradingAlgorithm(script=noop_algo)
algo = TradingAlgorithm(script=noop_algo, env=self.env)
algo.run(self.data_portal)
def test_no_handle_data(self):
algo = TradingAlgorithm(script=no_handle_data)
algo = TradingAlgorithm(script=no_handle_data, env=self.env)
algo.run(self.data_portal)
def test_api_calls(self):
@@ -1593,7 +1597,8 @@ class TestAlgoScript(WithLogger,
def test_api_get_environment(self):
platform = 'zipline'
algo = TradingAlgorithm(script=api_get_environment_algo,
platform=platform)
platform=platform,
env=self.env)
algo.run(self.data_portal)
self.assertEqual(algo.environment, platform)
@@ -1779,6 +1784,7 @@ def handle_data(context, data):
test_algo = TradingAlgorithm(
script=record_variables,
sim_params=self.sim_params,
env=self.env,
)
set_algo_instance(test_algo)
@@ -3785,6 +3791,7 @@ class TestTradingAlgorithm(ZiplineTestCase):
initialize=initialize,
handle_data=handle_data,
analyze=analyze,
env=self.env,
)
with empty_trading_env() as env:
@@ -4642,7 +4649,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase):
self.assertEqual(expected_message, w.message)
class AlgoInputValidationTestCase(ZiplineTestCase):
class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
def test_reject_passing_both_api_methods_and_script(self):
script = dedent(
@@ -4668,6 +4675,7 @@ class AlgoInputValidationTestCase(ZiplineTestCase):
with self.assertRaises(ValueError):
TradingAlgorithm(
script=script,
env=self.env,
**{method: lambda *args, **kwargs: None}
)
+11 -8
View File
@@ -19,7 +19,6 @@ from mock import patch
from nose_parameterized import parameterized
from six.moves import range
from unittest import TestCase
from zipline import TradingAlgorithm
from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR
@@ -28,8 +27,12 @@ from zipline.finance.asset_restrictions import NoRestrictions
from zipline.gens.tradesimulation import AlgorithmSimulator
from zipline.sources.benchmark_source import BenchmarkSource
from zipline.test_algorithms import NoopAlgorithm
from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \
WithDataPortal
from zipline.testing.fixtures import (
WithDataPortal,
WithSimParams,
WithTradingEnvironment,
ZiplineTestCase,
)
from zipline.utils import factory
from zipline.testing.core import FakeDataPortal
from zipline.utils.calendars.trading_calendar import days_at_time
@@ -50,7 +53,7 @@ class BeforeTradingAlgorithm(TradingAlgorithm):
FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute
class TestTradeSimulation(TestCase):
class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase):
def fake_minutely_benchmark(self, dt):
return 0.01
@@ -61,8 +64,8 @@ class TestTradeSimulation(TestCase):
emission_rate='minute')
with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = NoopAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
algo = NoopAlgorithm(sim_params=params, env=self.env)
algo.run(FakeDataPortal(self.env))
self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1)
@parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate),
@@ -82,8 +85,8 @@ class TestTradeSimulation(TestCase):
with patch.object(BenchmarkSource, "get_value",
self.fake_minutely_benchmark):
algo = BeforeTradingAlgorithm(sim_params=params)
algo.run(FakeDataPortal())
algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
algo.run(FakeDataPortal(self.env))
self.assertEqual(
len(algo.perf_tracker.sim_params.sessions),
+1 -4
View File
@@ -695,11 +695,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar,
class FakeDataPortal(DataPortal):
def __init__(self, env=None, trading_calendar=None,
def __init__(self, env, trading_calendar=None,
first_trading_day=None):
if env is None:
env = TradingEnvironment()
if trading_calendar is None:
trading_calendar = get_calendar("NYSE")