adding trading_day and trading_days as variables to load_market_data

This commit is contained in:
warren-oneill
2015-03-23 11:52:34 +01:00
committed by jfkirk
parent d578d5825e
commit 49c168b3d0
+13 -15
View File
@@ -31,10 +31,6 @@ from six import iteritems
from . import benchmarks
from . benchmarks import get_benchmark_returns
from zipline.utils.tradingcalendar import (
trading_day,
trading_days
)
logger = logbook.Logger('Loader')
@@ -154,7 +150,7 @@ def get_benchmark_filename(symbol):
return "%s_benchmark.csv" % symbol
def load_market_data(bm_symbol='^GSPC'):
def load_market_data(trading_day, trading_days, bm_symbol='^GSPC'):
bm_filepath = get_data_filepath(get_benchmark_filename(bm_symbol))
try:
saved_benchmarks = pd.Series.from_csv(bm_filepath)
@@ -180,18 +176,21 @@ Fetching data from Yahoo Finance.
# If more than 1 trading days has elapsed since the last day where
# we have data,then we need to update
# We're doing "> 2" rather than "> 1" because we're subtracting an array
# _length_ from an array _index_, and therefore even if we had data up to
# and including the current day, the difference would still be 1.
if len(days_up_to_now) - last_bm_date_offset > 2:
if len(days_up_to_now) - last_bm_date_offset > 1:
benchmark_returns = update_benchmarks(bm_symbol, last_bm_date)
if benchmark_returns.index.tz is None or \
benchmark_returns.index.tz.zone != 'UTC':
if (
benchmark_returns.index.tz is None
or
benchmark_returns.index.tz.zone != 'UTC'
):
benchmark_returns = benchmark_returns.tz_localize('UTC')
else:
benchmark_returns = saved_benchmarks
if benchmark_returns.index.tz is None or\
benchmark_returns.index.tz.zone != 'UTC':
if (
benchmark_returns.index.tz is None
or
benchmark_returns.index.tz.zone != 'UTC'
):
benchmark_returns = benchmark_returns.tz_localize('UTC')
# Get treasury curve module, filename & source from mapping.
@@ -218,8 +217,7 @@ Fetching data from {0}
# If more than 1 trading days has elapsed since the last day where
# we have data,then we need to update
# Comment above explains why this is "> 2".
if len(days_up_to_now) - last_tr_date_offset > 2:
if len(days_up_to_now) - last_tr_date_offset > 1:
treasury_curves = dump_treasury_curves(module, filename)
else:
treasury_curves = saved_curves.tz_localize('UTC')