ENH: allow users to specify the cutoff time for data queries in blaze loaders

This allows people to set their cutoff time to the time they will
actually execute 'before_trading_start'. Currently this is just passed
to the constructor of the loader; however, I would like to make this
managed by the algorithm simulation runner. This would help keep all of
the loaders in sync and lock 'before_trading_start's execution to the
time the data is queried for.
This commit is contained in:
Joe Jevnik
2016-01-07 15:11:49 -05:00
parent dad2bb201c
commit 5a235bdaef
6 changed files with 275 additions and 9 deletions
+38 -1
View File
@@ -4,7 +4,7 @@ Tests for the blaze interface to the pipeline api.
from __future__ import division
from collections import OrderedDict
from datetime import timedelta
from datetime import timedelta, time
from unittest import TestCase
import warnings
@@ -323,6 +323,43 @@ class BlazeToPipelineTestCase(TestCase):
))
assert_frame_equal(result, expected, check_dtype=False)
def test_custom_query_time_tz(self):
    # Run a pipeline through a BlazeLoader configured with an 08:45 EST
    # data query cutoff and compare the output to a hand-built frame.
    df = self.df.copy()
    # Reinterpret the fixture timestamps in EST, add 8h45m, and express
    # the result in UTC.
    # NOTE(review): presumably the fixture timestamps are midnights, which
    # would put every row exactly on the 08:45 EST cutoff -- confirm.
    df['timestamp'] = (
        pd.DatetimeIndex(df['timestamp'], tz='EST') +
        timedelta(hours=8, minutes=45)
    ).tz_convert('utc')
    # 13:46 UTC is 08:46 EST, i.e. one minute past the cutoff that the
    # loader below is configured with.
    df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc')
    expr = bz.Data(df, name='expr', dshape=self.dshape)
    loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST')
    ds = from_blaze(
        expr,
        loader=loader,
        no_deltas_rule=no_deltas_rules.ignore,
    )
    p = Pipeline()
    p.add(ds.value.latest, 'value')
    dates = self.dates

    with tmp_asset_finder() as finder:
        result = SimplePipelineEngine(
            loader,
            dates,
            finder,
        ).run_pipeline(p, dates[0], dates[-1])

    # Build the expected frame: timestamps are truncated to midnight and
    # made timezone-naive ...
    expected = df.drop('asof_date', axis=1)
    expected['timestamp'] = expected['timestamp'].dt.normalize().astype(
        'datetime64[ns]',
    )
    # ... and the rows that were moved past the cutoff above are expected
    # one day later.
    expected.ix[3:5, 'timestamp'] += timedelta(days=1)
    expected.set_index(['timestamp', 'sid'], inplace=True)
    # Replace the sid level with whatever the asset finder retrieves for
    # those sids so the index lines up with the pipeline output.
    expected.index = pd.MultiIndex.from_product((
        expected.index.levels[0],
        finder.retrieve_all(expected.index.levels[1]),
    ))
    assert_frame_equal(result, expected, check_dtype=False)
def test_id_macro_dataset(self):
expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape)
loader = BlazeLoader()
+30
View File
@@ -7,10 +7,12 @@ from unittest import TestCase
from nose_parameterized import parameterized
from numpy import arange, dtype
import pytz
from six import PY3
from zipline.utils.preprocess import call, preprocess
from zipline.utils.input_validation import (
ensure_timezone,
expect_element,
expect_dtypes,
expect_types,
@@ -317,3 +319,31 @@ class PreprocessTestCase(TestCase):
"or 'float64' for argument 'a', but got 'uint32' instead."
).format(qualname=qualname(foo))
self.assertEqual(e.exception.args[0], expected_message)
def test_ensure_timezone(self):
    # ``ensure_timezone`` should coerce timezone names, pass tzinfo
    # objects through untouched, and propagate pytz's error for names
    # that pytz does not know.
    @preprocess(tz=ensure_timezone)
    def identity(tz):
        return tz

    known_names = {'utc', 'EST', 'US/Eastern'}
    # unfortunately, these are not actually timezones (yet)
    unknown_names = {'ayy', 'lmao'}

    # a timezone name is coerced to the matching pytz timezone
    for name in known_names:
        expected = pytz.timezone(name)
        self.assertEqual(identity(name), expected)

    # tzinfo objects come back unchanged
    for zone in [pytz.timezone(name) for name in known_names]:
        self.assertEqual(identity(zone), zone)

    # unknown timezone names raise pytz's own error
    for name in unknown_names:
        with self.assertRaises(pytz.UnknownTimeZoneError):
            identity(name)
