mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 23:08:01 +08:00
TST: Use fixture's data with tmp_trading_env
instead of env needing to download it
This commit is contained in:
+33
-21
@@ -92,7 +92,6 @@ from zipline.testing import (
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
create_minute_df_for_asset,
|
||||
empty_trading_env,
|
||||
make_test_handler,
|
||||
make_trade_data_for_asset_info,
|
||||
parameter_space,
|
||||
@@ -108,7 +107,6 @@ from zipline.testing.fixtures import (
|
||||
WithSimParams,
|
||||
WithTradingEnvironment,
|
||||
WithTmpDir,
|
||||
WithTradingCalendars,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.test_algorithms import (
|
||||
@@ -786,7 +784,8 @@ def log_nyse_close(context, data):
|
||||
for i, date in enumerate(dates)
|
||||
]
|
||||
)
|
||||
with tmp_trading_env(equities=metadata) as env:
|
||||
with tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = TradingAlgorithm(env=env)
|
||||
|
||||
# Set the period end to a date after the period end
|
||||
@@ -853,7 +852,8 @@ class TestTransformAlgorithm(WithLogger,
|
||||
def init_class_fixtures(cls):
|
||||
super(TestTransformAlgorithm, cls).init_class_fixtures()
|
||||
cls.futures_env = cls.enter_class_context(
|
||||
tmp_trading_env(futures=cls.make_futures_info()),
|
||||
tmp_trading_env(futures=cls.make_futures_info(),
|
||||
load=cls.make_load_function()),
|
||||
)
|
||||
|
||||
def test_invalid_order_parameters(self):
|
||||
@@ -1065,7 +1065,8 @@ def before_trading_start(context, data):
|
||||
}] * 2)
|
||||
equities['symbol'] = ['A', 'B']
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
tmp_trading_env(equities=equities,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = SimulationParameters(
|
||||
start_session=start_session,
|
||||
end_session=period_end,
|
||||
@@ -3175,7 +3176,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
orient='index',
|
||||
)
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=start,
|
||||
num_days=4,
|
||||
@@ -3302,7 +3304,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = SetAssetDateBoundsAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
env=env,
|
||||
@@ -3324,7 +3327,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3347,7 +3351,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3774,7 +3779,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
self.assertEqual(txn['price'], expected_price)
|
||||
|
||||
|
||||
class TestTradingAlgorithm(ZiplineTestCase):
|
||||
class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase):
|
||||
def test_analyze_called(self):
|
||||
self.perf_ref = None
|
||||
|
||||
@@ -3794,9 +3799,8 @@ class TestTradingAlgorithm(ZiplineTestCase):
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
with empty_trading_env() as env:
|
||||
data_portal = FakeDataPortal(env)
|
||||
results = algo.run(data_portal)
|
||||
data_portal = FakeDataPortal(self.env)
|
||||
results = algo.run(data_portal)
|
||||
|
||||
self.assertIs(results, self.perf_ref)
|
||||
|
||||
@@ -3996,7 +4000,7 @@ class TestOrderCancelation(WithDataPortal,
|
||||
self.assertFalse(log_catcher.has_warnings)
|
||||
|
||||
|
||||
class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase):
|
||||
"""
|
||||
Tests if delisted equities are properly removed from a portfolio holding
|
||||
positions in said equities.
|
||||
@@ -4027,7 +4031,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
|
||||
sids = asset_info.index
|
||||
|
||||
env = self.enter_instance_context(tmp_trading_env(equities=asset_info))
|
||||
env = self.enter_instance_context(
|
||||
tmp_trading_env(equities=asset_info,
|
||||
load=self.make_load_function())
|
||||
)
|
||||
|
||||
if frequency == 'daily':
|
||||
dates = self.test_days
|
||||
@@ -4680,7 +4687,7 @@ class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestPanelData(ZiplineTestCase):
|
||||
class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
@parameterized.expand([
|
||||
('daily',
|
||||
@@ -4702,6 +4709,9 @@ class TestPanelData(ZiplineTestCase):
|
||||
|
||||
def dt_transform(dt):
|
||||
return dt
|
||||
else:
|
||||
raise AssertionError('Unexpected data_frequency: %s' %
|
||||
data_frequency)
|
||||
|
||||
sids = range(1, 3)
|
||||
dfs = {}
|
||||
@@ -4742,11 +4752,13 @@ class TestPanelData(ZiplineTestCase):
|
||||
'prev_close']].values.astype('float64')
|
||||
)
|
||||
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
with tmp_trading_env(load=self.make_load_function()) as env:
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
env=env)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
|
||||
Reference in New Issue
Block a user