mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 05:32:55 +08:00
ENH: allows users to specify the cutoff time for data query in blaze
loaders This allows people to set their cutoff time to the time they will actually execute 'before_trading_start'. Currently this is just passed to the constructor of the loader; however, I would like to make this managed by the algorithm simulation runner. This would help keep all of the loaders in sync and lock 'before_trading_start's execution to the time the data is queried for.
This commit is contained in:
@@ -4,7 +4,7 @@ Tests for the blaze interface to the pipeline api.
|
||||
from __future__ import division
|
||||
|
||||
from collections import OrderedDict
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta, time
|
||||
from unittest import TestCase
|
||||
import warnings
|
||||
|
||||
@@ -323,6 +323,43 @@ class BlazeToPipelineTestCase(TestCase):
|
||||
))
|
||||
assert_frame_equal(result, expected, check_dtype=False)
|
||||
|
||||
def test_custom_query_time_tz(self):
    """Check that a loader with a custom query time/timezone rolls data
    stamped after the cutoff onto the following day.

    Every row is stamped at exactly 08:45 EST (the cutoff), except rows
    3 through 5 which are stamped one minute later (13:46 UTC) and are
    therefore expected one day later.
    """
    query_time = time(8, 45)
    query_tz = 'EST'

    df = self.df.copy()
    # move every timestamp to exactly the cutoff, expressed in utc
    est_stamps = pd.DatetimeIndex(df['timestamp'], tz=query_tz)
    shifted = est_stamps + timedelta(hours=8, minutes=45)
    df['timestamp'] = shifted.tz_convert('utc')
    # these rows land one minute past the 08:45 EST cutoff
    df.ix[3:5, 'timestamp'] = pd.Timestamp('2014-01-01 13:46', tz='utc')

    expr = bz.Data(df, name='expr', dshape=self.dshape)
    loader = BlazeLoader(
        data_query_time=query_time,
        data_query_tz=query_tz,
    )
    ds = from_blaze(
        expr,
        loader=loader,
        no_deltas_rule=no_deltas_rules.ignore,
    )
    p = Pipeline()
    p.add(ds.value.latest, 'value')
    dates = self.dates
    start, end = dates[0], dates[-1]

    with tmp_asset_finder() as finder:
        engine = SimplePipelineEngine(loader, dates, finder)
        result = engine.run_pipeline(p, start, end)

    expected = df.drop('asof_date', axis=1)
    normalized = expected['timestamp'].dt.normalize()
    expected['timestamp'] = normalized.astype('datetime64[ns]')
    # the late rows only become visible on the next day
    expected.ix[3:5, 'timestamp'] += timedelta(days=1)
    expected.set_index(['timestamp', 'sid'], inplace=True)
    date_level, sid_level = expected.index.levels[0], expected.index.levels[1]
    expected.index = pd.MultiIndex.from_product((
        date_level,
        finder.retrieve_all(sid_level),
    ))
    assert_frame_equal(result, expected, check_dtype=False)
|
||||
|
||||
def test_id_macro_dataset(self):
|
||||
expr = bz.Data(self.macro_df, name='expr', dshape=self.macro_dshape)
|
||||
loader = BlazeLoader()
|
||||
|
||||
@@ -7,10 +7,12 @@ from unittest import TestCase
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from numpy import arange, dtype
|
||||
import pytz
|
||||
from six import PY3
|
||||
|
||||
from zipline.utils.preprocess import call, preprocess
|
||||
from zipline.utils.input_validation import (
|
||||
ensure_timezone,
|
||||
expect_element,
|
||||
expect_dtypes,
|
||||
expect_types,
|
||||
@@ -317,3 +319,31 @@ class PreprocessTestCase(TestCase):
|
||||
"or 'float64' for argument 'a', but got 'uint32' instead."
|
||||
).format(qualname=qualname(foo))
|
||||
self.assertEqual(e.exception.args[0], expected_message)
|
||||
|
||||
def test_ensure_timezone(self):
    """Exercise ``ensure_timezone`` through the ``preprocess`` decorator.

    Covers coercion from strings, pass-through of tzinfo objects, and
    rejection of strings that do not name a timezone.
    """
    @preprocess(tz=ensure_timezone)
    def f(tz):
        return tz

    valid = {'utc', 'EST', 'US/Eastern'}
    # unfortunately, these are not actually timezones (yet)
    invalid = {'ayy', 'lmao'}

    # test coercing from string
    for name in valid:
        expected = pytz.timezone(name)
        self.assertEqual(f(name), expected)

    # test pass through of tzinfo objects
    for tzobj in (pytz.timezone(name) for name in valid):
        self.assertEqual(f(tzobj), tzobj)

    # test invalid timezone strings
    for name in invalid:
        with self.assertRaises(pytz.UnknownTimeZoneError):
            f(name)
