TST: Use fixture's data with tmp_trading_env

instead of env needing to download it
This commit is contained in:
Richard Frank
2017-05-17 11:09:19 -04:00
parent c5b3ceecc1
commit 0f6dbcef3c
5 changed files with 53 additions and 27 deletions
+33 -21
View File
@@ -92,7 +92,6 @@ from zipline.testing import (
create_data_portal,
create_data_portal_from_trade_history,
create_minute_df_for_asset,
empty_trading_env,
make_test_handler,
make_trade_data_for_asset_info,
parameter_space,
@@ -108,7 +107,6 @@ from zipline.testing.fixtures import (
WithSimParams,
WithTradingEnvironment,
WithTmpDir,
WithTradingCalendars,
ZiplineTestCase,
)
from zipline.test_algorithms import (
@@ -786,7 +784,8 @@ def log_nyse_close(context, data):
for i, date in enumerate(dates)
]
)
with tmp_trading_env(equities=metadata) as env:
with tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = TradingAlgorithm(env=env)
# Set the period end to a date after the period end
@@ -853,7 +852,8 @@ class TestTransformAlgorithm(WithLogger,
def init_class_fixtures(cls):
super(TestTransformAlgorithm, cls).init_class_fixtures()
cls.futures_env = cls.enter_class_context(
tmp_trading_env(futures=cls.make_futures_info()),
tmp_trading_env(futures=cls.make_futures_info(),
load=cls.make_load_function()),
)
def test_invalid_order_parameters(self):
@@ -1065,7 +1065,8 @@ def before_trading_start(context, data):
}] * 2)
equities['symbol'] = ['A', 'B']
with TempDirectory() as tempdir, \
tmp_trading_env(equities=equities) as env:
tmp_trading_env(equities=equities,
load=self.make_load_function()) as env:
sim_params = SimulationParameters(
start_session=start_session,
end_session=period_end,
@@ -3175,7 +3176,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
orient='index',
)
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
sim_params = factory.create_simulation_parameters(
start=start,
num_days=4,
@@ -3302,7 +3304,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
algo = SetAssetDateBoundsAlgorithm(
sim_params=self.sim_params,
env=env,
@@ -3324,7 +3327,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
data_portal = create_data_portal(
env.asset_finder,
tempdir,
@@ -3347,7 +3351,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
'sid': 999,
}])
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
data_portal = create_data_portal(
env.asset_finder,
tempdir,
@@ -3774,7 +3779,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
self.assertEqual(txn['price'], expected_price)
class TestTradingAlgorithm(ZiplineTestCase):
class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase):
def test_analyze_called(self):
self.perf_ref = None
@@ -3794,9 +3799,8 @@ class TestTradingAlgorithm(ZiplineTestCase):
env=self.env,
)
with empty_trading_env() as env:
data_portal = FakeDataPortal(env)
results = algo.run(data_portal)
data_portal = FakeDataPortal(self.env)
results = algo.run(data_portal)
self.assertIs(results, self.perf_ref)
@@ -3996,7 +4000,7 @@ class TestOrderCancelation(WithDataPortal,
self.assertFalse(log_catcher.has_warnings)
class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase):
"""
Tests if delisted equities are properly removed from a portfolio holding
positions in said equities.
@@ -4027,7 +4031,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
sids = asset_info.index
env = self.enter_instance_context(tmp_trading_env(equities=asset_info))
env = self.enter_instance_context(
tmp_trading_env(equities=asset_info,
load=self.make_load_function())
)
if frequency == 'daily':
dates = self.test_days
@@ -4680,7 +4687,7 @@ class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
)
class TestPanelData(ZiplineTestCase):
class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
@parameterized.expand([
('daily',
@@ -4702,6 +4709,9 @@ class TestPanelData(ZiplineTestCase):
def dt_transform(dt):
return dt
else:
raise AssertionError('Unexpected data_frequency: %s' %
data_frequency)
sids = range(1, 3)
dfs = {}
@@ -4742,11 +4752,13 @@ class TestPanelData(ZiplineTestCase):
'prev_close']].values.astype('float64')
)
trading_algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data)
trading_algo.run(data=panel)
check_panels()
price_record.loc[:] = np.nan
with tmp_trading_env(load=self.make_load_function()) as env:
trading_algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data,
env=env)
trading_algo.run(data=panel)
check_panels()
price_record.loc[:] = np.nan
run_algorithm(
start=start_dt,
+2 -1
View File
@@ -190,7 +190,8 @@ class FinanceTestCase(WithLogger,
asset1 = self.asset_finder.retrieve_asset(1)
metadata = make_simple_equity_info([asset1.sid], self.start, self.end)
with TempDirectory() as tempdir, \
tmp_trading_env(equities=metadata) as env:
tmp_trading_env(equities=metadata,
load=self.make_load_function()) as env:
if trade_interval < timedelta(days=1):
sim_params = factory.create_simulation_parameters(
+3 -2
View File
@@ -57,7 +57,6 @@ from zipline.testing.fixtures import (
WithSimParams,
WithTmpDir,
WithTradingEnvironment,
WithTradingCalendars,
ZiplineTestCase,
)
from zipline.utils.calendars import get_calendar
@@ -1029,7 +1028,8 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance):
END_DATE = pd.Timestamp('2003-12-08', tz='utc')
class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
class TestPositionPerformance(WithInstanceTmpDir,
WithTradingEnvironment,
ZiplineTestCase):
def create_environment_stuff(self,
@@ -1054,6 +1054,7 @@ class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
self.env = self.enter_instance_context(tmp_trading_env(
equities=equities,
futures=futures,
load=self.make_load_function(),
))
self.sim_params = create_simulation_parameters(
start=start,
+8 -3
View File
@@ -15,7 +15,7 @@ from zipline.testing import (
)
from zipline.testing.fixtures import (
WithLogger,
WithTradingCalendars,
WithTradingEnvironment,
ZiplineTestCase,
)
from zipline.utils import factory
@@ -82,7 +82,9 @@ class IterateRLAlgo(TradingAlgorithm):
self.found = True
class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
class SecurityListTestCase(WithLogger,
WithTradingEnvironment,
ZiplineTestCase):
@classmethod
def init_class_fixtures(cls):
@@ -103,6 +105,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
'symbol': symbol,
'exchange': "TEST",
} for symbol in symbols]),
load=cls.make_load_function(),
))
cls.sim_params = factory.create_simulation_parameters(
@@ -122,6 +125,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
'symbol': symbol,
'exchange': "TEST",
} for symbol in symbols]),
load=cls.make_load_function(),
))
cls.tempdir = cls.enter_class_context(tmp_dir())
@@ -304,7 +308,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
}])
with TempDirectory() as new_tempdir, \
security_list_copy(), \
tmp_trading_env(equities=equities) as env:
tmp_trading_env(equities=equities,
load=self.make_load_function()) as env:
# add a delete statement removing bzq
# write a new delete statement file to disk
add_security_data([], ['BZQ'])
+7
View File
@@ -859,6 +859,8 @@ class tmp_trading_env(tmp_asset_finder):
Parameters
----------
load : callable, optional
Function that returns benchmark returns and treasury curves.
finder_cls : type, optional
The type of asset finder to create from the assets db.
**frames
@@ -869,8 +871,13 @@ class tmp_trading_env(tmp_asset_finder):
empty_trading_env
tmp_asset_finder
"""
def __init__(self, load=None, *args, **kwargs):
super(tmp_trading_env, self).__init__(*args, **kwargs)
self._load = load
def __enter__(self):
return TradingEnvironment(
load=self._load,
asset_db_path=super(tmp_trading_env, self).__enter__().engine,
)