TST: Make TradingEnvironment resources static

This commit is contained in:
dmichalowicz
2017-03-31 11:07:41 -04:00
parent cf68953bf2
commit 483ec5dae8
6 changed files with 13824 additions and 74 deletions
+9 -4
View File
@@ -72,6 +72,13 @@ install:
- sed -i "s/scipy==.*/scipy==%SCIPY_VERSION%/" etc/requirements.txt
- conda info -a
- conda install conda=4.1.11 conda-build=1.21.11 anaconda-client=1.5.1 --yes -q
# https://blog.ionelmc.ro/2014/12/21/compiling-python-extensions-on-windows/ for 64bit C compilation
- ps: copy .\ci\appveyor\vcvars64.bat "C:\Program Files (x86)\Microsoft Visual Studio 10.0\VC\bin\amd64"
- "%CMD_IN_ENV% python .\\ci\\make_conda_packages.py"
# test that we can conda install zipline in a new env
- conda create -n installenv --yes -q --use-local python=%PYTHON_VERSION% numpy=%NUMPY_VERSION% zipline -c quantopian -c https://conda.anaconda.org/quantopian/label/ci
- ps: $env:BCOLZ_VERSION=(sls "bcolz==(.*)" .\etc\requirements.txt -ca).matches.groups[1].value
- ps: $env:NUMEXPR_VERSION=(sls "numexpr==(.*)" .\etc\requirements.txt -ca).matches.groups[1].value
@@ -88,11 +95,9 @@ install:
- pip freeze | sort
test_script:
- nosetests -e zipline.utils.numpy_utils -x
- nosetests -e zipline.utils.numpy_utils
- flake8 zipline tests
branches:
only:
- master
on_finish:
- ps: $blockRdp = $true; iex ((new-object net.webclient).DownloadString('https://raw.githubusercontent.com/appveyor/ci/master/scripts/enable-rdp.ps1'))
+4 -9
View File
@@ -64,11 +64,6 @@ class TestRisk(WithTradingEnvironment, ZiplineTestCase):
treasury_curves=self.env.treasury_curves,
)
@classmethod
def init_class_fixtures(cls):
cls.TRADING_CALENDAR_PRIMARY_CAL = 'NYSE'
super(TestRisk, cls).init_class_fixtures()
def test_factory(self):
returns = [0.1] * 100
r_objects = factory.create_returns_from_list(returns, self.sim_params)
@@ -393,18 +388,18 @@ class TestRisk(WithTradingEnvironment, ZiplineTestCase):
pd.Timestamp("1991-01-01", tz='UTC')
)
# 2008 and 2012 were leap years
# 1992 and 1996 were leap years
total_days = 365 * 5 + 2
end_session = start_session + datetime.timedelta(days=total_days)
sim_params = SimulationParameters(
sim_params90s = SimulationParameters(
start_session=start_session,
end_session=end_session,
trading_calendar=self.trading_calendar,
)
returns = factory.create_returns_from_range(sim_params)
returns = factory.create_returns_from_range(sim_params90s)
returns = returns[:-10] # truncate the returns series to end mid-month
metrics = risk.RiskReport(returns, sim_params,
metrics = risk.RiskReport(returns, sim_params90s,
trading_calendar=self.trading_calendar,
treasury_curves=self.env.treasury_curves,
benchmark_returns=self.env.benchmark_returns)
+79 -60
View File
@@ -21,6 +21,7 @@ from pandas_datareader.data import DataReader
import pytz
from six import iteritems
from six.moves.urllib_error import HTTPError
import zipline
from .benchmarks import get_benchmark_returns
from . import treasuries, treasuries_can
@@ -43,6 +44,9 @@ INDEX_MAPPING = {
(treasuries, 'treasury_curves.csv', 'www.federalreserve.gov'),
}
zipline_dir = os.path.dirname(zipline.__file__)
MARKET_DATA_DIR = os.path.join(zipline_dir, 'resources', 'market_data')
ONE_HOUR = pd.Timedelta(hours=1)
@@ -194,51 +198,23 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
A trading day delta. Used to find the day before first_date so we can
get the close of the day prior to first_date.
We attempt to download data unless we already have data stored at the data
cache for `symbol` whose first entry is before or on `first_date` and whose
last entry is on or after `last_date`.
We attempt to download data unless we already have data stored in source or
in the data cache for `symbol` whose first entry is before or on
`first_date` and whose last entry is on or after `last_date`.
If we perform a download and the cache criteria are not satisfied, we wait
at least one hour before attempting a redownload. This is determined by
comparing the current time to the result of os.path.getmtime on the cache
path.
"""
path = get_data_filepath(get_benchmark_filename(symbol))
filename = get_benchmark_filename(symbol)
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
if data is not None:
return data
# If the path does not exist, it means the first download has not happened
# yet, so don't try to read from 'path'.
if os.path.exists(path):
try:
data = pd.Series.from_csv(path).tz_localize('UTC')
if has_data_for_dates(data, first_date, last_date):
return data
# Don't re-download if we've successfully downloaded and written a
# file in the last hour.
last_download_time = last_modified_time(path)
if (now - last_download_time) <= ONE_HOUR:
logger.warn(
"Refusing to download new benchmark data because a "
"download succeeded at %s." % last_download_time
)
return data
except (OSError, IOError, ValueError) as e:
# These can all be raised by various versions of pandas on various
# classes of malformed input. Treat them all as cache misses.
logger.info(
"Loading data for {path} failed with error [{error}].".format(
path=path, error=e,
)
)
logger.info(
"Cache at {path} does not have data from {start} to {end}.\n"
"Downloading benchmark data for '{symbol}'.",
start=first_date,
end=last_date,
symbol=symbol,
path=path,
)
# If no cached data was found or it was missing any dates then download the
# necessary data.
logger.info('Downloading benchmark data for {symbol!r}.', symbol=symbol)
try:
data = get_benchmark_returns(
@@ -246,7 +222,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
first_date - trading_day,
last_date,
)
data.to_csv(path)
data.to_csv(get_data_filepath(filename))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache the new benchmark returns')
raise
@@ -255,14 +231,14 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
return data
def ensure_treasury_data(bm_symbol, first_date, last_date, now):
def ensure_treasury_data(symbol, first_date, last_date, now):
"""
Ensure we have treasury data from treasury module associated with
`bm_symbol`.
`symbol`.
Parameters
----------
bm_symbol : str
symbol : str
Benchmark symbol for which we're loading associated treasury curves.
first_date : pd.Timestamp
First date required to be in the cache.
@@ -273,9 +249,9 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
re-download data that isn't available due to scheduling quirks or other
failures.
We attempt to download data unless we already have data stored in the cache
for `module_name` whose first entry is before or on `first_date` and whose
last entry is on or after `last_date`.
We attempt to download data unless we already have data stored in source or
in the cache for `module_name` whose first entry is before or on
`first_date` and whose last entry is on or after `last_date`.
If we perform a download and the cache criteria are not satisfied, we wait
at least one hour before attempting a redownload. This is determined by
@@ -283,16 +259,58 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
path.
"""
loader_module, filename, source = INDEX_MAPPING.get(
bm_symbol, INDEX_MAPPING['^GSPC']
symbol, INDEX_MAPPING['^GSPC'],
)
first_date = max(first_date, loader_module.earliest_possible_date())
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
if data is not None:
return data
# If no cached data was found or it was missing any dates then download the
# necessary data.
logger.info('Downloading treasury data for {symbol!r}.', symbol=symbol)
try:
data = loader_module.get_treasury_data(first_date, last_date)
data.to_csv(get_data_filepath(filename))
except (OSError, IOError, HTTPError):
logger.exception('failed to cache treasury data')
if not has_data_for_dates(data, first_date, last_date):
logger.warn("Still don't have expected data after redownload!")
return data
def _load_cached_data(filename, first_date, last_date, now, resource_name):
if resource_name == 'benchmark':
from_csv = pd.Series.from_csv
else:
from_csv = pd.DataFrame.from_csv
# First try to retrieve the data from a static csv file in source.
source_path = os.path.join(MARKET_DATA_DIR, filename)
try:
data = from_csv(source_path).tz_localize('UTC')
if has_data_for_dates(data, first_date, last_date):
return data
except (OSError, IOError, ValueError) as e:
# These can all be raised by various versions of pandas on various
# classes of malformed input.
logger.info(
'Loading data from source path {path!r} failed with error '
'[{error}].',
path=source_path,
error=e,
)
# If the data in source was missing any dates then check the cache.
path = get_data_filepath(filename)
# If the path does not exist, it means the first download has not happened
# yet, so don't try to read from 'path'.
if os.path.exists(path):
try:
data = pd.DataFrame.from_csv(path).tz_localize('UTC')
data = from_csv(path).tz_localize('UTC')
if has_data_for_dates(data, first_date, last_date):
return data
@@ -301,8 +319,10 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
last_download_time = last_modified_time(path)
if (now - last_download_time) <= ONE_HOUR:
logger.warn(
"Refusing to download new treasury data because a "
"download succeeded at %s." % last_download_time
"Refusing to download new {resource} data because a "
"download succeeded at {time}.",
resource=resource_name,
time=last_download_time,
)
return data
@@ -310,19 +330,18 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
# These can all be raised by various versions of pandas on various
# classes of malformed input. Treat them all as cache misses.
logger.info(
"Loading data for {path} failed with error [{error}].".format(
path=path, error=e,
)
"Loading data for {path} failed with error [{error}].",
path=path,
error=e,
)
try:
data = loader_module.get_treasury_data(first_date, last_date)
data.to_csv(path)
except (OSError, IOError, HTTPError):
logger.exception('failed to cache treasury data')
if not has_data_for_dates(data, first_date, last_date):
logger.warn("Still don't have expected data after redownload!")
return data
logger.info(
"Cache at {path} does not have data from {start} to {end}.\n",
start=first_date,
end=last_date,
path=path,
)
return None
def _load_raw_yahoo_data(indexes=None, stocks=None, start=None, end=None):
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+50 -1
View File
@@ -1,3 +1,4 @@
import os
import sqlite3
from unittest import TestCase
@@ -18,6 +19,11 @@ from ..data.data_portal import (
DEFAULT_MINUTE_HISTORY_PREFETCH,
DEFAULT_DAILY_HISTORY_PREFETCH,
)
from ..data.loader import (
get_benchmark_filename,
INDEX_MAPPING,
MARKET_DATA_DIR,
)
from ..data.resample import (
minute_frame_to_session_frame,
MinuteResampleSessionBarReader
@@ -484,7 +490,50 @@ class WithTradingEnvironment(WithAssetFinder,
@classmethod
def make_load_function(cls):
return None
def load(*args, **kwargs):
symbol = '^GSPC'
filename = get_benchmark_filename(symbol)
source_path = os.path.join(MARKET_DATA_DIR, filename)
benchmark_returns = \
pd.Series.from_csv(source_path).tz_localize('UTC')
filename = INDEX_MAPPING[symbol][1]
source_path = os.path.join(MARKET_DATA_DIR, filename)
treasury_curves = \
pd.DataFrame.from_csv(source_path).tz_localize('UTC')
# The TradingEnvironment ordinarily uses cached benchmark returns
# and treasury curves data, but when running the zipline tests this
# cache is not always updated to include the appropriate dates
# required by both the futures and equity calendars. In order to
# create more reliable and consistent data throughout the entirety
# of the tests, we read static benchmark returns and treasury curve
# csv files from source. If a test using the TradingEnvironment
# fixture attempts to run outside of the static date range of the
# csv files, raise an exception warning the user to either update
# the csv files in source or to use a date range within the current
# bounds.
static_start_date = benchmark_returns.index[0].date()
static_end_date = benchmark_returns.index[-1].date()
warning_message = (
'The TradingEnvironment fixture uses static data between '
'{static_start} and {static_end}. To use a start and end date '
'of {given_start} and {given_end} you will have to update the '
'files in {resource_dir} to include the missing dates.'.format(
static_start=static_start_date,
static_end=static_end_date,
given_start=cls.START_DATE.date(),
given_end=cls.END_DATE.date(),
resource_dir=MARKET_DATA_DIR,
)
)
if cls.START_DATE.date() < static_start_date or \
cls.END_DATE.date() > static_end_date:
raise Warning(warning_message)
return benchmark_returns, treasury_curves
return load
@classmethod
def make_trading_environment(cls):