mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 22:52:02 +08:00
ENH: Adds a blaze pipeline loader.
This commit is contained in:
@@ -105,7 +105,7 @@ class DataSetMeta(type):
|
||||
setattr(newtype, maybe_colname, bound_column)
|
||||
_columns.append(bound_column)
|
||||
|
||||
newtype._columns = _columns
|
||||
newtype._columns = frozenset(_columns)
|
||||
return newtype
|
||||
|
||||
@property
|
||||
|
||||
@@ -0,0 +1,588 @@
|
||||
from __future__ import division
|
||||
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from collections import namedtuple
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
|
||||
import blaze as bz
|
||||
from datashape import (
|
||||
Date,
|
||||
DateTime,
|
||||
Option,
|
||||
float64,
|
||||
isrecord,
|
||||
isscalar,
|
||||
promote,
|
||||
)
|
||||
from logbook import Logger
|
||||
from numpy.lib.stride_tricks import as_strided
|
||||
from odo import odo
|
||||
import pandas as pd
|
||||
from pytz import utc
|
||||
from toolz import flip, memoize, compose, complement, identity
|
||||
from six import with_metaclass
|
||||
|
||||
|
||||
from ..data.dataset import DataSet, Column
|
||||
from zipline.lib.adjusted_array import adjusted_array
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
|
||||
|
||||
AD_FIELD_NAME = 'asof_date'
|
||||
TS_FIELD_NAME = 'timestamp'
|
||||
SID_FIELD_NAME = 'sid'
|
||||
valid_deltas_node_types = (
|
||||
bz.expr.Field,
|
||||
bz.expr.ReLabel,
|
||||
bz.expr.Symbol,
|
||||
)
|
||||
getname = attrgetter('__name__')
|
||||
log = Logger(__name__)
|
||||
|
||||
|
||||
class ExprData(namedtuple('ExprData', 'expr deltas resources')):
|
||||
"""A pair of expressions and a data resources.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
epxr : Expr
|
||||
The first known values.
|
||||
deltas : Expr, optional
|
||||
The deltas for the data.
|
||||
resources : resource or dict of resources, optional
|
||||
The resources to compute the exprs against.
|
||||
"""
|
||||
def __new__(cls, expr, deltas=None, resources=None):
|
||||
return super(ExprData, cls).__new__(cls, expr, deltas, resources)
|
||||
|
||||
def __repr__(self):
|
||||
# If the expressions have _resources() then the repr will
|
||||
# drive computation so we str them.
|
||||
cls = type(self)
|
||||
return super(ExprData, cls).__repr__(cls(
|
||||
str(self.expr),
|
||||
str(self.deltas),
|
||||
self.resources,
|
||||
))
|
||||
|
||||
|
||||
class InvalidField(with_metaclass(ABCMeta)):
|
||||
"""A field that raises an exception that indicates that the
|
||||
field was invalid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
field : str
|
||||
The name of the field.
|
||||
type_ : dshape
|
||||
The shape of the field.
|
||||
"""
|
||||
@abstractproperty
|
||||
def error_format(self):
|
||||
raise NotImplementedError('error_format')
|
||||
|
||||
def __init__(self, field, type_):
|
||||
self._field = field
|
||||
self._type = type_
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
raise AttributeError(
|
||||
self.error_format.format(field=self._field, type_=self._type),
|
||||
)
|
||||
|
||||
|
||||
class NonNumpyField(InvalidField):
|
||||
error_format = "field '{field}' was a non numpy compatible type: '{type_}'"
|
||||
|
||||
|
||||
class NonPipelineField(InvalidField):
|
||||
error_format = (
|
||||
"field '{field}' was a non pipeline API compatible type:"
|
||||
" '{type_.__name__}'"
|
||||
)
|
||||
|
||||
|
||||
class NotPipelineCompatible(TypeError):
|
||||
"""Exception used to indicate that a dshape is not pipeline api
|
||||
compatible.
|
||||
"""
|
||||
def __str__(self):
|
||||
return "'%s' is a non pipleine API compatible type'" % self.args
|
||||
|
||||
|
||||
@memoize
|
||||
def new_dataset(expr, deltas):
|
||||
"""Creates or returns a dataset from a pair of blaze expressions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The blaze expression representing the first known values.
|
||||
deltas : Expr
|
||||
The blaze expression representing the deltas to the data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
ds : type
|
||||
A new dataset type.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This function is memoized, repeated calls will return the same type.
|
||||
"""
|
||||
columns = {}
|
||||
for name, type_ in expr.dshape.measure.fields:
|
||||
try:
|
||||
if promote(type_, float64, promote_option=False) != float64:
|
||||
raise NotPipelineCompatible
|
||||
if isinstance(type_, Option):
|
||||
type_ = type_.ty
|
||||
except TypeError:
|
||||
col = NonNumpyField(name, type_)
|
||||
except NotPipelineCompatible:
|
||||
col = NonPipelineField(name, type_)
|
||||
else:
|
||||
col = Column(type_.to_numpy_dtype().type)
|
||||
|
||||
columns[name] = col
|
||||
|
||||
return type(expr._name, (DataSet,), columns)
|
||||
|
||||
|
||||
def _check_resources(name, expr, resources):
|
||||
"""Validate that the exprssion and resources passed match up.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the argument we are checking.
|
||||
expr : Expr
|
||||
The potentially bound expr.
|
||||
resources
|
||||
The explicitly passed resources to compute expr.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
if the resources to not match for an expression
|
||||
"""
|
||||
if expr is None:
|
||||
return
|
||||
bound = expr._resources()
|
||||
if not bound and resources is None:
|
||||
raise ValueError('no resources provided to compute %s' % name)
|
||||
if bound and resources:
|
||||
raise ValueError(
|
||||
'explicit and implicit resources provided to compute %s' % name,
|
||||
)
|
||||
|
||||
|
||||
def _check_datetime_field(name, measure):
|
||||
"""Check that a field is a datetime inside some measure.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the field to check.
|
||||
measure : Record
|
||||
The record to check the field of.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
if the field is not a datetime inside ``measure``
|
||||
"""
|
||||
if not isinstance(measure[name], (Date, DateTime)):
|
||||
raise TypeError(
|
||||
"'{n}' field must be a '{dt}', not: '{dshape}'".format(
|
||||
name=name,
|
||||
dt=DateTime(),
|
||||
dshape=measure[name],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _get_deltas(expr, deltas, no_deltas_rule):
|
||||
"""Find the correct deltas for the expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The base expression.
|
||||
deltas : Expr, 'auto', or None
|
||||
The deltas argument. If this is 'auto', then the deltas table will
|
||||
be searched for by walking up the expression tree. If this can not be
|
||||
reflected, then an action will be taken based on the 'no_deltas_rule'.
|
||||
no_deltas_rule : {'log', 'raise', 'ignore'}
|
||||
How to handle the case where deltas='auto' but no deltas could be
|
||||
found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
deltas : Expr or None
|
||||
The deltas table to use.
|
||||
"""
|
||||
if no_deltas_rule not in _get_deltas.valid_no_deltas_rules:
|
||||
raise ValueError(
|
||||
'no_deltas_rule must be one of: %s' %
|
||||
_get_deltas.valid_no_deltas_rules
|
||||
)
|
||||
|
||||
if deltas != 'auto':
|
||||
return deltas
|
||||
|
||||
try:
|
||||
return expr._child[expr._name + '_deltas']
|
||||
except (AttributeError, KeyError):
|
||||
if no_deltas_rule == 'raise':
|
||||
raise ValueError(
|
||||
"no deltas table could be reflected for '%s'" % expr
|
||||
)
|
||||
elif no_deltas_rule == 'log':
|
||||
log.warn("no deltas table found for '%s'" % expr)
|
||||
return None
|
||||
|
||||
_get_deltas.valid_no_deltas_rules = 'log', 'raise', 'ignore'
|
||||
|
||||
|
||||
def pipeline_api_from_blaze(expr,
|
||||
deltas='auto',
|
||||
loader=None,
|
||||
resources=None,
|
||||
no_deltas_rule='log'):
|
||||
"""Create a pipeline api object from a blaze expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The blaze expression to use.
|
||||
deltas : Expr or 'auto', optional
|
||||
The expression to use for the point in time adjustments.
|
||||
If the string 'auto' is passed, a deltas expr will be looked up
|
||||
by stepping up the expression tree and looking for another field
|
||||
with the name of ``expr`` + '_deltas'. If None is passed, no deltas
|
||||
will be used.
|
||||
loader : BlazeLoader, optional
|
||||
The blaze loader to attach this pipeline dataset to. If none is passed,
|
||||
the global blaze loader is used.
|
||||
resources : dict or any, optional
|
||||
The data to execute the blaze expressions against. This is used as the
|
||||
scope for ``bz.compute``.
|
||||
no_deltas_rule : {'log', 'raise', 'ignore'}
|
||||
What should happen if ``deltas='auto'`` but no deltas can be found.
|
||||
'log' says to log a message but continue.
|
||||
'raise' says to raise an exception if no deltas can be found.
|
||||
'ignore' says take no action and proceed with no deltas.
|
||||
|
||||
Returns
|
||||
-------
|
||||
pipeline_api_obj : DataSet or BoundColumn
|
||||
Either a new dataset or bound column based on the shape of the expr
|
||||
passed in. If a tabular shaped expression is passed, this will return
|
||||
a ``DataSet`` that represents the whole table. If an array-like shape
|
||||
is passed, a ``BoundColumn`` on the dataset that would be constructed
|
||||
from passing the parent is returned.
|
||||
"""
|
||||
# Check if this is a single column out of a dataset.
|
||||
single_column = None
|
||||
if isscalar(expr.dshape.measure):
|
||||
# This is a single column, record which column we are to return
|
||||
# but create the entire dataset.
|
||||
single_column = expr._name
|
||||
col = expr
|
||||
for expr in expr._subterms():
|
||||
if isrecord(expr.dshape.measure):
|
||||
break
|
||||
else:
|
||||
expr = bz.Data({single_column: col})
|
||||
|
||||
deltas = _get_deltas(expr, deltas, no_deltas_rule)
|
||||
if deltas is not None:
|
||||
invalid_nodes = tuple(filter(
|
||||
complement(flip(isinstance, valid_deltas_node_types)),
|
||||
expr._subterms(),
|
||||
))
|
||||
if invalid_nodes:
|
||||
raise TypeError(
|
||||
'expression with deltas may only contain (%s) nodes,'
|
||||
" found: %s" % (
|
||||
', '.join(map(getname, valid_deltas_node_types)),
|
||||
', '.join(map(compose(getname, type), invalid_nodes)),
|
||||
),
|
||||
)
|
||||
|
||||
measure = expr.dshape.measure
|
||||
if not isrecord(measure) or AD_FIELD_NAME not in measure.names:
|
||||
raise TypeError(
|
||||
"expr must be a collection of records with at least an '{ad}'"
|
||||
" field. Fields provided: '{fields}'\nhint: maybe you need to use "
|
||||
' `relabel` to change your field names'.format(
|
||||
ad=AD_FIELD_NAME,
|
||||
fields=measure,
|
||||
),
|
||||
)
|
||||
_check_datetime_field(AD_FIELD_NAME, measure)
|
||||
|
||||
if TS_FIELD_NAME not in measure.names:
|
||||
expr = bz.transform(expr, **{TS_FIELD_NAME: expr[AD_FIELD_NAME]})
|
||||
if deltas is not None:
|
||||
deltas = bz.transform(
|
||||
deltas,
|
||||
**{TS_FIELD_NAME: deltas[AD_FIELD_NAME]}
|
||||
)
|
||||
else:
|
||||
_check_datetime_field(TS_FIELD_NAME, measure)
|
||||
|
||||
if deltas is not None and deltas.dshape.measure != measure:
|
||||
raise TypeError(
|
||||
"base measure != deltas measure ('%s' != '%s')" % (
|
||||
measure, deltas.dshape.measure,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure that we have a data resource to execute the query against.
|
||||
_check_resources('expr', expr, resources)
|
||||
_check_resources('deltas', deltas, resources)
|
||||
|
||||
# Create or retrieve the pipeline api dataset.
|
||||
ds = new_dataset(expr, deltas)
|
||||
# Register our new dataset with the loader.
|
||||
(loader if loader is not None else global_loader)[ds] = ExprData(
|
||||
expr,
|
||||
deltas,
|
||||
resources,
|
||||
)
|
||||
if single_column is not None:
|
||||
# We were passed a single column, extract and return it.
|
||||
return getattr(ds, single_column)
|
||||
return ds
|
||||
|
||||
|
||||
getdataset = attrgetter('dataset')
|
||||
dataset_name = attrgetter('name')
|
||||
|
||||
|
||||
def inline_novel_deltas(base, deltas, dates):
|
||||
"""Inline any deltas into the base set that would have changed our most
|
||||
recently known value.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
base : pd.DataFrame
|
||||
The first known values.
|
||||
deltas : pd.DataFrame
|
||||
Overwrites to the base data.
|
||||
dates : pd.DatetimeIndex
|
||||
The dates requested by the loader.
|
||||
|
||||
Returns
|
||||
-------
|
||||
new_base : pd.DataFrame
|
||||
The new base data with novel deltas inserted.
|
||||
"""
|
||||
get_indexes = dates.searchsorted
|
||||
return pd.concat(
|
||||
(base,
|
||||
deltas.loc[
|
||||
(get_indexes(deltas[TS_FIELD_NAME].values, 'right') -
|
||||
get_indexes(deltas[AD_FIELD_NAME].values, 'letf')) <= 1
|
||||
].drop(AD_FIELD_NAME, 1)),
|
||||
ignore_index=True,
|
||||
)
|
||||
|
||||
|
||||
def overwrite_from_dates(asof, dates, sparse_dates, asset_idx, value):
|
||||
"""Construct a `Float64Overwrite` with the correct
|
||||
start and end date based on the asof date of the delta,
|
||||
the dense_dates, and the sparse_dates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
asof : datetime
|
||||
The asof date of the delta.
|
||||
dates : pd.DatetimeIndex
|
||||
The dates requested by the loader.
|
||||
sparse_dates : pd.DatetimeIndex
|
||||
The dates that appeared in the dataset.
|
||||
asset_idx : int
|
||||
The index of the asset in the block.
|
||||
value : np.float64
|
||||
The value to overwrite with.
|
||||
|
||||
Returns
|
||||
-------
|
||||
overwrite : Float64Overwrite
|
||||
The overwrite that will apply the new value to the data.
|
||||
"""
|
||||
return Float64Overwrite(
|
||||
dates.searchsorted(asof),
|
||||
dates.get_loc(sparse_dates[sparse_dates.searchsorted(asof) + 1]) - 1,
|
||||
asset_idx,
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
def adjustments_from_deltas(dates,
|
||||
sparse_dates,
|
||||
column_idx,
|
||||
assets,
|
||||
deltas):
|
||||
"""Collect all the adjustments that occur in a dataset that does not
|
||||
have a sid column.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : pd.DatetimeIndex
|
||||
The dates requested by the loader.
|
||||
sparse_dates : pd.DatetimeIndex
|
||||
The dates that were in the sparse data.
|
||||
column_idx : int
|
||||
The index of the column in the dataset.
|
||||
deltas : pd.DataFrame
|
||||
The overwrites that should be applied to the dataset.
|
||||
|
||||
Returns
|
||||
-------
|
||||
adjustments : dict[idx -> Float64Overwrite]
|
||||
The adjustments dictionary to feed to the adjusted array.
|
||||
"""
|
||||
return {
|
||||
dates.get_loc(kd): tuple(
|
||||
overwrite_from_dates(
|
||||
deltas.loc[kd, AD_FIELD_NAME],
|
||||
dates,
|
||||
sparse_dates,
|
||||
n,
|
||||
v,
|
||||
) for n in range(len(assets))
|
||||
) for kd, v in deltas.icol(column_idx).iteritems()
|
||||
}
|
||||
|
||||
|
||||
def to_datetime(dt64, factory=datetime.fromtimestamp, _ns_to_s=1000 ** 3):
|
||||
"""Convert a numpy datetime64 to a datetime object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dt64 : datetime64
|
||||
The datetime64 to coerce.
|
||||
factory : callable, optional
|
||||
The function to coerce the timestamp as seconds into an object.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dt : datetime
|
||||
The dt64 coerced to a datetime.
|
||||
"""
|
||||
return factory(int(dt64) / _ns_to_s, tz=utc)
|
||||
|
||||
|
||||
class BlazeLoader(dict):
|
||||
def __init__(self, colmap=None):
|
||||
self.update(colmap or {})
|
||||
|
||||
@classmethod
|
||||
@memoize
|
||||
def global_instance(cls):
|
||||
return cls()
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
try:
|
||||
dataset, = set(map(getdataset, columns))
|
||||
except ValueError:
|
||||
raise AssertionError('all columns must come from the same dataset')
|
||||
|
||||
expr, deltas, resources = self[dataset]
|
||||
have_sids = SID_FIELD_NAME in expr.fields
|
||||
assets = list(map(int, assets)) # coerce from numpy.int64
|
||||
fields = list(map(dataset_name, columns))
|
||||
query_fields = fields + [AD_FIELD_NAME, TS_FIELD_NAME] + (
|
||||
[SID_FIELD_NAME] if have_sids else []
|
||||
)
|
||||
|
||||
def where(e):
|
||||
"""Create the query to run against the resources.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
e : Expr
|
||||
The base or deltas expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
q : Expr
|
||||
The query to run.
|
||||
"""
|
||||
ts = e[TS_FIELD_NAME]
|
||||
# Hack to get the lower bound to query:
|
||||
# This must be strictly executed because the data for `ts` will
|
||||
# be removed from scope too early otherwise.
|
||||
lower = odo(ts[ts <= to_datetime(dates[0])].max(), pd.Timestamp)
|
||||
return e[
|
||||
e[SID_FIELD_NAME].isin(assets) &
|
||||
(ts >= lower) &
|
||||
(ts < to_datetime(dates[-1]))
|
||||
][query_fields]
|
||||
|
||||
materialized_expr = odo(
|
||||
bz.compute(where(expr), resources),
|
||||
pd.DataFrame,
|
||||
)
|
||||
materialized_deltas = (
|
||||
odo(bz.compute(where(deltas), resources), pd.DataFrame)
|
||||
if deltas is not None else
|
||||
pd.DataFrame(columns=query_fields)
|
||||
)
|
||||
# Capture the original (sparse) dates that came from the resource.
|
||||
sparse_dates = pd.DatetimeIndex(materialized_expr[TS_FIELD_NAME])
|
||||
# Inline the deltas that changed our most recently known value.
|
||||
# Also, we reindex by the dates to create a dense representation of
|
||||
# the data.
|
||||
base = inline_novel_deltas(
|
||||
materialized_expr,
|
||||
materialized_deltas,
|
||||
dates,
|
||||
).drop(AD_FIELD_NAME, axis=1).set_index(TS_FIELD_NAME).reindex(
|
||||
dates,
|
||||
method='ffill',
|
||||
)
|
||||
if have_sids:
|
||||
base.index.name = TS_FIELD_NAME
|
||||
# Unstack by the sid so that we get a multi-index on the columns
|
||||
# of datacolumn, sid.
|
||||
base = base.set_index(SID_FIELD_NAME, append=True).unstack()
|
||||
column_view = identity
|
||||
else:
|
||||
def column_view(arr, _shape=(len(dates), len(assets))):
|
||||
"""Return a virtual matrix where we make a view that
|
||||
duplicates a single column for all the assets.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> arr = np.array([1, 2, 3])
|
||||
>>> as_strided(arr, shape=(3, 3), strides=(arr.itemsize, 0))
|
||||
array([[1, 1, 1],
|
||||
[2, 2, 2],
|
||||
[3, 3, 3]])
|
||||
"""
|
||||
return as_strided(
|
||||
arr,
|
||||
shape=_shape,
|
||||
strides=(arr.itemsize, 0),
|
||||
)
|
||||
|
||||
for column_idx, column in enumerate(columns):
|
||||
yield adjusted_array(
|
||||
column_view(base[column.name].values.astype(column.dtype)),
|
||||
mask,
|
||||
adjustments_from_deltas(
|
||||
dates,
|
||||
sparse_dates,
|
||||
column_idx,
|
||||
assets,
|
||||
materialized_deltas,
|
||||
)
|
||||
)
|
||||
|
||||
global_loader = BlazeLoader.global_instance()
|
||||
Reference in New Issue
Block a user