diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index feda4c99..75e365da 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -127,7 +127,6 @@ from __future__ import division, absolute_import from abc import ABCMeta, abstractproperty from collections import namedtuple, defaultdict from copy import copy -from datetime import timedelta from functools import partial from itertools import count import warnings @@ -161,7 +160,7 @@ import toolz.curried.operator as op from zipline.pipeline.data.dataset import DataSet, Column from zipline.pipeline.loaders.utils import ( check_data_query_args, - normalize_data_query_time, + normalize_data_query_bounds, normalize_timestamp_to_query_time, ) from zipline.lib.adjusted_array import AdjustedArray @@ -834,20 +833,12 @@ class BlazeLoader(dict): data_query_time = self._data_query_time data_query_tz = self._data_query_tz - if data_query_time is not None: - lower_dt = normalize_data_query_time( - dates[0] - timedelta(days=1), - data_query_time, - data_query_tz, - ) - upper_dt = normalize_data_query_time( - dates[-1], - data_query_time, - data_query_tz, - ) - else: - lower_dt = dates[0] - timedelta(days=1) - upper_dt = dates[-1] + lower_dt, upper_dt = normalize_data_query_bounds( + dates[0], + dates[-1], + data_query_time, + data_query_tz, + ) def where(e): """Create the query to run against the resources. 
diff --git a/zipline/pipeline/loaders/blaze/earnings.py b/zipline/pipeline/loaders/blaze/earnings.py
index bfb17d4d..9f5137f8 100644
--- a/zipline/pipeline/loaders/blaze/earnings.py
+++ b/zipline/pipeline/loaders/blaze/earnings.py
@@ -1,5 +1,3 @@
-from datetime import timedelta
-
 from datashape import istabular
 import pandas as pd
 from toolz import valmap
@@ -15,7 +13,7 @@ from zipline.pipeline.loaders.base import PipelineLoader
 from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
 from zipline.pipeline.loaders.utils import (
     check_data_query_args,
-    normalize_data_query_time,
+    normalize_data_query_bounds,
     normalize_timestamp_to_query_time,
 )
 from zipline.utils.input_validation import ensure_timezone, optionally
@@ -100,20 +98,12 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
     def load_adjusted_array(self, columns, dates, assets, mask):
         data_query_time = self._data_query_time
         data_query_tz = self._data_query_tz
-        if data_query_time is not None:
-            lower_dt = normalize_data_query_time(
-                dates[0] - timedelta(days=1),
-                data_query_time,
-                data_query_tz,
-            )
-            upper_dt = normalize_data_query_time(
-                dates[-1],
-                data_query_time,
-                data_query_tz,
-            )
-        else:
-            lower_dt = dates[0] - timedelta(days=1)
-            upper_dt = dates[-1]
+        lower_dt, upper_dt = normalize_data_query_bounds(
+            dates[0],
+            dates[-1],
+            data_query_time,
+            data_query_tz,
+        )
 
         raw = ffill_query_in_range(
             self._expr,
diff --git a/zipline/pipeline/loaders/utils.py b/zipline/pipeline/loaders/utils.py
index 244ac4ba..e5682461 100644
--- a/zipline/pipeline/loaders/utils.py
+++ b/zipline/pipeline/loaders/utils.py
@@ -125,6 +125,40 @@ def normalize_data_query_time(dt, time, tz):
     ).tz_convert('utc')
 
 
+def normalize_data_query_bounds(lower, upper, time, tz):
+    """Adjust the requested lower and upper dates for the query time and tz.
+
+    Parameters
+    ----------
+    lower : pd.Timestamp
+        The lower date requested.
+    upper : pd.Timestamp
+        The upper date requested.
+    time : datetime.time or None
+        The time of day to use as the cutoff point for new data. Data points
+        that you learn about after this time will become available to your
+        algorithm on the next trading day. If None, no time is applied.
+    tz : tzinfo
+        The timezone to normalize dates to before comparing against `time`.
+
+    Returns
+    -------
+    lower, upper : tuple
+        ``lower`` moved back by one day and ``upper``, each passed through
+        ``normalize_data_query_time`` when ``time`` is not None.
+    """
+    # Step back one day so that things which happened on the first requested
+    # day are picked up. The result need not be a trading day: it only serves
+    # as a lower bound that limits how much in-memory filtering happens.
+    lower -= datetime.timedelta(days=1)
+    if time is None:
+        return lower, upper
+    return (
+        normalize_data_query_time(lower, time, tz),
+        normalize_data_query_time(upper, time, tz),
+    )
+
+
 def normalize_timestamp_to_query_time(df,
                                       time,
                                       tz,