+65 -6
View File
@@ -127,6 +127,7 @@ from __future__ import division, absolute_import
from abc import ABCMeta, abstractproperty
from collections import namedtuple, defaultdict
from copy import copy
from datetime import time
from functools import partial
from itertools import count
import warnings
@@ -144,6 +145,7 @@ from datashape import (
)
from odo import odo
import pandas as pd
from six import with_metaclass, PY2, itervalues, iteritems
from toolz import (
complement,
compose,
@@ -154,15 +156,19 @@ from toolz import (
memoize,
)
import toolz.curried.operator as op
from six import with_metaclass, PY2, itervalues, iteritems
from zipline.pipeline.data.dataset import DataSet, Column
from zipline.pipeline.loaders.utils import (
normalize_data_query_time,
normalize_timestamp_to_query_time,
)
from zipline.lib.adjusted_array import AdjustedArray
from zipline.lib.adjustment import Float64Overwrite
from zipline.utils.enum import enum
from zipline.utils.input_validation import expect_element
from zipline.utils.input_validation import expect_element, ensure_timezone
from zipline.utils.numpy_utils import repeat_last_axis
from zipline.utils.preprocess import preprocess
AD_FIELD_NAME = 'asof_date'
@@ -764,8 +770,25 @@ def adjustments_from_deltas_with_sids(dates,
class BlazeLoader(dict):
def __init__(self, colmap=None):
"""A PipelineLoader for datasets constructed with ``from_blaze``.
Parameters
----------
colmap : mapping[BoundColumn -> tuple[Expr, Expr, any]], optional
The initial column mapping to use.
data_query_time : time, optional
The time to use for the data query cutoff.
data_query_tz : tzinfo or str
The timezone to use for the data query cutoff.
"""
@preprocess(data_query_tz=ensure_timezone)
def __init__(self,
colmap=None,
data_query_time=time(0),
data_query_tz='utc'):
self.update(colmap or {})
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
@classmethod
@memoize(cache=WeakKeyDictionary())
@@ -802,6 +825,19 @@ class BlazeLoader(dict):
[SID_FIELD_NAME] if have_sids else []
)
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
lower_dt = normalize_data_query_time(
dates[0],
data_query_time,
data_query_tz,
)
upper_dt = normalize_data_query_time(
dates[-1],
data_query_time,
data_query_tz,
)
def where(e):
"""Create the query to run against the resources.
@@ -819,8 +855,8 @@ class BlazeLoader(dict):
# Hack to get the lower bound to query:
# This must be strictly executed because the data for `ts` will
# be removed from scope too early otherwise.
lower = odo(ts[ts <= dates[0]].max(), pd.Timestamp)
selection = ts <= dates[-1]
lower = odo(ts[ts <= lower_dt].max(), pd.Timestamp)
selection = ts <= upper_dt
if have_sids:
selection &= e[SID_FIELD_NAME].isin(assets)
if lower is not pd.NaT:
@@ -830,11 +866,32 @@ class BlazeLoader(dict):
extra_kwargs = {'d': resources} if resources else {}
materialized_expr = odo(where(expr), pd.DataFrame, **extra_kwargs)
materialized_expr[TS_FIELD_NAME] = materialized_expr[
TS_FIELD_NAME
].astype('datetime64[ns]')
materialized_deltas = (
odo(where(deltas), pd.DataFrame, **extra_kwargs)
if deltas is not None else
pd.DataFrame(columns=query_fields)
)
materialized_deltas[TS_FIELD_NAME] = materialized_deltas[
TS_FIELD_NAME
].astype('datetime64[ns]')
normalize_timestamp_to_query_time(
materialized_expr,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
normalize_timestamp_to_query_time(
materialized_deltas,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
# Inline the deltas that changed our most recently known value.
# Also, we reindex by the dates to create a dense representation of
@@ -978,7 +1035,7 @@ def ffill_query_in_range(expr,
# range. It must all be null anyways.
computed_lower = lower
return odo(
raw = odo(
expr[
(expr[ts_field] >= computed_lower) &
(expr[ts_field] <= upper)
@@ -986,3 +1043,5 @@ def ffill_query_in_range(expr,
pd.DataFrame,
**odo_kwargs
)
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
return raw
+36 -2
View File
@@ -1,3 +1,5 @@
import datetime
from datashape import istabular
import pandas as pd
from toolz import valmap
@@ -11,6 +13,12 @@ from .core import (
from zipline.pipeline.data import EarningsCalendar
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
from zipline.pipeline.loaders.utils import (
normalize_data_query_time,
normalize_timestamp_to_query_time,
)
from zipline.utils.input_validation import ensure_timezone
from zipline.utils.preprocess import preprocess
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
@@ -28,6 +36,10 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
Mapping from the atomic terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
The time to use for the data query cutoff.
data_query_tz : tzinfo or str
The timezone to use for the data query cutoff.
Notes
-----
@@ -58,10 +70,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
ANNOUNCEMENT_FIELD_NAME,
})
@preprocess(data_query_tz=ensure_timezone)
def __init__(self,
expr,
resources=None,
odo_kwargs=None,
data_query_time=datetime.time(0),
data_query_tz='utc',
dataset=EarningsCalendar):
dshape = expr.dshape
@@ -77,12 +92,24 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
)
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
self._dataset = dataset
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
def load_adjusted_array(self, columns, dates, assets, mask):
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
raw = ffill_query_in_range(
self._expr,
dates[0],
dates[-1],
normalize_data_query_time(
dates[0],
data_query_time,
data_query_tz,
),
normalize_data_query_time(
dates[-1],
data_query_time,
data_query_tz,
),
self._odo_kwargs,
)
sids = raw.loc[:, SID_FIELD_NAME]
@@ -90,6 +117,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
sids[~sids.isin(assets)].index,
inplace=True
)
normalize_timestamp_to_query_time(
raw,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
gb = raw.groupby(SID_FIELD_NAME)
+77
View File
@@ -1,3 +1,5 @@
import datetime
import numpy as np
import pandas as pd
from six import iteritems
@@ -93,3 +95,78 @@ def previous_date_frame(date_index, events_by_sid):
frame = pd.DataFrame(out, index=date_index, columns=sids)
frame.ffill(inplace=True)
return frame
def normalize_data_query_time(dt, time, tz):
    """Build the UTC instant at which data is queried for a given date.

    Parameters
    ----------
    dt : pd.Timestamp
        Timestamp whose calendar date identifies the day being queried.
    time : datetime.time
        The time of day used as the cutoff point for new data. Data points
        learned about after this time become available to the algorithm on
        the next trading day.
    tz : tzinfo
        The timezone in which ``time`` is expressed.

    Returns
    -------
    query_dt : pd.Timestamp
        ``time`` on the date of ``dt``, interpreted in ``tz`` and then
        converted to utc.
    """
    # attach the cutoff time to the calendar date (still timezone-naive)
    naive_cutoff = datetime.datetime.combine(dt.date(), time)
    # interpret that wall-clock value in the requested timezone ...
    local_cutoff = pd.Timestamp(naive_cutoff, tz=tz)
    # ... and express the same instant in utc
    return local_cutoff.tz_convert('utc')
def normalize_timestamp_to_query_time(df,
                                      time,
                                      tz,
                                      inplace=False,
                                      ts_field='timestamp'):
    """Update the timestamp field of a dataframe to normalize dates around
    some data query time/timezone.

    Every timestamp is replaced by midnight (labelled as utc) of the
    calendar date, *in ``tz``*, on which the row first becomes visible:
    the local date itself when the local time of day is at or before
    ``time``, and the following local date otherwise.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to update. This needs a column named ``ts_field``.
    time : datetime.time
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone to normalize your dates to before comparing against
        `time`.
    inplace : bool, optional
        Update the dataframe in place.
    ts_field : str, optional
        The name of the timestamp field in ``df``.

    Returns
    -------
    df : pd.DataFrame
        The dataframe with the timestamp field normalized. If ``inplace`` is
        true, then this will be the same object as ``df`` otherwise this will
        be a copy.
    """
    if not inplace:
        # don't mutate the dataframe in place
        df = df.copy()

    dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc')
    dtidx_local_time = dtidx.tz_convert(tz)
    to_roll_forward = dtidx_local_time.time > time

    # Work with the *local* calendar date for every row. The cutoff is
    # expressed in ``tz``, so using the utc date for some rows would label
    # them with the wrong day whenever the local and utc dates differ.
    # Dropping the timezone before truncating keeps the arithmetic in
    # plain wall-clock dates, so adding a day can never skip or repeat a
    # date around a DST transition.
    local_dates = dtidx_local_time.tz_localize(None).normalize()

    # rows seen after the cutoff become visible on the next local date
    df.loc[to_roll_forward, ts_field] = (
        local_dates[to_roll_forward] + datetime.timedelta(days=1)
    ).tz_localize('utc')  # label the date as midnight utc
    # rows seen at or before the cutoff are visible on their local date
    df.loc[~to_roll_forward, ts_field] = (
        local_dates[~to_roll_forward].tz_localize('utc')
    )
    return df
+29
View File
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import tzinfo
from functools import partial
from operator import attrgetter
from numpy import dtype
from pytz import timezone
from six import iteritems, string_types, PY3
from toolz import valmap, complement, compose
import toolz.curried.operator as op
@@ -61,6 +63,33 @@ def ensure_dtype(func, argname, arg):
)
def ensure_timezone(func, argname, arg):
    """Argument preprocessor that coerces its input to a tzinfo object.

    ``tzinfo`` instances are returned unchanged and strings are looked up
    with ``pytz.timezone``. Any other input raises ``TypeError``.

    Usage
    -----
    >>> from zipline.utils.preprocess import preprocess
    >>> @preprocess(tz=ensure_timezone)
    ... def foo(tz):
    ...     return tz
    >>> foo('utc')
    <UTC>
    """
    if isinstance(arg, tzinfo):
        # already a timezone object; nothing to convert
        return arg

    if not isinstance(arg, string_types):
        message = (
            "{func}() couldn't convert argument "
            "{argname}={arg!r} to a timezone."
        ).format(
            func=_qualified_name(func),
            argname=argname,
            arg=arg,
        )
        raise TypeError(message)

    # a timezone name: let pytz resolve it (unknown names raise there)
    return timezone(arg)
def expect_dtypes(*_pos, **named):
"""
Preprocessing decorator that verifies inputs have expected numpy dtypes.