diff --git a/docs/source/whatsnew/0.8.4.txt b/docs/source/whatsnew/0.8.4.txt index 7dfd070a..d7185090 100644 --- a/docs/source/whatsnew/0.8.4.txt +++ b/docs/source/whatsnew/0.8.4.txt @@ -63,7 +63,25 @@ Enhancements subclasses inherit all of the columns from the parent. These columns will be new sentinels so you can register them a custom loader (:issue:`924`). -* Added :func:`~zipline.utils.input_validation.coerce`. +* Added :func:`~zipline.utils.input_validation.coerce` to coerce inputs from one + type into another before passing them to the function (:issue:`948`). + +* Added :func:`~zipline.utils.input_validation.optionally` to wrap other + preprocessor functions to explicitly allow ``None`` (:issue:`947`). + +* Added :func:`~zipline.utils.input_validation.ensure_timezone` to allow string + arguments to get converted into ``datetime.tzinfo`` objects. This also allows + ``tzinfo`` objects to be passed directly (:issue:`947`). + +* Added two optional arguments, ``data_query_time`` and ``data_query_tz`` to + :class:`~zipline.pipeline.loaders.blaze.core.BlazeLoader` and + :class:`~zipline.pipeline.loaders.blaze.earnings.BlazeEarningsCalendarLoader`. + These arguments allow the user to specify some cutoff time for data when + loading from the resource. For example, if I want to simulate executing my + ``before_trading_start`` function at ``8:45 US/Eastern`` then I could pass + ``datetime.time(8, 45)`` and ``'US/Eastern'`` to the loader. This means that + data that is timestamped on or after ``8:45`` will not seen on that day in the + simulation. The data will be made available on the next day (:issue:`947`). Experimental Features ~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/pipeline/test_blaze.py b/tests/pipeline/test_blaze.py index d22aa244..11d741b2 100644 --- a/tests/pipeline/test_blaze.py +++ b/tests/pipeline/test_blaze.py @@ -4,7 +4,7 @@ Tests for the blaze interface to the pipeline api. from __future__ import division from collections import OrderedDict -from datetime import timedelta +from datetime import timedelta, time from unittest import TestCase import warnings @@ -323,6 +323,43 @@ class BlazeToPipelineTestCase(TestCase): )) assert_frame_equal(result, expected, check_dtype=False) + def test_custom_query_time_tz(self): + df = self.df.copy() + df['timestamp'] = ( + pd.DatetimeIndex(df['timestamp'], tz='EST') + + timedelta(hours=8, minutes=44) + ).tz_convert('utc') + df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:45', tz='utc') + expr = bz.Data(df, name='expr', dshape=self.dshape) + loader = BlazeLoader(data_query_time=time(8, 45), data_query_tz='EST') + ds = from_blaze( + expr, + loader=loader, + no_deltas_rule=no_deltas_rules.ignore, + ) + p = Pipeline() + p.add(ds.value.latest, 'value') + dates = self.dates + + with tmp_asset_finder() as finder: + result = SimplePipelineEngine( + loader, + dates, + finder, + ).run_pipeline(p, dates[0], dates[-1]) + + expected = df.drop('asof_date', axis=1) + expected['timestamp'] = expected['timestamp'].dt.normalize().astype( + 'datetime64[ns]', + ) + expected.ix[3:5, 'timestamp'] += timedelta(days=1) + expected.set_index(['timestamp', 'sid'], inplace=True) + expected.index = pd.MultiIndex.from_product(( + expected.index.levels[0], + finder.retrieve_all(expected.index.levels[1]), + )) + assert_frame_equal(result, expected, check_dtype=False) + def test_id_macro_dataset(self): expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape) loader = BlazeLoader() diff --git a/tests/utils/test_preprocess.py b/tests/utils/test_preprocess.py index 3088f780..4f425b16 100644 --- a/tests/utils/test_preprocess.py +++ b/tests/utils/test_preprocess.py @@ -7,14 +7,17 @@ from unittest import TestCase from nose_parameterized import parameterized from numpy import arange, dtype +import pytz from six import PY3 from zipline.utils.preprocess import call, preprocess from zipline.utils.input_validation import ( + ensure_timezone, expect_element, expect_dtypes, expect_types, optional, + optionally, ) @@ -317,3 +320,50 @@ class PreprocessTestCase(TestCase): "or 'float64' for argument 'a', but got 'uint32' instead." ).format(qualname=qualname(foo)) self.assertEqual(e.exception.args[0], expected_message) + + def test_ensure_timezone(self): + @preprocess(tz=ensure_timezone) + def f(tz): + return tz + + valid = { + 'utc', + 'EST', + 'US/Eastern', + } + invalid = { + # unfortunatly, these are not actually timezones (yet) + 'ayy', + 'lmao', + } + + # test coercing from string + for tz in valid: + self.assertEqual(f(tz), pytz.timezone(tz)) + + # test pass through of tzinfo objects + for tz in map(pytz.timezone, valid): + self.assertEqual(f(tz), tz) + + # test invalid timezone strings + for tz in invalid: + self.assertRaises(pytz.UnknownTimeZoneError, f, tz) + + def test_optionally(self): + error = TypeError('arg must be int') + + def preprocessor(func, argname, arg): + if not isinstance(arg, int): + raise error + return arg + + @preprocess(a=optionally(preprocessor)) + def f(a): + return a + + self.assertIs(f(1), 1) + self.assertIsNone(f(None)) + + with self.assertRaises(TypeError) as e: + f('a') + self.assertIs(e.exception, error) diff --git a/zipline/pipeline/loaders/blaze/core.py b/zipline/pipeline/loaders/blaze/core.py index 556eb3f8..75e365da 100644 --- a/zipline/pipeline/loaders/blaze/core.py +++ b/zipline/pipeline/loaders/blaze/core.py @@ -144,6 +144,7 @@ from datashape import ( ) from odo import odo import pandas as pd +from six import with_metaclass, PY2, itervalues, iteritems from toolz import ( complement, compose, @@ -154,15 +155,24 @@ from toolz import ( memoize, ) import toolz.curried.operator as op -from six import with_metaclass, PY2, itervalues, iteritems from zipline.pipeline.data.dataset import DataSet, Column +from zipline.pipeline.loaders.utils import ( + check_data_query_args, + normalize_data_query_bounds, + normalize_timestamp_to_query_time, +) from zipline.lib.adjusted_array import AdjustedArray from zipline.lib.adjustment import Float64Overwrite from zipline.utils.enum import enum -from zipline.utils.input_validation import expect_element +from zipline.utils.input_validation import ( + expect_element, + ensure_timezone, + optionally, +) from zipline.utils.numpy_utils import repeat_last_axis +from zipline.utils.preprocess import preprocess AD_FIELD_NAME = 'asof_date' @@ -764,9 +774,28 @@ def adjustments_from_deltas_with_sids(dates, class BlazeLoader(dict): - def __init__(self, colmap=None): + """A PipelineLoader for datasets constructed with ``from_blaze``. + + Parameters + ---------- + colmap : mapping[BoundColumn -> tuple[Expr, Expr, any]], optional + The initial column mapping to use. + data_query_time : time, optional + The time to use for the data query cutoff. + data_query_tz : tzinfo or str + The timezeone to use for the data query cutoff. + """ + @preprocess(data_query_tz=optionally(ensure_timezone)) + def __init__(self, + colmap=None, + data_query_time=None, + data_query_tz=None): self.update(colmap or {}) + check_data_query_args(data_query_time, data_query_tz) + self._data_query_time = data_query_time + self._data_query_tz = data_query_tz + @classmethod @memoize(cache=WeakKeyDictionary()) def global_instance(cls): @@ -802,6 +831,15 @@ class BlazeLoader(dict): [SID_FIELD_NAME] if have_sids else [] ) + data_query_time = self._data_query_time + data_query_tz = self._data_query_tz + lower_dt, upper_dt = normalize_data_query_bounds( + dates[0], + dates[-1], + data_query_time, + data_query_tz, + ) + def where(e): """Create the query to run against the resources. @@ -819,8 +857,8 @@ class BlazeLoader(dict): # Hack to get the lower bound to query: # This must be strictly executed because the data for `ts` will # be removed from scope too early otherwise. - lower = odo(ts[ts <= dates[0]].max(), pd.Timestamp) - selection = ts <= dates[-1] + lower = odo(ts[ts <= lower_dt].max(), pd.Timestamp) + selection = ts <= upper_dt if have_sids: selection &= e[SID_FIELD_NAME].isin(assets) if lower is not pd.NaT: @@ -836,6 +874,20 @@ class BlazeLoader(dict): pd.DataFrame(columns=query_fields) ) + if data_query_time is not None: + for m in (materialized_expr, materialized_deltas): + m.loc[:, TS_FIELD_NAME] = m.loc[ + :, TS_FIELD_NAME + ].astype('datetime64[ns]') + + normalize_timestamp_to_query_time( + m, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) + # Inline the deltas that changed our most recently known value. # Also, we reindex by the dates to create a dense representation of # the data. @@ -978,7 +1030,7 @@ def ffill_query_in_range(expr, # range. It must all be null anyways. computed_lower = lower - return odo( + raw = odo( expr[ (expr[ts_field] >= computed_lower) & (expr[ts_field] <= upper) @@ -986,3 +1038,5 @@ def ffill_query_in_range(expr, pd.DataFrame, **odo_kwargs ) + raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]') + return raw diff --git a/zipline/pipeline/loaders/blaze/earnings.py b/zipline/pipeline/loaders/blaze/earnings.py index 2c4f8dfb..9f5137f8 100644 --- a/zipline/pipeline/loaders/blaze/earnings.py +++ b/zipline/pipeline/loaders/blaze/earnings.py @@ -11,6 +11,13 @@ from .core import ( from zipline.pipeline.data import EarningsCalendar from zipline.pipeline.loaders.base import PipelineLoader from zipline.pipeline.loaders.earnings import EarningsCalendarLoader +from zipline.pipeline.loaders.utils import ( + check_data_query_args, + normalize_data_query_bounds, + normalize_timestamp_to_query_time, +) +from zipline.utils.input_validation import ensure_timezone, optionally +from zipline.utils.preprocess import preprocess ANNOUNCEMENT_FIELD_NAME = 'announcement_date' @@ -28,6 +35,10 @@ class BlazeEarningsCalendarLoader(PipelineLoader): Mapping from the atomic terms of ``expr`` to actual data resources. odo_kwargs : dict, optional Extra keyword arguments to pass to odo when executing the expression. + data_query_time : time, optional + The time to use for the data query cutoff. + data_query_tz : tzinfo or str + The timezeone to use for the data query cutoff. Notes ----- @@ -58,10 +69,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ANNOUNCEMENT_FIELD_NAME, }) + @preprocess(data_query_tz=optionally(ensure_timezone)) def __init__(self, expr, resources=None, odo_kwargs=None, + data_query_time=None, + data_query_tz=None, dataset=EarningsCalendar): dshape = expr.dshape @@ -77,12 +91,24 @@ class BlazeEarningsCalendarLoader(PipelineLoader): ) self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {} self._dataset = dataset + check_data_query_args(data_query_time, data_query_tz) + self._data_query_time = data_query_time + self._data_query_tz = data_query_tz def load_adjusted_array(self, columns, dates, assets, mask): - raw = ffill_query_in_range( - self._expr, + data_query_time = self._data_query_time + data_query_tz = self._data_query_tz + lower_dt, upper_dt = normalize_data_query_bounds( dates[0], dates[-1], + data_query_time, + data_query_tz, + ) + + raw = ffill_query_in_range( + self._expr, + lower_dt, + upper_dt, self._odo_kwargs, ) sids = raw.loc[:, SID_FIELD_NAME] @@ -90,6 +116,14 @@ class BlazeEarningsCalendarLoader(PipelineLoader): sids[~sids.isin(assets)].index, inplace=True ) + if data_query_time is not None: + normalize_timestamp_to_query_time( + raw, + data_query_time, + data_query_tz, + inplace=True, + ts_field=TS_FIELD_NAME, + ) gb = raw.groupby(SID_FIELD_NAME) diff --git a/zipline/pipeline/loaders/utils.py b/zipline/pipeline/loaders/utils.py index 43dedfce..e5682461 100644 --- a/zipline/pipeline/loaders/utils.py +++ b/zipline/pipeline/loaders/utils.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pandas as pd from six import iteritems @@ -93,3 +95,136 @@ def previous_date_frame(date_index, events_by_sid): frame = pd.DataFrame(out, index=date_index, columns=sids) frame.ffill(inplace=True) return frame + + +def normalize_data_query_time(dt, time, tz): + """Apply the correct time and timezone to a date. + + Parameters + ---------- + dt : pd.Timestamp + The original datetime that represents the date. + time : datetime.time + The time of day to use as the cutoff point for new data. Data points + that you learn about after this time will become available to your + algorithm on the next trading day. + tz : tzinfo + The timezone to normalize your dates to before comparing against + `time`. + + Returns + ------- + query_dt : pd.Timestamp + The timestamp with the correct time and date in utc. + """ + # merge the correct date with the time in the given timezone then convert + # back to utc + return pd.Timestamp( + datetime.datetime.combine(dt.date(), time), + tz=tz, + ).tz_convert('utc') + + +def normalize_data_query_bounds(lower, upper, time, tz): + """Adjust the first and last dates in the requested datetime index based on + the provided query time and tz. + + lower : pd.Timestamp + The lower date requested. + upper : pd.Timestamp + The upper date requested. + time : datetime.time + The time of day to use as the cutoff point for new data. Data points + that you learn about after this time will become available to your + algorithm on the next trading day. + tz : tzinfo + The timezone to normalize your dates to before comparing against + `time`. + """ + # Subtract one day to grab things that happened on the first day we are + # requesting. This doesn't need to be a trading day, we are only adding + # a lower bound to limit the amount of in memory filtering that needs + # to happen. + lower -= datetime.timedelta(days=1) + if time is not None: + return normalize_data_query_time( + lower, + time, + tz, + ), normalize_data_query_time( + upper, + time, + tz, + ) + return lower, upper + + +def normalize_timestamp_to_query_time(df, + time, + tz, + inplace=False, + ts_field='timestamp'): + """Update the timestamp field of a dataframe to normalize dates around + some data query time/timezone. + + Parameters + ---------- + df : pd.DataFrame + The dataframe to update. This needs a column named ``ts_field``. + time : datetime.time + The time of day to use as the cutoff point for new data. Data points + that you learn about after this time will become available to your + algorithm on the next trading day. + tz : tzinfo + The timezone to normalize your dates to before comparing against + `time`. + inplace : bool, optional + Update the dataframe in place. + ts_field : str, optional + The name of the timestamp field in ``df``. + + Returns + ------- + df : pd.DataFrame + The dataframe with the timestamp field normalized. If ``inplace`` is + true, then this will be the same object as ``df`` otherwise this will + be a copy. + """ + if not inplace: + # don't mutate the dataframe in place + df = df.copy() + + dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc') + dtidx_local_time = dtidx.tz_convert(tz) + to_roll_forward = dtidx_local_time.time >= time + # for all of the times that are greater than our query time add 1 + # day and truncate to the date + df.loc[to_roll_forward, ts_field] = ( + dtidx_local_time[to_roll_forward] + datetime.timedelta(days=1) + ).normalize().tz_localize(None).tz_localize('utc') # cast back to utc + df.loc[~to_roll_forward, ts_field] = dtidx[~to_roll_forward].normalize() + return df + + +def check_data_query_args(data_query_time, data_query_tz): + """Checks the data_query_time and data_query_tz arguments for loaders + and raises a standard exception if one is None and the other is not. + + Parameters + ---------- + data_query_time : datetime.time or None + data_query_tz : tzinfo or None + + Raises + ------ + ValueError + Raised when only one of the arguments is None. + """ + if (data_query_time is None) ^ (data_query_tz is None): + raise ValueError( + "either 'data_query_time' and 'data_query_tz' must both be" + " None or neither may be None (got %r, %r)" % ( + data_query_time, + data_query_tz, + ), + ) diff --git a/zipline/utils/input_validation.py b/zipline/utils/input_validation.py index 4d7256a5..d04869fc 100644 --- a/zipline/utils/input_validation.py +++ b/zipline/utils/input_validation.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial +from datetime import tzinfo +from functools import partial, wraps from operator import attrgetter from numpy import dtype +from pytz import timezone from six import iteritems, string_types, PY3 from toolz import valmap, complement, compose import toolz.curried.operator as op @@ -22,6 +24,46 @@ import toolz.curried.operator as op from zipline.utils.preprocess import preprocess +def optionally(preprocessor): + """Modify a preprocessor to explicitly allow `None`. + + Parameters + ---------- + preprocessor : callable[callable, str, any -> any] + A preprocessor to delegate to when `arg is not None`. + + Returns + ------- + optional_preprocessor : callable[callable, str, any -> any] + A preprocessor that delegates to `preprocessor` when `arg is not None`. + + Usage + ----- + >>> def preprocessor(func, argname, arg): + ... if not isinstance(arg, int): + ... raise TypeError('arg must be int') + ... return arg + ... + >>> @preprocess(a=optionally(preprocessor)) + ... def f(a): + ... return a + ... + >>> f(1) # call with int + 1 + >>> f('a') # call with not int + Traceback (most recent call last): + ... + TypeError: arg must be int + >>> f(None) is None # call with explicit None + True + """ + @wraps(preprocessor) + def wrapper(func, argname, arg): + return arg if arg is None else preprocessor(func, argname, arg) + + return wrapper + + def ensure_upper_case(func, argname, arg): if isinstance(arg, string_types): return arg.upper() @@ -61,6 +103,33 @@ def ensure_dtype(func, argname, arg): ) +def ensure_timezone(func, argname, arg): + """Argument preprocessor that converts the input into a tzinfo object. + + Usage + ----- + >>> from zipline.utils.preprocess import preprocess + >>> @preprocess(tz=ensure_timezone) + ... def foo(tz): + ... return tz + >>> foo('utc') + + """ + if isinstance(arg, tzinfo): + return arg + if isinstance(arg, string_types): + return timezone(arg) + + raise TypeError( + "{func}() couldn't convert argument " + "{argname}={arg!r} to a timezone.".format( + func=_qualified_name(func), + argname=argname, + arg=arg, + ), + ) + + def expect_dtypes(*_pos, **named): """ Preprocessing decorator that verifies inputs have expected numpy dtypes.