mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:14:36 +08:00
Merge pull request #1793 from quantopian/tests-without-yahoo
TST: Don't require downloading of data for tests
This commit is contained in:
@@ -566,7 +566,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
# Yes, I really do want to use the start and end dates I passed to
|
||||
# TradingAlgorithm.
|
||||
overwrite_sim_params=False,
|
||||
@@ -606,7 +606,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
@@ -654,7 +654,7 @@ class PipelineAlgorithmTestCase(WithBcolzEquityDailyBarReaderFromCSVs,
|
||||
)
|
||||
|
||||
algo.run(
|
||||
FakeDataPortal(),
|
||||
FakeDataPortal(self.env),
|
||||
overwrite_sim_params=False,
|
||||
)
|
||||
|
||||
|
||||
+65
-38
@@ -88,11 +88,11 @@ from zipline.finance.asset_restrictions import (
|
||||
)
|
||||
from zipline.testing import (
|
||||
FakeDataPortal,
|
||||
copy_market_data,
|
||||
create_daily_df_for_asset,
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
create_minute_df_for_asset,
|
||||
empty_trading_env,
|
||||
make_test_handler,
|
||||
make_trade_data_for_asset_info,
|
||||
parameter_space,
|
||||
@@ -100,6 +100,7 @@ from zipline.testing import (
|
||||
tmp_trading_env,
|
||||
to_utc,
|
||||
trades_by_sid_to_dfs,
|
||||
tmp_dir,
|
||||
)
|
||||
from zipline.testing import RecordBatchBlotter
|
||||
from zipline.testing.fixtures import (
|
||||
@@ -108,7 +109,6 @@ from zipline.testing.fixtures import (
|
||||
WithSimParams,
|
||||
WithTradingEnvironment,
|
||||
WithTmpDir,
|
||||
WithTradingCalendars,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.test_algorithms import (
|
||||
@@ -314,6 +314,7 @@ def handle_data(algo, data):
|
||||
initialize=lambda context: None,
|
||||
handle_data=lambda context, data: None,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
# Verify that api methods get resolved dynamically by patching them out
|
||||
@@ -785,7 +786,8 @@ def log_nyse_close(context, data):
|
||||
for i, date in enumerate(dates)
|
||||
]
|
||||
)
|
||||
with tmp_trading_env(equities=metadata) as env:
|
||||
with tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = TradingAlgorithm(env=env)
|
||||
|
||||
# Set the period end to a date after the period end
|
||||
@@ -852,7 +854,8 @@ class TestTransformAlgorithm(WithLogger,
|
||||
def init_class_fixtures(cls):
|
||||
super(TestTransformAlgorithm, cls).init_class_fixtures()
|
||||
cls.futures_env = cls.enter_class_context(
|
||||
tmp_trading_env(futures=cls.make_futures_info()),
|
||||
tmp_trading_env(futures=cls.make_futures_info(),
|
||||
load=cls.make_load_function()),
|
||||
)
|
||||
|
||||
def test_invalid_order_parameters(self):
|
||||
@@ -892,7 +895,8 @@ def before_trading_start(context, data):
|
||||
def test_run_twice(self):
|
||||
algo1 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res1 = algo1.run(self.data_portal)
|
||||
@@ -901,7 +905,8 @@ def before_trading_start(context, data):
|
||||
# use the newly instantiated environment.
|
||||
algo2 = TestRegisterTransformAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
sids=[0, 1]
|
||||
sids=[0, 1],
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
res2 = algo2.run(self.data_portal)
|
||||
@@ -1062,7 +1067,8 @@ def before_trading_start(context, data):
|
||||
}] * 2)
|
||||
equities['symbol'] = ['A', 'B']
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
tmp_trading_env(equities=equities,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = SimulationParameters(
|
||||
start_session=start_session,
|
||||
end_session=period_end,
|
||||
@@ -1569,15 +1575,16 @@ class TestAlgoScript(WithLogger,
|
||||
|
||||
def test_noop(self):
|
||||
algo = TradingAlgorithm(initialize=initialize_noop,
|
||||
handle_data=handle_data_noop)
|
||||
handle_data=handle_data_noop,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_noop_string(self):
|
||||
algo = TradingAlgorithm(script=noop_algo)
|
||||
algo = TradingAlgorithm(script=noop_algo, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_no_handle_data(self):
|
||||
algo = TradingAlgorithm(script=no_handle_data)
|
||||
algo = TradingAlgorithm(script=no_handle_data, env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
def test_api_calls(self):
|
||||
@@ -1593,7 +1600,8 @@ class TestAlgoScript(WithLogger,
|
||||
def test_api_get_environment(self):
|
||||
platform = 'zipline'
|
||||
algo = TradingAlgorithm(script=api_get_environment_algo,
|
||||
platform=platform)
|
||||
platform=platform,
|
||||
env=self.env)
|
||||
algo.run(self.data_portal)
|
||||
self.assertEqual(algo.environment, platform)
|
||||
|
||||
@@ -1779,6 +1787,7 @@ def handle_data(context, data):
|
||||
test_algo = TradingAlgorithm(
|
||||
script=record_variables,
|
||||
sim_params=self.sim_params,
|
||||
env=self.env,
|
||||
)
|
||||
set_algo_instance(test_algo)
|
||||
|
||||
@@ -3169,7 +3178,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
orient='index',
|
||||
)
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=start,
|
||||
num_days=4,
|
||||
@@ -3296,7 +3306,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
algo = SetAssetDateBoundsAlgorithm(
|
||||
sim_params=self.sim_params,
|
||||
env=env,
|
||||
@@ -3318,7 +3329,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3341,7 +3353,8 @@ class TestTradingControls(WithSimParams, WithDataPortal, ZiplineTestCase):
|
||||
'sid': 999,
|
||||
}])
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
data_portal = create_data_portal(
|
||||
env.asset_finder,
|
||||
tempdir,
|
||||
@@ -3768,7 +3781,7 @@ class TestFuturesAlgo(WithDataPortal, WithSimParams, ZiplineTestCase):
|
||||
self.assertEqual(txn['price'], expected_price)
|
||||
|
||||
|
||||
class TestTradingAlgorithm(ZiplineTestCase):
|
||||
class TestTradingAlgorithm(WithTradingEnvironment, ZiplineTestCase):
|
||||
def test_analyze_called(self):
|
||||
self.perf_ref = None
|
||||
|
||||
@@ -3785,11 +3798,11 @@ class TestTradingAlgorithm(ZiplineTestCase):
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
analyze=analyze,
|
||||
env=self.env,
|
||||
)
|
||||
|
||||
with empty_trading_env() as env:
|
||||
data_portal = FakeDataPortal(env)
|
||||
results = algo.run(data_portal)
|
||||
data_portal = FakeDataPortal(self.env)
|
||||
results = algo.run(data_portal)
|
||||
|
||||
self.assertIs(results, self.perf_ref)
|
||||
|
||||
@@ -3989,7 +4002,7 @@ class TestOrderCancelation(WithDataPortal,
|
||||
self.assertFalse(log_catcher.has_warnings)
|
||||
|
||||
|
||||
class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
class TestEquityAutoClose(WithTradingEnvironment, WithTmpDir, ZiplineTestCase):
|
||||
"""
|
||||
Tests if delisted equities are properly removed from a portfolio holding
|
||||
positions in said equities.
|
||||
@@ -4020,7 +4033,10 @@ class TestEquityAutoClose(WithTmpDir, WithTradingCalendars, ZiplineTestCase):
|
||||
|
||||
sids = asset_info.index
|
||||
|
||||
env = self.enter_instance_context(tmp_trading_env(equities=asset_info))
|
||||
env = self.enter_instance_context(
|
||||
tmp_trading_env(equities=asset_info,
|
||||
load=self.make_load_function())
|
||||
)
|
||||
|
||||
if frequency == 'daily':
|
||||
dates = self.test_days
|
||||
@@ -4642,7 +4658,7 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase):
|
||||
self.assertEqual(expected_message, w.message)
|
||||
|
||||
|
||||
class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
class AlgoInputValidationTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
def test_reject_passing_both_api_methods_and_script(self):
|
||||
script = dedent(
|
||||
@@ -4668,11 +4684,12 @@ class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
TradingAlgorithm(
|
||||
script=script,
|
||||
env=self.env,
|
||||
**{method: lambda *args, **kwargs: None}
|
||||
)
|
||||
|
||||
|
||||
class TestPanelData(ZiplineTestCase):
|
||||
class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
@parameterized.expand([
|
||||
('daily',
|
||||
@@ -4694,6 +4711,9 @@ class TestPanelData(ZiplineTestCase):
|
||||
|
||||
def dt_transform(dt):
|
||||
return dt
|
||||
else:
|
||||
raise AssertionError('Unexpected data_frequency: %s' %
|
||||
data_frequency)
|
||||
|
||||
sids = range(1, 3)
|
||||
dfs = {}
|
||||
@@ -4734,19 +4754,26 @@ class TestPanelData(ZiplineTestCase):
|
||||
'prev_close']].values.astype('float64')
|
||||
)
|
||||
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
with tmp_trading_env(load=self.make_load_function()) as env:
|
||||
trading_algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
env=env)
|
||||
trading_algo.run(data=panel)
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
capital_base=1,
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
data_frequency=data_frequency,
|
||||
data=panel
|
||||
)
|
||||
check_panels()
|
||||
with tmp_dir() as tmpdir:
|
||||
root = tmpdir.getpath('example_data/root')
|
||||
copy_market_data(self.MARKET_DATA_DIR, root)
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
capital_base=1,
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
data_frequency=data_frequency,
|
||||
data=panel,
|
||||
environ={'ZIPLINE_ROOT': root},
|
||||
)
|
||||
check_panels()
|
||||
|
||||
@@ -21,8 +21,9 @@ import pandas as pd
|
||||
|
||||
from zipline import examples
|
||||
from zipline.data.bundles import register, unregister
|
||||
from zipline.testing import test_resource_path
|
||||
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
|
||||
from zipline.testing import test_resource_path, copy_market_data
|
||||
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \
|
||||
WithTradingEnvironment
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.cache import dataframe_cache
|
||||
|
||||
@@ -53,6 +54,9 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase):
|
||||
serialization='pickle',
|
||||
)
|
||||
|
||||
copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR,
|
||||
cls.tmpdir.getpath('example_data/root'))
|
||||
|
||||
@parameterized.expand(examples.EXAMPLE_MODULES)
|
||||
def test_example(self, example_name):
|
||||
actual_perf = examples.run_example(
|
||||
|
||||
@@ -190,7 +190,8 @@ class FinanceTestCase(WithLogger,
|
||||
asset1 = self.asset_finder.retrieve_asset(1)
|
||||
metadata = make_simple_equity_info([asset1.sid], self.start, self.end)
|
||||
with TempDirectory() as tempdir, \
|
||||
tmp_trading_env(equities=metadata) as env:
|
||||
tmp_trading_env(equities=metadata,
|
||||
load=self.make_load_function()) as env:
|
||||
|
||||
if trade_interval < timedelta(days=1):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
|
||||
@@ -57,7 +57,6 @@ from zipline.testing.fixtures import (
|
||||
WithSimParams,
|
||||
WithTmpDir,
|
||||
WithTradingEnvironment,
|
||||
WithTradingCalendars,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils.calendars import get_calendar
|
||||
@@ -1029,7 +1028,8 @@ class TestDividendPerformanceHolidayStyle(TestDividendPerformance):
|
||||
END_DATE = pd.Timestamp('2003-12-08', tz='utc')
|
||||
|
||||
|
||||
class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
|
||||
class TestPositionPerformance(WithInstanceTmpDir,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase):
|
||||
|
||||
def create_environment_stuff(self,
|
||||
@@ -1054,6 +1054,7 @@ class TestPositionPerformance(WithInstanceTmpDir, WithTradingCalendars,
|
||||
self.env = self.enter_instance_context(tmp_trading_env(
|
||||
equities=equities,
|
||||
futures=futures,
|
||||
load=self.make_load_function(),
|
||||
))
|
||||
self.sim_params = create_simulation_parameters(
|
||||
start=start,
|
||||
|
||||
@@ -15,7 +15,7 @@ from zipline.testing import (
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithLogger,
|
||||
WithTradingCalendars,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils import factory
|
||||
@@ -82,7 +82,9 @@ class IterateRLAlgo(TradingAlgorithm):
|
||||
self.found = True
|
||||
|
||||
|
||||
class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
class SecurityListTestCase(WithLogger,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase):
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
@@ -103,6 +105,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
'symbol': symbol,
|
||||
'exchange': "TEST",
|
||||
} for symbol in symbols]),
|
||||
load=cls.make_load_function(),
|
||||
))
|
||||
|
||||
cls.sim_params = factory.create_simulation_parameters(
|
||||
@@ -122,6 +125,7 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
'symbol': symbol,
|
||||
'exchange': "TEST",
|
||||
} for symbol in symbols]),
|
||||
load=cls.make_load_function(),
|
||||
))
|
||||
|
||||
cls.tempdir = cls.enter_class_context(tmp_dir())
|
||||
@@ -304,7 +308,8 @@ class SecurityListTestCase(WithLogger, WithTradingCalendars, ZiplineTestCase):
|
||||
}])
|
||||
with TempDirectory() as new_tempdir, \
|
||||
security_list_copy(), \
|
||||
tmp_trading_env(equities=equities) as env:
|
||||
tmp_trading_env(equities=equities,
|
||||
load=self.make_load_function()) as env:
|
||||
# add a delete statement removing bzq
|
||||
# write a new delete statement file to disk
|
||||
add_security_data([], ['BZQ'])
|
||||
|
||||
@@ -19,7 +19,6 @@ from mock import patch
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from six.moves import range
|
||||
from unittest import TestCase
|
||||
from zipline import TradingAlgorithm
|
||||
from zipline.gens.sim_engine import BEFORE_TRADING_START_BAR
|
||||
|
||||
@@ -28,8 +27,12 @@ from zipline.finance.asset_restrictions import NoRestrictions
|
||||
from zipline.gens.tradesimulation import AlgorithmSimulator
|
||||
from zipline.sources.benchmark_source import BenchmarkSource
|
||||
from zipline.test_algorithms import NoopAlgorithm
|
||||
from zipline.testing.fixtures import WithSimParams, ZiplineTestCase, \
|
||||
WithDataPortal
|
||||
from zipline.testing.fixtures import (
|
||||
WithDataPortal,
|
||||
WithSimParams,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.utils import factory
|
||||
from zipline.testing.core import FakeDataPortal
|
||||
from zipline.utils.calendars.trading_calendar import days_at_time
|
||||
@@ -50,7 +53,7 @@ class BeforeTradingAlgorithm(TradingAlgorithm):
|
||||
FREQUENCIES = {'daily': 0, 'minute': 1} # daily is less frequent than minute
|
||||
|
||||
|
||||
class TestTradeSimulation(TestCase):
|
||||
class TestTradeSimulation(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
def fake_minutely_benchmark(self, dt):
|
||||
return 0.01
|
||||
@@ -61,8 +64,8 @@ class TestTradeSimulation(TestCase):
|
||||
emission_rate='minute')
|
||||
with patch.object(BenchmarkSource, "get_value",
|
||||
self.fake_minutely_benchmark):
|
||||
algo = NoopAlgorithm(sim_params=params)
|
||||
algo.run(FakeDataPortal())
|
||||
algo = NoopAlgorithm(sim_params=params, env=self.env)
|
||||
algo.run(FakeDataPortal(self.env))
|
||||
self.assertEqual(len(algo.perf_tracker.sim_params.sessions), 1)
|
||||
|
||||
@parameterized.expand([('%s_%s_%s' % (num_sessions, freq, emission_rate),
|
||||
@@ -82,8 +85,8 @@ class TestTradeSimulation(TestCase):
|
||||
|
||||
with patch.object(BenchmarkSource, "get_value",
|
||||
self.fake_minutely_benchmark):
|
||||
algo = BeforeTradingAlgorithm(sim_params=params)
|
||||
algo.run(FakeDataPortal())
|
||||
algo = BeforeTradingAlgorithm(sim_params=params, env=self.env)
|
||||
algo.run(FakeDataPortal(self.env))
|
||||
|
||||
self.assertEqual(
|
||||
len(algo.perf_tracker.sim_params.sessions),
|
||||
|
||||
+18
-11
@@ -53,13 +53,13 @@ def last_modified_time(path):
|
||||
return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC')
|
||||
|
||||
|
||||
def get_data_filepath(name):
|
||||
def get_data_filepath(name, environ=None):
|
||||
"""
|
||||
Returns a handle to data file.
|
||||
|
||||
Creates containing directory, if needed.
|
||||
"""
|
||||
dr = data_root()
|
||||
dr = data_root(environ)
|
||||
|
||||
if not os.path.exists(dr):
|
||||
os.makedirs(dr)
|
||||
@@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date):
|
||||
return (first <= first_date) and (last >= last_date)
|
||||
|
||||
|
||||
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
|
||||
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC',
|
||||
environ=None):
|
||||
"""
|
||||
Load benchmark returns and treasury yield curves for the given calendar and
|
||||
benchmark symbol.
|
||||
@@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
|
||||
# We need the trading_day to figure out the close prior to the first
|
||||
# date so that we can compute returns for the first date.
|
||||
trading_day,
|
||||
environ,
|
||||
)
|
||||
tc = ensure_treasury_data(
|
||||
bm_symbol,
|
||||
first_date,
|
||||
last_date,
|
||||
now,
|
||||
environ,
|
||||
)
|
||||
benchmark_returns = br[br.index.slice_indexer(first_date, last_date)]
|
||||
treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)]
|
||||
return benchmark_returns, treasury_curves
|
||||
|
||||
|
||||
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day,
|
||||
environ=None):
|
||||
"""
|
||||
Ensure we have benchmark data for `symbol` from `first_date` to `last_date`
|
||||
|
||||
@@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
path.
|
||||
"""
|
||||
filename = get_benchmark_filename(symbol)
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark',
|
||||
environ)
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
@@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
first_date - trading_day,
|
||||
last_date,
|
||||
)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
data.to_csv(get_data_filepath(filename, environ))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache the new benchmark returns')
|
||||
raise
|
||||
@@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
return data
|
||||
|
||||
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now, environ=None):
|
||||
"""
|
||||
Ensure we have treasury data from treasury module associated with
|
||||
`symbol`.
|
||||
@@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
)
|
||||
first_date = max(first_date, loader_module.earliest_possible_date())
|
||||
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury',
|
||||
environ)
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
@@ -269,7 +275,7 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
data.to_csv(get_data_filepath(filename, environ))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
@@ -277,14 +283,15 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
return data
|
||||
|
||||
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name):
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name,
|
||||
environ=None):
|
||||
if resource_name == 'benchmark':
|
||||
from_csv = pd.Series.from_csv
|
||||
else:
|
||||
from_csv = pd.DataFrame.from_csv
|
||||
|
||||
# Path for the cache.
|
||||
path = get_data_filepath(filename)
|
||||
path = get_data_filepath(filename, environ)
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
|
||||
import logbook
|
||||
import pandas as pd
|
||||
@@ -86,11 +87,12 @@ class TradingEnvironment(object):
|
||||
trading_calendar=None,
|
||||
asset_db_path=':memory:',
|
||||
future_chain_predicates=CHAIN_PREDICATES,
|
||||
environ=None,
|
||||
):
|
||||
|
||||
self.bm_symbol = bm_symbol
|
||||
if not load:
|
||||
load = load_market_data
|
||||
load = partial(load_market_data, environ=environ)
|
||||
|
||||
if not trading_calendar:
|
||||
trading_calendar = get_calendar("NYSE")
|
||||
|
||||
@@ -16,6 +16,7 @@ from .core import ( # noqa
|
||||
check_allclose,
|
||||
check_arrays,
|
||||
chrange,
|
||||
copy_market_data,
|
||||
create_daily_df_for_asset,
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
|
||||
+23
-4
@@ -29,6 +29,7 @@ from toolz import concat, curry
|
||||
from zipline.assets import AssetFinder, AssetDBWriter
|
||||
from zipline.assets.synthetic import make_simple_equity_info
|
||||
from zipline.data.data_portal import DataPortal
|
||||
from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING
|
||||
from zipline.data.minute_bars import (
|
||||
BcolzMinuteBarReader,
|
||||
BcolzMinuteBarWriter,
|
||||
@@ -52,6 +53,7 @@ from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.numpy_utils import as_column, isnat
|
||||
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
|
||||
from zipline.utils.paths import ensure_directory
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
import numpy as np
|
||||
@@ -695,11 +697,8 @@ def create_data_portal_from_trade_history(asset_finder, trading_calendar,
|
||||
|
||||
|
||||
class FakeDataPortal(DataPortal):
|
||||
def __init__(self, env=None, trading_calendar=None,
|
||||
def __init__(self, env, trading_calendar=None,
|
||||
first_trading_day=None):
|
||||
if env is None:
|
||||
env = TradingEnvironment()
|
||||
|
||||
if trading_calendar is None:
|
||||
trading_calendar = get_calendar("NYSE")
|
||||
|
||||
@@ -862,6 +861,8 @@ class tmp_trading_env(tmp_asset_finder):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
load : callable, optional
|
||||
Function that returns benchmark returns and treasury curves.
|
||||
finder_cls : type, optional
|
||||
The type of asset finder to create from the assets db.
|
||||
**frames
|
||||
@@ -872,8 +873,13 @@ class tmp_trading_env(tmp_asset_finder):
|
||||
empty_trading_env
|
||||
tmp_asset_finder
|
||||
"""
|
||||
def __init__(self, load=None, *args, **kwargs):
|
||||
super(tmp_trading_env, self).__init__(*args, **kwargs)
|
||||
self._load = load
|
||||
|
||||
def __enter__(self):
|
||||
return TradingEnvironment(
|
||||
load=self._load,
|
||||
asset_db_path=super(tmp_trading_env, self).__enter__().engine,
|
||||
)
|
||||
|
||||
@@ -1486,6 +1492,19 @@ def patch_read_csv(url_map, module=pd, strict=False):
|
||||
yield
|
||||
|
||||
|
||||
def copy_market_data(src_market_data_dir, dest_root_dir):
|
||||
symbol = '^GSPC'
|
||||
filenames = (get_benchmark_filename(symbol), INDEX_MAPPING[symbol][1])
|
||||
|
||||
ensure_directory(os.path.join(dest_root_dir, 'data'))
|
||||
|
||||
for filename in filenames:
|
||||
shutil.copyfile(
|
||||
os.path.join(src_market_data_dir, filename),
|
||||
os.path.join(dest_root_dir, 'data', filename)
|
||||
)
|
||||
|
||||
|
||||
@curry
|
||||
def ensure_doctest(f, name=None):
|
||||
"""Ensure that an object gets doctested. This is useful for instances
|
||||
|
||||
@@ -129,7 +129,7 @@ def _run(handle_data,
|
||||
"invalid url %r, must begin with 'sqlite:///'" %
|
||||
str(bundle_data.asset_finder.engine.url),
|
||||
)
|
||||
env = TradingEnvironment(asset_db_path=connstr)
|
||||
env = TradingEnvironment(asset_db_path=connstr, environ=environ)
|
||||
first_trading_day =\
|
||||
bundle_data.equity_minute_bar_reader.first_trading_day
|
||||
data = DataPortal(
|
||||
@@ -152,7 +152,7 @@ def _run(handle_data,
|
||||
"No PipelineLoader registered for column %s." % column
|
||||
)
|
||||
else:
|
||||
env = None
|
||||
env = TradingEnvironment(environ=environ)
|
||||
choose_loader = None
|
||||
|
||||
perf = TradingAlgorithm(
|
||||
|
||||
Reference in New Issue
Block a user