mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 07:39:11 +08:00
TST: Make TradingEnvironment resources static
This commit is contained in:
+9
-4
@@ -72,6 +72,13 @@ install:
|
||||
- sed -i "s/scipy==.*/scipy==%SCIPY_VERSION%/" etc/requirements.txt
|
||||
|
||||
- conda info -a
|
||||
- conda install conda=4.1.11 conda-build=1.21.11 anaconda-client=1.5.1 --yes -q
|
||||
# https://blog.ionelmc.ro/2014/12/21/compiling-python-extensions-on-windows/ for 64bit C compilation
|
||||
- ps: copy .\ci\appveyor\vcvars64.bat "C:\Program Files (x86)\Microsoft Visual Studio 10.0\VC\bin\amd64"
|
||||
- "%CMD_IN_ENV% python .\\ci\\make_conda_packages.py"
|
||||
|
||||
# test that we can conda install zipline in a new env
|
||||
- conda create -n installenv --yes -q --use-local python=%PYTHON_VERSION% numpy=%NUMPY_VERSION% zipline -c quantopian -c https://conda.anaconda.org/quantopian/label/ci
|
||||
|
||||
- ps: $env:BCOLZ_VERSION=(sls "bcolz==(.*)" .\etc\requirements.txt -ca).matches.groups[1].value
|
||||
- ps: $env:NUMEXPR_VERSION=(sls "numexpr==(.*)" .\etc\requirements.txt -ca).matches.groups[1].value
|
||||
@@ -88,11 +95,9 @@ install:
|
||||
- pip freeze | sort
|
||||
|
||||
test_script:
|
||||
- nosetests -e zipline.utils.numpy_utils -x
|
||||
- nosetests -e zipline.utils.numpy_utils
|
||||
- flake8 zipline tests
|
||||
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
|
||||
on_finish:
|
||||
- ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1'))
|
||||
|
||||
@@ -64,11 +64,6 @@ class TestRisk(WithTradingEnvironment, ZiplineTestCase):
|
||||
treasury_curves=self.env.treasury_curves,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
cls.TRADING_CALENDAR_PRIMARY_CAL = 'NYSE'
|
||||
super(TestRisk, cls).init_class_fixtures()
|
||||
|
||||
def test_factory(self):
|
||||
returns = [0.1] * 100
|
||||
r_objects = factory.create_returns_from_list(returns, self.sim_params)
|
||||
@@ -393,18 +388,18 @@ class TestRisk(WithTradingEnvironment, ZiplineTestCase):
|
||||
pd.Timestamp("1991-01-01", tz='UTC')
|
||||
)
|
||||
|
||||
# 2008 and 2012 were leap years
|
||||
# 1992 and 1996 were leap years
|
||||
total_days = 365 * 5 + 2
|
||||
end_session = start_session + datetime.timedelta(days=total_days)
|
||||
sim_params = SimulationParameters(
|
||||
sim_params90s = SimulationParameters(
|
||||
start_session=start_session,
|
||||
end_session=end_session,
|
||||
trading_calendar=self.trading_calendar,
|
||||
)
|
||||
|
||||
returns = factory.create_returns_from_range(sim_params)
|
||||
returns = factory.create_returns_from_range(sim_params90s)
|
||||
returns = returns[:-10] # truncate the returns series to end mid-month
|
||||
metrics = risk.RiskReport(returns, sim_params,
|
||||
metrics = risk.RiskReport(returns, sim_params90s,
|
||||
trading_calendar=self.trading_calendar,
|
||||
treasury_curves=self.env.treasury_curves,
|
||||
benchmark_returns=self.env.benchmark_returns)
|
||||
|
||||
+79
-60
@@ -21,6 +21,7 @@ from pandas_datareader.data import DataReader
|
||||
import pytz
|
||||
from six import iteritems
|
||||
from six.moves.urllib_error import HTTPError
|
||||
import zipline
|
||||
|
||||
from .benchmarks import get_benchmark_returns
|
||||
from . import treasuries, treasuries_can
|
||||
@@ -43,6 +44,9 @@ INDEX_MAPPING = {
|
||||
(treasuries, 'treasury_curves.csv', 'www.federalreserve.gov'),
|
||||
}
|
||||
|
||||
zipline_dir = os.path.dirname(zipline.__file__)
|
||||
MARKET_DATA_DIR = os.path.join(zipline_dir, 'resources', 'market_data')
|
||||
|
||||
ONE_HOUR = pd.Timedelta(hours=1)
|
||||
|
||||
|
||||
@@ -194,51 +198,23 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
A trading day delta. Used to find the day before first_date so we can
|
||||
get the close of the day prior to first_date.
|
||||
|
||||
We attempt to download data unless we already have data stored at the data
|
||||
cache for `symbol` whose first entry is before or on `first_date` and whose
|
||||
last entry is on or after `last_date`.
|
||||
We attempt to download data unless we already have data stored in source or
|
||||
in the data cache for `symbol` whose first entry is before or on
|
||||
`first_date` and whose last entry is on or after `last_date`.
|
||||
|
||||
If we perform a download and the cache criteria are not satisfied, we wait
|
||||
at least one hour before attempting a redownload. This is determined by
|
||||
comparing the current time to the result of os.path.getmtime on the cache
|
||||
path.
|
||||
"""
|
||||
path = get_data_filepath(get_benchmark_filename(symbol))
|
||||
filename = get_benchmark_filename(symbol)
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
data = pd.Series.from_csv(path).tz_localize('UTC')
|
||||
if has_data_for_dates(data, first_date, last_date):
|
||||
return data
|
||||
|
||||
# Don't re-download if we've successfully downloaded and written a
|
||||
# file in the last hour.
|
||||
last_download_time = last_modified_time(path)
|
||||
if (now - last_download_time) <= ONE_HOUR:
|
||||
logger.warn(
|
||||
"Refusing to download new benchmark data because a "
|
||||
"download succeeded at %s." % last_download_time
|
||||
)
|
||||
return data
|
||||
|
||||
except (OSError, IOError, ValueError) as e:
|
||||
# These can all be raised by various versions of pandas on various
|
||||
# classes of malformed input. Treat them all as cache misses.
|
||||
logger.info(
|
||||
"Loading data for {path} failed with error [{error}].".format(
|
||||
path=path, error=e,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Cache at {path} does not have data from {start} to {end}.\n"
|
||||
"Downloading benchmark data for '{symbol}'.",
|
||||
start=first_date,
|
||||
end=last_date,
|
||||
symbol=symbol,
|
||||
path=path,
|
||||
)
|
||||
# If no cached data was found or it was missing any dates then download the
|
||||
# necessary data.
|
||||
logger.info('Downloading benchmark data for {symbol!r}.', symbol=symbol)
|
||||
|
||||
try:
|
||||
data = get_benchmark_returns(
|
||||
@@ -246,7 +222,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
first_date - trading_day,
|
||||
last_date,
|
||||
)
|
||||
data.to_csv(path)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache the new benchmark returns')
|
||||
raise
|
||||
@@ -255,14 +231,14 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
return data
|
||||
|
||||
|
||||
def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
"""
|
||||
Ensure we have treasury data from treasury module associated with
|
||||
`bm_symbol`.
|
||||
`symbol`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bm_symbol : str
|
||||
symbol : str
|
||||
Benchmark symbol for which we're loading associated treasury curves.
|
||||
first_date : pd.Timestamp
|
||||
First date required to be in the cache.
|
||||
@@ -273,9 +249,9 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
re-download data that isn't available due to scheduling quirks or other
|
||||
failures.
|
||||
|
||||
We attempt to download data unless we already have data stored in the cache
|
||||
for `module_name` whose first entry is before or on `first_date` and whose
|
||||
last entry is on or after `last_date`.
|
||||
We attempt to download data unless we already have data stored in source or
|
||||
in the cache for `module_name` whose first entry is before or on
|
||||
`first_date` and whose last entry is on or after `last_date`.
|
||||
|
||||
If we perform a download and the cache criteria are not satisfied, we wait
|
||||
at least one hour before attempting a redownload. This is determined by
|
||||
@@ -283,16 +259,58 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
path.
|
||||
"""
|
||||
loader_module, filename, source = INDEX_MAPPING.get(
|
||||
bm_symbol, INDEX_MAPPING['^GSPC']
|
||||
symbol, INDEX_MAPPING['^GSPC'],
|
||||
)
|
||||
first_date = max(first_date, loader_module.earliest_possible_date())
|
||||
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
# If no cached data was found or it was missing any dates then download the
|
||||
# necessary data.
|
||||
logger.info('Downloading treasury data for {symbol!r}.', symbol=symbol)
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
logger.warn("Still don't have expected data after redownload!")
|
||||
return data
|
||||
|
||||
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name):
|
||||
if resource_name == 'benchmark':
|
||||
from_csv = pd.Series.from_csv
|
||||
else:
|
||||
from_csv = pd.DataFrame.from_csv
|
||||
|
||||
# First try to retrieve the data from a static csv file in source.
|
||||
source_path = os.path.join(MARKET_DATA_DIR, filename)
|
||||
try:
|
||||
data = from_csv(source_path).tz_localize('UTC')
|
||||
if has_data_for_dates(data, first_date, last_date):
|
||||
return data
|
||||
except (OSError, IOError, ValueError) as e:
|
||||
# These can all be raised by various versions of pandas on various
|
||||
# classes of malformed input.
|
||||
logger.info(
|
||||
'Loading data from source path {path!r} failed with error '
|
||||
'[{error}].',
|
||||
path=source_path,
|
||||
error=e,
|
||||
)
|
||||
|
||||
# If the data in source was missing any dates then check the cache.
|
||||
path = get_data_filepath(filename)
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
data = pd.DataFrame.from_csv(path).tz_localize('UTC')
|
||||
data = from_csv(path).tz_localize('UTC')
|
||||
if has_data_for_dates(data, first_date, last_date):
|
||||
return data
|
||||
|
||||
@@ -301,8 +319,10 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
last_download_time = last_modified_time(path)
|
||||
if (now - last_download_time) <= ONE_HOUR:
|
||||
logger.warn(
|
||||
"Refusing to download new treasury data because a "
|
||||
"download succeeded at %s." % last_download_time
|
||||
"Refusing to download new {resource} data because a "
|
||||
"download succeeded at {time}.",
|
||||
resource=resource_name,
|
||||
time=last_download_time,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -310,19 +330,18 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
# These can all be raised by various versions of pandas on various
|
||||
# classes of malformed input. Treat them all as cache misses.
|
||||
logger.info(
|
||||
"Loading data for {path} failed with error [{error}].".format(
|
||||
path=path, error=e,
|
||||
)
|
||||
"Loading data for {path} failed with error [{error}].",
|
||||
path=path,
|
||||
error=e,
|
||||
)
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(path)
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
logger.warn("Still don't have expected data after redownload!")
|
||||
return data
|
||||
logger.info(
|
||||
"Cache at {path} does not have data from {start} to {end}.\n",
|
||||
start=first_date,
|
||||
end=last_date,
|
||||
path=path,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _load_raw_yahoo_data(indexes=None, stocks=None, start=None, end=None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -18,6 +19,11 @@ from ..data.data_portal import (
|
||||
DEFAULT_MINUTE_HISTORY_PREFETCH,
|
||||
DEFAULT_DAILY_HISTORY_PREFETCH,
|
||||
)
|
||||
from ..data.loader import (
|
||||
get_benchmark_filename,
|
||||
INDEX_MAPPING,
|
||||
MARKET_DATA_DIR,
|
||||
)
|
||||
from ..data.resample import (
|
||||
minute_frame_to_session_frame,
|
||||
MinuteResampleSessionBarReader
|
||||
@@ -484,7 +490,50 @@ class WithTradingEnvironment(WithAssetFinder,
|
||||
|
||||
@classmethod
|
||||
def make_load_function(cls):
|
||||
return None
|
||||
def load(*args, **kwargs):
|
||||
symbol = '^GSPC'
|
||||
|
||||
filename = get_benchmark_filename(symbol)
|
||||
source_path = os.path.join(MARKET_DATA_DIR, filename)
|
||||
benchmark_returns = \
|
||||
pd.Series.from_csv(source_path).tz_localize('UTC')
|
||||
|
||||
filename = INDEX_MAPPING[symbol][1]
|
||||
source_path = os.path.join(MARKET_DATA_DIR, filename)
|
||||
treasury_curves = \
|
||||
pd.DataFrame.from_csv(source_path).tz_localize('UTC')
|
||||
|
||||
# The TradingEnvironment ordinarily uses cached benchmark returns
|
||||
# and treasury curves data, but when running the zipline tests this
|
||||
# cache is not always updated to include the appropriate dates
|
||||
# required by both the futures and equity calendars. In order to
|
||||
# create more reliable and consistent data throughout the entirety
|
||||
# of the tests, we read static benchmark returns and treasury curve
|
||||
# csv files from source. If a test using the TradingEnvironment
|
||||
# fixture attempts to run outside of the static date range of the
|
||||
# csv files, raise an exception warning the user to either update
|
||||
# the csv files in source or to use a date range within the current
|
||||
# bounds.
|
||||
static_start_date = benchmark_returns.index[0].date()
|
||||
static_end_date = benchmark_returns.index[-1].date()
|
||||
warning_message = (
|
||||
'The TradingEnvironment fixture uses static data between '
|
||||
'{static_start} and {static_end}. To use a start and end date '
|
||||
'of {given_start} and {given_end} you will have to update the '
|
||||
'files in {resource_dir} to include the missing dates.'.format(
|
||||
static_start=static_start_date,
|
||||
static_end=static_end_date,
|
||||
given_start=cls.START_DATE.date(),
|
||||
given_end=cls.END_DATE.date(),
|
||||
resource_dir=MARKET_DATA_DIR,
|
||||
)
|
||||
)
|
||||
if cls.START_DATE.date() < static_start_date or \
|
||||
cls.END_DATE.date() > static_end_date:
|
||||
raise Warning(warning_message)
|
||||
|
||||
return benchmark_returns, treasury_curves
|
||||
return load
|
||||
|
||||
@classmethod
|
||||
def make_trading_environment(cls):
|
||||
|
||||
Reference in New Issue
Block a user