Merge pull request #947 from quantopian/data-query-time

allows users to specify the cutoff time for data query in blaze loaders
This commit is contained in:
Joe Jevnik
2016-01-14 14:11:07 -05:00
7 changed files with 408 additions and 11 deletions
+19 -1
View File
@@ -63,7 +63,25 @@ Enhancements
subclasses inherit all of the columns from the parent. These columns will be
new sentinels so you can register them with a custom loader (:issue:`924`).
* Added :func:`~zipline.utils.input_validation.coerce`.
* Added :func:`~zipline.utils.input_validation.coerce` to coerce inputs from one
type into another before passing them to the function (:issue:`948`).
* Added :func:`~zipline.utils.input_validation.optionally` to wrap other
preprocessor functions to explicitly allow ``None`` (:issue:`947`).
* Added :func:`~zipline.utils.input_validation.ensure_timezone` to allow string
arguments to get converted into ``datetime.tzinfo`` objects. This also allows
``tzinfo`` objects to be passed directly (:issue:`947`).
* Added two optional arguments, ``data_query_time`` and ``data_query_tz`` to
:class:`~zipline.pipeline.loaders.blaze.core.BlazeLoader` and
:class:`~zipline.pipeline.loaders.blaze.earnings.BlazeEarningsCalendarLoader`.
These arguments allow the user to specify some cutoff time for data when
loading from the resource. For example, if I want to simulate executing my
``before_trading_start`` function at ``8:45 US/Eastern`` then I could pass
``datetime.time(8, 45)`` and ``'US/Eastern'`` to the loader. This means that
data that is timestamped on or after ``8:45`` will not be seen on that day in the
simulation. The data will be made available on the next day (:issue:`947`).
Experimental Features
~~~~~~~~~~~~~~~~~~~~~
+38 -1
View File
@@ -4,7 +4,7 @@ Tests for the blaze interface to the pipeline api.
from __future__ import division
from collections import OrderedDict
from datetime import timedelta
from datetime import timedelta, time
from unittest import TestCase
import warnings
@@ -323,6 +323,43 @@ class BlazeToPipelineTestCase(TestCase):
))
assert_frame_equal(result, expected, check_dtype=False)
def test_custom_query_time_tz(self):
    """Run a pipeline through a ``BlazeLoader`` with an 8:45 EST cutoff.

    Every timestamp in the fixture frame is shifted by 8 hours 44
    minutes in EST and converted to utc; rows 3 through 5 are then
    overwritten with 13:45 utc, which is 8:45 EST, i.e. exactly on the
    cutoff. The expected frame keeps the normalized date for every row
    except rows 3 through 5, which are pushed forward by one day.
    """
    df = self.df.copy()
    # presumably the fixture timestamps sit at midnight, so this lands
    # them one minute before the cutoff -- verify against ``self.df``
    df['timestamp'] = (
        pd.DatetimeIndex(df['timestamp'], tz='EST') +
        timedelta(hours=8, minutes=44)
    ).tz_convert('utc')
    # 13:45 utc == 8:45 EST: these rows are exactly at the query time
    df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45', tz='utc')
    expr = bz.Data(df, name='expr', dshape=self.dshape)
    loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST')
    ds = from_blaze(
        expr,
        loader=loader,
        no_deltas_rule=no_deltas_rules.ignore,
    )
    p = Pipeline()
    p.add(ds.value.latest, 'value')
    dates = self.dates

    # NOTE(review): indentation was lost in the view this was recovered
    # from; the expected-frame construction is kept inside the ``with`` so
    # that ``finder`` is certainly still usable -- confirm against the
    # original file.
    with tmp_asset_finder() as finder:
        result = SimplePipelineEngine(
            loader,
            dates,
            finder,
        ).run_pipeline(p, dates[0], dates[-1])

        expected = df.drop('asof_date', axis=1)
        # compare on plain dates: drop the time of day and the timezone
        expected['timestamp'] = expected['timestamp'].dt.normalize().astype(
            'datetime64[ns]',
        )
        # the rows stamped at the cutoff are expected one day later
        expected.ix[3:5, 'timestamp'] += timedelta(days=1)
        expected.set_index(['timestamp', 'sid'], inplace=True)
        expected.index = pd.MultiIndex.from_product((
            expected.index.levels[0],
            finder.retrieve_all(expected.index.levels[1]),
        ))
        assert_frame_equal(result, expected, check_dtype=False)
