mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 09:09:19 +08:00
TST: Use fixture's trading env for FakeDataPortal or TradingAlgo
to avoid a new trading env needing to download data unnecessarily
This commit is contained in:
@@ -566,7 +566,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
# Yes, I really do want to use the start and end dates I passed to
|
||||
# TradingAlgorithm.
|
||||
overwrite_sim_params=False,
|
||||
@@ -606,7 +606,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
@@ -654,7 +654,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
|
||||
+15
-7
@@ -314,6 +314,7 @@ def handle_data(algo, data):
|
||||
initialize=lambda context: None,
|
||||
handle_data=lambda context, data: None,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
# Verify that api methods get resolved dynamically by patching them out
|
||||
@@ -892,7 +893,8 @@ def before_trading_start(context, data):
|
||||
def test_run_twice(self):
|
||||
algo1 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res1 = algo1.run(self.data_portal)
|
||||
@@ -901,7 +903,8 @@ def before_trading_start(context, data):
|
||||
# use the newly instantiated environment.
|
||||
algo2 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res2 = algo2.run(self.data_portal)
|
||||
@@ -1569,15 +1572,16 @@ class TestAlgoScript(WithLogger,
|
||||
|
||||
def test_noop(self):
|
||||
algo = TradingAlgorithm(initialize=initialize_noop,
|
||||
handle_data=handle_data_noop)
|
||||
handle_data=handle_data_noop,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_noop_string(self):
|
||||
algo = TradingAlgorithm(script=noop_algo)
|
||||
algo = TradingAlgorithm(script=noop_algo, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_no_handle_data(self):
|
||||
algo = TradingAlgorithm(script=no_handle_data)
|
||||
algo = TradingAlgorithm(script=no_handle_data, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_api_calls(self):
|
||||
@@ -1593,7 +1597,8 @@ class TestAlgoScript(WithLogger,
|
||||
def test_api_get_environment(self):
|
||||
platform = 'zipline'
|
||||
algo = TradingAlgorithm(script=api_get_environment_algo,
|
||||
platform=platform)
|
||||
platform=platform,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
self.assertEqual(algo.environment, platform)
|
||||
|
||||
@@ -1779,6 +1784,7 @@ def handle_data(context, data):
|
||||
test_algo = TradingAlgorithm(
|
||||
script=record_variables,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
set_algo_instance(test_algo)
|
||||
|
||||
@@ -3785,6 +3791,7 @@ class TestTradingAlgorithm(ZiplineTestCase):
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
analyze=analyze,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
with empty_trading_env() as env:
|
||||
@@ -4642,7 +4649,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase):
|
||||
self.assertEqual(expected_message, w.message)
|
||||
|
||||
|
||||
class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
def test_reject_passing_both_api_methods_and_script(self):
|
||||
script = dedent(
|
||||
@@ -4668,6 +4675,7 @@ class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
TradingAlgorithm(
|
||||
script=script,
|
||||
env=self.env,
|
||||
**{method: lambda *args, **kwargs: None}
|
||||
)
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from mock import patch
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from six.moves import range
|
||||
from unittest import TestCase
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR
|
||||
|
||||
@@ -28,8 +27,12 @@ from zipline.finance.asset_restrictions import NoRestrictions
|
||||
from zipline.gens.tradesimulation import AlgorithmSimulator
|
||||
from zipline.sources.benchmark_source import BenchmarkSource
|
||||
from zipline.test_algorithms import NoopAlgorithm
|
||||
from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \
|
||||
WithDataPortal
|
||||
from zipline.testing.fixtures import (
|
||||
WithDataPortal,
|
||||
WithSimParams,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils import factory
|
||||
from zipline.testing.core import FakeDataPortal
|
||||
from zipline.utils.calendars.trading_calendar import days_at_time
|
||||
@@ -50,7 +53,7 @@ class BeforeTradingAlgorithm(TradingAlgorithm):
|
||||
FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute
|
||||
|
||||
|
||||
class TestTradeSimulation(TestCase):
|
||||
class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
def fake_minutely_benchmark(self, dt):
|
||||
return 0.01
|
||||
@@ -61,8 +64,8 @@ class TestTradeSimulation(TestCase):
|
||||
emission_rate='minute')
|
||||
with patch.object(BenchmarkSource, "get_value",
|
||||
self.fake_minutely_benchmark):
|
||||
algo = NoopAlgorithm(sim_params=params)
|
||||
algo.run(FakeDataPortal())
|
||||
algo = NoopAlgorithm(sim_params=params, env=self.env)
|
||||
algo.run(FakeDataPortal(self.env))
|
||||
self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1)
|
||||
|
||||
@parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate),
|
||||
@@ -82,8 +85,8 @@ class TestTradeSimulation(TestCase):
|
||||
|
||||
with patch.object(BenchmarkSource, "get_value",
|
||||
self.fake_minutely_benchmark):
|
||||
algo = BeforeTradingAlgorithm(sim_params=params)
|
||||
algo.run(FakeDataPortal())
|
||||
algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
|
||||
algo.run(FakeDataPortal(self.env))
|
||||
|
||||
self.assertEqual(
|
||||
len(algo.perf_tracker.sim_params.sessions),
|
||||
|
||||
@@ -695,11 +695,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar,
|
||||
|
||||
|
||||
class FakeDataPortal(DataPortal):
|
||||
def __init__(self, env=None, trading_calendar=None,
|
||||
def __init__(self, env, trading_calendar=None,
|
||||
first_trading_day=None):
|
||||
if env is None:
|
||||
env = TradingEnvironment()
|
||||
|
||||
if trading_calendar is None:
|
||||
trading_calendar = get_calendar("NYSE")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user