mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 23:44:45 +08:00
Merge pull request #1729 from quantopian/us-futures-cal-in-tests
Use 'us_futures' calendar in test fixtures
This commit is contained in:
@@ -42,8 +42,8 @@ class AssetDispatchSessionBarTestCase(WithBcolzEquityDailyBarReader,
|
||||
WithTradingSessions,
|
||||
ZiplineTestCase):
|
||||
|
||||
TRADING_CALENDAR_STRS = ('CME', 'NYSE')
|
||||
TRADING_CALENDAR_PRIMARY_CAL = 'CME'
|
||||
TRADING_CALENDAR_STRS = ('us_futures', 'NYSE')
|
||||
TRADING_CALENDAR_PRIMARY_CAL = 'us_futures'
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = 1, 2, 3
|
||||
|
||||
@@ -54,7 +54,7 @@ class AssetDispatchSessionBarTestCase(WithBcolzEquityDailyBarReader,
|
||||
def make_future_minute_bar_data(cls):
|
||||
m_opens = [
|
||||
cls.trading_calendar.open_and_close_for_session(session)[0]
|
||||
for session in cls.trading_sessions['CME']]
|
||||
for session in cls.trading_sessions['us_futures']]
|
||||
yield 10001, DataFrame({
|
||||
'open': [10000.5, 10001.5, nan],
|
||||
'high': [10000.9, 10001.9, nan],
|
||||
@@ -171,8 +171,8 @@ class AssetDispatchMinuteBarTestCase(WithBcolzEquityMinuteBarReader,
|
||||
WithBcolzFutureMinuteBarReader,
|
||||
ZiplineTestCase):
|
||||
|
||||
TRADING_CALENDAR_STRS = ('CME', 'NYSE')
|
||||
TRADING_CALENDAR_PRIMARY_CAL = 'CME'
|
||||
TRADING_CALENDAR_STRS = ('us_futures', 'NYSE')
|
||||
TRADING_CALENDAR_PRIMARY_CAL = 'us_futures'
|
||||
|
||||
ASSET_FINDER_EQUITY_SIDS = 1, 2, 3
|
||||
|
||||
|
||||
+53
-54
@@ -203,42 +203,14 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
comparing the current time to the result of os.path.getmtime on the cache
|
||||
path.
|
||||
"""
|
||||
path = get_data_filepath(get_benchmark_filename(symbol))
|
||||
filename = get_benchmark_filename(symbol)
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'benchmark')
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
data = pd.Series.from_csv(path).tz_localize('UTC')
|
||||
if has_data_for_dates(data, first_date, last_date):
|
||||
return data
|
||||
|
||||
# Don't re-download if we've successfully downloaded and written a
|
||||
# file in the last hour.
|
||||
last_download_time = last_modified_time(path)
|
||||
if (now - last_download_time) <= ONE_HOUR:
|
||||
logger.warn(
|
||||
"Refusing to download new benchmark data because a "
|
||||
"download succeeded at %s." % last_download_time
|
||||
)
|
||||
return data
|
||||
|
||||
except (OSError, IOError, ValueError) as e:
|
||||
# These can all be raised by various versions of pandas on various
|
||||
# classes of malformed input. Treat them all as cache misses.
|
||||
logger.info(
|
||||
"Loading data for {path} failed with error [{error}].".format(
|
||||
path=path, error=e,
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Cache at {path} does not have data from {start} to {end}.\n"
|
||||
"Downloading benchmark data for '{symbol}'.",
|
||||
start=first_date,
|
||||
end=last_date,
|
||||
symbol=symbol,
|
||||
path=path,
|
||||
)
|
||||
# If no cached data was found or it was missing any dates then download the
|
||||
# necessary data.
|
||||
logger.info('Downloading benchmark data for {symbol!r}.', symbol=symbol)
|
||||
|
||||
try:
|
||||
data = get_benchmark_returns(
|
||||
@@ -246,7 +218,7 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
first_date - trading_day,
|
||||
last_date,
|
||||
)
|
||||
data.to_csv(path)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache the new benchmark returns')
|
||||
raise
|
||||
@@ -255,14 +227,14 @@ def ensure_benchmark_data(symbol, first_date, last_date, now, trading_day):
|
||||
return data
|
||||
|
||||
|
||||
def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
def ensure_treasury_data(symbol, first_date, last_date, now):
|
||||
"""
|
||||
Ensure we have treasury data from treasury module associated with
|
||||
`bm_symbol`.
|
||||
`symbol`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bm_symbol : str
|
||||
symbol : str
|
||||
Benchmark symbol for which we're loading associated treasury curves.
|
||||
first_date : pd.Timestamp
|
||||
First date required to be in the cache.
|
||||
@@ -283,16 +255,42 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
path.
|
||||
"""
|
||||
loader_module, filename, source = INDEX_MAPPING.get(
|
||||
bm_symbol, INDEX_MAPPING['^GSPC']
|
||||
symbol, INDEX_MAPPING['^GSPC'],
|
||||
)
|
||||
first_date = max(first_date, loader_module.earliest_possible_date())
|
||||
|
||||
data = _load_cached_data(filename, first_date, last_date, now, 'treasury')
|
||||
if data is not None:
|
||||
return data
|
||||
|
||||
# If no cached data was found or it was missing any dates then download the
|
||||
# necessary data.
|
||||
logger.info('Downloading treasury data for {symbol!r}.', symbol=symbol)
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(get_data_filepath(filename))
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
logger.warn("Still don't have expected data after redownload!")
|
||||
return data
|
||||
|
||||
|
||||
def _load_cached_data(filename, first_date, last_date, now, resource_name):
|
||||
if resource_name == 'benchmark':
|
||||
from_csv = pd.Series.from_csv
|
||||
else:
|
||||
from_csv = pd.DataFrame.from_csv
|
||||
|
||||
# Path for the cache.
|
||||
path = get_data_filepath(filename)
|
||||
|
||||
# If the path does not exist, it means the first download has not happened
|
||||
# yet, so don't try to read from 'path'.
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
data = pd.DataFrame.from_csv(path).tz_localize('UTC')
|
||||
data = from_csv(path).tz_localize('UTC')
|
||||
if has_data_for_dates(data, first_date, last_date):
|
||||
return data
|
||||
|
||||
@@ -301,8 +299,10 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
last_download_time = last_modified_time(path)
|
||||
if (now - last_download_time) <= ONE_HOUR:
|
||||
logger.warn(
|
||||
"Refusing to download new treasury data because a "
|
||||
"download succeeded at %s." % last_download_time
|
||||
"Refusing to download new {resource} data because a "
|
||||
"download succeeded at {time}.",
|
||||
resource=resource_name,
|
||||
time=last_download_time,
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -310,19 +310,18 @@ def ensure_treasury_data(bm_symbol, first_date, last_date, now):
|
||||
# These can all be raised by various versions of pandas on various
|
||||
# classes of malformed input. Treat them all as cache misses.
|
||||
logger.info(
|
||||
"Loading data for {path} failed with error [{error}].".format(
|
||||
path=path, error=e,
|
||||
)
|
||||
"Loading data for {path} failed with error [{error}].",
|
||||
path=path,
|
||||
error=e,
|
||||
)
|
||||
|
||||
try:
|
||||
data = loader_module.get_treasury_data(first_date, last_date)
|
||||
data.to_csv(path)
|
||||
except (OSError, IOError, HTTPError):
|
||||
logger.exception('failed to cache treasury data')
|
||||
if not has_data_for_dates(data, first_date, last_date):
|
||||
logger.warn("Still don't have expected data after redownload!")
|
||||
return data
|
||||
logger.info(
|
||||
"Cache at {path} does not have data from {start} to {end}.\n",
|
||||
start=first_date,
|
||||
end=last_date,
|
||||
path=path,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _load_raw_yahoo_data(indexes=None, stocks=None, start=None, end=None):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
+68
-18
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import sqlite3
|
||||
from unittest import TestCase
|
||||
|
||||
@@ -11,6 +12,8 @@ import responses
|
||||
from .core import (
|
||||
create_daily_bar_data,
|
||||
create_minute_bar_data,
|
||||
make_simple_equity_info,
|
||||
tmp_asset_finder,
|
||||
tmp_dir,
|
||||
)
|
||||
from ..data.data_portal import (
|
||||
@@ -18,17 +21,9 @@ from ..data.data_portal import (
|
||||
DEFAULT_MINUTE_HISTORY_PREFETCH,
|
||||
DEFAULT_DAILY_HISTORY_PREFETCH,
|
||||
)
|
||||
from ..data.resample import (
|
||||
minute_frame_to_session_frame,
|
||||
MinuteResampleSessionBarReader
|
||||
)
|
||||
from ..data.us_equity_pricing import (
|
||||
SQLiteAdjustmentReader,
|
||||
SQLiteAdjustmentWriter,
|
||||
)
|
||||
from ..data.us_equity_pricing import (
|
||||
BcolzDailyBarReader,
|
||||
BcolzDailyBarWriter,
|
||||
from ..data.loader import (
|
||||
get_benchmark_filename,
|
||||
INDEX_MAPPING,
|
||||
)
|
||||
from ..data.minute_bars import (
|
||||
BcolzMinuteBarReader,
|
||||
@@ -36,12 +31,22 @@ from ..data.minute_bars import (
|
||||
US_EQUITIES_MINUTES_PER_DAY,
|
||||
FUTURES_MINUTES_PER_DAY,
|
||||
)
|
||||
|
||||
from ..data.resample import (
|
||||
minute_frame_to_session_frame,
|
||||
MinuteResampleSessionBarReader
|
||||
)
|
||||
from ..data.us_equity_pricing import (
|
||||
BcolzDailyBarReader,
|
||||
BcolzDailyBarWriter,
|
||||
SQLiteAdjustmentReader,
|
||||
SQLiteAdjustmentWriter,
|
||||
)
|
||||
from ..finance.trading import TradingEnvironment
|
||||
from ..utils import factory
|
||||
from ..utils.classproperty import classproperty
|
||||
from ..utils.final import FinalMeta, final
|
||||
from .core import tmp_asset_finder, make_simple_equity_info
|
||||
|
||||
import zipline
|
||||
from zipline.assets import Equity, Future
|
||||
from zipline.finance.asset_restrictions import NoRestrictions
|
||||
from zipline.pipeline import SimplePipelineEngine
|
||||
@@ -51,6 +56,8 @@ from zipline.utils.calendars import (
|
||||
get_calendar,
|
||||
register_calendar)
|
||||
|
||||
zipline_dir = os.path.dirname(zipline.__file__)
|
||||
|
||||
|
||||
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
|
||||
"""
|
||||
@@ -481,10 +488,54 @@ class WithTradingEnvironment(WithAssetFinder,
|
||||
:class:`zipline.finance.trading.TradingEnvironment`
|
||||
"""
|
||||
TRADING_ENV_FUTURE_CHAIN_PREDICATES = None
|
||||
MARKET_DATA_DIR = os.path.join(zipline_dir, 'resources', 'market_data')
|
||||
|
||||
@classmethod
|
||||
def make_load_function(cls):
|
||||
return None
|
||||
def load(*args, **kwargs):
|
||||
symbol = '^GSPC'
|
||||
|
||||
filename = get_benchmark_filename(symbol)
|
||||
source_path = os.path.join(cls.MARKET_DATA_DIR, filename)
|
||||
benchmark_returns = \
|
||||
pd.Series.from_csv(source_path).tz_localize('UTC')
|
||||
|
||||
filename = INDEX_MAPPING[symbol][1]
|
||||
source_path = os.path.join(cls.MARKET_DATA_DIR, filename)
|
||||
treasury_curves = \
|
||||
pd.DataFrame.from_csv(source_path).tz_localize('UTC')
|
||||
|
||||
# The TradingEnvironment ordinarily uses cached benchmark returns
|
||||
# and treasury curves data, but when running the zipline tests this
|
||||
# cache is not always updated to include the appropriate dates
|
||||
# required by both the futures and equity calendars. In order to
|
||||
# create more reliable and consistent data throughout the entirety
|
||||
# of the tests, we read static benchmark returns and treasury curve
|
||||
# csv files from source. If a test using the TradingEnvironment
|
||||
# fixture attempts to run outside of the static date range of the
|
||||
# csv files, raise an exception warning the user to either update
|
||||
# the csv files in source or to use a date range within the current
|
||||
# bounds.
|
||||
static_start_date = benchmark_returns.index[0].date()
|
||||
static_end_date = benchmark_returns.index[-1].date()
|
||||
warning_message = (
|
||||
'The TradingEnvironment fixture uses static data between '
|
||||
'{static_start} and {static_end}. To use a start and end date '
|
||||
'of {given_start} and {given_end} you will have to update the '
|
||||
'files in {resource_dir} to include the missing dates.'.format(
|
||||
static_start=static_start_date,
|
||||
static_end=static_end_date,
|
||||
given_start=cls.START_DATE.date(),
|
||||
given_end=cls.END_DATE.date(),
|
||||
resource_dir=cls.MARKET_DATA_DIR,
|
||||
)
|
||||
)
|
||||
if cls.START_DATE.date() < static_start_date or \
|
||||
cls.END_DATE.date() > static_end_date:
|
||||
raise AssertionError(warning_message)
|
||||
|
||||
return benchmark_returns, treasury_curves
|
||||
return load
|
||||
|
||||
@classmethod
|
||||
def make_trading_environment(cls):
|
||||
@@ -964,7 +1015,7 @@ class WithFutureMinuteBarData(_WithMinuteBarDataBase):
|
||||
|
||||
@classmethod
|
||||
def make_future_minute_bar_data(cls):
|
||||
trading_calendar = get_calendar('CME')
|
||||
trading_calendar = get_calendar('us_futures')
|
||||
return create_minute_bar_data(
|
||||
trading_calendar.minutes_for_sessions_in_range(
|
||||
cls.future_minute_bar_days[0],
|
||||
@@ -976,8 +1027,7 @@ class WithFutureMinuteBarData(_WithMinuteBarDataBase):
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(WithFutureMinuteBarData, cls).init_class_fixtures()
|
||||
# To be replaced by quanto calendar.
|
||||
trading_calendar = get_calendar('CME')
|
||||
trading_calendar = get_calendar('us_futures')
|
||||
cls.future_minute_bar_days = _trading_days_for_minute_bars(
|
||||
trading_calendar,
|
||||
pd.Timestamp(cls.FUTURE_MINUTE_BAR_START_DATE),
|
||||
@@ -1087,7 +1137,7 @@ class WithBcolzFutureMinuteBarReader(WithFutureMinuteBarData, WithTmpDir):
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(WithBcolzFutureMinuteBarReader, cls).init_class_fixtures()
|
||||
trading_calendar = get_calendar('CME')
|
||||
trading_calendar = get_calendar('us_futures')
|
||||
cls.bcolz_future_minute_bar_path = p = \
|
||||
cls.make_bcolz_future_minute_bar_rootdir_path()
|
||||
days = cls.future_minute_bar_days
|
||||
|
||||
Reference in New Issue
Block a user