TST: Use fixture's data with tmp_trading_env

instead of env needing to download it
This commit is contained in:
Richard Frank
2017-05-17 11:09:19 -04:00
parent c5b3ceecc1
commit 0f6dbcef3c
5 changed files with 53 additions and 27 deletions
+33 -21
View File
@@ -92,7 +92,6 @@ from zipline.testing import (
create_data_portal,
create_data_portal_from_trade_history,
create_minute_df_for_asset,
empty_trading_env,
make_test_handler,
make_trade_data_for_asset_info,
parameter_space,
@@ -108,7 +107,6 @@ from zipline.testing.fixtures import (
WithSimParams,
WithTradingEnvironment,
WithTmpDir,
WithTradingCalendars,
ZiplineTestCase,
)
from zipline.test_algorithms import (
@@ -786,7 +784,8 @@ def log_nyse_close(context, data):
for i, date in enumerate(dates)
]
)
with tmp_trading_env(equities=metadata) as env:
with tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = TradingAlgorithm(env=env)
# Set the period end to a date after the period end
@@ -853,7 +852,8 @@ class TestTransformAlgorithm(WithLogger,
def init_class_fixtures(cls):
super(TestTransformAlgorithm, cls).init_class_fixtures()
cls.futures_env = cls.enter_class_context(
tmp_trading_env(futures=cls.make_futures_info()),
tmp_trading_env(futures=cls.make_futures_info(),
load=cls.make_load_function()),
)
def test_invalid_order_parameters(self):
@@ -1065,7 +1065,8 @@ def before_trading_start(context, data):
}] * 2)
equities['symbol'] = ['A', 'B']
with TempDirectory() as tempdir, \
tmp_trading_env(equities=equities) as env:
tmp_trading_env(equities=equities,
load=self.make_load_function()) as env:
sim_params = SimulationParameters(
start_session=start_session,
end_session=period_end,
@@ -3175,7 +3176,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
orient='index',
)
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
sim_params = factory.create_simulation_parameters(
start=start,
num_days=4,
@@ -3302,7 +3304,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = SetAssetDateBoundsAlgorithm(
sim_params=self.sim_params,
env=env,
@@ -3324,7 +3327,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
data_portal = create_data_portal(
env.asset_finder,
tempdir,
@@ -3347,7 +3351,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
data_portal = create_data_portal(
env.asset_finder,
tempdir,
@@ -3774,7 +3779,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
self.assertEqual(txn['price'], expected_price)
class TestTradingAlgorithm(ZiplineTestCase):
class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase):
def test_analyze_called(self):
self.perf_ref = None
@@ -3794,9 +3799,8 @@ class TestTradingAlgorithm(ZiplineTestCase):
env=self.env,
)
with empty_trading_env() as env:
data_portal = FakeDataPortal(env)
results = algo.run(data_portal)
data_portal = FakeDataPortal(self.env)
results = algo.run(data_portal)
self.assertIs(results, self.perf_ref)
@@ -3996,7 +4000,7 @@ class TestOrderCancelation(WithDataPortal,
self.assertFalse(log_catcher.has_warnings)
class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase):
"""
Tests if delisted equities are properly removed from a portfolio holding
positions in said equities.
@@ -4027,7 +4031,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
sids = asset_info.index
env = self.enter_instance_context(tmp_trading_env(equities=asset_info))
env = self.enter_instance_context(
tmp_trading_env(equities=asset_info,
load=self.make_load_function())
)
if frequency == 'daily':
dates = self.test_days
@@ -4680,7 +4687,7 @@ class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
)
class TestPanelData(ZiplineTestCase):
class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
@parameterized.expand([
('daily',
@@ -4702,6 +4709,9 @@ class TestPanelData(ZiplineTestCase):
def dt_transform(dt):
return dt
else:
raise AssertionError('Unexpected data_frequency: %s' %
data_frequency)
sids = range(1, 3)
dfs = {}
@@ -4742,11 +4752,13 @@ class TestPanelData(ZiplineTestCase):
'prev_close']].values.astype('float64')
)
trading_algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data)
trading_algo.run(data=panel)
check_panels()
price_record.loc[:] = np.nan
with tmp_trading_env(load=self.make_load_function()) as env:
trading_algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
env=env)
trading_algo.run(data=panel)
check_panels()
price_record.loc[:] = np.nan
run_algorithm(
start=start_dt,