mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 21:20:06 +08:00
TST: add test to check previous columns w/ multiple qtrs
MAINT: pass column to name dict MAINT: make check for invalid num columns py3-compatible
This commit is contained in:
@@ -202,11 +202,7 @@ def _gen_multiplicative_adjustment_cases(dtype):
|
||||
)
|
||||
|
||||
|
||||
def _gen_overwrite_adjustment_cases(name,
|
||||
make_input,
|
||||
make_expected_output,
|
||||
dtype,
|
||||
missing_value):
|
||||
def _gen_overwrite_adjustment_cases(dtype):
|
||||
"""
|
||||
Generate test cases for overwrite adjustments.
|
||||
|
||||
@@ -226,6 +222,8 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
unicode_dtype: ObjectOverwrite,
|
||||
object_dtype: ObjectOverwrite,
|
||||
}[dtype]
|
||||
make_expected_dtype = as_dtype(dtype)
|
||||
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
|
||||
|
||||
if dtype == object_dtype:
|
||||
# When we're testing object dtypes, we expect to have strings, but
|
||||
@@ -237,30 +235,30 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
|
||||
adjustments = {}
|
||||
buffer_as_of = [None] * 6
|
||||
baseline = make_input([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
baseline = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
buffer_as_of[0] = make_expected_output([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
adjustment_type(0, 0, 0, 0, make_overwrite_value(dtype, 1)),
|
||||
]
|
||||
buffer_as_of[1] = make_expected_output([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# No adjustment at index 2.
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
@@ -269,33 +267,33 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
adjustment_type(1, 2, 1, 1, make_overwrite_value(dtype, 3)),
|
||||
adjustment_type(0, 1, 0, 0, make_overwrite_value(dtype, 4)),
|
||||
]
|
||||
buffer_as_of[3] = make_expected_output([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
[2, 3, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
[2, 3, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
adjustments[4] = [
|
||||
adjustment_type(0, 3, 2, 2, make_overwrite_value(dtype, 5))
|
||||
]
|
||||
buffer_as_of[4] = make_expected_output([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
[2, 3, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
[2, 3, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
adjustments[5] = [
|
||||
adjustment_type(0, 4, 1, 1, make_overwrite_value(dtype, 6)),
|
||||
adjustment_type(2, 2, 2, 2, make_overwrite_value(dtype, 7)),
|
||||
]
|
||||
buffer_as_of[5] = make_expected_output([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
[2, 6, 7],
|
||||
[2, 6, 5],
|
||||
[2, 6, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[5] = make_expected_dtype([[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
[2, 6, 7],
|
||||
[2, 6, 5],
|
||||
[2, 6, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
@@ -306,11 +304,7 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
)
|
||||
|
||||
|
||||
def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
make_input,
|
||||
make_expected_output,
|
||||
dtype,
|
||||
missing_value):
|
||||
def _gen_overwrite_1d_array_adjustment_case(dtype):
|
||||
"""
|
||||
Generate test cases for overwrite adjustments.
|
||||
|
||||
@@ -327,21 +321,24 @@ def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
float64_dtype: Float641DArrayOverwrite,
|
||||
datetime64ns_dtype: Datetime641DArrayOverwrite,
|
||||
}[dtype]
|
||||
make_expected_dtype = as_dtype(dtype)
|
||||
missing_value = default_missing_value_for_dtype(datetime64ns_dtype)
|
||||
|
||||
adjustments = {}
|
||||
buffer_as_of = [None] * 6
|
||||
baseline = make_input([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
baseline = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
buffer_as_of[0] = make_expected_output([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[0] = make_expected_dtype([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals1 = [1]
|
||||
# Note that row indices are inclusive!
|
||||
@@ -351,12 +348,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals1])
|
||||
)
|
||||
]
|
||||
buffer_as_of[1] = make_input([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[1] = make_expected_dtype([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# No adjustment at index 2.
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
@@ -368,12 +365,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals3])
|
||||
)
|
||||
]
|
||||
buffer_as_of[3] = make_input([[4, 2, 2],
|
||||
[4, 2, 2],
|
||||
[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[3] = make_expected_dtype([[4, 2, 2],
|
||||
[4, 2, 2],
|
||||
[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals4 = [5] * 4
|
||||
adjustments[4] = [
|
||||
@@ -381,12 +378,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
0, 3, 2, 2,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals4]))
|
||||
]
|
||||
buffer_as_of[4] = make_input([[4, 2, 5],
|
||||
[4, 2, 5],
|
||||
[1, 2, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[4] = make_expected_dtype([[4, 2, 5],
|
||||
[4, 2, 5],
|
||||
[1, 2, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals5 = range(1, 6)
|
||||
adjustments[5] = [
|
||||
@@ -394,12 +391,12 @@ def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
0, 4, 1, 1,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals5])),
|
||||
]
|
||||
buffer_as_of[5] = make_input([[4, 1, 5],
|
||||
[4, 2, 5],
|
||||
[1, 3, 5],
|
||||
[2, 4, 5],
|
||||
[2, 5, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[5] = make_expected_dtype([[4, 1, 5],
|
||||
[4, 2, 5],
|
||||
[1, 3, 5],
|
||||
[2, 4, 5],
|
||||
[2, 5, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
@@ -532,38 +529,10 @@ class AdjustedArrayTestCase(TestCase):
|
||||
|
||||
@parameterized.expand(
|
||||
chain(
|
||||
_gen_overwrite_adjustment_cases(
|
||||
'float',
|
||||
make_input=as_dtype(float64_dtype),
|
||||
make_expected_output=as_dtype(float64_dtype),
|
||||
dtype=float64_dtype,
|
||||
missing_value=default_missing_value_for_dtype(float64_dtype),
|
||||
),
|
||||
_gen_overwrite_adjustment_cases(
|
||||
'datetime',
|
||||
make_input=as_dtype(datetime64ns_dtype),
|
||||
make_expected_output=as_dtype(datetime64ns_dtype),
|
||||
dtype=datetime64ns_dtype,
|
||||
missing_value=default_missing_value_for_dtype(
|
||||
datetime64ns_dtype,
|
||||
),
|
||||
),
|
||||
_gen_overwrite_1d_array_adjustment_case(
|
||||
'float',
|
||||
make_input=as_dtype(float64_dtype),
|
||||
make_expected_output=as_dtype(float64_dtype),
|
||||
dtype=float64_dtype,
|
||||
missing_value=default_missing_value_for_dtype(float64_dtype),
|
||||
),
|
||||
_gen_overwrite_1d_array_adjustment_case(
|
||||
'datetime',
|
||||
make_input=as_dtype(datetime64ns_dtype),
|
||||
make_expected_output=as_dtype(datetime64ns_dtype),
|
||||
dtype=datetime64ns_dtype,
|
||||
missing_value=default_missing_value_for_dtype(
|
||||
datetime64ns_dtype,
|
||||
),
|
||||
),
|
||||
_gen_overwrite_adjustment_cases(float64_dtype),
|
||||
_gen_overwrite_adjustment_cases(datetime64ns_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(float64_dtype),
|
||||
_gen_overwrite_1d_array_adjustment_case(datetime64ns_dtype),
|
||||
# There are six cases here:
|
||||
# Using np.bytes/np.unicode/object arrays as inputs.
|
||||
# Passing np.bytes/np.unicode/object arrays to LabelArray,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
+27
-23
@@ -3,7 +3,7 @@ from cpython cimport Py_EQ
|
||||
|
||||
from pandas import isnull, Timestamp
|
||||
from numpy cimport float64_t, uint8_t, int64_t
|
||||
from numpy import asarray, datetime64, float64
|
||||
from numpy import asarray, datetime64, float64, int64
|
||||
# Purely for readability. There aren't C-level declarations for these types.
|
||||
ctypedef object Int64Index_t
|
||||
ctypedef object DatetimeIndex_t
|
||||
@@ -451,28 +451,32 @@ cdef class Datetime641DArrayOverwrite(ArrayAdjustment):
|
||||
Example
|
||||
-------
|
||||
|
||||
>>> import numpy as np
|
||||
>>> arr = np.arange(25, dtype=float).reshape(5, 5)
|
||||
>>> arr
|
||||
array([[ 0., 1., 2., 3., 4.],
|
||||
[ 5., 6., 7., 8., 9.],
|
||||
[ 10., 11., 12., 13., 14.],
|
||||
[ 15., 16., 17., 18., 19.],
|
||||
[ 20., 21., 22., 23., 24.]])
|
||||
>>> import numpy as np; import pandas as pd
|
||||
>>> dts = pd.date_range('2014', freq='D', periods=9, tz='UTC')
|
||||
>>> arr = dts.values.reshape(3, 3)
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, False, False],
|
||||
[False, False, False]], dtype=bool)
|
||||
>>> adj = Datetime641DArrayOverwrite(
|
||||
... row_start=0,
|
||||
... row_end=3,
|
||||
... column_start=0,
|
||||
... column_end=0,
|
||||
... values=np.array([1, 2, 3, 4]),
|
||||
)
|
||||
>>> adj.mutate(arr)
|
||||
>>> arr
|
||||
array([[ 1., 1., 2., 3., 4.],
|
||||
[ 2., 6., 7., 8., 9.],
|
||||
[ 3., 11., 12., 13., 14.],
|
||||
[ 4., 16., 17., 18., 19.],
|
||||
[ 20., 21., 22., 23., 24.]])
|
||||
... first_row=1,
|
||||
... last_row=2,
|
||||
... first_col=1,
|
||||
... last_col=2,
|
||||
... values=np.array([
|
||||
... np.datetime64(0, 'ns'),
|
||||
... np.datetime64(1, 'ns')
|
||||
... ])
|
||||
... )
|
||||
>>> adj.mutate(arr.view(np.int64))
|
||||
>>> arr == np.datetime64(0, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, True, True],
|
||||
[False, False, False]], dtype=bool)
|
||||
>>> arr == np.datetime64(1, 'ns')
|
||||
array([[False, False, False],
|
||||
[False, False, False],
|
||||
[False, True, True]], dtype=bool)
|
||||
"""
|
||||
cdef:
|
||||
readonly int64_t[:] values
|
||||
@@ -598,7 +602,7 @@ cdef datetime_to_int(object datetimelike):
|
||||
datetimelike.dtype.name,
|
||||
)
|
||||
|
||||
return datetimelike.astype(int)
|
||||
return datetimelike.astype(int64)
|
||||
|
||||
|
||||
cdef class Datetime64Adjustment(_Int64Adjustment):
|
||||
|
||||
@@ -1096,12 +1096,15 @@ class BlazeLoader(dict):
|
||||
sparse_deltas = last_in_date_group(non_novel_deltas,
|
||||
dates,
|
||||
assets,
|
||||
reindex=False)
|
||||
reindex=False,
|
||||
have_sids=have_sids)
|
||||
dense_output = last_in_date_group(sparse_output,
|
||||
dates,
|
||||
assets,
|
||||
reindex=True)
|
||||
ffill_across_cols(dense_output, columns)
|
||||
reindex=True,
|
||||
have_sids=have_sids)
|
||||
ffill_across_cols(dense_output, columns, {c.name: c.name
|
||||
for c in columns})
|
||||
if have_sids:
|
||||
adjustments_from_deltas = adjustments_from_deltas_with_sids
|
||||
column_view = identity
|
||||
|
||||
@@ -25,6 +25,8 @@ class BlazeEstimatesLoader(PipelineLoader):
|
||||
----------
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
columns : dict[str -> str]
|
||||
A dict mapping BoundColumn names to the associated names in `expr`.
|
||||
resources : dict, optional
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
@@ -33,8 +35,6 @@ class BlazeEstimatesLoader(PipelineLoader):
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezone to use for the data query cutoff.
|
||||
dataset : DataSet
|
||||
The DataSet object for which this loader loads data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@@ -43,12 +43,14 @@ class BlazeEstimatesLoader(PipelineLoader):
|
||||
Dim * {{
|
||||
{SID_FIELD_NAME}: int64,
|
||||
{TS_FIELD_NAME}: datetime,
|
||||
{FISCAL_YEAR_FIELD_NAME}: float64,
|
||||
{FISCAL_QUARTER_FIELD_NAME}: float64,
|
||||
{EVENT_DATE_FIELD_NAME}: datetime,
|
||||
}}
|
||||
|
||||
And other dataset-specific fields, where each row of the table is a
|
||||
record including the sid to identify the company, the timestamp where we
|
||||
learned about the announcement, and the date when the earnings will be
|
||||
announced.
|
||||
learned about the announcement, and the date of the event.
|
||||
|
||||
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
|
||||
start the backtest with knowledge of all announcements.
|
||||
|
||||
@@ -24,6 +24,10 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
----------
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
next_value_columns : dict[BoundColumn -> raw column name]
|
||||
A dict mapping 'next' BoundColumns to their column names in `expr`.
|
||||
previous_value_columns : dict[BoundColumn -> raw column name]
|
||||
A dict mapping 'previous' BoundColumns to their column names in `expr`.
|
||||
resources : dict, optional
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
@@ -32,8 +36,6 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
The time to use for the data query cutoff.
|
||||
data_query_tz : tzinfo or str
|
||||
The timezone to use for the data query cutoff.
|
||||
dataset : DataSet
|
||||
The DataSet object for which this loader loads data.
|
||||
|
||||
Notes
|
||||
-----
|
||||
@@ -42,12 +44,12 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
Dim * {{
|
||||
{SID_FIELD_NAME}: int64,
|
||||
{TS_FIELD_NAME}: datetime,
|
||||
{EVENT_DATE_FIELD_NAME}: datetime,
|
||||
}}
|
||||
|
||||
And other dataset-specific fields, where each row of the table is a
|
||||
record including the sid to identify the company, the timestamp where we
|
||||
learned about the announcement, and the date when the earnings will be z
|
||||
announced.
|
||||
learned about the announcement, and the event date.
|
||||
|
||||
If the '{TS_FIELD_NAME}' field is not included it is assumed that we
|
||||
start the backtest with knowledge of all announcements.
|
||||
@@ -84,8 +86,12 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
self._data_query_tz = data_query_tz
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
raw = load_raw_data(assets, dates, self._data_query_time,
|
||||
self._data_query_tz, self._expr, self._odo_kwargs)
|
||||
raw = load_raw_data(assets,
|
||||
dates,
|
||||
self._data_query_time,
|
||||
self._data_query_tz,
|
||||
self._expr,
|
||||
self._odo_kwargs)
|
||||
|
||||
return EventsLoader(
|
||||
events=raw,
|
||||
|
||||
@@ -6,7 +6,11 @@ from zipline.pipeline.loaders.utils import (
|
||||
)
|
||||
|
||||
|
||||
def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
|
||||
def load_raw_data(assets,
|
||||
dates,
|
||||
data_query_time,
|
||||
data_query_tz,
|
||||
expr,
|
||||
odo_kwargs):
|
||||
"""
|
||||
given an expression representing data to load, perform normalization and
|
||||
@@ -25,13 +29,14 @@ def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
|
||||
`time`.
|
||||
expr : expr
|
||||
the expression representing the data to load.
|
||||
odo_kwargs : dict, optional
|
||||
odo_kwargs : dict
|
||||
extra keyword arguments to pass to odo when executing the expression.
|
||||
|
||||
returns
|
||||
-------
|
||||
raw : pd.dataframe
|
||||
the data symbolized by `expr` materialized in a dataframe.
|
||||
The result of computing expr and materializing the result as a
|
||||
dataframe.
|
||||
"""
|
||||
lower_dt, upper_dt = normalize_data_query_bounds(
|
||||
dates[0],
|
||||
@@ -45,7 +50,7 @@ def load_raw_data(assets, dates, data_query_time, data_query_tz, expr,
|
||||
upper_dt,
|
||||
odo_kwargs,
|
||||
)
|
||||
sids = raw.loc[:, SID_FIELD_NAME]
|
||||
sids = raw[SID_FIELD_NAME]
|
||||
raw.drop(
|
||||
sids[~sids.isin(assets)].index,
|
||||
inplace=True
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from six import viewvalues
|
||||
from toolz import groupby
|
||||
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import (Datetime641DArrayOverwrite,
|
||||
Float641DArrayOverwrite)
|
||||
@@ -22,14 +21,15 @@ from zipline.pipeline.loaders.utils import (
|
||||
last_in_date_group
|
||||
)
|
||||
|
||||
NORMALIZED_QUARTERS = 'normalized_quarters'
|
||||
|
||||
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
|
||||
|
||||
INVALID_NUM_QTRS_MESSAGE = "Passed invalid number of quarters %s; " \
|
||||
"must pass a number of quarters >= 0"
|
||||
NEXT_FISCAL_QUARTER = 'next_fiscal_quarter'
|
||||
NEXT_FISCAL_YEAR = 'next_fiscal_year'
|
||||
NORMALIZED_QUARTERS = 'normalized_quarters'
|
||||
PREVIOUS_FISCAL_QUARTER = 'previous_fiscal_quarter'
|
||||
PREVIOUS_FISCAL_YEAR = 'previous_fiscal_year'
|
||||
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
|
||||
SIMULTATION_DATES = 'dates'
|
||||
|
||||
|
||||
@@ -86,10 +86,10 @@ def validate_column_specs(events, columns):
|
||||
class QuarterEstimatesLoader(PipelineLoader):
|
||||
def __init__(self,
|
||||
estimates,
|
||||
base_column_name_map):
|
||||
name_map):
|
||||
validate_column_specs(
|
||||
estimates,
|
||||
base_column_name_map
|
||||
name_map
|
||||
)
|
||||
|
||||
self.estimates = estimates[
|
||||
@@ -97,12 +97,16 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
estimates[FISCAL_QUARTER_FIELD_NAME].notnull() &
|
||||
estimates[FISCAL_YEAR_FIELD_NAME].notnull()
|
||||
]
|
||||
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
|
||||
self.estimates[FISCAL_YEAR_FIELD_NAME],
|
||||
self.estimates[FISCAL_QUARTER_FIELD_NAME],
|
||||
)
|
||||
|
||||
self.base_column_name_map = base_column_name_map
|
||||
self.name_map = name_map
|
||||
|
||||
@abstractmethod
|
||||
def load_quarters(self, num_quarters, last, dates):
|
||||
pass
|
||||
raise NotImplementedError('load_quarters')
|
||||
|
||||
def get_requested_data_for_col(self, stacked_last_per_qtr, idx, dates):
|
||||
"""
|
||||
@@ -111,8 +115,8 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
Parameters
|
||||
----------
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
The latest estimate known per sid per date per quarter with the
|
||||
dates, normalized quarter, and sid as the index.
|
||||
The latest estimate known with the dates, normalized quarter, and
|
||||
sid as the index.
|
||||
idx : pd.MultiIndex
|
||||
The index of the row of the requested quarter from each date for
|
||||
each sid.
|
||||
@@ -122,16 +126,18 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
Returns
|
||||
--------
|
||||
requested_qtr_data : pd.DataFrame
|
||||
The DataFrame with final values for the requested quarter for all
|
||||
columns; `dates` are the index and columns are a MultiIndex with
|
||||
sids at the top level and the dataset columns on the bottom.
|
||||
The DataFrame with the latest values for the requested quarter
|
||||
for all columns; `dates` are the index and columns are a MultiIndex
|
||||
with sids at the top level and the dataset columns on the bottom.
|
||||
"""
|
||||
requested_qtr_data = stacked_last_per_qtr.loc[idx]
|
||||
# We no longer need this in the index, but we do need it as a column
|
||||
# to calculate adjustments.
|
||||
# We no longer need the shifted normalized quarters in the index, but
|
||||
# we do need it as a column to calculate adjustments.
|
||||
requested_qtr_data = requested_qtr_data.reset_index(
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
)
|
||||
# Calculate the actual year/quarter being requested and add those in
|
||||
# as columns.
|
||||
(requested_qtr_data[FISCAL_YEAR_FIELD_NAME],
|
||||
requested_qtr_data[FISCAL_QUARTER_FIELD_NAME]) = \
|
||||
split_normalized_quarters(
|
||||
@@ -154,8 +160,7 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
column_name,
|
||||
column,
|
||||
mask,
|
||||
assets,
|
||||
qtr_crossover_point):
|
||||
assets):
|
||||
"""
|
||||
Creates an AdjustedArray from the given estimates data for the given
|
||||
dates.
|
||||
@@ -183,18 +188,17 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
computed.
|
||||
column : BoundColumn
|
||||
The column for which the AdjustedArray is being computed.
|
||||
mask :
|
||||
assets :
|
||||
qtr_crossover_point :
|
||||
Whether we should use the 'right' or 'left' side when doing
|
||||
searchsorted on the dates for quarter boundaries.
|
||||
mask : np.array
|
||||
Mask array of dimensions len(dates) X len(assets).
|
||||
assets : pd.Int64Index
|
||||
An index of all the assets from the raw data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
adjusted_array : AdjustedArray
|
||||
The array of data and overwrites for the given column.
|
||||
"""
|
||||
adjustments = defaultdict(list)
|
||||
adjustments = {}
|
||||
requested_qtr_data = self.get_requested_data_for_col(
|
||||
stacked_last_per_qtr, requested_qtr_idx, dates
|
||||
)
|
||||
@@ -204,10 +208,8 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
zero_qtr_data = zero_qtr_data.reset_index(NORMALIZED_QUARTERS)
|
||||
if column.dtype == datetime64ns_dtype:
|
||||
overwrite = Datetime641DArrayOverwrite
|
||||
missing_value = np.datetime64('NaT', 'ns')
|
||||
else:
|
||||
overwrite = Float641DArrayOverwrite
|
||||
missing_value = np.NaN
|
||||
for sid_idx, sid in enumerate(assets):
|
||||
zero_qtr_sid_data = zero_qtr_data[
|
||||
zero_qtr_data.index.get_level_values(SID_FIELD_NAME) == sid
|
||||
@@ -225,7 +227,7 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
]
|
||||
# For the given sid, determine which quarters we have estimates
|
||||
# for.
|
||||
quarters_with_estimates_for_sid = last_per_qtr.xs(
|
||||
qtrs_with_estimates_for_sid = last_per_qtr.xs(
|
||||
sid, axis=1, level=SID_FIELD_NAME
|
||||
).groupby(axis=1, level=1).first().columns.values
|
||||
for row_indexer in list(qtr_shifts.index):
|
||||
@@ -233,108 +235,162 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
# after this row. This isn't the starting index of the
|
||||
# requested quarter, but simply the date we cross over into a
|
||||
# new quarter.
|
||||
qtr_start_idx = dates.searchsorted(
|
||||
next_qtr_start_idx = dates.searchsorted(
|
||||
zero_qtr_data.loc[
|
||||
row_indexer
|
||||
][EVENT_DATE_FIELD_NAME],
|
||||
side=qtr_crossover_point
|
||||
side='left'
|
||||
if isinstance(self, PreviousQuartersEstimatesLoader)
|
||||
else 'right'
|
||||
)
|
||||
|
||||
# Only add adjustments if the next quarter starts somewhere in
|
||||
# our date index for this sid. Our 'next' quarter can never
|
||||
# start at index 0; a starting index of 0 means that the next
|
||||
# quarter's event date was NaT.
|
||||
if 0 < qtr_start_idx < len(dates):
|
||||
# Find the quarter being requested in the quarter we're
|
||||
# crossing into.
|
||||
requested_quarter = requested_qtr_data[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
][sid].iloc[qtr_start_idx]
|
||||
|
||||
# If there are estimates for the requested quarter,
|
||||
# overwrite all values going up to the starting index of
|
||||
# that quarter with estimates for that quarter.
|
||||
if requested_quarter in quarters_with_estimates_for_sid:
|
||||
adjustments[qtr_start_idx] = \
|
||||
[overwrite(
|
||||
0,
|
||||
qtr_start_idx - 1, # overwrite thru last qtr
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
last_per_qtr[column_name,
|
||||
requested_quarter,
|
||||
sid][:qtr_start_idx].values)]
|
||||
# There are no estimates for the quarter. Overwrite all
|
||||
# values going up to the starting index of that quarter
|
||||
# with the missing value for this column.
|
||||
else:
|
||||
adjustments[qtr_start_idx] = [
|
||||
overwrite(
|
||||
0,
|
||||
qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
np.array(
|
||||
[missing_value] *
|
||||
len(last_per_qtr.index[:qtr_start_idx]))
|
||||
)
|
||||
]
|
||||
adjustments[next_qtr_start_idx] = \
|
||||
self.create_overwrite_for_quarter(
|
||||
next_qtr_start_idx,
|
||||
column,
|
||||
column_name,
|
||||
dates,
|
||||
last_per_qtr,
|
||||
overwrite,
|
||||
qtrs_with_estimates_for_sid,
|
||||
requested_qtr_data,
|
||||
sid,
|
||||
sid_idx,
|
||||
)
|
||||
|
||||
return AdjustedArray(
|
||||
requested_qtr_data[column_name].values.astype(column.dtype),
|
||||
mask,
|
||||
dict(adjustments),
|
||||
column.missing_value,
|
||||
)
|
||||
requested_qtr_data[column_name].values.astype(column.dtype),
|
||||
mask,
|
||||
dict(adjustments),
|
||||
column.missing_value,
|
||||
)
|
||||
|
||||
def create_overwrite_for_quarter(self,
|
||||
next_qtr_start_idx,
|
||||
column,
|
||||
column_name,
|
||||
dates,
|
||||
last_per_qtr,
|
||||
overwrite,
|
||||
quarters_with_estimates_for_sid,
|
||||
requested_qtr_data,
|
||||
sid,
|
||||
sid_idx):
|
||||
# Only add adjustments if the next quarter starts somewhere in
|
||||
# our date index for this sid. Our 'next' quarter can never
|
||||
# start at index 0; a starting index of 0 means that the next
|
||||
# quarter's event date was NaT.
|
||||
if 0 < next_qtr_start_idx < len(dates):
|
||||
# Find the quarter being requested in the quarter we're
|
||||
# crossing into.
|
||||
requested_quarter = requested_qtr_data[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
][sid].iloc[next_qtr_start_idx]
|
||||
|
||||
# If there are estimates for the requested quarter,
|
||||
# overwrite all values going up to the starting index of
|
||||
# that quarter with estimates for that quarter.
|
||||
if requested_quarter in quarters_with_estimates_for_sid:
|
||||
return self.create_overwrite_for_estimate(
|
||||
column,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx
|
||||
)
|
||||
# There are no estimates for the quarter. Overwrite all
|
||||
# values going up to the starting index of that quarter
|
||||
# with the missing value for this column.
|
||||
else:
|
||||
return self.overwrite_with_null(
|
||||
column,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
sid_idx
|
||||
)
|
||||
|
||||
def overwrite_with_null(self,
|
||||
column,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
sid_idx):
|
||||
return [overwrite(
|
||||
0,
|
||||
next_qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
np.full(
|
||||
len(
|
||||
last_per_qtr.index[:next_qtr_start_idx]
|
||||
),
|
||||
column.missing_value,
|
||||
dtype=column.dtype
|
||||
))]
|
||||
|
||||
def load_adjusted_array(self, columns, dates, assets, mask):
|
||||
# TODO: how can we enforce that datasets have the num_quarters
|
||||
# attribute, given that they're created dynamically?
|
||||
groups = groupby(lambda x: x.dataset.num_quarters, columns)
|
||||
groups_columns = dict(groups)
|
||||
if (pd.Series(groups_columns.keys()) < 0).any():
|
||||
raise ValueError("Must pass a number of quarters >= 0")
|
||||
# Separate out getting the columns' datasets and the datasets'
|
||||
# num_quarters attributes to ensure that we're catching the right
|
||||
# AttributeError.
|
||||
col_to_datasets = {col: col.dataset for col in columns}
|
||||
try:
|
||||
groups = groupby(lambda col: col_to_datasets[col].num_quarters,
|
||||
col_to_datasets)
|
||||
except AttributeError:
|
||||
raise AttributeError("Datasets loaded via the "
|
||||
"QuarterEstimatesLoader must define a "
|
||||
"`num_quarters` attribute that defines how "
|
||||
"many quarters out the loader should load "
|
||||
"the data relative to `dates`.")
|
||||
if any(num_qtr < 0 for num_qtr in groups):
|
||||
raise ValueError(
|
||||
INVALID_NUM_QTRS_MESSAGE % ','.join(
|
||||
str(qtr) for qtr in groups if qtr < 0
|
||||
)
|
||||
|
||||
)
|
||||
out = {}
|
||||
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
|
||||
self.estimates[FISCAL_YEAR_FIELD_NAME],
|
||||
self.estimates[FISCAL_QUARTER_FIELD_NAME],
|
||||
)
|
||||
for num_quarters, columns in groups_columns.items():
|
||||
# The column's dataset is itself dynamic and the mapping we
|
||||
# actually want is to its dataset's parent's column name.
|
||||
name_map = {c: self.base_column_name_map[
|
||||
getattr(c.dataset.__base__, c.name)
|
||||
] for c in columns}
|
||||
|
||||
for num_quarters, columns in groups.items():
|
||||
# Determine the last piece of information we know for each column
|
||||
# on each date in the index for each sid and quarter.
|
||||
last_per_qtr = last_in_date_group(
|
||||
self.estimates, True, dates, assets,
|
||||
self.estimates, dates, assets, reindex=True,
|
||||
extra_groupers=[NORMALIZED_QUARTERS]
|
||||
)
|
||||
|
||||
# Forward fill values for each quarter/sid/dataset column.
|
||||
ffill_across_cols(last_per_qtr, columns)
|
||||
ffill_across_cols(last_per_qtr, columns, self.name_map)
|
||||
# Stack quarter and sid into the index.
|
||||
stacked_last_per_qtr = last_per_qtr.stack([NORMALIZED_QUARTERS,
|
||||
SID_FIELD_NAME])
|
||||
stacked_last_per_qtr = last_per_qtr.stack([SID_FIELD_NAME,
|
||||
NORMALIZED_QUARTERS])
|
||||
# Set date index name for ease of reference
|
||||
stacked_last_per_qtr.index.set_names(SIMULTATION_DATES, 0, True)
|
||||
stacked_last_per_qtr.index.set_names(SIMULTATION_DATES,
|
||||
level=0,
|
||||
inplace=True)
|
||||
# We want to know the most recent/next event relative to each date.
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(
|
||||
EVENT_DATE_FIELD_NAME
|
||||
)
|
||||
# Determine which quarter is next/previous for each date.
|
||||
shifted_qtr_data = self.load_quarters(num_quarters,
|
||||
stacked_last_per_qtr)
|
||||
zero_qtr_idx = shifted_qtr_data.index
|
||||
requested_qtr_idx = shifted_qtr_data.set_index([
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SIMULTATION_DATES
|
||||
),
|
||||
shifted_qtr_data[SHIFTED_NORMALIZED_QTRS],
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SID_FIELD_NAME
|
||||
)]
|
||||
).index
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SIMULTATION_DATES
|
||||
),
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SID_FIELD_NAME
|
||||
),
|
||||
shifted_qtr_data[SHIFTED_NORMALIZED_QTRS]
|
||||
]).index
|
||||
|
||||
for c in columns:
|
||||
column_name = name_map[c]
|
||||
column_name = self.name_map[c.name]
|
||||
adjusted_array = self.get_adjustments(zero_qtr_idx,
|
||||
requested_qtr_idx,
|
||||
stacked_last_per_qtr,
|
||||
@@ -343,26 +399,68 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
column_name,
|
||||
c,
|
||||
mask,
|
||||
assets,
|
||||
self.qtr_crossover_point)
|
||||
assets)
|
||||
out[c] = adjusted_array
|
||||
return out
|
||||
|
||||
|
||||
class NextQuartersEstimatesLoader(QuarterEstimatesLoader):
|
||||
qtr_crossover_point = 'right'
|
||||
def create_overwrite_for_estimate(self,
|
||||
column,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx):
|
||||
return [overwrite(
|
||||
0,
|
||||
# overwrite thru last qtr
|
||||
next_qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
last_per_qtr[
|
||||
column_name,
|
||||
requested_quarter,
|
||||
sid
|
||||
][0:next_qtr_start_idx].values)]
|
||||
|
||||
def load_quarters(self, num_quarters, stacked_last_per_qtr):
|
||||
# Filter for releases that are on or after each simulation date and
|
||||
# determine the next quarter by picking out the upcoming release for
|
||||
# each date in the index.
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(
|
||||
EVENT_DATE_FIELD_NAME
|
||||
)
|
||||
"""
|
||||
Filters for releases that are on or after each simulation date and
|
||||
determines the next quarter by picking out the upcoming release for
|
||||
each date in the index. Adds a SHIFTED_NORMALIZED_QTRS column which
|
||||
contains the requested next quarter for each calendar date and sid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_quarters : int
|
||||
Number of quarters to go out in the future.
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters with each row being the latest estimate for the row's
|
||||
index values, sorted by event date.
|
||||
|
||||
Returns
|
||||
-------
|
||||
next_releases_per_date : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters, keeping only rows with next event information relative to
|
||||
the index values and with an added column for
|
||||
SHIFTED_NORMALIZED_QTRS, which contains the requested quarter for
|
||||
each row.
|
||||
"""
|
||||
|
||||
# We reset the index here because in pandas3, a groupby on the index
|
||||
# will set the index to just the items in the groupby, so we will lose
|
||||
# the normalized quarters.
|
||||
next_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(0)
|
||||
].reset_index(NORMALIZED_QUARTERS).groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME]
|
||||
).nth(0).set_index(NORMALIZED_QUARTERS, append=True)
|
||||
next_releases_per_date[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
] = next_releases_per_date.index.get_level_values(
|
||||
@@ -372,18 +470,57 @@ class NextQuartersEstimatesLoader(QuarterEstimatesLoader):
|
||||
|
||||
|
||||
class PreviousQuartersEstimatesLoader(QuarterEstimatesLoader):
|
||||
qtr_crossover_point = 'left'
|
||||
def create_overwrite_for_estimate(self,
|
||||
column,
|
||||
column_name,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
requested_quarter,
|
||||
sid,
|
||||
sid_idx):
|
||||
return self.overwrite_with_null(column,
|
||||
last_per_qtr,
|
||||
next_qtr_start_idx,
|
||||
overwrite,
|
||||
sid_idx)
|
||||
|
||||
def load_quarters(self, num_quarters, stacked_last_per_qtr):
|
||||
# Filter for releases that are on or before each simulation date and
|
||||
# determine the previous quarter by picking out the upcoming release
|
||||
# for each date in the index.
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(EVENT_DATE_FIELD_NAME)
|
||||
"""
|
||||
Filters for releases that are on or before each simulation date and
|
||||
determines the previous quarter by picking out the most recent
|
||||
release relative to each date in the index. Adds a
|
||||
SHIFTED_NORMALIZED_QTRS column which contains the requested previous
|
||||
quarter for each calendar date and sid.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
num_quarters : int
|
||||
Number of quarters to go out in the past.
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters with each row being the latest estimate for the row's
|
||||
index values, sorted by event date.
|
||||
|
||||
Returns
|
||||
-------
|
||||
previous_releases_per_date : pd.DataFrame
|
||||
A DataFrame with index of calendar dates, sid, and normalized
|
||||
quarters, keeping only rows that have a previous event relative
|
||||
to the index values and with an added column for
|
||||
SHIFTED_NORMALIZED_QTRS, which contains the requested quarter for
|
||||
each row.
|
||||
"""
|
||||
|
||||
# We reset the index here because in pandas3, a groupby on the index
|
||||
# will set the index to just the items in the groupby, so we will lose
|
||||
# the normalized quarters.
|
||||
previous_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
stacked_last_per_qtr.index.get_level_values(
|
||||
SIMULTATION_DATES
|
||||
)].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(-1)
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
].reset_index(NORMALIZED_QUARTERS).groupby(
|
||||
level=[SIMULTATION_DATES, SID_FIELD_NAME]
|
||||
).nth(-1).set_index(NORMALIZED_QUARTERS, append=True)
|
||||
previous_releases_per_date[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
] = previous_releases_per_date.index.get_level_values(
|
||||
|
||||
@@ -276,7 +276,7 @@ def check_data_query_args(data_query_time, data_query_tz):
|
||||
)
|
||||
|
||||
|
||||
def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
def last_in_date_group(df, dates, assets, reindex=True, have_sids=True,
|
||||
extra_groupers=[]):
|
||||
"""
|
||||
Determine the last piece of information known on each date in the date
|
||||
@@ -286,14 +286,14 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The DataFrame containing the data to be grouped.
|
||||
reindex : bool
|
||||
Whether or not the DataFrame should be reindexed against the date
|
||||
index. This will add back any dates to the index that were grouped
|
||||
away.
|
||||
dates : pd.DatetimeIndex
|
||||
The dates to use for grouping and reindexing.
|
||||
assets : pd.Int64Index
|
||||
The assets that should be included in the column multiindex.
|
||||
reindex : bool
|
||||
Whether or not the DataFrame should be reindexed against the date
|
||||
index. This will add back any dates to the index that were grouped
|
||||
away.
|
||||
have_sids : bool
|
||||
Whether or not the DataFrame has sids. If it does, they will be used
|
||||
in the groupby.
|
||||
@@ -307,11 +307,11 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
levels of a multiindex of columns.
|
||||
|
||||
"""
|
||||
idx = dates[dates.searchsorted(
|
||||
idx = [dates[dates.searchsorted(
|
||||
df[TS_FIELD_NAME].values.astype('datetime64[D]')
|
||||
)]
|
||||
)]]
|
||||
if have_sids:
|
||||
idx = [idx, SID_FIELD_NAME]
|
||||
idx += [SID_FIELD_NAME]
|
||||
idx += extra_groupers
|
||||
|
||||
last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(
|
||||
@@ -321,7 +321,7 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
|
||||
# For the number of things that we're grouping by (except TS), unstack
|
||||
# the df
|
||||
last_in_group = last_in_group.unstack([-1, -2])
|
||||
last_in_group = last_in_group.unstack(list(range(-1, -len(idx), -1)))
|
||||
|
||||
if reindex:
|
||||
if have_sids:
|
||||
@@ -339,7 +339,7 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
return last_in_group
|
||||
|
||||
|
||||
def ffill_across_cols(df, columns):
|
||||
def ffill_across_cols(df, columns, name_map):
|
||||
"""
|
||||
Forward fill values in a DataFrame with special logic to handle cases
|
||||
that pd.DataFrame.ffill cannot and cast columns to appropriate types.
|
||||
@@ -351,6 +351,9 @@ def ffill_across_cols(df, columns):
|
||||
columns : list of BoundColumn
|
||||
The BoundColumns that correspond to columns in the DataFrame to which
|
||||
special filling and/or casting logic should be applied.
|
||||
name_map : map of string -> string
|
||||
Mapping from the name of each BoundColumn to the associated column
|
||||
name in `df`.
|
||||
"""
|
||||
df.ffill(inplace=True)
|
||||
|
||||
@@ -369,18 +372,19 @@ def ffill_across_cols(df, columns):
|
||||
# pandas to replace NaNs in an object column with None using fillna,
|
||||
# so we have to roll our own instead using df.where.
|
||||
for column in columns:
|
||||
column_name = name_map[column.name]
|
||||
# Special logic for strings since `fillna` doesn't work if the
|
||||
# missing value is `None`.
|
||||
if column.dtype == categorical_dtype:
|
||||
df[column.name] = df[
|
||||
df[column_name] = df[
|
||||
column.name
|
||||
].where(pd.notnull(df[column.name]),
|
||||
].where(pd.notnull(df[column_name]),
|
||||
column.missing_value)
|
||||
else:
|
||||
# We need to execute `fillna` before `astype` in case the
|
||||
# column contains NaNs and needs to be cast to bool or int.
|
||||
# This is so that the NaNs are replaced first, since pandas
|
||||
# can't convert NaNs for those types.
|
||||
df[column.name] = df[
|
||||
column.name
|
||||
df[column_name] = df[
|
||||
column_name
|
||||
].fillna(column.missing_value).astype(column.dtype)
|
||||
|
||||
@@ -49,8 +49,14 @@ from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils import security_list
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.input_validation import expect_dimensions
|
||||
<<<<<<< HEAD
|
||||
from zipline.utils.numpy_utils import as_column, isnat
|
||||
from zipline.utils.pandas_utils import timedelta_to_integral_seconds
|
||||
=======
|
||||
from zipline.utils.numpy_utils import (
|
||||
as_column,
|
||||
)
|
||||
>>>>>>> WIP
|
||||
from zipline.utils.sentinel import sentinel
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -34,13 +34,14 @@ from ..finance.trading import TradingEnvironment
|
||||
from ..utils import factory
|
||||
from ..utils.classproperty import classproperty
|
||||
from ..utils.final import FinalMeta, final
|
||||
from .core import tmp_asset_finder, make_simple_equity_info
|
||||
from .core import (tmp_asset_finder, make_simple_equity_info)
|
||||
from zipline.assets import Equity, Future
|
||||
from zipline.pipeline import SimplePipelineEngine
|
||||
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils.calendars import (
|
||||
get_calendar,
|
||||
register_calendar)
|
||||
register_calendar
|
||||
)
|
||||
|
||||
|
||||
class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
|
||||
|
||||
Reference in New Issue
Block a user