ENH: Make the data_query_time arguments optional

This commit is contained in:
Joe Jevnik
2016-01-12 17:31:17 -05:00
parent 5351b60a4c
commit 2caa9277c4
4 changed files with 95 additions and 48 deletions
+2 -2
View File
@@ -327,9 +327,9 @@ class BlazeToPipelineTestCase(TestCase):
df = self.df.copy()
df['timestamp'] = (
pd.DatetimeIndex(df['timestamp'], tz='EST') +
timedelta(hours=8, minutes=45)
timedelta(hours=8, minutes=44)
).tz_convert('utc')
df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc')
df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45', tz='utc')
expr = bz.Data(df, name='expr', dshape=self.dshape)
loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST')
ds = from_blaze(
+39 -26
View File
@@ -127,7 +127,7 @@ from __future__ import division, absolute_import
from abc import ABCMeta, abstractproperty
from collections import namedtuple, defaultdict
from copy import copy
from datetime import time
from datetime import timedelta
from functools import partial
from itertools import count
import warnings
@@ -160,13 +160,18 @@ import toolz.curried.operator as op
from zipline.pipeline.data.dataset import DataSet, Column
from zipline.pipeline.loaders.utils import (
check_data_query_args,
normalize_data_query_time,
normalize_timestamp_to_query_time,
)
from zipline.lib.adjusted_array import AdjustedArray
from zipline.lib.adjustment import Float64Overwrite
from zipline.utils.enum import enum
from zipline.utils.input_validation import expect_element, ensure_timezone
from zipline.utils.input_validation import (
expect_element,
ensure_timezone,
optionally,
)
from zipline.utils.numpy_utils import repeat_last_axis
from zipline.utils.preprocess import preprocess
@@ -781,12 +786,14 @@ class BlazeLoader(dict):
data_query_tz : tzinfo or str
The timezeone to use for the data query cutoff.
"""
@preprocess(data_query_tz=ensure_timezone)
@preprocess(data_query_tz=optionally(ensure_timezone))
def __init__(self,
colmap=None,
data_query_time=time(0),
data_query_tz='utc'):
data_query_time=None,
data_query_tz=None):
self.update(colmap or {})
check_data_query_args(data_query_time, data_query_tz)
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
@@ -827,16 +834,20 @@ class BlazeLoader(dict):
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
lower_dt = normalize_data_query_time(
dates[0],
data_query_time,
data_query_tz,
)
upper_dt = normalize_data_query_time(
dates[-1],
data_query_time,
data_query_tz,
)
if data_query_time is not None:
lower_dt = normalize_data_query_time(
dates[0] - timedelta(days=1),
data_query_time,
data_query_tz,
)
upper_dt = normalize_data_query_time(
dates[-1],
data_query_time,
data_query_tz,
)
else:
lower_dt = dates[0] - timedelta(days=1)
upper_dt = dates[-1]
def where(e):
"""Create the query to run against the resources.
@@ -871,18 +882,20 @@ class BlazeLoader(dict):
if deltas is not None else
pd.DataFrame(columns=query_fields)
)
for m in (materialized_expr, materialized_deltas):
m.loc[:, TS_FIELD_NAME] = m.loc[
:, TS_FIELD_NAME
].astype('datetime64[ns]')
normalize_timestamp_to_query_time(
m,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
if data_query_time is not None:
for m in (materialized_expr, materialized_deltas):
m.loc[:, TS_FIELD_NAME] = m.loc[
:, TS_FIELD_NAME
].astype('datetime64[ns]')
normalize_timestamp_to_query_time(
m,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
# Inline the deltas that changed our most recently known value.
# Also, we reindex by the dates to create a dense representation of
+29 -19
View File
@@ -1,4 +1,4 @@
import datetime
from datetime import timedelta
from datashape import istabular
import pandas as pd
@@ -14,10 +14,11 @@ from zipline.pipeline.data import EarningsCalendar
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
from zipline.pipeline.loaders.utils import (
check_data_query_args,
normalize_data_query_time,
normalize_timestamp_to_query_time,
)
from zipline.utils.input_validation import ensure_timezone
from zipline.utils.input_validation import ensure_timezone, optionally
from zipline.utils.preprocess import preprocess
@@ -70,13 +71,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
ANNOUNCEMENT_FIELD_NAME,
})
@preprocess(data_query_tz=ensure_timezone)
@preprocess(data_query_tz=optionally(ensure_timezone))
def __init__(self,
expr,
resources=None,
odo_kwargs=None,
data_query_time=datetime.time(0),
data_query_tz='utc',
data_query_time=None,
data_query_tz=None,
dataset=EarningsCalendar):
dshape = expr.dshape
@@ -92,24 +93,32 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
)
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
self._dataset = dataset
check_data_query_args(data_query_time, data_query_tz)
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
def load_adjusted_array(self, columns, dates, assets, mask):
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
raw = ffill_query_in_range(
self._expr,
normalize_data_query_time(
dates[0],
if data_query_time is not None:
lower_dt = normalize_data_query_time(
dates[0] - timedelta(days=1),
data_query_time,
data_query_tz,
),
normalize_data_query_time(
)
upper_dt = normalize_data_query_time(
dates[-1],
data_query_time,
data_query_tz,
),
)
else:
lower_dt = dates[0] - timedelta(days=1)
upper_dt = dates[-1]
raw = ffill_query_in_range(
self._expr,
lower_dt,
upper_dt,
self._odo_kwargs,
)
sids = raw.loc[:, SID_FIELD_NAME]
@@ -117,13 +126,14 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
sids[~sids.isin(assets)].index,
inplace=True
)
normalize_timestamp_to_query_time(
raw,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
if data_query_time is not None:
normalize_timestamp_to_query_time(
raw,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
gb = raw.groupby(SID_FIELD_NAME)
+25 -1
View File
@@ -162,7 +162,7 @@ def normalize_timestamp_to_query_time(df,
dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc')
dtidx_local_time = dtidx.tz_convert(tz)
to_roll_forward = dtidx_local_time.time > time
to_roll_forward = dtidx_local_time.time >= time
# for all of the times that are greater than our query time add 1
# day and truncate to the date
df.loc[to_roll_forward, ts_field] = (
@@ -170,3 +170,27 @@ def normalize_timestamp_to_query_time(df,
).normalize().tz_localize(None).tz_localize('utc') # cast back to utc
df.loc[~to_roll_forward, ts_field] = dtidx[~to_roll_forward].normalize()
return df
def check_data_query_args(data_query_time, data_query_tz):
"""Checks the data_query_time and data_query_tz arguments for loaders
and raises a standard exception if one is None and the other is not.
Parameters
----------
data_query_time : datetime.time or None
data_query_tz : tzinfo or None
Raises
------
ValueError
Raised when only one of the arguments is None.
"""
if (data_query_time is None) ^ (data_query_tz is None):
raise ValueError(
"either 'data_query_time' and 'data_query_tz' must both be"
" None or neither may be None (got %r, %r)" % (
data_query_time,
data_query_tz,
),
)