mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 22:34:50 +08:00
Merge pull request #1396 from quantopian/add_estimates_quarter_loader_to_pipeline
Add estimates quarter loader to pipeline
This commit is contained in:
@@ -20,8 +20,10 @@ from toolz import curry
|
||||
from zipline.errors import WindowLengthNotPositive, WindowLengthTooLong
|
||||
from zipline.lib.adjustment import (
|
||||
Datetime64Overwrite,
|
||||
Datetime641DArrayOverwrite,
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
Float641DArrayOverwrite,
|
||||
ObjectOverwrite,
|
||||
)
|
||||
from zipline.lib.adjusted_array import AdjustedArray, NOMASK
|
||||
@@ -200,11 +202,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
)
|
||||
|
||||
|
||||
def _gen_overwrite_adjustment_cases(name,
|
||||
make_input,
|
||||
make_expected_output,
|
||||
dtype,
|
||||
missing_value):
|
||||
def _gen_overwrite_adjustment_cases(dtype):
|
||||
"""
|
||||
Generate test cases for overwrite adjustments.
|
||||
|
||||
@@ -224,6 +222,8 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
unicode_dtype: ObjectOverwrite,
|
||||
object_dtype: ObjectOverwrite,
|
||||
}[dtype]
|
||||
make_expected_dtype = as_dtype(dtype)
|
||||
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
|
||||
|
||||
if dtype == object_dtype:
|
||||
# When we're testing object dtypes, we expect to have strings, but
|
||||
@@ -235,30 +235,30 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
|
||||
adjustments = {}
|
||||
buffer_as_of = [None] * 6
|
||||
baseline = make_input([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
baseline = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
buffer_as_of[0] = make_expected_output([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, make_overwrite_value(dtype, 1)),
|
||||
]
|
||||
buffer_as_of[1] = make_expected_output([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# No adjustment at index 2.
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
@@ -267,33 +267,136 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
adjustment_type(1, 2, 1, 1, make_overwrite_value(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, make_overwrite_value(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = make_expected_output([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
[2, 3, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
[2, 3, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, make_overwrite_value(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = make_expected_output([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
[2, 3, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
[2, 3, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, make_overwrite_value(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, make_overwrite_value(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = make_expected_output([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
[2, 6, 7],
|
||||
[2, 6, 5],
|
||||
[2, 6, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[5] = make_expected_dtype([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
[2, 6, 7],
|
||||
[2, 6, 5],
|
||||
[2, 6, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
missing_value,
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows=6,
|
||||
)
|
||||
|
||||
|
||||
def _gen_overwrite_1d_array_adjustment_case(dtype):
|
||||
"""
|
||||
Generate test cases for overwrite adjustments.
|
||||
|
||||
The algorithm used here is the same as the one used above for
|
||||
multiplicative adjustments. The only difference is the semantics of how
|
||||
the adjustments are expected to modify the arrays.
|
||||
|
||||
This is parameterized on `make_input` and `make_expected_output` functions,
|
||||
which take 1-D lists of values and transform them into desired input/output
|
||||
arrays. We do this so that we can easily test both vanilla numpy ndarrays
|
||||
and our own LabelArray class for strings.
|
||||
"""
|
||||
adjustment_type = {
|
||||
float64_dtype: Float641DArrayOverwrite,
|
||||
datetime64ns_dtype: Datetime641DArrayOverwrite,
|
||||
}[dtype]
|
||||
make_expected_dtype = as_dtype(dtype)
|
||||
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
|
||||
|
||||
adjustments = {}
|
||||
buffer_as_of = [None] * 6
|
||||
baseline = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals1 = [1]
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(
|
||||
0, 0, 0, 0,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals1])
|
||||
)
|
||||
]
|
||||
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# No adjustment at index 2.
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
vals3 = [4, 4, 1]
|
||||
adjustments[3] = [
|
||||
adjustment_type(
|
||||
0, 2, 0, 0,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals3])
|
||||
)
|
||||
]
|
||||
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
|
||||
[4, 2, 2],
|
||||
[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals4 = [5] * 4
|
||||
adjustments[4] = [
|
||||
adjustment_type(
|
||||
0, 3, 2, 2,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals4]))
|
||||
]
|
||||
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
|
||||
[4, 2, 5],
|
||||
[1, 2, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals5 = range(1, 6)
|
||||
adjustments[5] = [
|
||||
adjustment_type(
|
||||
0, 4, 1, 1,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals5])),
|
||||
]
|
||||
buffer_as_of[5] = make_expected_dtype([[4, 1, 5],
|
||||
[4, 2, 5],
|
||||
[1, 3, 5],
|
||||
[2, 4, 5],
|
||||
[2, 5, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
@@ -426,22 +529,10 @@ class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_overwrite_adjustment_cases(
|
||||
'float',
|
||||
make_input=as_dtype(float64_dtype),
|
||||
make_expected_output=as_dtype(float64_dtype),
|
||||
dtype=float64_dtype,
|
||||
missing_value=default_missing_value_for_dtype(float64_dtype),
|
||||
),
|
||||
_gen_overwrite_adjustment_cases(
|
||||
'datetime',
|
||||
make_input=as_dtype(datetime64ns_dtype),
|
||||
make_expected_output=as_dtype(datetime64ns_dtype),
|
||||
dtype=datetime64ns_dtype,
|
||||
missing_value=default_missing_value_for_dtype(
|
||||
datetime64ns_dtype,
|
||||
),
|
||||
),
|
||||
_gen_overwrite_adjustment_cases(float64_dtype),
|
||||
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(float64_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(datetime64ns_dtype),
|
||||
# There are six cases here:
|
||||
# Using np.bytes/np.unicode/object arrays as inputs.
|
||||
# Passing np.bytes/np.unicode/object arrays to LabelArray,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+146
-2
@@ -3,7 +3,7 @@ from cpython cimport Py_EQ
|
||||
|
||||
from pandas import isnull, Timestamp
|
||||
from numpy cimport float64_t, uint8_t, int64_t
|
||||
from numpy import datetime64, float64
|
||||
from numpy import asarray, datetime64, float64, int64
|
||||
# Purely for readability. There aren't C-level declarations for these types.
|
||||
ctypedef object Int64Index_t
|
||||
ctypedef object DatetimeIndex_t
|
||||
@@ -364,6 +364,150 @@ cdef class Float64Overwrite(Float64Adjustment):
|
||||
data[row, col] = value
|
||||
|
||||
|
||||
cdef class ArrayAdjustment(Adjustment):
|
||||
"""
|
||||
Base class for ArrayAdjustments.
|
||||
|
||||
Subclasses should inherit and provide a `values` attribute and a `mutate`
|
||||
method.
|
||||
"""
|
||||
def __repr__(self):
|
||||
return (
|
||||
"%s(first_row=%d, last_row=%d,"
|
||||
" first_col=%d, last_col=%d, values=%s)" % (
|
||||
type(self).__name__,
|
||||
self.first_row,
|
||||
self.last_row,
|
||||
self.first_col,
|
||||
self.last_col,
|
||||
asarray(self.values),
|
||||
)
|
||||
)
|
||||
|
||||
cdef class Float641DArrayOverwrite(ArrayAdjustment):
|
||||
"""
|
||||
An adjustment that overwrites subarrays with a value for each subarray.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np
|
||||
>>> arr = np.arange(25, dtype=float).reshape(5, 5)
|
||||
>>> arr
|
||||
array([[ 0., 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8., 9.],
|
||||
[ 10., 11., 12., 13., 14.],
|
||||
[ 15., 16., 17., 18., 19.],
|
||||
[ 20., 21., 22., 23., 24.]])
|
||||
>>> adj = Float641DArrayOverwrite(
|
||||
... row_start=0,
|
||||
... row_end=3,
|
||||
... column_start=0,
|
||||
... column_end=0,
|
||||
... values=np.array([1, 2, 3, 4]),
|
||||
)
|
||||
>>> adj.mutate(arr)
|
||||
>>> arr
|
||||
array([[ 1., 1., 2., 3., 4.],
|
||||
[ 2., 6., 7., 8., 9.],
|
||||
[ 3., 11., 12., 13., 14.],
|
||||
[ 4., 16., 17., 18., 19.],
|
||||
[ 20., 21., 22., 23., 24.]])
|
||||
"""
|
||||
cdef:
|
||||
readonly float64_t[:] values
|
||||
|
||||
def __init__(self,
|
||||
int64_t first_row,
|
||||
int64_t last_row,
|
||||
int64_t first_col,
|
||||
int64_t last_col,
|
||||
float64_t[:] values):
|
||||
super(Float641DArrayOverwrite, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
if last_row + 1 - first_row != len(values):
|
||||
raise ValueError(
|
||||
"Mismatch: got %d values for rows starting at index %d and "
|
||||
"ending at index %d." % (len(values), first_row, last_row)
|
||||
)
|
||||
self.values = values
|
||||
|
||||
cpdef mutate(self, float64_t[:, :] data):
|
||||
cdef Py_ssize_t fill_range, row, col
|
||||
cdef float64_t[:] values = self.values
|
||||
for col in range(self.first_col, self.last_col + 1):
|
||||
for i, row in enumerate(range(self.first_row, self.last_row + 1)):
|
||||
data[row, col] = values[i]
|
||||
|
||||
|
||||
cdef class Datetime641DArrayOverwrite(ArrayAdjustment):
|
||||
"""
|
||||
An adjustment that overwrites subarrays with a value for each subarray.
|
||||
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np; import pandas as pd
|
||||
>>> dts = pd.date_range('2014', freq='D', periods=9, tz='UTC')
|
||||
>>> arr = dts.values.reshape(3, 3)
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
>>> adj = Datetime641DArrayOverwrite(
|
||||
... first_row=1,
|
||||
... last_row=2,
|
||||
... first_col=1,
|
||||
... last_col=2,
|
||||
... values=np.array([
|
||||
... np.datetime64(0, 'ns'),
|
||||
... np.datetime64(1, 'ns')
|
||||
... ])
|
||||
... )
|
||||
>>> adj.mutate(arr.view(np.int64))
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, True, True],
|
||||
[False, False, False]], dtype=bool)
|
||||
>>> arr == np.datetime64(1, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, False, False],
|
||||
[False, True, True]], dtype=bool)
|
||||
"""
|
||||
cdef:
|
||||
readonly int64_t[:] values
|
||||
|
||||
def __init__(self,
|
||||
int64_t first_row,
|
||||
int64_t last_row,
|
||||
int64_t first_col,
|
||||
int64_t last_col,
|
||||
object values):
|
||||
super(Datetime641DArrayOverwrite, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
if last_row + 1 - first_row != len(values):
|
||||
raise ValueError("Mismatch: got %d values for rows starting at"
|
||||
" index %d and ending at index %d." % (
|
||||
len(values), first_row, last_row)
|
||||
)
|
||||
self.values = asarray([datetime_to_int(value) for value in values])
|
||||
|
||||
cpdef mutate(self, int64_t[:, :] data):
|
||||
cdef Py_ssize_t row, col
|
||||
cdef int64_t[:] values = self.values
|
||||
for col in range(self.first_col, self.last_col + 1):
|
||||
for i, row in enumerate(range(self.first_row, self.last_row + 1)):
|
||||
data[row, col] = values[i]
|
||||
|
||||
|
||||
cdef class Float64Add(Float64Adjustment):
|
||||
"""
|
||||
An adjustment that adds a float.
|
||||
@@ -458,7 +602,7 @@ cdef datetime_to_int(object datetimelike):
|
||||
datetimelike.dtype.name,
|
||||
)
|
||||
|
||||
return datetimelike.astype(int)
|
||||
return datetimelike.astype(int64)
|
||||
|
||||
|
||||
cdef class Datetime64Adjustment(_Int64Adjustment):
|
||||
|
||||
@@ -6,6 +6,8 @@ ANNOUNCEMENT_FIELD_NAME = 'announcement_date'
|
||||
CASH_FIELD_NAME = 'cash'
|
||||
DAYS_SINCE_PREV = 'days_since_prev'
|
||||
DAYS_TO_NEXT = 'days_to_next'
|
||||
FISCAL_QUARTER_FIELD_NAME = 'fiscal_quarter'
|
||||
FISCAL_YEAR_FIELD_NAME = 'fiscal_year'
|
||||
NEXT_ANNOUNCEMENT = 'next_announcement'
|
||||
PREVIOUS_AMOUNT = 'previous_amount'
|
||||
PREVIOUS_ANNOUNCEMENT = 'previous_announcement'
|
||||
|
||||
@@ -175,8 +175,10 @@ from zipline.pipeline.common import (
|
||||
from zipline.pipeline.data.dataset import DataSet, Column
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
last_in_date_group,
|
||||
normalize_data_query_bounds,
|
||||
normalize_timestamp_to_query_time,
|
||||
ffill_across_cols
|
||||
)
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.lib.adjusted_array import AdjustedArray, can_represent_dtype
|
||||
@@ -186,7 +188,7 @@ from zipline.utils.input_validation import (
|
||||
ensure_timezone,
|
||||
optionally,
|
||||
)
|
||||
from zipline.utils.numpy_utils import bool_dtype, categorical_dtype
|
||||
from zipline.utils.numpy_utils import bool_dtype
|
||||
from zipline.utils.pool import SequentialPool
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
@@ -869,9 +871,9 @@ def adjustments_from_deltas_with_sids(dense_dates,
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dates : pd.DatetimeIndex
|
||||
The dates requested by the loader.
|
||||
dense_dates : pd.DatetimeIndex
|
||||
The dates requested by the loader.
|
||||
sparse_dates : pd.DatetimeIndex
|
||||
The dates that were in the raw data.
|
||||
column_idx : int
|
||||
The index of the column in the dataset.
|
||||
@@ -1026,22 +1028,9 @@ class BlazeLoader(dict):
|
||||
|
||||
return odo(e[predicate][colnames], pd.DataFrame, **odo_kwargs)
|
||||
|
||||
if checkpoints is not None:
|
||||
ts = checkpoints[TS_FIELD_NAME]
|
||||
checkpoints_ts = odo(ts[ts <= lower_dt].max(), pd.Timestamp)
|
||||
if pd.isnull(checkpoints_ts):
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
else:
|
||||
materialized_checkpoints = odo(
|
||||
checkpoints[ts == checkpoints_ts][colnames],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
)
|
||||
lower = checkpoints_ts
|
||||
else:
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
lower, materialized_checkpoints = get_materialized_checkpoints(
|
||||
checkpoints, colnames, lower_dt, odo_kwargs
|
||||
)
|
||||
|
||||
materialized_expr = self.pool.apply_async(collect_expr, (expr, lower))
|
||||
materialized_deltas = (
|
||||
@@ -1091,71 +1080,18 @@ class BlazeLoader(dict):
|
||||
)
|
||||
sparse_output.drop(AD_FIELD_NAME, axis=1, inplace=True)
|
||||
|
||||
def last_in_date_group(df, reindex, have_sids=have_sids):
|
||||
idx = dates[dates.searchsorted(
|
||||
df[TS_FIELD_NAME].values.astype('datetime64[D]')
|
||||
)]
|
||||
if have_sids:
|
||||
idx = [idx, SID_FIELD_NAME]
|
||||
|
||||
last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(
|
||||
idx,
|
||||
sort=False,
|
||||
).last()
|
||||
|
||||
if have_sids:
|
||||
last_in_group = last_in_group.unstack()
|
||||
|
||||
if reindex:
|
||||
if have_sids:
|
||||
cols = last_in_group.columns
|
||||
last_in_group = last_in_group.reindex(
|
||||
index=dates,
|
||||
columns=pd.MultiIndex.from_product(
|
||||
(cols.levels[0], assets),
|
||||
names=cols.names,
|
||||
),
|
||||
)
|
||||
else:
|
||||
last_in_group = last_in_group.reindex(dates)
|
||||
|
||||
return last_in_group
|
||||
|
||||
sparse_deltas = last_in_date_group(non_novel_deltas, reindex=False)
|
||||
dense_output = last_in_date_group(sparse_output, reindex=True)
|
||||
dense_output.ffill(inplace=True)
|
||||
|
||||
# Fill in missing values specified by each column. This is made
|
||||
# significantly more complex by the fact that we need to work around
|
||||
# two pandas issues:
|
||||
|
||||
# 1) When we have sids, if there are no records for a given sid for any
|
||||
# dates, pandas will generate a column full of NaNs for that sid.
|
||||
# This means that some of the columns in `dense_output` are now
|
||||
# float instead of the intended dtype, so we have to coerce back to
|
||||
# our expected type and convert NaNs into the desired missing value.
|
||||
|
||||
# 2) DataFrame.ffill assumes that receiving None as a fill-value means
|
||||
# that no value was passed. Consequently, there's no way to tell
|
||||
# pandas to replace NaNs in an object column with None using fillna,
|
||||
# so we have to roll our own instead using df.where.
|
||||
for column in columns:
|
||||
# Special logic for strings since `fillna` doesn't work if the
|
||||
# missing value is `None`.
|
||||
if column.dtype == categorical_dtype:
|
||||
dense_output[column.name] = dense_output[
|
||||
column.name
|
||||
].where(pd.notnull(dense_output[column.name]),
|
||||
column.missing_value)
|
||||
else:
|
||||
# We need to execute `fillna` before `astype` in case the
|
||||
# column contains NaNs and needs to be cast to bool or int.
|
||||
# This is so that the NaNs are replaced first, since pandas
|
||||
# can't convert NaNs for those types.
|
||||
dense_output[column.name] = dense_output[
|
||||
column.name
|
||||
].fillna(column.missing_value).astype(column.dtype)
|
||||
|
||||
sparse_deltas = last_in_date_group(non_novel_deltas,
|
||||
dates,
|
||||
assets,
|
||||
reindex=False,
|
||||
have_sids=have_sids)
|
||||
dense_output = last_in_date_group(sparse_output,
|
||||
dates,
|
||||
assets,
|
||||
reindex=True,
|
||||
have_sids=have_sids)
|
||||
ffill_across_cols(dense_output, columns, {c.name: c.name
|
||||
for c in columns})
|
||||
if have_sids:
|
||||
adjustments_from_deltas = adjustments_from_deltas_with_sids
|
||||
column_view = identity
|
||||
@@ -1188,6 +1124,7 @@ class BlazeLoader(dict):
|
||||
for column_idx, column in enumerate(columns)
|
||||
}
|
||||
|
||||
|
||||
global_loader = BlazeLoader.global_instance()
|
||||
|
||||
|
||||
@@ -1219,12 +1156,48 @@ def bind_expression_to_resources(expr, resources):
|
||||
})
|
||||
|
||||
|
||||
def get_materialized_checkpoints(checkpoints, colnames, lower_dt, odo_kwargs):
|
||||
"""
|
||||
Computes a lower bound and a DataFrame checkpoints.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
checkpoints : Expr
|
||||
Bound blaze expression for a checkpoints table from which to get a
|
||||
computed lower bound.
|
||||
colnames : iterable of str
|
||||
The names of the columns for which checkpoints should be computed.
|
||||
lower_dt : pd.Timestamp
|
||||
The lower date being queried for that serves as an upper bound for
|
||||
checkpoints.
|
||||
odo_kwargs : dict, optional
|
||||
The extra keyword arguments to pass to ``odo``.
|
||||
"""
|
||||
if checkpoints is not None:
|
||||
ts = checkpoints[TS_FIELD_NAME]
|
||||
checkpoints_ts = odo(ts[ts <= lower_dt].max(), pd.Timestamp)
|
||||
if pd.isnull(checkpoints_ts):
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
else:
|
||||
materialized_checkpoints = odo(
|
||||
checkpoints[ts == checkpoints_ts][colnames],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
)
|
||||
lower = checkpoints_ts
|
||||
else:
|
||||
materialized_checkpoints = pd.DataFrame(columns=colnames)
|
||||
lower = None
|
||||
return lower, materialized_checkpoints
|
||||
|
||||
|
||||
def ffill_query_in_range(expr,
|
||||
lower,
|
||||
upper,
|
||||
checkpoints=None,
|
||||
odo_kwargs=None,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
sid_field=SID_FIELD_NAME):
|
||||
ts_field=TS_FIELD_NAME):
|
||||
"""Query a blaze expression in a given time range properly forward filling
|
||||
from values that fall before the lower date.
|
||||
|
||||
@@ -1236,12 +1209,13 @@ def ffill_query_in_range(expr,
|
||||
The lower date to query for.
|
||||
upper : datetime
|
||||
The upper date to query for.
|
||||
checkpoints : Expr, optional
|
||||
Bound blaze expression for a checkpoints table from which to get a
|
||||
computed lower bound.
|
||||
odo_kwargs : dict, optional
|
||||
The extra keyword arguments to pass to ``odo``.
|
||||
ts_field : str, optional
|
||||
The name of the timestamp field in the given blaze expression.
|
||||
sid_field : str, optional
|
||||
The name of the sid field in the given blaze expression.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -1250,27 +1224,24 @@ def ffill_query_in_range(expr,
|
||||
start before the requested start date if a value is needed to ffill.
|
||||
"""
|
||||
odo_kwargs = odo_kwargs or {}
|
||||
filtered = expr[expr[ts_field] <= lower]
|
||||
computed_lower = odo(
|
||||
bz.by(
|
||||
filtered[sid_field],
|
||||
timestamp=filtered[ts_field].max(),
|
||||
).timestamp.min(),
|
||||
pd.Timestamp,
|
||||
**odo_kwargs
|
||||
computed_lower, materialized_checkpoints = get_materialized_checkpoints(
|
||||
checkpoints, expr.fields, lower, odo_kwargs
|
||||
)
|
||||
if pd.isnull(computed_lower):
|
||||
# If there is no lower date, just query for data in the date
|
||||
# range. It must all be null anyways.
|
||||
computed_lower = lower
|
||||
|
||||
raw = odo(
|
||||
expr[
|
||||
(expr[ts_field] >= computed_lower) &
|
||||
(expr[ts_field] <= upper)
|
||||
],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
raw = pd.concat(
|
||||
[materialized_checkpoints,
|
||||
odo(
|
||||
expr[
|
||||
(expr[ts_field] >= computed_lower) &
|
||||
(expr[ts_field] <= upper)
|
||||
],
|
||||
pd.DataFrame,
|
||||
**odo_kwargs
|
||||
)]
|
||||
)
|
||||
raw.loc[:, ts_field] = raw.loc[:, ts_field].astype('datetime64[ns]')
|
||||
return raw
|
||||
|
||||
@@ -0,0 +1,137 @@
|
||||
from datashape import istabular
|
||||
|
||||
from .core import (
|
||||
bind_expression_to_resources,
|
||||
)
|
||||
from zipline.pipeline.common import (
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.blaze.utils import load_raw_data
|
||||
from zipline.pipeline.loaders.earnings_estimates import (
|
||||
NextEarningsEstimatesLoader,
|
||||
PreviousEarningsEstimatesLoader,
|
||||
required_estimates_fields,
|
||||
metadata_columns,
|
||||
)
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_timezone, optionally
|
||||
from zipline.utils.preprocess import preprocess
|
||||
|
||||
|
||||
class BlazeEstimatesLoader(PipelineLoader):
|
||||
"""An abstract pipeline loader for the estimates datasets that loads
|
||||
data from a blaze expression.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
columns : dict[str -> str]
|
||||
A dict mapping BoundColumn names to the associated names in `expr`.
|
||||
resources : dict, optional
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezeone to use for the data query cutoff.
|
||||
checkpoints : Expr, optional
|
||||
The expression representing checkpointed data to be used for faster
|
||||
forward-filling of data from `expr`.
|
||||
|
||||
Notes
|
||||
-----
|
||||
The expression should have a tabular dshape of::
|
||||
|
||||
Dim * {{
|
||||
{SID_FIELD_NAME}: int64,
|
||||
{TS_FIELD_NAME}: datetime,
|
||||
{FISCAL_YEAR_FIELD_NAME}: float64,
|
||||
{FISCAL_QUARTER_FIELD_NAME}: float64,
|
||||
{EVENT_DATE_FIELD_NAME}: datetime,
|
||||
}}
|
||||
|
||||
And other dataset-specific fields, where each row of the table is a
|
||||
record including the sid to identify the company, the timestamp where we
|
||||
learned about the announcement, and the date of the event.
|
||||
|
||||
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
|
||||
start the backtest with knowledge of all announcements.
|
||||
"""
|
||||
__doc__ = __doc__.format(
|
||||
SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME=FISCAL_YEAR_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME=FISCAL_QUARTER_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME=EVENT_DATE_FIELD_NAME,
|
||||
)
|
||||
|
||||
@preprocess(data_query_tz=optionally(ensure_timezone))
|
||||
def __init__(self,
|
||||
expr,
|
||||
columns,
|
||||
resources=None,
|
||||
odo_kwargs=None,
|
||||
data_query_time=None,
|
||||
data_query_tz=None,
|
||||
checkpoints=None):
|
||||
|
||||
dshape = expr.dshape
|
||||
if not istabular(dshape):
|
||||
raise ValueError(
|
||||
'expression dshape must be tabular, got: %s' % dshape,
|
||||
)
|
||||
|
||||
required_cols = list(
|
||||
required_estimates_fields(columns)
|
||||
)
|
||||
self._expr = bind_expression_to_resources(
|
||||
expr[required_cols],
|
||||
resources,
|
||||
)
|
||||
self._columns = columns
|
||||
self._odo_kwargs = odo_kwargs if odo_kwargs is not None else {}
|
||||
check_data_query_args(data_query_time, data_query_tz)
|
||||
self._data_query_time = data_query_time
|
||||
self._data_query_tz = data_query_tz
|
||||
self._checkpoints = checkpoints
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
# Only load requested columns.
|
||||
requested_column_names = [self._columns[column.name]
|
||||
for column in columns]
|
||||
raw = load_raw_data(
|
||||
assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr[sorted(metadata_columns.union(requested_column_names))],
|
||||
self._odo_kwargs,
|
||||
checkpoints=self._checkpoints,
|
||||
)
|
||||
|
||||
return self.loader(
|
||||
raw,
|
||||
{column.name: self._columns[column.name] for column in columns}
|
||||
).load_adjusted_array(
|
||||
columns,
|
||||
dates,
|
||||
assets,
|
||||
mask,
|
||||
)
|
||||
|
||||
|
||||
class BlazeNextEstimatesLoader(BlazeEstimatesLoader):
|
||||
loader = NextEarningsEstimatesLoader
|
||||
|
||||
|
||||
class BlazePreviousEstimatesLoader(BlazeEstimatesLoader):
|
||||
loader = PreviousEarningsEstimatesLoader
|
||||
@@ -2,21 +2,17 @@ from datashape import istabular
|
||||
|
||||
from .core import (
|
||||
bind_expression_to_resources,
|
||||
ffill_query_in_range,
|
||||
)
|
||||
from zipline.pipeline.common import SID_FIELD_NAME, TS_FIELD_NAME, \
|
||||
EVENT_DATE_FIELD_NAME
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.blaze.utils import load_raw_data
|
||||
from zipline.pipeline.loaders.events import (
|
||||
EventsLoader,
|
||||
required_event_fields,
|
||||
)
|
||||
from zipline.pipeline.common import (
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
check_data_query_args,
|
||||
normalize_data_query_bounds,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_timezone, optionally
|
||||
from zipline.utils.preprocess import preprocess
|
||||
@@ -30,6 +26,10 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
----------
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
next_value_columns : dict[BoundColumn -> raw column name]
|
||||
A dict mapping 'next' BoundColumns to their column names in `expr`.
|
||||
previous_value_columns : dict[BoundColumn -> raw column name]
|
||||
A dict mapping 'previous' BoundColumns to their column names in `expr`.
|
||||
resources : dict, optional
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
@@ -37,9 +37,7 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
data_query_time : time, optional
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezeone to use for the data query cutoff.
|
||||
dataset : DataSet
|
||||
The DataSet object for which this loader loads data.
|
||||
The timezone to use for the data query cutoff.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@@ -48,17 +46,21 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
Dim * {{
|
||||
{SID_FIELD_NAME}: int64,
|
||||
{TS_FIELD_NAME}: datetime,
|
||||
{EVENT_DATE_FIELD_NAME}: datetime,
|
||||
}}
|
||||
|
||||
And other dataset-specific fields, where each row of the table is a
|
||||
record including the sid to identify the company, the timestamp where we
|
||||
learned about the announcement, and the date when the earnings will be z
|
||||
announced.
|
||||
learned about the announcement, and the event date.
|
||||
|
||||
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
|
||||
start the backtest with knowledge of all announcements.
|
||||
"""
|
||||
|
||||
__doc__ = __doc__.format(SID_FIELD_NAME=SID_FIELD_NAME,
|
||||
TS_FIELD_NAME=TS_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME=EVENT_DATE_FIELD_NAME)
|
||||
|
||||
@preprocess(data_query_tz=optionally(ensure_timezone))
|
||||
def __init__(self,
|
||||
expr,
|
||||
@@ -90,34 +92,12 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
self._data_query_tz = data_query_tz
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
data_query_time = self._data_query_time
|
||||
data_query_tz = self._data_query_tz
|
||||
lower_dt, upper_dt = normalize_data_query_bounds(
|
||||
dates[0],
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
|
||||
raw = ffill_query_in_range(
|
||||
self._expr,
|
||||
lower_dt,
|
||||
upper_dt,
|
||||
self._odo_kwargs,
|
||||
)
|
||||
sids = raw.loc[:, SID_FIELD_NAME]
|
||||
raw.drop(
|
||||
sids[~sids.isin(assets)].index,
|
||||
inplace=True
|
||||
)
|
||||
if data_query_time is not None:
|
||||
normalize_timestamp_to_query_time(
|
||||
raw,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
raw = load_raw_data(assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr,
|
||||
self._odo_kwargs)
|
||||
|
||||
return EventsLoader(
|
||||
events=raw,
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
from zipline.pipeline.common import SID_FIELD_NAME, TS_FIELD_NAME
|
||||
from zipline.pipeline.loaders.blaze.core import ffill_query_in_range
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
normalize_data_query_bounds,
|
||||
normalize_timestamp_to_query_time,
|
||||
)
|
||||
|
||||
|
||||
def load_raw_data(assets,
|
||||
dates,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
expr,
|
||||
odo_kwargs,
|
||||
checkpoints=None):
|
||||
"""
|
||||
Given an expression representing data to load, perform normalization and
|
||||
forward-filling and return the data, materialized. Only accepts data with a
|
||||
`sid` field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
assets : pd.int64index
|
||||
the assets to load data for.
|
||||
dates : pd.datetimeindex
|
||||
the simulation dates to load data for.
|
||||
data_query_time : datetime.time
|
||||
the time used as cutoff for new information.
|
||||
data_query_tz : tzinfo
|
||||
the timezone to normalize your dates to before comparing against
|
||||
`time`.
|
||||
expr : expr
|
||||
the expression representing the data to load.
|
||||
odo_kwargs : dict
|
||||
extra keyword arguments to pass to odo when executing the expression.
|
||||
checkpoints : expr, optional
|
||||
the expression representing the checkpointed data for `expr`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
raw : pd.dataframe
|
||||
The result of computing expr and materializing the result as a
|
||||
dataframe.
|
||||
"""
|
||||
lower_dt, upper_dt = normalize_data_query_bounds(
|
||||
dates[0],
|
||||
dates[-1],
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
)
|
||||
raw = ffill_query_in_range(
|
||||
expr,
|
||||
lower_dt,
|
||||
upper_dt,
|
||||
checkpoints=checkpoints,
|
||||
odo_kwargs=odo_kwargs,
|
||||
)
|
||||
sids = raw[SID_FIELD_NAME]
|
||||
raw.drop(
|
||||
sids[~sids.isin(assets)].index,
|
||||
inplace=True
|
||||
)
|
||||
if data_query_time is not None:
|
||||
normalize_timestamp_to_query_time(
|
||||
raw,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
inplace=True,
|
||||
ts_field=TS_FIELD_NAME,
|
||||
)
|
||||
return raw
|
||||
@@ -0,0 +1,647 @@
|
||||
from abc import abstractmethod, abstractproperty
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from six import viewvalues
|
||||
from toolz import groupby
|
||||
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import (
|
||||
Datetime641DArrayOverwrite,
|
||||
Datetime64Overwrite,
|
||||
Float641DArrayOverwrite,
|
||||
Float64Overwrite,
|
||||
)
|
||||
|
||||
from zipline.pipeline.common import (
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype, float64_dtype
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
ffill_across_cols,
|
||||
last_in_date_group
|
||||
)
|
||||
|
||||
|
||||
INVALID_NUM_QTRS_MESSAGE = "Passed invalid number of quarters %s; " \
|
||||
"must pass a number of quarters >= 0"
|
||||
NEXT_FISCAL_QUARTER = 'next_fiscal_quarter'
|
||||
NEXT_FISCAL_YEAR = 'next_fiscal_year'
|
||||
NORMALIZED_QUARTERS = 'normalized_quarters'
|
||||
PREVIOUS_FISCAL_QUARTER = 'previous_fiscal_quarter'
|
||||
PREVIOUS_FISCAL_YEAR = 'previous_fiscal_year'
|
||||
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
|
||||
SIMULATION_DATES = 'dates'
|
||||
|
||||
|
||||
def normalize_quarters(years, quarters):
|
||||
return years * 4 + quarters - 1
|
||||
|
||||
|
||||
def split_normalized_quarters(normalized_quarters):
|
||||
years = normalized_quarters // 4
|
||||
quarters = normalized_quarters % 4
|
||||
return years, quarters + 1
|
||||
|
||||
|
||||
# These metadata columns are used to align event indexers.
|
||||
metadata_columns = frozenset({
|
||||
TS_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
FISCAL_YEAR_FIELD_NAME,
|
||||
})
|
||||
|
||||
|
||||
def required_estimates_fields(columns):
|
||||
"""
|
||||
Compute the set of resource columns required to serve
|
||||
`columns`.
|
||||
"""
|
||||
# We also expect any of the field names that our loadable columns
|
||||
# are mapped to.
|
||||
return metadata_columns.union(viewvalues(columns))
|
||||
|
||||
|
||||
def validate_column_specs(events, columns):
|
||||
"""
|
||||
Verify that the columns of ``events`` can be used by a
|
||||
EarningsEstimatesLoader to serve the BoundColumns described by
|
||||
`columns`.
|
||||
"""
|
||||
required = required_estimates_fields(columns)
|
||||
received = set(events.columns)
|
||||
missing = required - received
|
||||
if missing:
|
||||
raise ValueError(
|
||||
"EarningsEstimatesLoader missing required columns {missing}.\n"
|
||||
"Got Columns: {received}\n"
|
||||
"Expected Columns: {required}".format(
|
||||
missing=sorted(missing),
|
||||
received=sorted(received),
|
||||
required=sorted(required),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class EarningsEstimatesLoader(PipelineLoader):
|
||||
"""
|
||||
An abstract pipeline loader for estimates data that can load data a
|
||||
variable number of quarters forwards/backwards from calendar dates
|
||||
depending on the `num_announcements` attribute of the columns' dataset.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
estimates : pd.DataFrame
|
||||
The raw estimates data.
|
||||
``estimates`` must contain at least 5 columns:
|
||||
sid : int64
|
||||
The asset id associated with each estimate.
|
||||
|
||||
event_date : datetime64[ns]
|
||||
The date on which the event that the estimate is for will/has
|
||||
occurred..
|
||||
|
||||
timestamp : datetime64[ns]
|
||||
The date on which we learned about the estimate.
|
||||
|
||||
fiscal_quarter : int64
|
||||
The quarter during which the event has/will occur.
|
||||
|
||||
fiscal_year : int64
|
||||
The year during which the event has/will occur.
|
||||
|
||||
name_map : dict[str -> str]
|
||||
A map of names of BoundColumns that this loader will load to the
|
||||
names of the corresponding columns in `events`.
|
||||
"""
|
||||
def __init__(self,
|
||||
estimates,
|
||||
name_map):
|
||||
validate_column_specs(
|
||||
estimates,
|
||||
name_map
|
||||
)
|
||||
|
||||
self.estimates = estimates[
|
||||
estimates[EVENT_DATE_FIELD_NAME].notnull() &
|
||||
estimates[FISCAL_QUARTER_FIELD_NAME].notnull() &
|
||||
estimates[FISCAL_YEAR_FIELD_NAME].notnull()
|
||||
]
|
||||
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
|
||||
self.estimates[FISCAL_YEAR_FIELD_NAME],
|
||||
self.estimates[FISCAL_QUARTER_FIELD_NAME],
|
||||
)
|
||||
|
||||
self.array_overwrites_dict = {
|
||||
datetime64ns_dtype: Datetime641DArrayOverwrite,
|
||||
float64_dtype: Float641DArrayOverwrite,
|
||||
}
|
||||
self.scalar_overwrites_dict = {
|
||||
datetime64ns_dtype: Datetime64Overwrite,
|
||||
float64_dtype: Float64Overwrite,
|
||||
}
|
||||
|
||||
self.name_map = name_map
|
||||
|
||||
@abstractmethod
|
||||
def get_zeroth_quarter_idx(self, num_announcements, last, dates):
|
||||
raise NotImplementedError('get_zeroth_quarter_idx')
|
||||
|
||||
@abstractmethod
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
raise NotImplementedError('get_shifted_qtrs')
|
||||
|
||||
@abstractmethod
|
||||
def create_overwrite_for_estimate(self,
|
||||
column,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx):
|
||||
raise NotImplementedError('create_overwrite_for_estimate')
|
||||
|
||||
@abstractproperty
|
||||
def searchsorted_side(self):
|
||||
return NotImplementedError('searchsorted_side')
|
||||
|
||||
def get_requested_quarter_data(self,
|
||||
zero_qtr_data,
|
||||
zeroth_quarter_idx,
|
||||
stacked_last_per_qtr,
|
||||
num_announcements,
|
||||
dates):
|
||||
"""
|
||||
Selects the requested data for each date.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
zero_qtr_data : pd.DataFrame
|
||||
The 'time zero' data for each calendar date per sid.
|
||||
zeroth_quarter_idx : pd.Index
|
||||
An index of calendar dates, sid, and normalized quarters, for only
|
||||
the rows that have a next or previous earnings estimate.
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
The latest estimate known with the dates, normalized quarter, and
|
||||
sid as the index.
|
||||
dates : pd.DatetimeIndex
|
||||
The calendar dates for which estimates data is requested.
|
||||
|
||||
Returns
|
||||
--------
|
||||
requested_qtr_data : pd.DataFrame
|
||||
The DataFrame with the latest values for the requested quarter
|
||||
for all columns; `dates` are the index and columns are a MultiIndex
|
||||
with sids at the top level and the dataset columns on the bottom.
|
||||
"""
|
||||
zero_qtr_data_idx = zero_qtr_data.index
|
||||
requested_qtr_idx = pd.MultiIndex.from_arrays(
|
||||
[
|
||||
zero_qtr_data_idx.get_level_values(0),
|
||||
zero_qtr_data_idx.get_level_values(1),
|
||||
self.get_shifted_qtrs(
|
||||
zeroth_quarter_idx.get_level_values(
|
||||
NORMALIZED_QUARTERS,
|
||||
),
|
||||
num_announcements,
|
||||
),
|
||||
],
|
||||
names=[
|
||||
zero_qtr_data_idx.names[0],
|
||||
zero_qtr_data_idx.names[1],
|
||||
SHIFTED_NORMALIZED_QTRS,
|
||||
],
|
||||
)
|
||||
requested_qtr_data = stacked_last_per_qtr.loc[requested_qtr_idx]
|
||||
requested_qtr_data = requested_qtr_data.reset_index(
|
||||
SHIFTED_NORMALIZED_QTRS,
|
||||
)
|
||||
# Calculate the actual year/quarter being requested and add those in
|
||||
# as columns.
|
||||
(requested_qtr_data[FISCAL_YEAR_FIELD_NAME],
|
||||
requested_qtr_data[FISCAL_QUARTER_FIELD_NAME]) = \
|
||||
split_normalized_quarters(
|
||||
requested_qtr_data[SHIFTED_NORMALIZED_QTRS]
|
||||
)
|
||||
# Once we're left with just dates as the index, we can reindex by all
|
||||
# dates so that we have a value for each calendar date.
|
||||
return requested_qtr_data.unstack(SID_FIELD_NAME).reindex(dates)
|
||||
|
||||
def get_adjustments(self,
|
||||
zero_qtr_data,
|
||||
requested_qtr_data,
|
||||
last_per_qtr,
|
||||
dates,
|
||||
assets,
|
||||
columns):
|
||||
"""
|
||||
Creates an AdjustedArray from the given estimates data for the given
|
||||
dates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
zero_qtr_data : pd.DataFrame
|
||||
The 'time zero' data for each calendar date per sid.
|
||||
requested_qtr_data : pd.DataFrame
|
||||
The requested quarter data for each calendar date per sid.
|
||||
last_per_qtr : pd.DataFrame
|
||||
A DataFrame with a column MultiIndex of [self.estimates.columns,
|
||||
normalized_quarters, sid] that allows easily getting the timeline
|
||||
of estimates for a particular sid for a particular quarter.
|
||||
dates : pd.DatetimeIndex
|
||||
The calendar dates for which estimates data is requested.
|
||||
assets : pd.Int64Index
|
||||
An index of all the assets from the raw data.
|
||||
columns : list of BoundColumn
|
||||
The columns for which adjustments need to be calculated.
|
||||
|
||||
Returns
|
||||
-------
|
||||
adjusted_array : AdjustedArray
|
||||
The array of data and overwrites for the given column.
|
||||
"""
|
||||
col_to_overwrites = defaultdict(dict)
|
||||
zero_qtr_data.sort_index(inplace=True)
|
||||
# Here we want to get the LAST record from each group of records
|
||||
# corresponding to a single quarter. This is to ensure that we select
|
||||
# the most up-to-date event date in case the event date changes.
|
||||
quarter_shifts = zero_qtr_data.groupby(
|
||||
level=[SID_FIELD_NAME, NORMALIZED_QUARTERS]
|
||||
).nth(-1)
|
||||
|
||||
sid_to_idx = dict(zip(assets, range(len(assets))))
|
||||
|
||||
def collect_adjustments(group):
|
||||
next_qtr_start_indices = dates.searchsorted(
|
||||
group[EVENT_DATE_FIELD_NAME].values,
|
||||
side=self.searchsorted_side,
|
||||
)
|
||||
sid = int(group.name)
|
||||
qtrs_with_estimates = group.index.get_level_values(
|
||||
NORMALIZED_QUARTERS
|
||||
).values
|
||||
for idx in next_qtr_start_indices:
|
||||
if 0 < idx < len(dates):
|
||||
# Only add adjustments if the next quarter starts somewhere
|
||||
# in our date index for this sid. Our 'next' quarter can
|
||||
# never start at index 0; a starting index of 0 means that
|
||||
# the next quarter's event date was NaT.
|
||||
self.create_overwrite_for_quarter(
|
||||
col_to_overwrites,
|
||||
idx,
|
||||
last_per_qtr,
|
||||
qtrs_with_estimates,
|
||||
requested_qtr_data,
|
||||
sid,
|
||||
sid_to_idx[sid],
|
||||
columns,
|
||||
)
|
||||
|
||||
quarter_shifts.groupby(level=SID_FIELD_NAME).apply(collect_adjustments)
|
||||
return col_to_overwrites
|
||||
|
||||
def create_overwrite_for_quarter(self,
|
||||
col_to_overwrites,
|
||||
next_qtr_start_idx,
|
||||
last_per_qtr,
|
||||
quarters_with_estimates_for_sid,
|
||||
requested_qtr_data,
|
||||
sid,
|
||||
sid_idx,
|
||||
columns):
|
||||
"""
|
||||
Add entries to the dictionary of columns to adjustments for the given
|
||||
sid and the given quarter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
col_to_overwrites : dict [column_name -> list of ArrayAdjustment]
|
||||
A dictionary mapping column names to all overwrites for those
|
||||
columns.
|
||||
next_qtr_start_idx : int
|
||||
The index of the first day of the next quarter in the calendar
|
||||
dates.
|
||||
last_per_qtr : pd.DataFrame
|
||||
A DataFrame with a column MultiIndex of [self.estimates.columns,
|
||||
normalized_quarters, sid] that allows easily getting the timeline
|
||||
of estimates for a particular sid for a particular quarter; this
|
||||
is particularly useful for getting adjustments for 'next'
|
||||
estimates.
|
||||
quarters_with_estimates_for_sid : np.array
|
||||
An array of all quarters for which there are estimates for the
|
||||
given sid.
|
||||
sid : int
|
||||
The sid for which to create overwrites.
|
||||
sid_idx : int
|
||||
The index of the sid in `assets`.
|
||||
columns : list of BoundColumn
|
||||
The columns for which to create overwrites.
|
||||
"""
|
||||
|
||||
# Find the quarter being requested in the quarter we're
|
||||
# crossing into.
|
||||
requested_quarter = requested_qtr_data[
|
||||
SHIFTED_NORMALIZED_QTRS, sid,
|
||||
].iloc[next_qtr_start_idx]
|
||||
for col in columns:
|
||||
column_name = self.name_map[col.name]
|
||||
# If there are estimates for the requested quarter,
|
||||
# overwrite all values going up to the starting index of
|
||||
# that quarter with estimates for that quarter.
|
||||
if requested_quarter in quarters_with_estimates_for_sid:
|
||||
col_to_overwrites[column_name][next_qtr_start_idx] = [
|
||||
self.create_overwrite_for_estimate(
|
||||
col,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx
|
||||
),
|
||||
]
|
||||
# There are no estimates for the quarter. Overwrite all
|
||||
# values going up to the starting index of that quarter
|
||||
# with the missing value for this column.
|
||||
else:
|
||||
col_to_overwrites[column_name][next_qtr_start_idx] = [
|
||||
self.overwrite_with_null(
|
||||
col,
|
||||
last_per_qtr.index,
|
||||
next_qtr_start_idx,
|
||||
sid_idx
|
||||
),
|
||||
]
|
||||
|
||||
def overwrite_with_null(self,
|
||||
column,
|
||||
dates,
|
||||
next_qtr_start_idx,
|
||||
sid_idx):
|
||||
return self.scalar_overwrites_dict[column.dtype](
|
||||
0,
|
||||
next_qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
column.missing_value
|
||||
)
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
# Separate out getting the columns' datasets and the datasets'
|
||||
# num_announcements attributes to ensure that we're catching the right
|
||||
# AttributeError.
|
||||
col_to_datasets = {col: col.dataset for col in columns}
|
||||
try:
|
||||
groups = groupby(lambda col:
|
||||
col_to_datasets[col].num_announcements,
|
||||
col_to_datasets)
|
||||
except AttributeError:
|
||||
raise AttributeError("Datasets loaded via the "
|
||||
"EarningsEstimatesLoader must define a "
|
||||
"`num_announcements` attribute that defines "
|
||||
"how many quarters out the loader should load"
|
||||
" the data relative to `dates`.")
|
||||
if any(num_qtr < 0 for num_qtr in groups):
|
||||
raise ValueError(
|
||||
INVALID_NUM_QTRS_MESSAGE % ','.join(
|
||||
str(qtr) for qtr in groups if qtr < 0
|
||||
)
|
||||
|
||||
)
|
||||
out = {}
|
||||
# To optimize performance, only work below on assets that are
|
||||
# actually in the raw data.
|
||||
assets_with_data = set(assets) & set(self.estimates[SID_FIELD_NAME])
|
||||
last_per_qtr, stacked_last_per_qtr = self.get_last_data_per_qtr(
|
||||
assets_with_data,
|
||||
columns,
|
||||
dates
|
||||
)
|
||||
# Determine which quarter is immediately next/previous for each
|
||||
# date.
|
||||
zeroth_quarter_idx = self.get_zeroth_quarter_idx(stacked_last_per_qtr)
|
||||
zero_qtr_data = stacked_last_per_qtr.loc[zeroth_quarter_idx]
|
||||
|
||||
for num_announcements, columns in groups.items():
|
||||
requested_qtr_data = self.get_requested_quarter_data(
|
||||
zero_qtr_data,
|
||||
zeroth_quarter_idx,
|
||||
stacked_last_per_qtr,
|
||||
num_announcements,
|
||||
dates,
|
||||
)
|
||||
|
||||
# Calculate all adjustments for the given quarter and accumulate
|
||||
# them for each column.
|
||||
col_to_adjustments = self.get_adjustments(zero_qtr_data,
|
||||
requested_qtr_data,
|
||||
last_per_qtr,
|
||||
dates,
|
||||
assets_with_data,
|
||||
columns)
|
||||
|
||||
# Lookup the asset indexer once, this is so we can reindex
|
||||
# the assets returned into the assets requested for each column.
|
||||
# This depends on the fact that our column multiindex has the same
|
||||
# sids for each field. This allows us to do the lookup once on
|
||||
# level 1 instead of doing the lookup each time per value in
|
||||
# level 0.
|
||||
asset_indexer = assets.get_indexer_for(
|
||||
requested_qtr_data.columns.levels[1],
|
||||
)
|
||||
for col in columns:
|
||||
column_name = self.name_map[col.name]
|
||||
# allocate the empty output with the correct missing value
|
||||
output_array = np.full(
|
||||
(len(dates), len(assets)),
|
||||
col.missing_value,
|
||||
dtype=col.dtype,
|
||||
)
|
||||
# overwrite the missing value with values from the computed
|
||||
# data
|
||||
output_array[
|
||||
:,
|
||||
asset_indexer,
|
||||
] = requested_qtr_data[column_name].values
|
||||
|
||||
out[col] = AdjustedArray(
|
||||
output_array,
|
||||
mask,
|
||||
dict(col_to_adjustments[column_name]),
|
||||
col.missing_value,
|
||||
)
|
||||
return out
|
||||
|
||||
def get_last_data_per_qtr(self, assets_with_data, columns, dates):
|
||||
"""
|
||||
Determine the last piece of information we know for each column on each
|
||||
date in the index for each sid and quarter.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
assets_with_data : pd.Index
|
||||
Index of all assets that appear in the raw data given to the
|
||||
loader.
|
||||
columns : iterable of BoundColumn
|
||||
The columns that need to be loaded from the raw data.
|
||||
dates : pd.DatetimeIndex
|
||||
The calendar of dates for which data should be loaded.
|
||||
|
||||
Returns
|
||||
-------
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
A DataFrame indexed by [dates, sid, normalized_quarters] that has
|
||||
the latest information for each row of the index, sorted by event
|
||||
date.
|
||||
last_per_qtr : pd.DataFrame
|
||||
A DataFrame with columns that are a MultiIndex of [
|
||||
self.estimates.columns, normalized_quarters, sid].
|
||||
"""
|
||||
# Get a DataFrame indexed by date with a MultiIndex of columns of [
|
||||
# self.estimates.columns, normalized_quarters, sid], where each cell
|
||||
# contains the latest data for that day.
|
||||
last_per_qtr = last_in_date_group(
|
||||
self.estimates,
|
||||
dates,
|
||||
assets_with_data,
|
||||
reindex=True,
|
||||
extra_groupers=[NORMALIZED_QUARTERS],
|
||||
)
|
||||
# Forward fill values for each quarter/sid/dataset column.
|
||||
ffill_across_cols(last_per_qtr, columns, self.name_map)
|
||||
# Stack quarter and sid into the index.
|
||||
stacked_last_per_qtr = last_per_qtr.stack(
|
||||
[SID_FIELD_NAME, NORMALIZED_QUARTERS],
|
||||
)
|
||||
# Set date index name for ease of reference
|
||||
stacked_last_per_qtr.index.set_names(
|
||||
SIMULATION_DATES,
|
||||
level=0,
|
||||
inplace=True,
|
||||
)
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
)
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] = pd.to_datetime(
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME]
|
||||
)
|
||||
return last_per_qtr, stacked_last_per_qtr
|
||||
|
||||
|
||||
class NextEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
searchsorted_side = 'right'
|
||||
|
||||
def create_overwrite_for_estimate(self,
|
||||
column,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx):
|
||||
return self.array_overwrites_dict[column.dtype](
|
||||
0,
|
||||
# overwrite thru last qtr
|
||||
next_qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
last_per_qtr[
|
||||
column_name,
|
||||
requested_quarter,
|
||||
sid,
|
||||
].values[:next_qtr_start_idx],
|
||||
)
|
||||
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
return zero_qtrs + (num_announcements - 1)
|
||||
|
||||
def get_zeroth_quarter_idx(self, stacked_last_per_qtr):
|
||||
"""
|
||||
Filters for releases that are on or after each simulation date and
|
||||
determines the next quarter by picking out the upcoming release for
|
||||
each date in the index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters with each row being the latest estimate for the row's
|
||||
index values, sorted by event date.
|
||||
|
||||
Returns
|
||||
-------
|
||||
next_releases_per_date_index : pd.MultiIndex
|
||||
An index of calendar dates, sid, and normalized quarters, for only
|
||||
the rows that have a next event.
|
||||
"""
|
||||
next_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
).nth(0)
|
||||
return next_releases_per_date.index
|
||||
|
||||
|
||||
class PreviousEarningsEstimatesLoader(EarningsEstimatesLoader):
|
||||
searchsorted_side = 'left'
|
||||
|
||||
def create_overwrite_for_estimate(self,
|
||||
column,
|
||||
column_name,
|
||||
dates,
|
||||
next_qtr_start_idx,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx):
|
||||
return self.overwrite_with_null(
|
||||
column,
|
||||
dates,
|
||||
next_qtr_start_idx,
|
||||
sid_idx,
|
||||
)
|
||||
|
||||
def get_shifted_qtrs(self, zero_qtrs, num_announcements):
|
||||
return zero_qtrs - (num_announcements - 1)
|
||||
|
||||
def get_zeroth_quarter_idx(self, stacked_last_per_qtr):
|
||||
"""
|
||||
Filters for releases that are on or after each simulation date and
|
||||
determines the previous quarter by picking out the most recent
|
||||
release relative to each date in the index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters with each row being the latest estimate for the row's
|
||||
index values, sorted by event date.
|
||||
|
||||
Returns
|
||||
-------
|
||||
previous_releases_per_date_index : pd.MultiIndex
|
||||
An index of calendar dates, sid, and normalized quarters, for only
|
||||
the rows that have a previous event.
|
||||
"""
|
||||
previous_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULATION_DATES)
|
||||
].groupby(
|
||||
level=[SIMULATION_DATES, SID_FIELD_NAME],
|
||||
as_index=False,
|
||||
# Here we take advantage of the fact that `stacked_last_per_qtr` is
|
||||
# sorted by event date.
|
||||
).nth(-1)
|
||||
return previous_releases_per_date.index
|
||||
@@ -5,12 +5,12 @@ from six import viewvalues
|
||||
from toolz import groupby, merge
|
||||
|
||||
from .base import PipelineLoader
|
||||
from .frame import DataFrameLoader
|
||||
from zipline.pipeline.common import (
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
next_event_indexer,
|
||||
previous_event_indexer,
|
||||
@@ -41,16 +41,8 @@ def validate_column_specs(events, next_value_columns, previous_value_columns):
|
||||
serve the BoundColumns described by ``next_value_columns`` and
|
||||
``previous_value_columns``.
|
||||
"""
|
||||
required = {
|
||||
TS_FIELD_NAME,
|
||||
SID_FIELD_NAME,
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
}.union(
|
||||
# We also expect any of the field names that our loadable columns
|
||||
# are mapped to.
|
||||
viewvalues(next_value_columns),
|
||||
viewvalues(previous_value_columns),
|
||||
)
|
||||
required = required_event_fields(next_value_columns,
|
||||
previous_value_columns)
|
||||
received = set(events.columns)
|
||||
missing = required - received
|
||||
if missing:
|
||||
|
||||
@@ -2,6 +2,8 @@ import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from zipline.pipeline.common import TS_FIELD_NAME, SID_FIELD_NAME
|
||||
from zipline.utils.numpy_utils import categorical_dtype
|
||||
from zipline.utils.pandas_utils import mask_between_time
|
||||
|
||||
|
||||
@@ -272,3 +274,124 @@ def check_data_query_args(data_query_time, data_query_tz):
|
||||
data_query_tz,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def last_in_date_group(df,
|
||||
dates,
|
||||
assets,
|
||||
reindex=True,
|
||||
have_sids=True,
|
||||
extra_groupers=[]):
|
||||
"""
|
||||
Determine the last piece of information known on each date in the date
|
||||
index for each group.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The DataFrame containing the data to be grouped.
|
||||
dates : pd.DatetimeIndex
|
||||
The dates to use for grouping and reindexing.
|
||||
assets : pd.Int64Index
|
||||
The assets that should be included in the column multiindex.
|
||||
reindex : bool
|
||||
Whether or not the DataFrame should be reindexed against the date
|
||||
index. This will add back any dates to the index that were grouped
|
||||
away.
|
||||
have_sids : bool
|
||||
Whether or not the DataFrame has sids. If it does, they will be used
|
||||
in the groupby.
|
||||
extra_groupers : list of str
|
||||
Any extra field names that should be included in the groupby.
|
||||
|
||||
Returns
|
||||
-------
|
||||
last_in_group : pd.DataFrame
|
||||
A DataFrame with dates as the index and fields used in the groupby as
|
||||
levels of a multiindex of columns.
|
||||
|
||||
"""
|
||||
idx = [dates[dates.searchsorted(
|
||||
df[TS_FIELD_NAME].values.astype('datetime64[D]')
|
||||
)]]
|
||||
if have_sids:
|
||||
idx += [SID_FIELD_NAME]
|
||||
idx += extra_groupers
|
||||
|
||||
last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(
|
||||
idx,
|
||||
sort=False,
|
||||
).last()
|
||||
|
||||
# For the number of things that we're grouping by (except TS), unstack
|
||||
# the df. Done this way because of an unresolved pandas bug whereby
|
||||
# passing a list of levels with mixed dtypes to unstack causes the
|
||||
# resulting DataFrame to have all object-type columns.
|
||||
for _ in range(len(idx) - 1):
|
||||
last_in_group = last_in_group.unstack(-1)
|
||||
|
||||
if reindex:
|
||||
if have_sids:
|
||||
cols = last_in_group.columns
|
||||
last_in_group = last_in_group.reindex(
|
||||
index=dates,
|
||||
columns=pd.MultiIndex.from_product(
|
||||
tuple(cols.levels[0:len(extra_groupers) + 1]) + (assets,),
|
||||
names=cols.names,
|
||||
),
|
||||
)
|
||||
else:
|
||||
last_in_group = last_in_group.reindex(dates)
|
||||
|
||||
return last_in_group
|
||||
|
||||
|
||||
def ffill_across_cols(df, columns, name_map):
|
||||
"""
|
||||
Forward fill values in a DataFrame with special logic to handle cases
|
||||
that pd.DataFrame.ffill cannot and cast columns to appropriate types.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The DataFrame to do forward-filling on.
|
||||
columns : list of BoundColumn
|
||||
The BoundColumns that correspond to columns in the DataFrame to which
|
||||
special filling and/or casting logic should be applied.
|
||||
name_map: map of string -> string
|
||||
Mapping from the name of each BoundColumn to the associated column
|
||||
name in `df`.
|
||||
"""
|
||||
df.ffill(inplace=True)
|
||||
|
||||
# Fill in missing values specified by each column. This is made
|
||||
# significantly more complex by the fact that we need to work around
|
||||
# two pandas issues:
|
||||
|
||||
# 1) When we have sids, if there are no records for a given sid for any
|
||||
# dates, pandas will generate a column full of NaNs for that sid.
|
||||
# This means that some of the columns in `dense_output` are now
|
||||
# float instead of the intended dtype, so we have to coerce back to
|
||||
# our expected type and convert NaNs into the desired missing value.
|
||||
|
||||
# 2) DataFrame.ffill assumes that receiving None as a fill-value means
|
||||
# that no value was passed. Consequently, there's no way to tell
|
||||
# pandas to replace NaNs in an object column with None using fillna,
|
||||
# so we have to roll our own instead using df.where.
|
||||
for column in columns:
|
||||
column_name = name_map[column.name]
|
||||
# Special logic for strings since `fillna` doesn't work if the
|
||||
# missing value is `None`.
|
||||
if column.dtype == categorical_dtype:
|
||||
df[column_name] = df[
|
||||
column.name
|
||||
].where(pd.notnull(df[column_name]),
|
||||
column.missing_value)
|
||||
else:
|
||||
# We need to execute `fillna` before `astype` in case the
|
||||
# column contains NaNs and needs to be cast to bool or int.
|
||||
# This is so that the NaNs are replaced first, since pandas
|
||||
# can't convert NaNs for those types.
|
||||
df[column_name] = df[
|
||||
column_name
|
||||
].fillna(column.missing_value).astype(column.dtype)
|
||||
|
||||
Reference in New Issue
Block a user