def test_id_macro_dataset(self):
expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape)
loader = BlazeLoader()
+50
View File
@@ -7,14 +7,17 @@ from unittest import TestCase
from nose_parameterized import parameterized
from numpy import arange, dtype
import pytz
from six import PY3
from zipline.utils.preprocess import call, preprocess
from zipline.utils.input_validation import (
ensure_timezone,
expect_element,
expect_dtypes,
expect_types,
optional,
optionally,
)
@@ -317,3 +320,50 @@ class PreprocessTestCase(TestCase):
"or 'float64' for argument 'a', but got 'uint32' instead."
).format(qualname=qualname(foo))
self.assertEqual(e.exception.args[0], expected_message)
def test_ensure_timezone(self):
    """``ensure_timezone`` accepts timezone names and tzinfo objects.

    Known names are coerced to the matching ``pytz`` timezone, tzinfo
    objects are returned unchanged, and unknown names raise
    ``pytz.UnknownTimeZoneError``.
    """
    @preprocess(tz=ensure_timezone)
    def identity(tz):
        return tz

    known_names = ('utc', 'EST', 'US/Eastern')
    # unfortunately, these are not actually timezones (yet)
    unknown_names = ('ayy', 'lmao')

    for name in known_names:
        zone = pytz.timezone(name)
        # coercion from the string name
        self.assertEqual(identity(name), zone)
        # pass through of an existing tzinfo object
        self.assertEqual(identity(zone), zone)

    for name in unknown_names:
        with self.assertRaises(pytz.UnknownTimeZoneError):
            identity(name)
def test_optionally(self):
    """``optionally`` bypasses the wrapped preprocessor only for ``None``.

    Valid values and ``None`` come back untouched; any other value must
    surface the exact exception raised by the wrapped preprocessor.
    """
    sentinel_error = TypeError('arg must be int')

    def require_int(func, argname, arg):
        if isinstance(arg, int):
            return arg
        raise sentinel_error

    @preprocess(a=optionally(require_int))
    def passthrough(a):
        return a

    self.assertIs(passthrough(1), 1)
    self.assertIsNone(passthrough(None))

    with self.assertRaises(TypeError) as caught:
        passthrough('a')
    # the very same exception object must propagate, not a copy
    self.assertIs(caught.exception, sentinel_error)
