mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 01:41:38 +08:00
ENH: Make the data_query_time arguments optional
This commit is contained in:
@@ -327,9 +327,9 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
df = self.df.copy()
|
||||
df['timestamp'] = (
|
||||
pd.DatetimeIndex(df['timestamp'], tz='EST') +
|
||||
timedelta(hours=8, minutes=45)
|
||||
timedelta(hours=8, minutes=44)
|
||||
).tz_convert('utc')
|
||||
df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc')
|
||||
df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45', tz='utc')
|
||||
expr = bz.Data(df, name='expr', dshape=self.dshape)
|
||||
loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST')
|
||||
ds = from_blaze(
|
||||
|
||||
@@ -127,7 +127,7 @@ from __future__ import division, absolute_import
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from collections import namedtuple, defaultdict
|
||||
from copy import copy
|
||||
from datetime import time
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
import warnings
|
||||
@@ -160,13 +160,18 @@ import toolz.curried.operator as op
|
||||
|
||||
from zipline.pipeline.data.dataset import DataSet, Column
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
normalize_data_query_time,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
from zipline.utils.enum import enum
|
||||
from zipline.utils.input_validation import expect_element, ensure_timezone
|
||||
from zipline.utils.input_validation import (
|
||||
expect_element,
|
||||
ensure_timezone,
|
||||
optionally,
|
||||
)
|
||||
from zipline.utils.numpy_utils import repeat_last_axis
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
@@ -781,12 +786,14 @@ class BlazeLoader(dict):
|
||||
data_query_tz : tzinfo or str
|
||||
The timezeone to use for the data query cutoff.
|
||||
"""
|
||||
@preprocess(data_query_tz=ensure_timezone)
|
||||
@preprocess(data_query_tz=optionally(ensure_timezone))
|
||||
def __init__(self,
|
||||
colmap=None,
|
||||
data_query_time=time(0),
|
||||
data_query_tz='utc'):
|
||||
data_query_time=None,
|
||||
data_query_tz=None):
|
||||
self.update(colmap or {})
|
||||
|
||||
check_data_query_args(data_query_time, data_query_tz)
|
||||
self._data_query_time = data_query_time
|
||||
self._data_query_tz = data_query_tz
|
||||
|
||||
@@ -827,16 +834,20 @@ class BlazeLoader(dict):
|
||||
|
||||
data_query_time = self._data_query_time
|
||||
data_query_tz = self._data_query_tz
|
||||
lower_dt = normalize_data_query_time(
|
||||
dates[0],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
upper_dt = normalize_data_query_time(
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
if data_query_time is not None:
|
||||
lower_dt = normalize_data_query_time(
|
||||
dates[0] - timedelta(days=1),
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
upper_dt = normalize_data_query_time(
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
else:
|
||||
lower_dt = dates[0] - timedelta(days=1)
|
||||
upper_dt = dates[-1]
|
||||
|
||||
def where(e):
|
||||
"""Create the query to run against the resources.
|
||||
@@ -871,18 +882,20 @@ class BlazeLoader(dict):
|
||||
if deltas is not None else
|
||||
pd.DataFrame(columns=query_fields)
|
||||
)
|
||||
for m in (materialized_expr, materialized_deltas):
|
||||
m.loc[:, TS_FIELD_NAME] = m.loc[
|
||||
:, TS_FIELD_NAME
|
||||
].astype('datetime64[ns]')
|
||||
|
||||
normalize_timestamp_to_query_time(
|
||||
m,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
if data_query_time is not None:
|
||||
for m in (materialized_expr, materialized_deltas):
|
||||
m.loc[:, TS_FIELD_NAME] = m.loc[
|
||||
:, TS_FIELD_NAME
|
||||
].astype('datetime64[ns]')
|
||||
|
||||
normalize_timestamp_to_query_time(
|
||||
m,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
|
||||
# Inline the deltas that changed our most recently known value.
|
||||
# Also, we reindex by the dates to create a dense representation of
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import datetime
|
||||
from datetime import timedelta
|
||||
|
||||
from datashape import istabular
|
||||
import pandas as pd
|
||||
@@ -14,10 +14,11 @@ from zipline.pipeline.data import EarningsCalendar
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
normalize_data_query_time,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_timezone
|
||||
from zipline.utils.input_validation import ensure_timezone, optionally
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
@@ -70,13 +71,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
})
|
||||
|
||||
@preprocess(data_query_tz=ensure_timezone)
|
||||
@preprocess(data_query_tz=optionally(ensure_timezone))
|
||||
def __init__(self,
|
||||
expr,
|
||||
resources=None,
|
||||
odo_kwargs=None,
|
||||
data_query_time=datetime.time(0),
|
||||
data_query_tz='utc',
|
||||
data_query_time=None,
|
||||
data_query_tz=None,
|
||||
dataset=EarningsCalendar):
|
||||
dshape = expr.dshape
|
||||
|
||||
@@ -92,24 +93,32 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
)
|
||||
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
|
||||
self._dataset = dataset
|
||||
check_data_query_args(data_query_time, data_query_tz)
|
||||
self._data_query_time = data_query_time
|
||||
self._data_query_tz = data_query_tz
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
data_query_time = self._data_query_time
|
||||
data_query_tz = self._data_query_tz
|
||||
raw = ffill_query_in_range(
|
||||
self._expr,
|
||||
normalize_data_query_time(
|
||||
dates[0],
|
||||
if data_query_time is not None:
|
||||
lower_dt = normalize_data_query_time(
|
||||
dates[0] - timedelta(days=1),
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
),
|
||||
normalize_data_query_time(
|
||||
)
|
||||
upper_dt = normalize_data_query_time(
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
),
|
||||
)
|
||||
else:
|
||||
lower_dt = dates[0] - timedelta(days=1)
|
||||
upper_dt = dates[-1]
|
||||
|
||||
raw = ffill_query_in_range(
|
||||
self._expr,
|
||||
lower_dt,
|
||||
upper_dt,
|
||||
self._odo_kwargs,
|
||||
)
|
||||
sids = raw.loc[:, SID_FIELD_NAME]
|
||||
@@ -117,13 +126,14 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
sids[~sids.isin(assets)].index,
|
||||
inplace=True
|
||||
)
|
||||
normalize_timestamp_to_query_time(
|
||||
raw,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
if data_query_time is not None:
|
||||
normalize_timestamp_to_query_time(
|
||||
raw,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
|
||||
gb = raw.groupby(SID_FIELD_NAME)
|
||||
|
||||
|
||||
@@ -162,7 +162,7 @@ def normalize_timestamp_to_query_time(df,
|
||||
|
||||
dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc')
|
||||
dtidx_local_time = dtidx.tz_convert(tz)
|
||||
to_roll_forward = dtidx_local_time.time > time
|
||||
to_roll_forward = dtidx_local_time.time >= time
|
||||
# for all of the times that are greater than our query time add 1
|
||||
# day and truncate to the date
|
||||
df.loc[to_roll_forward, ts_field] = (
|
||||
@@ -170,3 +170,27 @@ def normalize_timestamp_to_query_time(df,
|
||||
).normalize().tz_localize(None).tz_localize('utc') # cast back to utc
|
||||
df.loc[~to_roll_forward, ts_field] = dtidx[~to_roll_forward].normalize()
|
||||
return df
|
||||
|
||||
|
||||
def check_data_query_args(data_query_time, data_query_tz):
|
||||
"""Checks the data_query_time and data_query_tz arguments for loaders
|
||||
and raises a standard exception if one is None and the other is not.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_query_time : datetime.time or None
|
||||
data_query_tz : tzinfo or None
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
Raised when only one of the arguments is None.
|
||||
"""
|
||||
if (data_query_time is None) ^ (data_query_tz is None):
|
||||
raise ValueError(
|
||||
"either 'data_query_time' and 'data_query_tz' must both be"
|
||||
" None or neither may be None (got %r, %r)" % (
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user