From 5a235bdaef9024a36bd6abfa2d055ddc8fc96241 Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Thu, 7 Jan 2016 15:11:49 -0500 Subject: [PATCH] ENH: allows users to specify the cutoff time for data query in blaze loaders This allows people to set their cutoff time to the time they will actually execute 'before_trading_start'. Currently this is just passed to the constructor of the loader; however, I would like to make this managed by the algorithm simulation runner. This would help keep all of the loaders in sync and lock 'before_trading_start's execution to the time the data is queried for. --- tests/pipeline/test_blaze.py | 39 ++++++++++- tests/utils/test_preprocess.py | 30 +++++++++ zipline/pipeline/loaders/blaze/core.py | 71 ++++++++++++++++++-- zipline/pipeline/loaders/blaze/earnings.py | 38 ++++++++++- zipline/pipeline/loaders/utils.py | 77 ++++++++++++++++++++++ zipline/utils/input_validation.py | 29 ++++++++ 6 files changed, 275 insertions(+), 9 deletions(-) diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index d22aa244..289c22c3 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -4,7 +4,7 @@ Tests for the blaze interface to the pipeline api. from __future__ import division from collections import OrderedDict -from datetime import timedelta +from datetime import timedelta, time from unittest import TestCase import warnings @@ -323,6 +323,43 @@ class BlazeToPipelineTestCase(TestCase): )) assert_frame_equal(result, expected, check_dtype=False) + def test_custom_query_time_tz(self): + df = self.df.copy() + df['timestamp'] = ( + pd.DatetimeIndex(df['timestamp'], tz='EST') + + timedelta(hours=8, minutes=45) + ).tz_convert('utc') + df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc') + expr = bz.Data(df, name='expr', dshape=self.dshape) + loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST') + ds = from_blaze( + expr, + loader=loader, + no_deltas_rule=no_deltas_rules.ignore, + ) + p = Pipeline() + p.add(ds.value.latest, 'value') + dates = self.dates + + with tmp_asset_finder() as finder: + result = SimplePipelineEngine( + loader, + dates, + finder, + ).run_pipeline(p, dates[0], dates[-1]) + + expected = df.drop('asof_date', axis=1) + expected['timestamp'] = expected['timestamp'].dt.normalize().astype( + 'datetime64[ns]', + ) + expected.ix[3:5, 'timestamp'] += timedelta(days=1) + expected.set_index(['timestamp', 'sid'], inplace=True) + expected.index = pd.MultiIndex.from_product(( + expected.index.levels[0], + finder.retrieve_all(expected.index.levels[1]), + )) + assert_frame_equal(result, expected, check_dtype=False) + def test_id_macro_dataset(self): expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape) loader = BlazeLoader() diff --git a/tests/utils/test_preprocess.py b/tests/utils/test_preprocess.py index 3088f780..5270e605 100644 --- a/tests/utils/test_preprocess.py +++ b/tests/utils/test_preprocess.py @@ -7,10 +7,12 @@ from unittest import TestCase from nose_parameterized import parameterized from numpy import arange, dtype +import pytz from six import PY3 from zipline.utils.preprocess import call, preprocess from zipline.utils.input_validation import ( + ensure_timezone, expect_element, expect_dtypes, expect_types, @@ -317,3 +319,31 @@ class PreprocessTestCase(TestCase): "or 'float64' for argument 'a', but got 'uint32' instead." ).format(qualname=qualname(foo)) self.assertEqual(e.exception.args[0], expected_message) + + def test_ensure_timezone(self): + @preprocess(tz=ensure_timezone) + def f(tz): + return tz + + valid = { + 'utc', + 'EST', + 'US/Eastern', + } + invalid = { + # unfortunatly, these are not actually timezones (yet) + 'ayy', + 'lmao', + } + + # test coercing from string + for tz in valid: + self.assertEqual(f(tz), pytz.timezone(tz)) + + # test pass through of tzinfo objects + for tz in map(pytz.timezone, valid): + self.assertEqual(f(tz), tz) + + # test invalid timezone strings + for tz in invalid: + self.assertRaises(pytz.UnknownTimeZoneError, f, tz) diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 556eb3f8..e2e36b93 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -127,6 +127,7 @@ from __future__ import division, absolute_import from abc import ABCMeta, abstractproperty from collections import namedtuple, defaultdict from copy import copy +from datetime import time from functools import partial from itertools import count import warnings @@ -144,6 +145,7 @@ from datashape import ( ) from odo import odo import pandas as pd +from six import with_metaclass, PY2, itervalues, iteritems from toolz import ( complement, compose, @@ -154,15 +156,19 @@ from toolz import ( memoize, ) import toolz.curried.operator as op -from six import with_metaclass, PY2, itervalues, iteritems from zipline.pipeline.data.dataset import DataSet, Column +from zipline.pipeline.loaders.utils import ( + normalize_data_query_time, + normalize_timestamp_to_query_time, +) from zipline.lib.adjusted_array import AdjustedArray from zipline.lib.adjustment import Float64Overwrite from zipline.utils.enum import enum -from zipline.utils.input_validation import expect_element +from zipline.utils.input_validation import expect_element, ensure_timezone from zipline.utils.numpy_utils import repeat_last_axis +from zipline.utils.preprocess import preprocess AD_FIELD_NAME = 'asof_date' @@ -764,8 +770,25 @@ def adjustments_from_deltas_with_sids(dates, class BlazeLoader(dict): - def __init__(self, colmap=None): + """A PipelineLoader for datasets constructed with ``from_blaze``. + + Parameters + ---------- + colmap : mapping[BoundColumn -> tuple[Expr, Expr, any]], optional + The initial column mapping to use. + data_query_time : time, optional + The time to use for the data query cutoff. + data_query_tz : tzinfo or str + The timezeone to use for the data query cutoff. + """ + @preprocess(data_query_tz=ensure_timezone) + def __init__(self, + colmap=None, + data_query_time=time(0), + data_query_tz='utc'): self.update(colmap or {}) + self._data_query_time = data_query_time + self._data_query_tz = data_query_tz @classmethod @memoize(cache=WeakKeyDictionary()) @@ -802,6 +825,19 @@ class BlazeLoader(dict): [SID_FIELD_NAME] if have_sids else [] ) + data_query_time = self._data_query_time + data_query_tz = self._data_query_tz + lower_dt = normalize_data_query_time( + dates[0], + data_query_time, + data_query_tz, + ) + upper_dt = normalize_data_query_time( + dates[-1], + data_query_time, + data_query_tz, + ) + def where(e): """Create the query to run against the resources. @@ -819,8 +855,8 @@ class BlazeLoader(dict): # Hack to get the lower bound to query: # This must be strictly executed because the data for `ts` will # be removed from scope too early otherwise. - lower = odo(ts[ts <= dates[0]].max(), pd.Timestamp) - selection = ts <= dates[-1] + lower = odo(ts[ts <= lower_dt].max(), pd.Timestamp) + selection = ts <= upper_dt if have_sids: selection &= e[SID_FIELD_NAME].isin(assets) if lower is not pd.NaT: @@ -830,11 +866,32 @@ class BlazeLoader(dict): extra_kwargs = {'d': resources} if resources else {} materialized_expr = odo(where(expr), pd.DataFrame, **extra_kwargs) + materialized_expr[TS_FIELD_NAME] = materialized_expr[ + TS_FIELD_NAME + ].astype('datetime64[ns]') materialized_deltas = ( odo(where(deltas), pd.DataFrame, **extra_kwargs) if deltas is not None else pd.DataFrame(columns=query_fields) ) + materialized_deltas[TS_FIELD_NAME] = materialized_deltas[ + TS_FIELD_NAME + ].astype('datetime64[ns]') + + normalize_timestamp_to_query_time( + materialized_expr, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) + normalize_timestamp_to_query_time( + materialized_deltas, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) # Inline the deltas that changed our most recently known value. # Also, we reindex by the dates to create a dense representation of @@ -978,7 +1035,7 @@ def ffill_query_in_range(expr, # range. It must all be null anyways. computed_lower = lower - return odo( + raw = odo( expr[ (expr[ts_field] >= computed_lower) & (expr[ts_field] <= upper) @@ -986,3 +1043,5 @@ def ffill_query_in_range(expr, pd.DataFrame, **odo_kwargs ) + raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]') + return raw diff --git a/zipline/pipeline/loaders/blaze/earnings.py b/zipline/pipeline/loaders/blaze/earnings.py index 2c4f8dfb..0b2e312c 100644 --- a/zipline/pipeline/loaders/blaze/earnings.py +++ b/zipline/pipeline/loaders/blaze/earnings.py @@ -1,3 +1,5 @@ +import datetime + from datashape import istabular import pandas as pd from toolz import valmap @@ -11,6 +13,12 @@ from .core import ( from zipline.pipeline.data import EarningsCalendar from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.loaders.earnings import EarningsCalendarLoader +from zipline.pipeline.loaders.utils import ( + normalize_data_query_time, + normalize_timestamp_to_query_time, +) +from zipline.utils.input_validation import ensure_timezone +from zipline.utils.preprocess import preprocess ANNOUNCEMENT_FIELD_NAME = 'announcement_date' @@ -28,6 +36,10 @@ class BlazeEarningsCalendarLoader(PipelineLoader): Mapping from the atomic terms of ``expr`` to actual data resources. odo_kwargs : dict, optional Extra keyword arguments to pass to odo when executing the expression. + data_query_time : time, optional + The time to use for the data query cutoff. + data_query_tz : tzinfo or str + The timezeone to use for the data query cutoff. Notes ----- @@ -58,10 +70,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ANNOUNCEMENT_FIELD_NAME, }) + @preprocess(data_query_tz=ensure_timezone) def __init__(self, expr, resources=None, odo_kwargs=None, + data_query_time=datetime.time(0), + data_query_tz='utc', dataset=EarningsCalendar): dshape = expr.dshape @@ -77,12 +92,24 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ) self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {} self._dataset = dataset + self._data_query_time = data_query_time + self._data_query_tz = data_query_tz def load_adjusted_array(self, columns, dates, assets, mask): + data_query_time = self._data_query_time + data_query_tz = self._data_query_tz raw = ffill_query_in_range( self._expr, - dates[0], - dates[-1], + normalize_data_query_time( + dates[0], + data_query_time, + data_query_tz, + ), + normalize_data_query_time( + dates[-1], + data_query_time, + data_query_tz, + ), self._odo_kwargs, ) sids = raw.loc[:, SID_FIELD_NAME] @@ -90,6 +117,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader): sids[~sids.isin(assets)].index, inplace=True ) + normalize_timestamp_to_query_time( + raw, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) gb = raw.groupby(SID_FIELD_NAME) diff --git a/zipline/pipeline/loaders/utils.py b/zipline/pipeline/loaders/utils.py index 43dedfce..bef8906b 100644 --- a/zipline/pipeline/loaders/utils.py +++ b/zipline/pipeline/loaders/utils.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pandas as pd from six import iteritems @@ -93,3 +95,78 @@ def previous_date_frame(date_index, events_by_sid): frame = pd.DataFrame(out, index=date_index, columns=sids) frame.ffill(inplace=True) return frame + + +def normalize_data_query_time(dt, time, tz): + """Apply the correct time and timezone to a date. + + Parameters + ---------- + dt : pd.Timestamp + The original datetime that represents the date. + time : datetime.time + The time of day to use as the cutoff point for new data. Data points + that you learn about after this time will become available to your + algorithm on the next trading day. + tz : tzinfo + The timezone to normalize your dates to before comparing against + `time`. + + Returns + ------- + query_dt : pd.Timestamp + The timestamp with the correct time and date in utc. + """ + # merge the correct date with the time in the given timezone then convert + # back to utc + return pd.Timestamp( + datetime.datetime.combine(dt.date(), time), + tz=tz, + ).tz_convert('utc') + + +def normalize_timestamp_to_query_time(df, + time, + tz, + inplace=False, + ts_field='timestamp'): + """Update the timestamp field of a dataframe to normalize dates around + some data query time/timezone. + + Parameters + ---------- + df : pd.DataFrame + The dataframe to update. This needs a column named ``ts_field``. + time : datetime.time + The time of day to use as the cutoff point for new data. Data points + that you learn about after this time will become available to your + algorithm on the next trading day. + tz : tzinfo + The timezone to normalize your dates to before comparing against + `time`. + inplace : bool, optional + Update the dataframe in place. + ts_field : str, optional + The name of the timestamp field in ``df``. + + Returns + ------- + df : pd.DataFrame + The dataframe with the timestamp field normalized. If ``inplace`` is + true, then this will be the same object as ``df`` otherwise this will + be a copy. + """ + if not inplace: + # don't mutate the dataframe in place + df = df.copy() + + dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc') + dtidx_local_time = dtidx.tz_convert(tz) + to_roll_forward = dtidx_local_time.time > time + # for all of the times that are greater than our query time add 1 + # day and truncate to the date + df.loc[to_roll_forward, ts_field] = ( + dtidx_local_time[to_roll_forward] + datetime.timedelta(days=1) + ).normalize().tz_localize(None).tz_localize('utc') # cast back to utc + df.loc[~to_roll_forward, ts_field] = dtidx[~to_roll_forward].normalize() + return df diff --git a/zipline/utils/input_validation.py b/zipline/utils/input_validation.py index 4d7256a5..fc8f7c24 100644 --- a/zipline/utils/input_validation.py +++ b/zipline/utils/input_validation.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from datetime import tzinfo from functools import partial from operator import attrgetter from numpy import dtype +from pytz import timezone from six import iteritems, string_types, PY3 from toolz import valmap, complement, compose import toolz.curried.operator as op @@ -61,6 +63,33 @@ def ensure_dtype(func, argname, arg): ) +def ensure_timezone(func, argname, arg): + """Argument preprocessor that converts the input into a tzinfo object. + + Usage + ----- + >>> from zipline.utils.preprocess import preprocess + >>> @preprocess(tz=ensure_timezone) + ... def foo(tz): + ... return tz + >>> foo('utc') + + """ + if isinstance(arg, tzinfo): + return arg + if isinstance(arg, string_types): + return timezone(arg) + + raise TypeError( + "{func}() couldn't convert argument " + "{argname}={arg!r} to a timezone.".format( + func=_qualified_name(func), + argname=argname, + arg=arg, + ), + ) + + def expect_dtypes(*_pos, **named): """ Preprocessing decorator that verifies inputs have expected numpy dtypes.