TST: add test to check previous columns w/ multiple qtrs

MAINT: pass column to name dict

MAINT: make check for invalid num columns py3-compatible
This commit is contained in:
Maya Tydykov
2016-08-30 10:13:38 -04:00
parent cc07a00d16
commit 2a09160ca8
11 changed files with 1154 additions and 740 deletions
+83 -114
View File
@@ -202,11 +202,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
)
def _gen_overwrite_adjustment_cases(name,
make_input,
make_expected_output,
dtype,
missing_value):
def _gen_overwrite_adjustment_cases(dtype):
"""
Generate test cases for overwrite adjustments.
@@ -226,6 +222,8 @@ def _gen_overwrite_adjustment_cases(name,
unicode_dtype: ObjectOverwrite,
object_dtype: ObjectOverwrite,
}[dtype]
make_expected_dtype = as_dtype(dtype)
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
if dtype == object_dtype:
# When we're testing object dtypes, we expect to have strings, but
@@ -237,30 +235,30 @@ def _gen_overwrite_adjustment_cases(name,
adjustments = {}
buffer_as_of = [None] * 6
baseline = make_input([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
baseline = make_expected_dtype([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[0] = make_expected_output([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
# Note that row indices are inclusive!
adjustments[1] = [
adjustment_type(0, 0, 0, 0, make_overwrite_value(dtype, 1)),
]
buffer_as_of[1] = make_expected_output([[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
# No adjustment at index 2.
buffer_as_of[2] = buffer_as_of[1]
@@ -269,33 +267,33 @@ def _gen_overwrite_adjustment_cases(name,
adjustment_type(1, 2, 1, 1, make_overwrite_value(dtype, 3)),
adjustment_type(0, 1, 0, 0, make_overwrite_value(dtype, 4)),
]
buffer_as_of[3] = make_expected_output([[4, 2, 2],
[4, 3, 2],
[2, 3, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
[4, 3, 2],
[2, 3, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
adjustments[4] = [
adjustment_type(0, 3, 2, 2, make_overwrite_value(dtype, 5))
]
buffer_as_of[4] = make_expected_output([[4, 2, 5],
[4, 3, 5],
[2, 3, 5],
[2, 2, 5],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
[4, 3, 5],
[2, 3, 5],
[2, 2, 5],
[2, 2, 2],
[2, 2, 2]])
adjustments[5] = [
adjustment_type(0, 4, 1, 1, make_overwrite_value(dtype, 6)),
adjustment_type(2, 2, 2, 2, make_overwrite_value(dtype, 7)),
]
buffer_as_of[5] = make_expected_output([[4, 6, 5],
[4, 6, 5],
[2, 6, 7],
[2, 6, 5],
[2, 6, 2],
[2, 2, 2]])
buffer_as_of[5] = make_expected_dtype([[4, 6, 5],
[4, 6, 5],
[2, 6, 7],
[2, 6, 5],
[2, 6, 2],
[2, 2, 2]])
return _gen_expectations(
baseline,
@@ -306,11 +304,7 @@ def _gen_overwrite_adjustment_cases(name,
)
def _gen_overwrite_1d_array_adjustment_case(name,
make_input,
make_expected_output,
dtype,
missing_value):
def _gen_overwrite_1d_array_adjustment_case(dtype):
"""
Generate test cases for overwrite adjustments.
@@ -327,21 +321,24 @@ def _gen_overwrite_1d_array_adjustment_case(name,
float64_dtype: Float641DArrayOverwrite,
datetime64ns_dtype: Datetime641DArrayOverwrite,
}[dtype]
make_expected_dtype = as_dtype(dtype)
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
adjustments = {}
buffer_as_of = [None] * 6
baseline = make_input([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
baseline = make_expected_dtype([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[0] = make_expected_output([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
vals1 = [1]
# Note that row indices are inclusive!
@@ -351,12 +348,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
array([coerce_to_dtype(dtype, val) for val in vals1])
)
]
buffer_as_of[1] = make_input([[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
# No adjustment at index 2.
buffer_as_of[2] = buffer_as_of[1]
@@ -368,12 +365,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
array([coerce_to_dtype(dtype, val) for val in vals3])
)
]
buffer_as_of[3] = make_input([[4, 2, 2],
[4, 2, 2],
[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
[4, 2, 2],
[1, 2, 2],
[2, 2, 2],
[2, 2, 2],
[2, 2, 2]])
vals4 = [5] * 4
adjustments[4] = [
@@ -381,12 +378,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
0, 3, 2, 2,
array([coerce_to_dtype(dtype, val) for val in vals4]))
]
buffer_as_of[4] = make_input([[4, 2, 5],
[4, 2, 5],
[1, 2, 5],
[2, 2, 5],
[2, 2, 2],
[2, 2, 2]])
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
[4, 2, 5],
[1, 2, 5],
[2, 2, 5],
[2, 2, 2],
[2, 2, 2]])
vals5 = range(1, 6)
adjustments[5] = [
@@ -394,12 +391,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
0, 4, 1, 1,
array([coerce_to_dtype(dtype, val) for val in vals5])),
]
buffer_as_of[5] = make_input([[4, 1, 5],
[4, 2, 5],
[1, 3, 5],
[2, 4, 5],
[2, 5, 2],
[2, 2, 2]])
buffer_as_of[5] = make_expected_dtype([[4, 1, 5],
[4, 2, 5],
[1, 3, 5],
[2, 4, 5],
[2, 5, 2],
[2, 2, 2]])
return _gen_expectations(
baseline,
@@ -532,38 +529,10 @@ class AdjustedArrayTestCase(TestCase):
@parameterized.expand(
chain(
_gen_overwrite_adjustment_cases(
'float',
make_input=as_dtype(float64_dtype),
make_expected_output=as_dtype(float64_dtype),
dtype=float64_dtype,
missing_value=default_missing_value_for_dtype(float64_dtype),
),
_gen_overwrite_adjustment_cases(
'datetime',
make_input=as_dtype(datetime64ns_dtype),
make_expected_output=as_dtype(datetime64ns_dtype),
dtype=datetime64ns_dtype,
missing_value=default_missing_value_for_dtype(
datetime64ns_dtype,
),
),
_gen_overwrite_1d_array_adjustment_case(
'float',
make_input=as_dtype(float64_dtype),
make_expected_output=as_dtype(float64_dtype),
dtype=float64_dtype,
missing_value=default_missing_value_for_dtype(float64_dtype),
),
_gen_overwrite_1d_array_adjustment_case(
'datetime',
make_input=as_dtype(datetime64ns_dtype),
make_expected_output=as_dtype(datetime64ns_dtype),
dtype=datetime64ns_dtype,
missing_value=default_missing_value_for_dtype(
datetime64ns_dtype,
),
),
_gen_overwrite_adjustment_cases(float64_dtype),
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
_gen_overwrite_1d_array_adjustment_case(float64_dtype),
_gen_overwrite_1d_array_adjustment_case(datetime64ns_dtype),
# There are six cases here:
# Using np.bytes/np.unicode/object arrays as inputs.
# Passing np.bytes/np.unicode/object arrays to LabelArray,
File diff suppressed because it is too large Load Diff
+27 -23
View File
@@ -3,7 +3,7 @@ from cpython cimport Py_EQ
from pandas import isnull, Timestamp
from numpy cimport float64_t, uint8_t, int64_t
from numpy import asarray, datetime64, float64
from numpy import asarray, datetime64, float64, int64
# Purely for readability. There aren't C-level declarations for these types.
ctypedef object Int64Index_t
ctypedef object DatetimeIndex_t
@@ -451,28 +451,32 @@ cdef class Datetime641DArrayOverwrite(ArrayAdjustment):
Example
-------
>>> import numpy as np
>>> arr = np.arange(25, dtype=float).reshape(5, 5)
>>> arr
array([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.],
[ 20., 21., 22., 23., 24.]])
>>> import numpy as np; import pandas as pd
>>> dts = pd.date_range('2014', freq='D', periods=9, tz='UTC')
>>> arr = dts.values.reshape(3, 3)
>>> arr == np.datetime64(0, 'ns')
array([[False, False, False],
[False, False, False],
[False, False, False]], dtype=bool)
>>> adj = Datetime641DArrayOverwrite(
... row_start=0,
... row_end=3,
... column_start=0,
... column_end=0,
... values=np.array([1, 2, 3, 4]),
)
>>> adj.mutate(arr)
>>> arr
array([[ 1., 1., 2., 3., 4.],
[ 2., 6., 7., 8., 9.],
[ 3., 11., 12., 13., 14.],
[ 4., 16., 17., 18., 19.],
[ 20., 21., 22., 23., 24.]])
... first_row=1,
... last_row=2,
... first_col=1,
... last_col=2,
... values=np.array([
... np.datetime64(0, 'ns'),
... np.datetime64(1, 'ns')
... ])
... )
>>> adj.mutate(arr.view(np.int64))
>>> arr == np.datetime64(0, 'ns')
array([[False, False, False],
[False, True, True],
[False, False, False]], dtype=bool)
>>> arr == np.datetime64(1, 'ns')
array([[False, False, False],
[False, False, False],
[False, True, True]], dtype=bool)
"""
cdef:
readonly int64_t[:] values
@@ -598,7 +602,7 @@ cdef datetime_to_int(object datetimelike):
datetimelike.dtype.name,
)
return datetimelike.astype(int)
return datetimelike.astype(int64)
cdef class Datetime64Adjustment(_Int64Adjustment):
+6 -3
View File
@@ -1096,12 +1096,15 @@ class BlazeLoader(dict):
sparse_deltas = last_in_date_group(non_novel_deltas,
dates,
assets,
reindex=False)
reindex=False,
have_sids=have_sids)
dense_output = last_in_date_group(sparse_output,
dates,
assets,
reindex=True)
ffill_across_cols(dense_output, columns)
reindex=True,
have_sids=have_sids)
ffill_across_cols(dense_output, columns, {c.name: c.name
for c in columns})
if have_sids:
adjustments_from_deltas = adjustments_from_deltas_with_sids
column_view = identity
+6 -4
View File
@@ -25,6 +25,8 @@ class BlazeEstimatesLoader(PipelineLoader):
----------
expr : Expr
The expression representing the data to load.
columns : dict[str -> str]
A dict mapping BoundColumn names to the associated names in `expr`.
resources : dict, optional
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
@@ -33,8 +35,6 @@ class BlazeEstimatesLoader(PipelineLoader):
The time to use for the data query cutoff.
data_query_tz : tzinfo or str
The timezone to use for the data query cutoff.
dataset : DataSet
The DataSet object for which this loader loads data.
Notes
-----
@@ -43,12 +43,14 @@ class BlazeEstimatesLoader(PipelineLoader):
Dim * {{
{SID_FIELD_NAME}: int64,
{TS_FIELD_NAME}: datetime,
{FISCAL_YEAR_FIELD_NAME}: float64,
{FISCAL_QUARTER_FIELD_NAME}: float64,
{EVENT_DATE_FIELD_NAME}: datetime,
}}
And other dataset-specific fields, where each row of the table is a
record including the sid to identify the company, the timestamp where we
learned about the announcement, and the date when the earnings will be
announced.
learned about the announcement, and the date of the event.
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
start the backtest with knowledge of all announcements.
+12 -6
View File
@@ -24,6 +24,10 @@ class BlazeEventsLoader(PipelineLoader):
----------
expr : Expr
The expression representing the data to load.
next_value_columns : dict[BoundColumn -> raw column name]
A dict mapping 'next' BoundColumns to their column names in `expr`.
previous_value_columns : dict[BoundColumn -> raw column name]
A dict mapping 'previous' BoundColumns to their column names in `expr`.
resources : dict, optional
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
@@ -32,8 +36,6 @@ class BlazeEventsLoader(PipelineLoader):
The time to use for the data query cutoff.
data_query_tz : tzinfo or str
The timezone to use for the data query cutoff.
dataset : DataSet
The DataSet object for which this loader loads data.
Notes
-----
@@ -42,12 +44,12 @@ class BlazeEventsLoader(PipelineLoader):
Dim * {{
{SID_FIELD_NAME}: int64,
{TS_FIELD_NAME}: datetime,
{EVENT_DATE_FIELD_NAME}: datetime,
}}
And other dataset-specific fields, where each row of the table is a
record including the sid to identify the company, the timestamp where we
learned about the announcement, and the date when the earnings will be z
announced.
learned about the announcement, and the event date.
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
start the backtest with knowledge of all announcements.
@@ -84,8 +86,12 @@ class BlazeEventsLoader(PipelineLoader):
self._data_query_tz = data_query_tz
def load_adjusted_array(self, columns, dates, assets, mask):
raw = load_raw_data(assets, dates, self._data_query_time,
self._data_query_tz, self._expr, self._odo_kwargs)
raw = load_raw_data(assets,
dates,
self._data_query_time,
self._data_query_tz,
self._expr,
self._odo_kwargs)
return EventsLoader(
events=raw,
+9 -4
View File
@@ -6,7 +6,11 @@ from zipline.pipeline.loaders.utils import (
)
def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
def load_raw_data(assets,
dates,
data_query_time,
data_query_tz,
expr,
odo_kwargs):
"""
given an expression representing data to load, perform normalization and
@@ -25,13 +29,14 @@ def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
`time`.
expr : Expr
the expression representing the data to load.
odo_kwargs : dict, optional
odo_kwargs : dict
extra keyword arguments to pass to odo when executing the expression.
Returns
-------
raw : pd.DataFrame
the data symbolized by `expr` materialized in a dataframe.
The result of computing expr and materializing the result as a
dataframe.
"""
lower_dt, upper_dt = normalize_data_query_bounds(
dates[0],
@@ -45,7 +50,7 @@ def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
upper_dt,
odo_kwargs,
)
sids = raw.loc[:, SID_FIELD_NAME]
sids = raw[SID_FIELD_NAME]
raw.drop(
sids[~sids.isin(assets)].index,
inplace=True
+259 -122
View File
@@ -1,9 +1,8 @@
from abc import abstractmethod
from collections import defaultdict
import numpy as np
import pandas as pd
from six import viewvalues
from toolz import groupby
from zipline.lib.adjusted_array import AdjustedArray
from zipline.lib.adjustment import (Datetime641DArrayOverwrite,
Float641DArrayOverwrite)
@@ -22,14 +21,15 @@ from zipline.pipeline.loaders.utils import (
last_in_date_group
)
NORMALIZED_QUARTERS = 'normalized_quarters'
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
INVALID_NUM_QTRS_MESSAGE = "Passed invalid number of quarters %s; " \
"must pass a number of quarters >= 0"
NEXT_FISCAL_QUARTER = 'next_fiscal_quarter'
NEXT_FISCAL_YEAR = 'next_fiscal_year'
NORMALIZED_QUARTERS = 'normalized_quarters'
PREVIOUS_FISCAL_QUARTER = 'previous_fiscal_quarter'
PREVIOUS_FISCAL_YEAR = 'previous_fiscal_year'
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
SIMULTATION_DATES = 'dates'
@@ -86,10 +86,10 @@ def validate_column_specs(events, columns):
class QuarterEstimatesLoader(PipelineLoader):
def __init__(self,
estimates,
base_column_name_map):
name_map):
validate_column_specs(
estimates,
base_column_name_map
name_map
)
self.estimates = estimates[
@@ -97,12 +97,16 @@ class QuarterEstimatesLoader(PipelineLoader):
estimates[FISCAL_QUARTER_FIELD_NAME].notnull() &
estimates[FISCAL_YEAR_FIELD_NAME].notnull()
]
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
self.estimates[FISCAL_YEAR_FIELD_NAME],
self.estimates[FISCAL_QUARTER_FIELD_NAME],
)
self.base_column_name_map = base_column_name_map
self.name_map = name_map
@abstractmethod
def load_quarters(self, num_quarters, last, dates):
pass
raise NotImplementedError('load_quarters')
def get_requested_data_for_col(self, stacked_last_per_qtr, idx, dates):
"""
@@ -111,8 +115,8 @@ class QuarterEstimatesLoader(PipelineLoader):
Parameters
----------
stacked_last_per_qtr : pd.DataFrame
The latest estimate known per sid per date per quarter with the
dates, normalized quarter, and sid as the index.
The latest estimate known with the dates, normalized quarter, and
sid as the index.
idx : pd.MultiIndex
The index of the row of the requested quarter from each date for
each sid.
@@ -122,16 +126,18 @@ class QuarterEstimatesLoader(PipelineLoader):
Returns
--------
requested_qtr_data : pd.DataFrame
The DataFrame with final values for the requested quarter for all
columns; `dates` are the index and columns are a MultiIndex with
sids at the top level and the dataset columns on the bottom.
The DataFrame with the latest values for the requested quarter
for all columns; `dates` are the index and columns are a MultiIndex
with sids at the top level and the dataset columns on the bottom.
"""
requested_qtr_data = stacked_last_per_qtr.loc[idx]
# We no longer need this in the index, but we do need it as a column
# to calculate adjustments.
# We no longer need the shifted normalized quarters in the index, but
# we do need it as a column to calculate adjustments.
requested_qtr_data = requested_qtr_data.reset_index(
SHIFTED_NORMALIZED_QTRS
)
# Calculate the actual year/quarter being requested and add those in
# as columns.
(requested_qtr_data[FISCAL_YEAR_FIELD_NAME],
requested_qtr_data[FISCAL_QUARTER_FIELD_NAME]) = \
split_normalized_quarters(
@@ -154,8 +160,7 @@ class QuarterEstimatesLoader(PipelineLoader):
column_name,
column,
mask,
assets,
qtr_crossover_point):
assets):
"""
Creates an AdjustedArray from the given estimates data for the given
dates.
@@ -183,18 +188,17 @@ class QuarterEstimatesLoader(PipelineLoader):
computed.
column : BoundColumn
The column for which the AdjustedArray is being computed.
mask :
assets :
qtr_crossover_point :
Whether we should use the 'right' or 'left' side when doing
searchsorted on the dates for quarter boundaries.
mask : np.array
Mask array of dimensions len(dates) X len(assets).
assets : pd.Int64Index
An index of all the assets from the raw data.
Returns
-------
adjusted_array : AdjustedArray
The array of data and overwrites for the given column.
"""
adjustments = defaultdict(list)
adjustments = {}
requested_qtr_data = self.get_requested_data_for_col(
stacked_last_per_qtr, requested_qtr_idx, dates
)
@@ -204,10 +208,8 @@ class QuarterEstimatesLoader(PipelineLoader):
zero_qtr_data = zero_qtr_data.reset_index(NORMALIZED_QUARTERS)
if column.dtype == datetime64ns_dtype:
overwrite = Datetime641DArrayOverwrite
missing_value = np.datetime64('NaT', 'ns')
else:
overwrite = Float641DArrayOverwrite
missing_value = np.NaN
for sid_idx, sid in enumerate(assets):
zero_qtr_sid_data = zero_qtr_data[
zero_qtr_data.index.get_level_values(SID_FIELD_NAME) == sid
@@ -225,7 +227,7 @@ class QuarterEstimatesLoader(PipelineLoader):
]
# For the given sid, determine which quarters we have estimates
# for.
quarters_with_estimates_for_sid = last_per_qtr.xs(
qtrs_with_estimates_for_sid = last_per_qtr.xs(
sid, axis=1, level=SID_FIELD_NAME
).groupby(axis=1, level=1).first().columns.values
for row_indexer in list(qtr_shifts.index):
@@ -233,108 +235,162 @@ class QuarterEstimatesLoader(PipelineLoader):
# after this row. This isn't the starting index of the
# requested quarter, but simply the date we cross over into a
# new quarter.
qtr_start_idx = dates.searchsorted(
next_qtr_start_idx = dates.searchsorted(
zero_qtr_data.loc[
row_indexer
][EVENT_DATE_FIELD_NAME],
side=qtr_crossover_point
side='left'
if isinstance(self, PreviousQuartersEstimatesLoader)
else 'right'
)
# Only add adjustments if the next quarter starts somewhere in
# our date index for this sid. Our 'next' quarter can never
# start at index 0; a starting index of 0 means that the next
# quarter's event date was NaT.
if 0 < qtr_start_idx < len(dates):
# Find the quarter being requested in the quarter we're
# crossing into.
requested_quarter = requested_qtr_data[
SHIFTED_NORMALIZED_QTRS
][sid].iloc[qtr_start_idx]
# If there are estimates for the requested quarter,
# overwrite all values going up to the starting index of
# that quarter with estimates for that quarter.
if requested_quarter in quarters_with_estimates_for_sid:
adjustments[qtr_start_idx] = \
[overwrite(
0,
qtr_start_idx - 1, # overwrite thru last qtr
sid_idx,
sid_idx,
last_per_qtr[column_name,
requested_quarter,
sid][:qtr_start_idx].values)]
# There are no estimates for the quarter. Overwrite all
# values going up to the starting index of that quarter
# with the missing value for this column.
else:
adjustments[qtr_start_idx] = [
overwrite(
0,
qtr_start_idx - 1,
sid_idx,
sid_idx,
np.array(
[missing_value] *
len(last_per_qtr.index[:qtr_start_idx]))
)
]
adjustments[next_qtr_start_idx] = \
self.create_overwrite_for_quarter(
next_qtr_start_idx,
column,
column_name,
dates,
last_per_qtr,
overwrite,
qtrs_with_estimates_for_sid,
requested_qtr_data,
sid,
sid_idx,
)
return AdjustedArray(
requested_qtr_data[column_name].values.astype(column.dtype),
mask,
dict(adjustments),
column.missing_value,
)
requested_qtr_data[column_name].values.astype(column.dtype),
mask,
dict(adjustments),
column.missing_value,
)
def create_overwrite_for_quarter(self,
next_qtr_start_idx,
column,
column_name,
dates,
last_per_qtr,
overwrite,
quarters_with_estimates_for_sid,
requested_qtr_data,
sid,
sid_idx):
# Only add adjustments if the next quarter starts somewhere in
# our date index for this sid. Our 'next' quarter can never
# start at index 0; a starting index of 0 means that the next
# quarter's event date was NaT.
if 0 < next_qtr_start_idx < len(dates):
# Find the quarter being requested in the quarter we're
# crossing into.
requested_quarter = requested_qtr_data[
SHIFTED_NORMALIZED_QTRS
][sid].iloc[next_qtr_start_idx]
# If there are estimates for the requested quarter,
# overwrite all values going up to the starting index of
# that quarter with estimates for that quarter.
if requested_quarter in quarters_with_estimates_for_sid:
return self.create_overwrite_for_estimate(
column,
column_name,
last_per_qtr,
next_qtr_start_idx,
overwrite,
requested_quarter,
sid,
sid_idx
)
# There are no estimates for the quarter. Overwrite all
# values going up to the starting index of that quarter
# with the missing value for this column.
else:
return self.overwrite_with_null(
column,
last_per_qtr,
next_qtr_start_idx,
overwrite,
sid_idx
)
def overwrite_with_null(self,
column,
last_per_qtr,
next_qtr_start_idx,
overwrite,
sid_idx):
return [overwrite(
0,
next_qtr_start_idx - 1,
sid_idx,
sid_idx,
np.full(
len(
last_per_qtr.index[:next_qtr_start_idx]
),
column.missing_value,
dtype=column.dtype
))]
def load_adjusted_array(self, columns, dates, assets, mask):
# TODO: how can we enforce that datasets have the num_quarters
# attribute, given that they're created dynamically?
groups = groupby(lambda x: x.dataset.num_quarters, columns)
groups_columns = dict(groups)
if (pd.Series(groups_columns.keys()) < 0).any():
raise ValueError("Must pass a number of quarters >= 0")
# Separate out getting the columns' datasets and the datasets'
# num_quarters attributes to ensure that we're catching the right
# AttributeError.
col_to_datasets = {col: col.dataset for col in columns}
try:
groups = groupby(lambda col: col_to_datasets[col].num_quarters,
col_to_datasets)
except AttributeError:
raise AttributeError("Datasets loaded via the "
"QuarterEstimatesLoader must define a "
"`num_quarters` attribute that defines how "
"many quarters out the loader should load "
"the data relative to `dates`.")
if any(num_qtr < 0 for num_qtr in groups):
raise ValueError(
INVALID_NUM_QTRS_MESSAGE % ','.join(
str(qtr) for qtr in groups if qtr < 0
)
)
out = {}
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
self.estimates[FISCAL_YEAR_FIELD_NAME],
self.estimates[FISCAL_QUARTER_FIELD_NAME],
)
for num_quarters, columns in groups_columns.items():
# The column's dataset is itself dynamic and the mapping we
# actually want is to its dataset's parent's column name.
name_map = {c: self.base_column_name_map[
getattr(c.dataset.__base__, c.name)
] for c in columns}
for num_quarters, columns in groups.items():
# Determine the last piece of information we know for each column
# on each date in the index for each sid and quarter.
last_per_qtr = last_in_date_group(
self.estimates, True, dates, assets,
self.estimates, dates, assets, reindex=True,
extra_groupers=[NORMALIZED_QUARTERS]
)
# Forward fill values for each quarter/sid/dataset column.
ffill_across_cols(last_per_qtr, columns)
ffill_across_cols(last_per_qtr, columns, self.name_map)
# Stack quarter and sid into the index.
stacked_last_per_qtr = last_per_qtr.stack([NORMALIZED_QUARTERS,
SID_FIELD_NAME])
stacked_last_per_qtr = last_per_qtr.stack([SID_FIELD_NAME,
NORMALIZED_QUARTERS])
# Set date index name for ease of reference
stacked_last_per_qtr.index.set_names(SIMULTATION_DATES, 0, True)
stacked_last_per_qtr.index.set_names(SIMULTATION_DATES,
level=0,
inplace=True)
# We want to know the most recent/next event relative to each date.
stacked_last_per_qtr = stacked_last_per_qtr.sort(
EVENT_DATE_FIELD_NAME
)
# Determine which quarter is next/previous for each date.
shifted_qtr_data = self.load_quarters(num_quarters,
stacked_last_per_qtr)
zero_qtr_idx = shifted_qtr_data.index
requested_qtr_idx = shifted_qtr_data.set_index([
shifted_qtr_data.index.get_level_values(
SIMULTATION_DATES
),
shifted_qtr_data[SHIFTED_NORMALIZED_QTRS],
shifted_qtr_data.index.get_level_values(
SID_FIELD_NAME
)]
).index
shifted_qtr_data.index.get_level_values(
SIMULTATION_DATES
),
shifted_qtr_data.index.get_level_values(
SID_FIELD_NAME
),
shifted_qtr_data[SHIFTED_NORMALIZED_QTRS]
]).index
for c in columns:
column_name = name_map[c]
column_name = self.name_map[c.name]
adjusted_array = self.get_adjustments(zero_qtr_idx,
requested_qtr_idx,
stacked_last_per_qtr,
@@ -343,26 +399,68 @@ class QuarterEstimatesLoader(PipelineLoader):
column_name,
c,
mask,
assets,
self.qtr_crossover_point)
assets)
out[c] = adjusted_array
return out
class NextQuartersEstimatesLoader(QuarterEstimatesLoader):
qtr_crossover_point = 'right'
def create_overwrite_for_estimate(self,
column,
column_name,
last_per_qtr,
next_qtr_start_idx,
overwrite,
requested_quarter,
sid,
sid_idx):
return [overwrite(
0,
# overwrite thru last qtr
next_qtr_start_idx - 1,
sid_idx,
sid_idx,
last_per_qtr[
column_name,
requested_quarter,
sid
][0:next_qtr_start_idx].values)]
def load_quarters(self, num_quarters, stacked_last_per_qtr):
# Filter for releases that are on or after each simulation date and
# determine the next quarter by picking out the upcoming release for
# each date in the index.
stacked_last_per_qtr = stacked_last_per_qtr.sort(
EVENT_DATE_FIELD_NAME
)
"""
Filters for releases that are on or after each simulation date and
determines the next quarter by picking out the upcoming release for
each date in the index. Adds a SHIFTED_NORMALIZED_QTRS column which
contains the requested next quarter for each calendar date and sid.
Parameters
----------
num_quarters : int
Number of quarters to go out in the future.
stacked_last_per_qtr : pd.DataFrame
A DataFrame with index of calendar dates, sid, and normalized
quarters with each row being the latest estimate for the row's
index values, sorted by event date.
Returns
-------
next_releases_per_date : pd.DataFrame
A DataFrame with index of calendar dates, sid, and normalized
quarters, keeping only rows with next event information relative to
the index values and with an added column for
SHIFTED_NORMALIZED_QTRS, which contains the requested quarter for
each row.
"""
# We reset the index here because in pandas3, a groupby on the index
# will set the index to just the items in the groupby, so we will lose
# the normalized quarters.
next_releases_per_date = stacked_last_per_qtr.loc[
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(0)
].reset_index(NORMALIZED_QUARTERS).groupby(
level=[SIMULTATION_DATES, SID_FIELD_NAME]
).nth(0).set_index(NORMALIZED_QUARTERS, append=True)
next_releases_per_date[
SHIFTED_NORMALIZED_QTRS
] = next_releases_per_date.index.get_level_values(
@@ -372,18 +470,57 @@ class NextQuartersEstimatesLoader(QuarterEstimatesLoader):
class PreviousQuartersEstimatesLoader(QuarterEstimatesLoader):
qtr_crossover_point = 'left'
def create_overwrite_for_estimate(self,
column,
column_name,
last_per_qtr,
next_qtr_start_idx,
overwrite,
requested_quarter,
sid,
sid_idx):
return self.overwrite_with_null(column,
last_per_qtr,
next_qtr_start_idx,
overwrite,
sid_idx)
def load_quarters(self, num_quarters, stacked_last_per_qtr):
# Filter for releases that are on or before each simulation date and
# determine the previous quarter by picking out the upcoming release
# for each date in the index.
stacked_last_per_qtr = stacked_last_per_qtr.sort(EVENT_DATE_FIELD_NAME)
"""
Filters for releases that are on or before each simulation date and
determines the previous quarter by picking out the most recent
release relative to each date in the index. Adds a
SHIFTED_NORMALIZED_QTRS column which contains the requested previous
quarter for each calendar date and sid.
Parameters
----------
num_quarters : int
Number of quarters to go out in the past.
stacked_last_per_qtr : pd.DataFrame
A DataFrame with index of calendar dates, sid, and normalized
quarters with each row being the latest estimate for the row's
index values, sorted by event date.
Returns
-------
previous_releases_per_date : pd.DataFrame
A DataFrame with index of calendar dates, sid, and normalized
quarters, keeping only rows which have a previous event relative
to the index values and with an added column for
SHIFTED_NORMALIZED_QTRS, which contains the requested quarter for
each row.
"""
# We reset the index here because in pandas3, a groupby on the index
# will set the index to just the items in the groupby, so we will lose
# the normalized quarters.
previous_releases_per_date = stacked_last_per_qtr.loc[
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
stacked_last_per_qtr.index.get_level_values(
SIMULTATION_DATES
)].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(-1)
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
].reset_index(NORMALIZED_QUARTERS).groupby(
level=[SIMULTATION_DATES, SID_FIELD_NAME]
).nth(-1).set_index(NORMALIZED_QUARTERS, append=True)
previous_releases_per_date[
SHIFTED_NORMALIZED_QTRS
] = previous_releases_per_date.index.get_level_values(
+18 -14
View File
@@ -276,7 +276,7 @@ def check_data_query_args(data_query_time, data_query_tz):
)
def last_in_date_group(df, reindex, dates, assets, have_sids=True,
def last_in_date_group(df, dates, assets, reindex=True, have_sids=True,
extra_groupers=[]):
"""
Determine the last piece of information known on each date in the date
@@ -286,14 +286,14 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
----------
df : pd.DataFrame
The DataFrame containing the data to be grouped.
reindex : bool
Whether or not the DataFrame should be reindexed against the date
index. This will add back any dates to the index that were grouped
away.
dates : pd.DatetimeIndex
The dates to use for grouping and reindexing.
assets : pd.Int64Index
The assets that should be included in the column multiindex.
reindex : bool
Whether or not the DataFrame should be reindexed against the date
index. This will add back any dates to the index that were grouped
away.
have_sids : bool
Whether or not the DataFrame has sids. If it does, they will be used
in the groupby.
@@ -307,11 +307,11 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
levels of a multiindex of columns.
"""
idx = dates[dates.searchsorted(
idx = [dates[dates.searchsorted(
df[TS_FIELD_NAME].values.astype('datetime64[D]')
)]
)]]
if have_sids:
idx = [idx, SID_FIELD_NAME]
idx += [SID_FIELD_NAME]
idx += extra_groupers
last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(
@@ -321,7 +321,7 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
# For the number of things that we're grouping by (except TS), unstack
# the df
last_in_group = last_in_group.unstack([-1, -2])
last_in_group = last_in_group.unstack(list(range(-1, -len(idx), -1)))
if reindex:
if have_sids:
@@ -339,7 +339,7 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
return last_in_group
def ffill_across_cols(df, columns):
def ffill_across_cols(df, columns, name_map):
"""
Forward fill values in a DataFrame with special logic to handle cases
that pd.DataFrame.ffill cannot and cast columns to appropriate types.
@@ -351,6 +351,9 @@ def ffill_across_cols(df, columns):
columns : list of BoundColumn
The BoundColumns that correspond to columns in the DataFrame to which
special filling and/or casting logic should be applied.
name_map : dict[str -> str]
Mapping from the name of each BoundColumn to the associated column
name in `df`.
"""
df.ffill(inplace=True)
@@ -369,18 +372,19 @@ def ffill_across_cols(df, columns):
# pandas to replace NaNs in an object column with None using fillna,
# so we have to roll our own instead using df.where.
for column in columns:
column_name = name_map[column.name]
# Special logic for strings since `fillna` doesn't work if the
# missing value is `None`.
if column.dtype == categorical_dtype:
df[column.name] = df[
df[column_name] = df[
column.name
].where(pd.notnull(df[column.name]),
].where(pd.notnull(df[column_name]),
column.missing_value)
else:
# We need to execute `fillna` before `astype` in case the
# column contains NaNs and needs to be cast to bool or int.
# This is so that the NaNs are replaced first, since pandas
# can't convert NaNs for those types.
df[column.name] = df[
column.name
df[column_name] = df[
column_name
].fillna(column.missing_value).astype(column.dtype)
+6
View File
@@ -49,8 +49,14 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
from zipline.utils import security_list
from zipline.utils.calendars import get_calendar
from zipline.utils.input_validation import expect_dimensions
<<<<<<< HEAD
from zipline.utils.numpy_utils import as_column, isnat
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
=======
from zipline.utils.numpy_utils import (
as_column,
)
>>>>>>> WIP
from zipline.utils.sentinel import sentinel
import numpy as np
+3 -2
View File
@@ -34,13 +34,14 @@ from ..finance.trading import TradingEnvironment
from ..utils import factory
from ..utils.classproperty import classproperty
from ..utils.final import FinalMeta, final
from .core import tmp_asset_finder, make_simple_equity_info
from .core import (tmp_asset_finder, make_simple_equity_info)
from zipline.assets import Equity, Future
from zipline.pipeline import SimplePipelineEngine
from zipline.pipeline.loaders.testing import make_seeded_random_loader
from zipline.utils.calendars import (
get_calendar,
register_calendar)
register_calendar
)
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):