mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 17:16:55 +08:00
TST: Use fixture's data with tmp_trading_env
instead of env needing to download it
This commit is contained in:
+33
-21
@@ -92,7 +92,6 @@ from zipline.testing import (
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
create_minute_df_for_asset,
|
||||
empty_trading_env,
|
||||
make_test_handler,
|
||||
make_trade_data_for_asset_info,
|
||||
parameter_space,
|
||||
@@ -108,7 +107,6 @@ from zipline.testing.fixtures import (
|
||||
WithSimParams,
|
||||
WithTradingEnvironment,
|
||||
WithTmpDir,
|
||||
WithTradingCalendars,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.test_algorithms import (
|
||||
@@ -786,7 +784,8 @@ def log_nyse_close(context, data):
|
||||
for i, date in enumerate(dates)
|
||||
]
|
||||
)
|
||||
with tmp_trading_env(equities=metadata) as env:
|
||||
with tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = TradingAlgorithm(env=env)
|
||||
|
||||
# Set the period end to a date after the period end
|
||||
@@ -853,7 +852,8 @@ class TestTransformAlgorithm(WithLogger,
|
||||
def init_class_fixtures(cls):
|
||||
super(TestTransformAlgorithm, cls).init_class_fixtures()
|
||||
cls.futures_env = cls.enter_class_context(
|
||||
tmp_trading_env(futures=cls.make_futures_info()),
|
||||
tmp_trading_env(futures=cls.make_futures_info(),
|
||||
load=cls.make_load_function()),
|
||||
)
|
||||
|
||||
def test_invalid_order_parameters(self):
|
||||
@@ -1065,7 +1065,8 @@ def before_trading_start(context, data):
|
||||
}] * 2)
|
||||
equities['symbol'] = ['A', 'B']
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
tmp_trading_env(equities=equities,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = SimulationParameters(
|
||||
start_session=start_session,
|
||||
end_session=period_end,
|
||||
@@ -3175,7 +3176,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
orient='index',
|
||||
)
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=start,
|
||||
num_days=4,
|
||||
@@ -3302,7 +3304,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = SetAssetDateBoundsAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
env=env,
|
||||
@@ -3324,7 +3327,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3347,7 +3351,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3774,7 +3779,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
self.assertEqual(txn['price'], expected_price)
|
||||
|
||||
|
||||
class TestTradingAlgorithm(ZiplineTestCase):
|
||||
class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase):
|
||||
def test_analyze_called(self):
|
||||
self.perf_ref = None
|
||||
|
||||
@@ -3794,9 +3799,8 @@ class TestTradingAlgorithm(ZiplineTestCase):
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
with empty_trading_env() as env:
|
||||
data_portal = FakeDataPortal(env)
|
||||
results = algo.run(data_portal)
|
||||
data_portal = FakeDataPortal(self.env)
|
||||
results = algo.run(data_portal)
|
||||
|
||||
self.assertIs(results, self.perf_ref)
|
||||
|
||||
@@ -3996,7 +4000,7 @@ class TestOrderCancelation(WithDataPortal,
|
||||
self.assertFalse(log_catcher.has_warnings)
|
||||
|
||||
|
||||
class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase):
|
||||
"""
|
||||
Tests if delisted equities are properly removed from a portfolio holding
|
||||
positions in said equities.
|
||||
@@ -4027,7 +4031,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
|
||||
sids = asset_info.index
|
||||
|
||||
env = self.enter_instance_context(tmp_trading_env(equities=asset_info))
|
||||
env = self.enter_instance_context(
|
||||
tmp_trading_env(equities=asset_info,
|
||||
load=self.make_load_function())
|
||||
)
|
||||
|
||||
if frequency == 'daily':
|
||||
dates = self.test_days
|
||||
@@ -4680,7 +4687,7 @@ class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestPanelData(ZiplineTestCase):
|
||||
class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
@parameterized.expand([
|
||||
('daily',
|
||||
@@ -4702,6 +4709,9 @@ class TestPanelData(ZiplineTestCase):
|
||||
|
||||
def dt_transform(dt):
|
||||
return dt
|
||||
else:
|
||||
raise AssertionError('Unexpected data_frequency: %s' %
|
||||
data_frequency)
|
||||
|
||||
sids = range(1, 3)
|
||||
dfs = {}
|
||||
@@ -4742,11 +4752,13 @@ class TestPanelData(ZiplineTestCase):
|
||||
'prev_close']].values.astype('float64')
|
||||
)
|
||||
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
with tmp_trading_env(load=self.make_load_function()) as env:
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
env=env)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
|
||||
@@ -190,7 +190,8 @@ class FinanceTestCase(WithLogger,
|
||||
asset1 = self.asset_finder.retrieve_asset(1)
|
||||
metadata = make_simple_equity_info([asset1.sid], self.start, self.end)
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
|
||||
if trade_interval < timedelta(days=1):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
|
||||
@@ -57,7 +57,6 @@ from zipline.testing.fixtures import (
|
||||
WithSimParams,
|
||||
WithTmpDir,
|
||||
WithTradingEnvironment,
|
||||
WithTradingCalendars,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils.calendars import get_calendar
|
||||
@@ -1029,7 +1028,8 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance):
|
||||
END_DATE = pd.Timestamp('2003-12-08', tz='utc')
|
||||
|
||||
|
||||
class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
|
||||
class TestPositionPerformance(WithInstanceTmpDir,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase):
|
||||
|
||||
def create_environment_stuff(self,
|
||||
@@ -1054,6 +1054,7 @@ class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
|
||||
self.env = self.enter_instance_context(tmp_trading_env(
|
||||
equities=equities,
|
||||
futures=futures,
|
||||
load=self.make_load_function(),
|
||||
))
|
||||
self.sim_params = create_simulation_parameters(
|
||||
start=start,
|
||||
|
||||
@@ -15,7 +15,7 @@ from zipline.testing import (
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithLogger,
|
||||
WithTradingCalendars,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils import factory
|
||||
@@ -82,7 +82,9 @@ class IterateRLAlgo(TradingAlgorithm):
|
||||
self.found = True
|
||||
|
||||
|
||||
class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
class SecurityListTestCase(WithLogger,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase):
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
@@ -103,6 +105,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
'symbol': symbol,
|
||||
'exchange': "TEST",
|
||||
} for symbol in symbols]),
|
||||
load=cls.make_load_function(),
|
||||
))
|
||||
|
||||
cls.sim_params = factory.create_simulation_parameters(
|
||||
@@ -122,6 +125,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
'symbol': symbol,
|
||||
'exchange': "TEST",
|
||||
} for symbol in symbols]),
|
||||
load=cls.make_load_function(),
|
||||
))
|
||||
|
||||
cls.tempdir = cls.enter_class_context(tmp_dir())
|
||||
@@ -304,7 +308,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
}])
|
||||
with TempDirectory() as new_tempdir, \
|
||||
security_list_copy(), \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
tmp_trading_env(equities=equities,
|
||||
load=self.make_load_function()) as env:
|
||||
# add a delete statement removing bzq
|
||||
# write a new delete statement file to disk
|
||||
add_security_data([], ['BZQ'])
|
||||
|
||||
@@ -859,6 +859,8 @@ class tmp_trading_env(tmp_asset_finder):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
load : callable, optional
|
||||
Function that returns benchmark returns and treasury curves.
|
||||
finder_cls : type, optional
|
||||
The type of asset finder to create from the assets db.
|
||||
**frames
|
||||
@@ -869,8 +871,13 @@ class tmp_trading_env(tmp_asset_finder):
|
||||
empty_trading_env
|
||||
tmp_asset_finder
|
||||
"""
|
||||
def __init__(self, load=None, *args, **kwargs):
|
||||
super(tmp_trading_env, self).__init__(*args, **kwargs)
|
||||
self._load = load
|
||||
|
||||
def __enter__(self):
|
||||
return TradingEnvironment(
|
||||
load=self._load,
|
||||
asset_db_path=super(tmp_trading_env, self).__enter__().engine,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user