|
||||
|
||||
@@ -127,6 +127,7 @@ from __future__ import division, absolute_import
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from collections import namedtuple, defaultdict
|
||||
from copy import copy
|
||||
from datetime import time
|
||||
from functools import partial
|
||||
from itertools import count
|
||||
import warnings
|
||||
@@ -144,6 +145,7 @@ from datashape import (
|
||||
)
|
||||
from odo import odo
|
||||
import pandas as pd
|
||||
from six import with_metaclass, PY2, itervalues, iteritems
|
||||
from toolz import (
|
||||
complement,
|
||||
compose,
|
||||
@@ -154,15 +156,19 @@ from toolz import (
|
||||
memoize,
|
||||
)
|
||||
import toolz.curried.operator as op
|
||||
from six import with_metaclass, PY2, itervalues, iteritems
|
||||
|
||||
|
||||
from zipline.pipeline.data.dataset import DataSet, Column
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
normalize_data_query_time,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
from zipline.utils.enum import enum
|
||||
from zipline.utils.input_validation import expect_element
|
||||
from zipline.utils.input_validation import expect_element, ensure_timezone
|
||||
from zipline.utils.numpy_utils import repeat_last_axis
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
AD_FIELD_NAME = 'asof_date'
|
||||
@@ -764,8 +770,25 @@ def adjustments_from_deltas_with_sids(dates,
|
||||
|
||||
|
||||
class BlazeLoader(dict):
|
||||
def __init__(self, colmap=None):
|
||||
"""A PipelineLoader for datasets constructed with ``from_blaze``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
colmap : mapping[BoundColumn -> tuple[Expr, Expr, any]], optional
|
||||
The initial column mapping to use.
|
||||
data_query_time : time, optional
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezone to use for the data query cutoff.
|
||||
"""
|
||||
@preprocess(data_query_tz=ensure_timezone)
def __init__(self,
             colmap=None,
             data_query_time=time(0),
             data_query_tz='utc'):
    # ``data_query_tz`` has already been passed through the
    # ``ensure_timezone`` preprocessor by the time this body runs.
    if colmap:
        # seed the mapping with the caller supplied columns
        self.update(colmap)
    # remember the cutoff used when querying for data
    self._data_query_tz = data_query_tz
    self._data_query_time = data_query_time
|
||||
|
||||
@classmethod
|
||||
@memoize(cache=WeakKeyDictionary())
|
||||
@@ -802,6 +825,19 @@ class BlazeLoader(dict):
|
||||
[SID_FIELD_NAME] if have_sids else []
|
||||
)
|
||||
|
||||
data_query_time = self._data_query_time
|
||||
data_query_tz = self._data_query_tz
|
||||
lower_dt = normalize_data_query_time(
|
||||
dates[0],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
upper_dt = normalize_data_query_time(
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
|
||||
def where(e):
|
||||
"""Create the query to run against the resources.
|
||||
|
||||
@@ -819,8 +855,8 @@ class BlazeLoader(dict):
|
||||
# Hack to get the lower bound to query:
|
||||
# This must be strictly executed because the data for `ts` will
|
||||
# be removed from scope too early otherwise.
|
||||
lower = odo(ts[ts <= dates[0]].max(), pd.Timestamp)
|
||||
selection = ts <= dates[-1]
|
||||
lower = odo(ts[ts <= lower_dt].max(), pd.Timestamp)
|
||||
selection = ts <= upper_dt
|
||||
if have_sids:
|
||||
selection &= e[SID_FIELD_NAME].isin(assets)
|
||||
if lower is not pd.NaT:
|
||||
@@ -830,11 +866,32 @@ class BlazeLoader(dict):
|
||||
|
||||
extra_kwargs = {'d': resources} if resources else {}
|
||||
materialized_expr = odo(where(expr), pd.DataFrame, **extra_kwargs)
|
||||
materialized_expr[TS_FIELD_NAME] = materialized_expr[
|
||||
TS_FIELD_NAME
|
||||
].astype('datetime64[ns]')
|
||||
materialized_deltas = (
|
||||
odo(where(deltas), pd.DataFrame, **extra_kwargs)
|
||||
if deltas is not None else
|
||||
pd.DataFrame(columns=query_fields)
|
||||
)
|
||||
materialized_deltas[TS_FIELD_NAME] = materialized_deltas[
|
||||
TS_FIELD_NAME
|
||||
].astype('datetime64[ns]')
|
||||
|
||||
normalize_timestamp_to_query_time(
|
||||
materialized_expr,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
normalize_timestamp_to_query_time(
|
||||
materialized_deltas,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
|
||||
# Inline the deltas that changed our most recently known value.
|
||||
# Also, we reindex by the dates to create a dense representation of
|
||||
@@ -978,7 +1035,7 @@ def ffill_query_in_range(expr,
|
||||
# range. It must all be null anyways.
|
||||
computed_lower = lower
|
||||
|
||||
return odo(
|
||||
raw = odo(
|
||||
expr[
|
||||
(expr[ts_field] >= computed_lower) &
|
||||
(expr[ts_field] <= upper)
|
||||
@@ -986,3 +1043,5 @@ def ffill_query_in_range(expr,
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
)
|
||||
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
|
||||
return raw
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import datetime
|
||||
|
||||
from datashape import istabular
|
||||
import pandas as pd
|
||||
from toolz import valmap
|
||||
@@ -11,6 +13,12 @@ from .core import (
|
||||
from zipline.pipeline.data import EarningsCalendar
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.earnings import EarningsCalendarLoader
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
normalize_data_query_time,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_timezone
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
|
||||
@@ -28,6 +36,10 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezone to use for the data query cutoff.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@@ -58,10 +70,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
ANNOUNCEMENT_FIELD_NAME,
|
||||
})
|
||||
|
||||
@preprocess(data_query_tz=ensure_timezone)
|
||||
def __init__(self,
|
||||
expr,
|
||||
resources=None,
|
||||
odo_kwargs=None,
|
||||
data_query_time=datetime.time(0),
|
||||
data_query_tz='utc',
|
||||
dataset=EarningsCalendar):
|
||||
dshape = expr.dshape
|
||||
|
||||
@@ -77,12 +92,24 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
)
|
||||
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
|
||||
self._dataset = dataset
|
||||
self._data_query_time = data_query_time
|
||||
self._data_query_tz = data_query_tz
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
data_query_time = self._data_query_time
|
||||
data_query_tz = self._data_query_tz
|
||||
raw = ffill_query_in_range(
|
||||
self._expr,
|
||||
dates[0],
|
||||
dates[-1],
|
||||
normalize_data_query_time(
|
||||
dates[0],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
),
|
||||
normalize_data_query_time(
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
),
|
||||
self._odo_kwargs,
|
||||
)
|
||||
sids = raw.loc[:, SID_FIELD_NAME]
|
||||
@@ -90,6 +117,13 @@ class BlazeEarningsCalendarLoader(PipelineLoader):
|
||||
sids[~sids.isin(assets)].index,
|
||||
inplace=True
|
||||
)
|
||||
normalize_timestamp_to_query_time(
|
||||
raw,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
|
||||
gb = raw.groupby(SID_FIELD_NAME)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from six import iteritems
|
||||
@@ -93,3 +95,78 @@ def previous_date_frame(date_index, events_by_sid):
|
||||
frame = pd.DataFrame(out, index=date_index, columns=sids)
|
||||
frame.ffill(inplace=True)
|
||||
return frame
|
||||
|
||||
|
||||
def normalize_data_query_time(dt, time, tz):
    """Build the utc cutoff timestamp for a given date.

    Parameters
    ----------
    dt : pd.Timestamp
        The datetime whose calendar date is used.
    time : datetime.time
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone in which ``time`` is expressed.

    Returns
    -------
    query_dt : pd.Timestamp
        The date of ``dt`` at ``time`` in ``tz``, converted to utc.
    """
    # attach the cutoff time to the calendar date (still timezone naive)
    naive_cutoff = datetime.datetime.combine(dt.date(), time)
    # interpret that wall-clock time in the requested timezone ...
    local_cutoff = pd.Timestamp(naive_cutoff, tz=tz)
    # ... and express the same instant in utc
    return local_cutoff.tz_convert('utc')
|
||||
|
||||
|
||||
def normalize_timestamp_to_query_time(df,
                                      time,
                                      tz,
                                      inplace=False,
                                      ts_field='timestamp'):
    """Update the timestamp field of a dataframe to normalize dates around
    some data query time/timezone.

    Each timestamp is converted to ``tz`` and replaced by a midnight-utc
    label carrying the *local* calendar date on which the row becomes
    available: the same local date when the local time of day is at or
    before ``time``, the following local date when it is strictly after.

    Parameters
    ----------
    df : pd.DataFrame
        The dataframe to update. This needs a column named ``ts_field``.
    time : datetime.time
        The time of day to use as the cutoff point for new data. Data points
        that you learn about after this time will become available to your
        algorithm on the next trading day.
    tz : tzinfo
        The timezone to normalize your dates to before comparing against
        `time`.
    inplace : bool, optional
        Update the dataframe in place.
    ts_field : str, optional
        The name of the timestamp field in ``df``.

    Returns
    -------
    df : pd.DataFrame
        The dataframe with the timestamp field normalized. If ``inplace`` is
        true, then this will be the same object as ``df`` otherwise this will
        be a copy.
    """
    if not inplace:
        # don't mutate the dataframe in place
        df = df.copy()

    dtidx = pd.DatetimeIndex(df.loc[:, ts_field], tz='utc')
    dtidx_local_time = dtidx.tz_convert(tz)
    # strictly greater: a row stamped exactly at the cutoff stays on its day
    to_roll_forward = dtidx_local_time.time > time

    def local_date_as_utc_midnight(local_index):
        # truncate to the local date, then relabel that date as utc midnight
        return local_index.normalize().tz_localize(None).tz_localize('utc')

    # for all of the times that are greater than our query time add 1
    # day and truncate to the date
    df.loc[to_roll_forward, ts_field] = local_date_as_utc_midnight(
        dtidx_local_time[to_roll_forward] + datetime.timedelta(days=1)
    )
    # Rows at or before the cutoff keep their *local* date. Truncating the
    # utc timestamp instead would pick the wrong day whenever the local and
    # utc calendar dates differ (e.g. early morning in a utc+9 timezone),
    # making the row visible a day too early.
    df.loc[~to_roll_forward, ts_field] = local_date_as_utc_midnight(
        dtidx_local_time[~to_roll_forward]
    )
    return df
|
||||
|
||||
@@ -11,10 +11,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from datetime import tzinfo
|
||||
from functools import partial
|
||||
from operator import attrgetter
|
||||
|
||||
from numpy import dtype
|
||||
from pytz import timezone
|
||||
from six import iteritems, string_types, PY3
|
||||
from toolz import valmap, complement, compose
|
||||
import toolz.curried.operator as op
|
||||
@@ -61,6 +63,33 @@ def ensure_dtype(func, argname, arg):
|
||||
)
|
||||
|
||||
|
||||
def ensure_timezone(func, argname, arg):
    """Argument preprocessor that converts the input into a tzinfo object.

    Strings are looked up with ``pytz.timezone``; objects that are already
    ``tzinfo`` instances are returned untouched. Anything else raises a
    ``TypeError`` naming the decorated function and argument.

    Usage
    -----
    >>> from zipline.utils.preprocess import preprocess
    >>> @preprocess(tz=ensure_timezone)
    ... def foo(tz):
    ...     return tz
    >>> foo('utc')
    <UTC>
    """
    if isinstance(arg, string_types):
        # coerce a timezone name into the matching tzinfo
        return timezone(arg)
    if isinstance(arg, tzinfo):
        # already a timezone; nothing to convert
        return arg

    message = (
        "{func}() couldn't convert argument "
        "{argname}={arg!r} to a timezone."
    ).format(
        func=_qualified_name(func),
        argname=argname,
        arg=arg,
    )
    raise TypeError(message)
|
||||
|
||||
|
||||
def expect_dtypes(*_pos, **named):
|
||||
"""
|
||||
Preprocessing decorator that verifies inputs have expected numpy dtypes.
|
||||
|
||||
Reference in New Issue
Block a user