mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-27 19:30:28 +08:00
TST: Use testing market data with run_algorithm
so env doesn't need to download it
This commit is contained in:
+17
-10
@@ -88,6 +88,7 @@ from zipline.finance.asset_restrictions import (
|
||||
)
|
||||
from zipline.testing import (
|
||||
FakeDataPortal,
|
||||
copy_market_data,
|
||||
create_daily_df_for_asset,
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
@@ -99,6 +100,7 @@ from zipline.testing import (
|
||||
tmp_trading_env,
|
||||
to_utc,
|
||||
trades_by_sid_to_dfs,
|
||||
tmp_dir,
|
||||
)
|
||||
from zipline.testing import RecordBatchBlotter
|
||||
from zipline.testing.fixtures import (
|
||||
@@ -4760,13 +4762,18 @@ class TestPanelData(WithTradingEnvironment, ZiplineTestCase):
|
||||
check_panels()
|
||||
price_record.loc[:] = np.nan
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
capital_base=1,
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
data_frequency=data_frequency,
|
||||
data=panel
|
||||
)
|
||||
check_panels()
|
||||
with tmp_dir() as tmpdir:
|
||||
root = tmpdir.getpath('example_data/root')
|
||||
copy_market_data(self.MARKET_DATA_DIR, root)
|
||||
|
||||
run_algorithm(
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
capital_base=1,
|
||||
initialize=initialize,
|
||||
handle_data=handle_data,
|
||||
data_frequency=data_frequency,
|
||||
data=panel,
|
||||
environ={'ZIPLINE_ROOT': root},
|
||||
)
|
||||
check_panels()
|
||||
|
||||
@@ -21,8 +21,9 @@ import pandas as pd
|
||||
|
||||
from zipline import examples
|
||||
from zipline.data.bundles import register, unregister
|
||||
from zipline.testing import test_resource_path
|
||||
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
|
||||
from zipline.testing import test_resource_path, copy_market_data
|
||||
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase, \
|
||||
WithTradingEnvironment
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.cache import dataframe_cache
|
||||
|
||||
@@ -53,6 +54,9 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase):
|
||||
serialization='pickle',
|
||||
)
|
||||
|
||||
copy_market_data(WithTradingEnvironment.MARKET_DATA_DIR,
|
||||
cls.tmpdir.getpath('example_data/root'))
|
||||
|
||||
@parameterized.expand(examples.EXAMPLE_MODULES)
|
||||
def test_example(self, example_name):
|
||||
actual_perf = examples.run_example(
|
||||
|
||||
+18
-11
@@ -53,13 +53,13 @@ def last_modified_time(path):
|
||||
return pd.Timestamp(os.path.getmtime(path), unit='s', tz='UTC')
|
||||
|
||||
|
||||
def get_data_filepath(name):
|
||||
def get_data_filepath(name, environ=None):
|
||||
"""
|
||||
Returns a handle to data file.
|
||||
|
||||
Creates containing directory, if needed.
|
||||
"""
|
||||
dr = data_root()
|
||||
dr = data_root(environ)
|
||||
|
||||
if not os.path.exists(dr):
|
||||
os.makedirs(dr)
|
||||
@@ -91,7 +91,8 @@ def has_data_for_dates(series_or_df, first_date, last_date):
|
||||
return (first <= first_date) and (last >= last_date)
|
||||
|
||||
|
||||
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
|
||||
def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC',
|
||||
environ=None):
|
||||
"""
|
||||
Load benchmark returns and treasury yield curves for the given calendar and
|
||||
benchmark symbol.
|
||||
@@ -162,19 +163,22 @@ def load_market_data(trading_day=None, trading_days=None, bm_symbol='^GSPC'):
|
||||
# We need the trading_day to figure out the close prior to the first
|
||||
# date so that we can compute returns for the first date.
|
||||
trading_day,
|
||||
environ,
|
||||
)
|
||||
tc = ensure_treasury_data(
|
||||
bm_symbol,
|
||||
first_date,
|
||||
last_date,
|
||||
now,
|
||||
environ,
|
||||
)
|
||||
benchmark_returns = br[br.index.slice_indexer(first_date, last_date)]
|
||||
treasury_curves = tc[tc.index.slice_indexer(first_date, last_date)]
|
||||
return benchmark_returns, treasury_curves
|
||||
|
||||
|
||||
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day,
|
||||
environ=None):
|
||||
"""
|
||||
Ensure we have benchmark data for `symbol` from `first_date` to `last_date`
|
||||
|
||||
@@ -204,7 +208,8 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
path.
|
||||
"""
|
||||
filename = get_benchmark_filename(symbol)
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark',
|
||||
environ)
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
@@ -218,7 +223,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
first_date - trading_day,
|
||||
last_date,
|
||||
)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
data.to_csv(get_data_filepath(filename, environ))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache the new benchmark returns')
|
||||
raise
|
||||
@@ -227,7 +232,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
return data
|
||||
|
||||
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now, environ=None):
|
||||
"""
|
||||
Ensure we have treasury data from treasury module associated with
|
||||
`symbol`.
|
||||
@@ -259,7 +264,8 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
)
|
||||
first_date = max(first_date, loader_module.earliest_possible_date())
|
||||
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury',
|
||||
environ)
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
@@ -269,7 +275,7 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
data.to_csv(get_data_filepath(filename, environ))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
@@ -277,14 +283,15 @@ def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
return data
|
||||
|
||||
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name):
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name,
|
||||
environ=None):
|
||||
if resource_name == 'benchmark':
|
||||
from_csv = pd.Series.from_csv
|
||||
else:
|
||||
from_csv = pd.DataFrame.from_csv
|
||||
|
||||
# Path for the cache.
|
||||
path = get_data_filepath(filename)
|
||||
path = get_data_filepath(filename, environ)
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
|
||||
import logbook
|
||||
import pandas as pd
|
||||
@@ -86,11 +87,12 @@ class TradingEnvironment(object):
|
||||
trading_calendar=None,
|
||||
asset_db_path=':memory:',
|
||||
future_chain_predicates=CHAIN_PREDICATES,
|
||||
environ=None,
|
||||
):
|
||||
|
||||
self.bm_symbol = bm_symbol
|
||||
if not load:
|
||||
load = load_market_data
|
||||
load = partial(load_market_data, environ=environ)
|
||||
|
||||
if not trading_calendar:
|
||||
trading_calendar = get_calendar("NYSE")
|
||||
|
||||
@@ -16,6 +16,7 @@ from .core import ( # noqa
|
||||
check_allclose,
|
||||
check_arrays,
|
||||
chrange,
|
||||
copy_market_data,
|
||||
create_daily_df_for_asset,
|
||||
create_data_portal,
|
||||
create_data_portal_from_trade_history,
|
||||
|
||||
@@ -29,6 +29,7 @@ from toolz import concat, curry
|
||||
from zipline.assets import AssetFinder, AssetDBWriter
|
||||
from zipline.assets.synthetic import make_simple_equity_info
|
||||
from zipline.data.data_portal import DataPortal
|
||||
from zipline.data.loader import get_benchmark_filename, INDEX_MAPPING
|
||||
from zipline.data.minute_bars import (
|
||||
BcolzMinuteBarReader,
|
||||
BcolzMinuteBarWriter,
|
||||
@@ -52,6 +53,7 @@ from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
from zipline.utils.numpy_utils import as_column, isnat
|
||||
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
|
||||
from zipline.utils.paths import ensure_directory
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
import numpy as np
|
||||
@@ -1490,6 +1492,19 @@ def patch_read_csv(url_map, module=pd, strict=False):
|
||||
yield
|
||||
|
||||
|
||||
def copy_market_data(src_market_data_dir, dest_root_dir):
|
||||
symbol = '^GSPC'
|
||||
filenames = (get_benchmark_filename(symbol), INDEX_MAPPING[symbol][1])
|
||||
|
||||
ensure_directory(os.path.join(dest_root_dir, 'data'))
|
||||
|
||||
for filename in filenames:
|
||||
shutil.copyfile(
|
||||
os.path.join(src_market_data_dir, filename),
|
||||
os.path.join(dest_root_dir, 'data', filename)
|
||||
)
|
||||
|
||||
|
||||
@curry
|
||||
def ensure_doctest(f, name=None):
|
||||
"""Ensure that an object gets doctested. This is useful for instances
|
||||
|
||||
@@ -129,7 +129,7 @@ def _run(handle_data,
|
||||
"invalid url %r, must begin with 'sqlite:///'" %
|
||||
str(bundle_data.asset_finder.engine.url),
|
||||
)
|
||||
env = TradingEnvironment(asset_db_path=connstr)
|
||||
env = TradingEnvironment(asset_db_path=connstr, environ=environ)
|
||||
first_trading_day =\
|
||||
bundle_data.equity_minute_bar_reader.first_trading_day
|
||||
data = DataPortal(
|
||||
@@ -152,7 +152,7 @@ def _run(handle_data,
|
||||
"No PipelineLoader registered for column %s." % column
|
||||
)
|
||||
else:
|
||||
env = None
|
||||
env = TradingEnvironment(environ=environ)
|
||||
choose_loader = None
|
||||
|
||||
perf = TradingAlgorithm(
|
||||
|
||||
Reference in New Issue
Block a user