+60 -6
View File
@@ -144,6 +144,7 @@ from datashape import (
)
from odo import odo
import pandas as pd
from six import with_metaclass, PY2, itervalues, iteritems
from toolz import (
complement,
compose,
@@ -154,15 +155,24 @@ from toolz import (
memoize,
)
import toolz.curried.operator as op
from six import with_metaclass, PY2, itervalues, iteritems
from zipline.pipeline.data.dataset import DataSet, Column
from zipline.pipeline.loaders.utils import (
check_data_query_args,
normalize_data_query_bounds,
normalize_timestamp_to_query_time,
)
from zipline.lib.adjusted_array import AdjustedArray
from zipline.lib.adjustment import Float64Overwrite
from zipline.utils.enum import enum
from zipline.utils.input_validation import expect_element
from zipline.utils.input_validation import (
expect_element,
ensure_timezone,
optionally,
)
from zipline.utils.numpy_utils import repeat_last_axis
from zipline.utils.preprocess import preprocess
AD_FIELD_NAME = 'asof_date'
@@ -764,9 +774,28 @@ def adjustments_from_deltas_with_sids(dates,
class BlazeLoader(dict):
def __init__(self, colmap=None):
"""A PipelineLoader for datasets constructed with ``from_blaze``.
Parameters
----------
colmap : mapping[BoundColumn -> tuple[Expr, Expr, any]], optional
The initial column mapping to use.
data_query_time : time, optional
The time to use for the data query cutoff.
data_query_tz : tzinfo or str, optional
The timezone to use for the data query cutoff.
"""
@preprocess(data_query_tz=optionally(ensure_timezone))
def __init__(self,
colmap=None,
data_query_time=None,
data_query_tz=None):
self.update(colmap or {})
check_data_query_args(data_query_time, data_query_tz)
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
@classmethod
@memoize(cache=WeakKeyDictionary())
def global_instance(cls):
@@ -802,6 +831,15 @@ class BlazeLoader(dict):
[SID_FIELD_NAME] if have_sids else []
)
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
lower_dt, upper_dt = normalize_data_query_bounds(
dates[0],
dates[-1],
data_query_time,
data_query_tz,
)
def where(e):
"""Create the query to run against the resources.
@@ -819,8 +857,8 @@ class BlazeLoader(dict):
# Hack to get the lower bound to query:
# This must be strictly executed because the data for `ts` will
# be removed from scope too early otherwise.
lower = odo(ts[ts <= dates[0]].max(), pd.Timestamp)
selection = ts <= dates[-1]
lower = odo(ts[ts <= lower_dt].max(), pd.Timestamp)
selection = ts <= upper_dt
if have_sids:
selection &= e[SID_FIELD_NAME].isin(assets)
if lower is not pd.NaT:
@@ -836,6 +874,20 @@ class BlazeLoader(dict):
pd.DataFrame(columns=query_fields)
)
if data_query_time is not None:
for m in (materialized_expr, materialized_deltas):
m.loc[:, TS_FIELD_NAME] = m.loc[
:, TS_FIELD_NAME
].astype('datetime64[ns]')
normalize_timestamp_to_query_time(
m,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
# Inline the deltas that changed our most recently known value.
# Also, we reindex by the dates to create a dense representation of
# the data.
@@ -978,7 +1030,7 @@ def ffill_query_in_range(expr,
# range. It must all be null anyways.
computed_lower = lower
return odo(
raw = odo(
expr[
(expr[ts_field] >= computed_lower) &
(expr[ts_field] <= upper)
@@ -986,3 +1038,5 @@ def ffill_query_in_range(expr,
pd.DataFrame,
**odo_kwargs
)
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
return raw
+36 -2
View File
@@ -11,6 +11,13 @@ from .core import (
from zipline.pipeline.data import EarningsCalendar
from zipline.pipeline.loaders.base import PipelineLoader
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
from zipline.pipeline.loaders.utils import (
check_data_query_args,
normalize_data_query_bounds,
normalize_timestamp_to_query_time,
)
from zipline.utils.input_validation import ensure_timezone, optionally
from zipline.utils.preprocess import preprocess
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
@@ -28,6 +35,10 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
Mapping from the atomic terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
The time to use for the data query cutoff.
data_query_tz : tzinfo or str, optional
The timezone to use for the data query cutoff.
Notes
-----
@@ -58,10 +69,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
ANNOUNCEMENT_FIELD_NAME,
})
@preprocess(data_query_tz=optionally(ensure_timezone))
def __init__(self,
expr,
resources=None,
odo_kwargs=None,
data_query_time=None,
data_query_tz=None,
dataset=EarningsCalendar):
dshape = expr.dshape
@@ -77,12 +91,24 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
)
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
self._dataset = dataset
check_data_query_args(data_query_time, data_query_tz)
self._data_query_time = data_query_time
self._data_query_tz = data_query_tz
def load_adjusted_array(self, columns, dates, assets, mask):
raw = ffill_query_in_range(
self._expr,
data_query_time = self._data_query_time
data_query_tz = self._data_query_tz
lower_dt, upper_dt = normalize_data_query_bounds(
dates[0],
dates[-1],
data_query_time,
data_query_tz,
)
raw = ffill_query_in_range(
self._expr,
lower_dt,
upper_dt,
self._odo_kwargs,
)
sids = raw.loc[:, SID_FIELD_NAME]
@@ -90,6 +116,14 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
sids[~sids.isin(assets)].index,
inplace=True
)
if data_query_time is not None:
normalize_timestamp_to_query_time(
raw,
data_query_time,
data_query_tz,
inplace=True,
ts_field=TS_FIELD_NAME,
)
gb = raw.groupby(SID_FIELD_NAME)
+135
View File
@@ -1,3 +1,5 @@
import datetime
import numpy as np
import pandas as pd
from six import iteritems
@@ -93,3 +95,136 @@ def previous_date_frame(date_index, events_by_sid):
frame = pd.DataFrame(out, index=date_index, columns=sids)
frame.ffill(inplace=True)
return frame
def normalize_data_query_time(dt, time, tz):
    """Build the utc query cutoff for the calendar date of ``dt``.

    Parameters
    ----------
    dt : pd.Timestamp
        The datetime whose date component is used; its time of day is
        discarded.
    time : datetime.time
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone in which ``time`` is expressed.

    Returns
    -------
    query_dt : pd.Timestamp
        The date of ``dt`` at ``time`` in ``tz``, converted to utc.
    """
    # Pair the calendar date with the cutoff time of day, read that
    # wall-clock value in ``tz``, and only then move it to utc.
    local_cutoff = pd.Timestamp(
        datetime.datetime.combine(dt.date(), time),
        tz=tz,
    )
    return local_cutoff.tz_convert('utc')
def normalize_data_query_bounds(lower, upper, time, tz):
    """Adjust the first and last requested dates for a query time and tz.

    Parameters
    ----------
    lower : pd.Timestamp
        The lower date requested.
    upper : pd.Timestamp
        The upper date requested.
    time : datetime.time or None
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone to normalize your dates to before comparing against
        `time`.

    Returns
    -------
    bounds : tuple[pd.Timestamp, pd.Timestamp]
        ``(lower, upper)`` with ``lower`` moved back by one day. When
        ``time`` is not None both bounds are additionally passed through
        ``normalize_data_query_time``.
    """
    # Step back one day to grab things that happened on the first day we are
    # requesting. This doesn't need to be a trading day, we are only adding
    # a lower bound to limit the amount of in memory filtering that needs
    # to happen.
    padded_lower = lower - datetime.timedelta(days=1)

    if time is None:
        return padded_lower, upper

    return (
        normalize_data_query_time(padded_lower, time, tz),
        normalize_data_query_time(upper, time, tz),
    )
def normalize_timestamp_to_query_time(df,
                                      time,
                                      tz,
                                      inplace=False,
                                      ts_field='timestamp'):
    """Snap the timestamp column of a frame to dates around a query cutoff.

    Rows whose time of day in ``tz`` is at or after ``time`` are moved to
    midnight of the following local date; every other row is truncated to
    midnight of its own date.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to update. This needs a column named ``ts_field``.
    time : datetime.time
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone to normalize your dates to before comparing against
        `time`.
    inplace : bool, optional
        Update the dataframe in place.
    ts_field : str, optional
        The name of the timestamp field in ``df``.

    Returns
    -------
    df : pd.DataFrame
        The dataframe with the timestamp field normalized. If ``inplace`` is
        true, then this will be the same object as ``df`` otherwise this will
        be a copy.
    """
    # work on a private copy unless the caller asked for mutation
    frame = df if inplace else df.copy()

    utc_index = pd.DatetimeIndex(frame.loc[:, ts_field], tz='utc')
    local_index = utc_index.tz_convert(tz)
    at_or_after_cutoff = local_index.time >= time

    # Rows at or past the query time: add one day in local time, truncate
    # to the date, then relabel that midnight as utc.
    rolled = (
        local_index[at_or_after_cutoff] + datetime.timedelta(days=1)
    ).normalize().tz_localize(None).tz_localize('utc')
    frame.loc[at_or_after_cutoff, ts_field] = rolled

    # NOTE(review): the remaining rows are truncated on their *utc* date,
    # not their local date; the two can differ for cutoffs or timezones far
    # from utc midnight -- confirm this is intended.
    frame.loc[~at_or_after_cutoff, ts_field] = (
        utc_index[~at_or_after_cutoff].normalize()
    )
    return frame
def check_data_query_args(data_query_time, data_query_tz):
    """Validate that the query time and query timezone are given together.

    Loaders accept either both arguments or neither; supplying exactly one
    of them is rejected with a standard exception.

    Parameters
    ----------
    data_query_time : datetime.time or None
    data_query_tz : tzinfo or None

    Raises
    ------
    ValueError
        Raised when only one of the arguments is None.
    """
    time_missing = data_query_time is None
    tz_missing = data_query_tz is None
    if time_missing == tz_missing:
        # both supplied or both omitted: nothing to complain about
        return

    raise ValueError(
        "either 'data_query_time' and 'data_query_tz' must both be"
        " None or neither may be None (got %r, %r)" % (
            data_query_time,
            data_query_tz,
        ),
    )
+70 -1
View File
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from datetime import tzinfo
from functools import partial, wraps
from operator import attrgetter
from numpy import dtype
from pytz import timezone
from six import iteritems, string_types, PY3
from toolz import valmap, complement, compose
import toolz.curried.operator as op
@@ -22,6 +24,46 @@ import toolz.curried.operator as op
from zipline.utils.preprocess import preprocess
def optionally(preprocessor):
    """Wrap a preprocessor so that an explicit `None` is let through.

    Parameters
    ----------
    preprocessor : callable[callable, str, any -> any]
        The preprocessor that handles every argument other than `None`.

    Returns
    -------
    optional_preprocessor : callable[callable, str, any -> any]
        A preprocessor that returns `None` untouched and otherwise defers to
        `preprocessor`.

    Usage
    -----
    >>> def preprocessor(func, argname, arg):
    ...     if not isinstance(arg, int):
    ...         raise TypeError('arg must be int')
    ...     return arg
    ...
    >>> @preprocess(a=optionally(preprocessor))
    ... def f(a):
    ...     return a
    ...
    >>> f(1)  # call with int
    1
    >>> f('a')  # call with not int
    Traceback (most recent call last):
       ...
    TypeError: arg must be int
    >>> f(None) is None  # call with explicit None
    True
    """
    @wraps(preprocessor)
    def wrapper(func, argname, arg):
        if arg is None:
            # explicit None short-circuits the wrapped preprocessor
            return arg
        return preprocessor(func, argname, arg)

    return wrapper
def ensure_upper_case(func, argname, arg):
if isinstance(arg, string_types):
return arg.upper()
@@ -61,6 +103,33 @@ def ensure_dtype(func, argname, arg):
)
def ensure_timezone(func, argname, arg):
    """Argument preprocessor that converts the input into a tzinfo object.

    ``tzinfo`` instances are returned as-is, strings are looked up with
    ``timezone``, and anything else raises ``TypeError``.

    Usage
    -----
    >>> from zipline.utils.preprocess import preprocess
    >>> @preprocess(tz=ensure_timezone)
    ... def foo(tz):
    ...     return tz
    >>> foo('utc')
    <UTC>
    """
    if isinstance(arg, tzinfo):
        # already a timezone object: nothing to convert
        return arg

    if not isinstance(arg, string_types):
        raise TypeError(
            "{func}() couldn't convert argument "
            "{argname}={arg!r} to a timezone.".format(
                func=_qualified_name(func),
                argname=argname,
                arg=arg,
            ),
        )

    return timezone(arg)
def expect_dtypes(*_pos, **named):
"""
Preprocessing decorator that verifies inputs have expected numpy dtypes.