From 2caa9277c45e56f255ea6d6d90f3dbd7485e07be Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Tue, 12 Jan 2016 17:31:17 -0500 Subject: [PATCH] ENH: Make the data_query_time arguments optional --- tests/pipeline/test_blaze.py | 4 +- zipline/pipeline/loaders/blaze/core.py | 65 +++++++++++++--------- zipline/pipeline/loaders/blaze/earnings.py | 48 +++++++++------- zipline/pipeline/loaders/utils.py | 26 ++++++++- 4 files changed, 95 insertions(+), 48 deletions(-) diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index 289c22c3..11d741b2 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -327,9 +327,9 @@ class BlazeToPipelineTestCase(TestCase): df = self.df.copy() df['timestamp'] = ( pd.DatetimeIndex(df['timestamp'], tz='EST') + - timedelta(hours=8, minutes=45) + timedelta(hours=8, minutes=44) ).tz_convert('utc') - df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc') + df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45', tz='utc') expr = bz.Data(df, name='expr', dshape=self.dshape) loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST') ds = from_blaze( diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index c0cea719..feda4c99 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -127,7 +127,7 @@ from __future__ import division, absolute_import from abc import ABCMeta, abstractproperty from collections import namedtuple, defaultdict from copy import copy -from datetime import time +from datetime import timedelta from functools import partial from itertools import count import warnings @@ -160,13 +160,18 @@ import toolz.curried.operator as op from zipline.pipeline.data.dataset import DataSet, Column from zipline.pipeline.loaders.utils import ( + check_data_query_args, normalize_data_query_time, normalize_timestamp_to_query_time, ) from zipline.lib.adjusted_array import AdjustedArray from zipline.lib.adjustment import Float64Overwrite from zipline.utils.enum import enum -from zipline.utils.input_validation import expect_element, ensure_timezone +from zipline.utils.input_validation import ( + expect_element, + ensure_timezone, + optionally, +) from zipline.utils.numpy_utils import repeat_last_axis from zipline.utils.preprocess import preprocess @@ -781,12 +786,14 @@ class BlazeLoader(dict): data_query_tz : tzinfo or str The timezeone to use for the data query cutoff. """ - @preprocess(data_query_tz=ensure_timezone) + @preprocess(data_query_tz=optionally(ensure_timezone)) def __init__(self, colmap=None, - data_query_time=time(0), - data_query_tz='utc'): + data_query_time=None, + data_query_tz=None): self.update(colmap or {}) + + check_data_query_args(data_query_time, data_query_tz) self._data_query_time = data_query_time self._data_query_tz = data_query_tz @@ -827,16 +834,20 @@ class BlazeLoader(dict): data_query_time = self._data_query_time data_query_tz = self._data_query_tz - lower_dt = normalize_data_query_time( - dates[0], - data_query_time, - data_query_tz, - ) - upper_dt = normalize_data_query_time( - dates[-1], - data_query_time, - data_query_tz, - ) + if data_query_time is not None: + lower_dt = normalize_data_query_time( + dates[0] - timedelta(days=1), + data_query_time, + data_query_tz, + ) + upper_dt = normalize_data_query_time( + dates[-1], + data_query_time, + data_query_tz, + ) + else: + lower_dt = dates[0] - timedelta(days=1) + upper_dt = dates[-1] def where(e): """Create the query to run against the resources. @@ -871,18 +882,20 @@ class BlazeLoader(dict): if deltas is not None else pd.DataFrame(columns=query_fields) ) - for m in (materialized_expr, materialized_deltas): - m.loc[:, TS_FIELD_NAME] = m.loc[ - :, TS_FIELD_NAME - ].astype('datetime64[ns]') - normalize_timestamp_to_query_time( - m, - data_query_time, - data_query_tz, - inplace=True, - ts_field=TS_FIELD_NAME, - ) + if data_query_time is not None: + for m in (materialized_expr, materialized_deltas): + m.loc[:, TS_FIELD_NAME] = m.loc[ + :, TS_FIELD_NAME + ].astype('datetime64[ns]') + + normalize_timestamp_to_query_time( + m, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) # Inline the deltas that changed our most recently known value. # Also, we reindex by the dates to create a dense representation of diff --git a/zipline/pipeline/loaders/blaze/earnings.py b/zipline/pipeline/loaders/blaze/earnings.py index 0b2e312c..bfb17d4d 100644 --- a/zipline/pipeline/loaders/blaze/earnings.py +++ b/zipline/pipeline/loaders/blaze/earnings.py @@ -1,4 +1,4 @@ -import datetime +from datetime import timedelta from datashape import istabular import pandas as pd @@ -14,10 +14,11 @@ from zipline.pipeline.data import EarningsCalendar from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.loaders.earnings import EarningsCalendarLoader from zipline.pipeline.loaders.utils import ( + check_data_query_args, normalize_data_query_time, normalize_timestamp_to_query_time, ) -from zipline.utils.input_validation import ensure_timezone +from zipline.utils.input_validation import ensure_timezone, optionally from zipline.utils.preprocess import preprocess @@ -70,13 +71,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ANNOUNCEMENT_FIELD_NAME, }) - @preprocess(data_query_tz=ensure_timezone) + @preprocess(data_query_tz=optionally(ensure_timezone)) def __init__(self, expr, resources=None, odo_kwargs=None, - data_query_time=datetime.time(0), - data_query_tz='utc', + data_query_time=None, + data_query_tz=None, dataset=EarningsCalendar): dshape = expr.dshape @@ -92,24 +93,32 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ) self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {} self._dataset = dataset + check_data_query_args(data_query_time, data_query_tz) self._data_query_time = data_query_time self._data_query_tz = data_query_tz def load_adjusted_array(self, columns, dates, assets, mask): data_query_time = self._data_query_time data_query_tz = self._data_query_tz - raw = ffill_query_in_range( - self._expr, - normalize_data_query_time( - dates[0], + if data_query_time is not None: + lower_dt = normalize_data_query_time( + dates[0] - timedelta(days=1), data_query_time, data_query_tz, - ), - normalize_data_query_time( + ) + upper_dt = normalize_data_query_time( dates[-1], data_query_time, data_query_tz, - ), + ) + else: + lower_dt = dates[0] - timedelta(days=1) + upper_dt = dates[-1] + + raw = ffill_query_in_range( + self._expr, + lower_dt, + upper_dt, self._odo_kwargs, ) sids = raw.loc[:, SID_FIELD_NAME] @@ -117,13 +126,14 @@ class BlazeEarningsCalendarLoader(PipelineLoader): sids[~sids.isin(assets)].index, inplace=True ) - normalize_timestamp_to_query_time( - raw, - data_query_time, - data_query_tz, - inplace=True, - ts_field=TS_FIELD_NAME, - ) + if data_query_time is not None: + normalize_timestamp_to_query_time( + raw, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) gb = raw.groupby(SID_FIELD_NAME) diff --git a/zipline/pipeline/loaders/utils.py b/zipline/pipeline/loaders/utils.py index bef8906b..244ac4ba 100644 --- a/zipline/pipeline/loaders/utils.py +++ b/zipline/pipeline/loaders/utils.py @@ -162,7 +162,7 @@ def normalize_timestamp_to_query_time(df, dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc') dtidx_local_time = dtidx.tz_convert(tz) - to_roll_forward = dtidx_local_time.time > time + to_roll_forward = dtidx_local_time.time >= time # for all of the times that are greater than our query time add 1 # day and truncate to the date df.loc[to_roll_forward, ts_field] = ( @@ -170,3 +170,27 @@ def normalize_timestamp_to_query_time(df, ).normalize().tz_localize(None).tz_localize('utc') # cast back to utc df.loc[~to_roll_forward, ts_field] = dtidx[~to_roll_forward].normalize() return df + + +def check_data_query_args(data_query_time, data_query_tz): + """Checks the data_query_time and data_query_tz arguments for loaders + and raises a standard exception if one is None and the other is not. + + Parameters + ---------- + data_query_time : datetime.time or None + data_query_tz : tzinfo or None + + Raises + ------ + ValueError + Raised when only one of the arguments is None. + """ + if (data_query_time is None) ^ (data_query_tz is None): + raise ValueError( + "either 'data_query_time' and 'data_query_tz' must both be" + " None or neither may be None (got %r, %r)" % ( + data_query_time, + data_query_tz, + ), + )