mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 12:36:44 +08:00
TST: add test for datetime array and update test
TST: fix quarter normalization test
TST: change test name
BUG: remove arg
BUG: look at dict keys
TST: add test for windowing
MAINT: raise ValueError instead of asserting
TST: add assertion to check windowing
TST: parametrize test over number of quarters forward/back.
BUG: fix adjustment calculation logic for quarter crossovers.
TST: add test for previous quarter windows
BUG: fix bugs in calculating previous windows
BUG: fix missing value for datetime
TST: add test case for missing quarter
This commit is contained in:
@@ -20,6 +20,7 @@ from toolz import curry
|
||||
from zipline.errors import WindowLengthNotPositive, WindowLengthTooLong
|
||||
from zipline.lib.adjustment import (
|
||||
Datetime64Overwrite,
|
||||
Datetime641DArrayOverwrite,
|
||||
Float64Multiply,
|
||||
Float64Overwrite,
|
||||
Float641DArrayOverwrite,
|
||||
@@ -305,7 +306,11 @@ def _gen_overwrite_adjustment_cases(name,
|
||||
)
|
||||
|
||||
|
||||
def _gen_overwrite_1d_array_adjustment_case():
|
||||
def _gen_overwrite_1d_array_adjustment_case(name,
|
||||
make_input,
|
||||
make_expected_output,
|
||||
dtype,
|
||||
missing_value):
|
||||
"""
|
||||
Generate test cases for overwrite adjustments.
|
||||
|
||||
@@ -314,90 +319,91 @@ def _gen_overwrite_1d_array_adjustment_case():
|
||||
the adjustments are expected to modify the arrays.
|
||||
|
||||
This is parameterized on `make_input` and `make_expected_output` functions,
|
||||
which take 2-D lists of values and transform them into desired input/output
|
||||
which take 1-D lists of values and transform them into desired input/output
|
||||
arrays. We do this so that we can easily test both vanilla numpy ndarrays
|
||||
and our own LabelArray class for strings.
|
||||
"""
|
||||
|
||||
adjustment_type = {
|
||||
float64_dtype: Float641DArrayOverwrite,
|
||||
datetime64ns_dtype: Datetime641DArrayOverwrite,
|
||||
}[dtype]
|
||||
adjustments = {}
|
||||
buffer_as_of = [None] * 6
|
||||
baseline = as_dtype(float64_dtype, [[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
baseline = make_input([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
buffer_as_of[0] = as_dtype(float64_dtype, [[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[0] = make_expected_output([[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals1 = [1]
|
||||
# Note that row indices are inclusive!
|
||||
adjustments[1] = [
|
||||
Float641DArrayOverwrite(array([0]),
|
||||
array([0]),
|
||||
array([0]),
|
||||
array([0]),
|
||||
as_dtype(float64_dtype, array([1])))
|
||||
adjustment_type(
|
||||
0, 0, 0, 0,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals1])
|
||||
)
|
||||
]
|
||||
buffer_as_of[1] = as_dtype(float64_dtype, [[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[1] = make_input([[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
# No adjustment at index 2.
|
||||
buffer_as_of[2] = buffer_as_of[1]
|
||||
|
||||
vals3 = [4, 4, 1]
|
||||
adjustments[3] = [
|
||||
Float641DArrayOverwrite(array([0, 2, 1]),
|
||||
array([1, 2, 2]),
|
||||
array([0, 0, 1]),
|
||||
array([0, 0, 1]),
|
||||
as_dtype(float64_dtype, array([4, 1, 3])))
|
||||
adjustment_type(
|
||||
0, 2, 0, 0,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals3])
|
||||
)
|
||||
]
|
||||
buffer_as_of[3] = as_dtype(float64_dtype, [[4, 2, 2],
|
||||
[4, 3, 2],
|
||||
[1, 3, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[3] = make_input([[4, 2, 2],
|
||||
[4, 2, 2],
|
||||
[1, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals4 = [5] * 4
|
||||
adjustments[4] = [
|
||||
Float641DArrayOverwrite(array([0]),
|
||||
array([3]),
|
||||
array([2]),
|
||||
array([2]),
|
||||
as_dtype(float64_dtype, array([5])))
|
||||
adjustment_type(
|
||||
0, 3, 2, 2,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals4]))
|
||||
]
|
||||
buffer_as_of[4] = as_dtype(float64_dtype, [[4, 2, 5],
|
||||
[4, 3, 5],
|
||||
[1, 3, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[4] = make_input([[4, 2, 5],
|
||||
[4, 2, 5],
|
||||
[1, 2, 5],
|
||||
[2, 2, 5],
|
||||
[2, 2, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
vals5 = range(1, 6)
|
||||
adjustments[5] = [
|
||||
Float641DArrayOverwrite(array([0, 2]),
|
||||
array([4, 2]),
|
||||
array([1, 2]),
|
||||
array([1, 2]),
|
||||
as_dtype(float64_dtype, array([6, 7]))),
|
||||
adjustment_type(
|
||||
0, 4, 1, 1,
|
||||
array([coerce_to_dtype(dtype, val) for val in vals5])),
|
||||
]
|
||||
buffer_as_of[5] = as_dtype(float64_dtype, [[4, 6, 5],
|
||||
[4, 6, 5],
|
||||
[1, 6, 7],
|
||||
[2, 6, 5],
|
||||
[2, 6, 2],
|
||||
[2, 2, 2]])
|
||||
buffer_as_of[5] = make_input([[4, 1, 5],
|
||||
[4, 2, 5],
|
||||
[1, 3, 5],
|
||||
[2, 4, 5],
|
||||
[2, 5, 2],
|
||||
[2, 2, 2]])
|
||||
|
||||
return _gen_expectations(
|
||||
baseline,
|
||||
default_missing_value_for_dtype(float64_dtype),
|
||||
missing_value,
|
||||
adjustments,
|
||||
buffer_as_of,
|
||||
nrows=6,
|
||||
@@ -542,7 +548,22 @@ class AdjustedArrayTestCase(TestCase):
|
||||
datetime64ns_dtype,
|
||||
),
|
||||
),
|
||||
_gen_overwrite_1d_array_adjustment_case(),
|
||||
_gen_overwrite_1d_array_adjustment_case(
|
||||
'float',
|
||||
make_input=as_dtype(float64_dtype),
|
||||
make_expected_output=as_dtype(float64_dtype),
|
||||
dtype=float64_dtype,
|
||||
missing_value=default_missing_value_for_dtype(float64_dtype),
|
||||
),
|
||||
_gen_overwrite_1d_array_adjustment_case(
|
||||
'datetime',
|
||||
make_input=as_dtype(datetime64ns_dtype),
|
||||
make_expected_output=as_dtype(datetime64ns_dtype),
|
||||
dtype=datetime64ns_dtype,
|
||||
missing_value=default_missing_value_for_dtype(
|
||||
datetime64ns_dtype,
|
||||
),
|
||||
),
|
||||
# There are six cases here:
|
||||
# Using np.bytes/np.unicode/object arrays as inputs.
|
||||
# Passing np.bytes/np.unicode/object arrays to LabelArray,
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import blaze as bz
|
||||
import itertools
|
||||
from nose_parameterized import parameterized
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from zipline.pipeline import SimplePipelineEngine, Pipeline
|
||||
from zipline.pipeline import SimplePipelineEngine, Pipeline, CustomFactor
|
||||
from zipline.pipeline.common import (
|
||||
EVENT_DATE_FIELD_NAME,
|
||||
FISCAL_QUARTER_FIELD_NAME,
|
||||
@@ -18,8 +19,8 @@ from zipline.pipeline.loaders.blaze.estimates import (
|
||||
)
|
||||
from zipline.pipeline.loaders.quarter_estimates import (
|
||||
NextQuartersEstimatesLoader,
|
||||
PreviousQuartersEstimatesLoader
|
||||
)
|
||||
PreviousQuartersEstimatesLoader,
|
||||
split_normalized_quarters, normalize_quarters)
|
||||
from zipline.testing import ZiplineTestCase
|
||||
from zipline.testing.fixtures import WithAssetFinder, WithTradingSessions
|
||||
from zipline.testing.predicates import assert_equal
|
||||
@@ -31,7 +32,6 @@ class Estimates(DataSet):
|
||||
fiscal_quarter = Column(dtype=float64_dtype)
|
||||
fiscal_year = Column(dtype=float64_dtype)
|
||||
estimate = Column(dtype=float64_dtype)
|
||||
value = Column(dtype=float64_dtype)
|
||||
|
||||
|
||||
def QuartersEstimates(num_qtr):
|
||||
@@ -40,6 +40,28 @@ def QuartersEstimates(num_qtr):
|
||||
name = Estimates
|
||||
return QtrEstimates
|
||||
|
||||
|
||||
# 0Q1: 2015-01-05.Q1.e1.2015-01-06, 2015-01-10.Q1.e1.2015-01-11,
|
||||
# 0Q2: 2015-01-15.Q2.e1.2015-01-16, 2015-01-20.Q2.e1.2015-01-21,
|
||||
# 0Q4: 2015-02-05.Q4.e1.2015-02-06, 2015-02-10.Q4.e1.2015-02-11,
|
||||
# Skip Q3 to make sure we handle skipped quarter data correctly.
|
||||
estimates_timeline = pd.DataFrame({
|
||||
TS_FIELD_NAME: [pd.Timestamp('2015-01-05'), pd.Timestamp('2015-01-07'),
|
||||
pd.Timestamp('2015-01-05'), pd.Timestamp('2015-01-17'),
|
||||
pd.Timestamp('2015-01-05'), pd.Timestamp('2015-01-17'),
|
||||
pd.Timestamp('2015-01-22'), pd.Timestamp('2015-02-02')],
|
||||
EVENT_DATE_FIELD_NAME:
|
||||
[pd.Timestamp('2015-01-10'), pd.Timestamp('2015-01-10'),
|
||||
pd.Timestamp('2015-01-20'), pd.Timestamp('2015-01-20'),
|
||||
pd.Timestamp('2015-02-10'), pd.Timestamp('2015-02-10'),
|
||||
pd.Timestamp('2015-02-10'), pd.Timestamp('2015-02-10')],
|
||||
'estimate': [1.]*2 + [2.] * 2 + [4.] * 4,
|
||||
FISCAL_QUARTER_FIELD_NAME: [1]*2 + [2] * 2 + [4] * 4,
|
||||
FISCAL_YEAR_FIELD_NAME: [2015]*8,
|
||||
SID_FIELD_NAME: [0]*8
|
||||
})
|
||||
|
||||
|
||||
# Final release dates never change. The quarters have very tight date ranges
|
||||
# in order to reduce the number of dates we need to iterate through when
|
||||
# testing.
|
||||
@@ -48,7 +70,6 @@ releases = pd.DataFrame({
|
||||
EVENT_DATE_FIELD_NAME: [pd.Timestamp('2015-01-15'),
|
||||
pd.Timestamp('2015-01-31')],
|
||||
'estimate': [0.5, 0.8],
|
||||
'value': [0.6, 0.9],
|
||||
FISCAL_QUARTER_FIELD_NAME: [1.0, 2.0],
|
||||
FISCAL_YEAR_FIELD_NAME: [2015.0, 2015.0]
|
||||
})
|
||||
@@ -70,7 +91,6 @@ q2_release_dates = [pd.Timestamp('2015-01-30'), # One day early
|
||||
estimates = pd.DataFrame({
|
||||
EVENT_DATE_FIELD_NAME: q1_release_dates + q2_release_dates,
|
||||
'estimate': [.1, .2, .3, .4],
|
||||
'value': [np.NaN, np.NaN, np.NaN, np.NaN],
|
||||
FISCAL_QUARTER_FIELD_NAME: [1.0, 1.0, 2.0, 2.0],
|
||||
FISCAL_YEAR_FIELD_NAME: [2015.0, 2015.0, 2015.0, 2015.0]
|
||||
})
|
||||
@@ -110,14 +130,12 @@ class EstimateTestCase(WithAssetFinder,
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
cls.events = gen_estimates()
|
||||
cls.sids = cls.events['sid'].unique()
|
||||
cls.columns = {
|
||||
Estimates.estimate: 'estimate',
|
||||
Estimates.event_date: EVENT_DATE_FIELD_NAME,
|
||||
Estimates.fiscal_quarter: FISCAL_QUARTER_FIELD_NAME,
|
||||
Estimates.fiscal_year: FISCAL_YEAR_FIELD_NAME,
|
||||
Estimates.value: 'value',
|
||||
}
|
||||
cls.loader = cls.make_loader(
|
||||
events=cls.events,
|
||||
@@ -147,7 +165,138 @@ class EstimateTestCase(WithAssetFinder,
|
||||
)
|
||||
|
||||
|
||||
window_test_cases = [
|
||||
(window_len, start_idx, num_quarters_out) for
|
||||
(window_len, start_idx), num_quarters_out in
|
||||
itertools.product(
|
||||
[[5, pd.Timestamp('2015-01-09').tz_localize('utc')],
|
||||
[6, pd.Timestamp('2015-01-12').tz_localize('utc')],
|
||||
[11, pd.Timestamp('2015-01-20').tz_localize('utc')],
|
||||
[19, pd.Timestamp('2015-01-30').tz_localize('utc')],
|
||||
[26, pd.Timestamp('2015-02-10').tz_localize('utc')]],
|
||||
[1, 2, 3, 4])
|
||||
]
|
||||
|
||||
|
||||
class NextEstimateWindowsTestCase(EstimateTestCase):
|
||||
events = estimates_timeline
|
||||
START_DATE = pd.Timestamp('2014-12-31')
|
||||
END_DATE = pd.Timestamp('2015-02-15')
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return NextQuartersEstimatesLoader(events, columns)
|
||||
|
||||
@parameterized.expand(window_test_cases)
|
||||
def test_next_estimate_windows_at_quarter_boundaries(self,
|
||||
window_len,
|
||||
start_idx,
|
||||
num_quarters_out):
|
||||
"""
|
||||
Tests that we overwrite values with the correct quarter's estimate at
|
||||
the correct dates.
|
||||
"""
|
||||
dataset = QuartersEstimates(num_quarters_out)
|
||||
|
||||
class SomeFactor(CustomFactor):
|
||||
inputs = [dataset.estimate]
|
||||
window_length = window_len
|
||||
|
||||
def compute(self, today, assets, out, *inputs):
|
||||
unique_inputs = np.unique(inputs).tolist()
|
||||
requested_quarter = None
|
||||
if (pd.Timestamp('2015-02-10').tz_localize('utc') >= today >=
|
||||
pd.Timestamp('2015-01-05').tz_localize('utc')):
|
||||
next_quarter = estimates_timeline[
|
||||
estimates_timeline[EVENT_DATE_FIELD_NAME] >= today
|
||||
].min()[FISCAL_QUARTER_FIELD_NAME]
|
||||
requested_quarter = next_quarter + num_quarters_out - 1
|
||||
|
||||
# If we know something about the requested quarter, assert
|
||||
# that all our estimates in the window are about that quarter.
|
||||
if requested_quarter and requested_quarter <= 4 and \
|
||||
requested_quarter != 3:
|
||||
assert np.equal(unique_inputs, requested_quarter).all()
|
||||
else:
|
||||
# We don't have any information yet about the next quarter
|
||||
# or about the requested quarter; in that case, all our
|
||||
# estimates in the window should be NaN across time.
|
||||
assert np.isnan(unique_inputs).all()
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
lambda x: self.loader,
|
||||
self.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
engine.run_pipeline(
|
||||
Pipeline({'est': SomeFactor()}),
|
||||
start_date=start_idx,
|
||||
end_date=self.trading_days[-1],
|
||||
)
|
||||
|
||||
|
||||
class PreviousEstimateWindowsTestCase(EstimateTestCase):
|
||||
events = estimates_timeline
|
||||
START_DATE = pd.Timestamp('2014-12-31')
|
||||
END_DATE = pd.Timestamp('2015-02-15')
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return PreviousQuartersEstimatesLoader(events, columns)
|
||||
|
||||
@parameterized.expand(window_test_cases)
|
||||
def test_previous_estimate_windows_at_quarter_boundaries(self,
|
||||
window_len,
|
||||
start_idx,
|
||||
num_quarters_out):
|
||||
"""
|
||||
Tests that we overwrite values with the correct quarter's estimate at
|
||||
the correct dates.
|
||||
"""
|
||||
dataset = QuartersEstimates(num_quarters_out)
|
||||
|
||||
class SomeFactor(CustomFactor):
|
||||
inputs = [dataset.estimate]
|
||||
window_length = window_len
|
||||
|
||||
def compute(self, today, assets, out, *inputs):
|
||||
unique_inputs = np.unique(inputs).tolist()
|
||||
requested_quarter = None
|
||||
if today >= pd.Timestamp('2015-01-12').tz_localize('utc'):
|
||||
previous_quarter = estimates_timeline[
|
||||
estimates_timeline[EVENT_DATE_FIELD_NAME] <= today
|
||||
].max()[FISCAL_QUARTER_FIELD_NAME]
|
||||
requested_quarter = (
|
||||
previous_quarter - (num_quarters_out - 1)
|
||||
)
|
||||
|
||||
# If we know something about the requested quarter, assert
|
||||
# that all our estimates in the window are about that quarter.
|
||||
if requested_quarter and requested_quarter >= 0 and \
|
||||
requested_quarter != 3:
|
||||
assert np.equal(unique_inputs, requested_quarter).all()
|
||||
else:
|
||||
# We don't have any information yet about the previous
|
||||
# quarter
|
||||
# or about the requested quarter; in that case, all our
|
||||
# estimates in the window should be NaN across time.
|
||||
assert np.isnan(unique_inputs).all()
|
||||
|
||||
engine = SimplePipelineEngine(
|
||||
lambda x: self.loader,
|
||||
self.trading_days,
|
||||
self.asset_finder,
|
||||
)
|
||||
engine.run_pipeline(
|
||||
Pipeline({'est': SomeFactor()}),
|
||||
start_date=start_idx,
|
||||
end_date=self.trading_days[-1],
|
||||
)
|
||||
|
||||
|
||||
class NextEstimateTestCase(EstimateTestCase):
|
||||
events = gen_estimates()
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return NextQuartersEstimatesLoader(events, columns)
|
||||
@@ -229,6 +378,8 @@ class BlazeNextEstimateLoaderTestCase(NextEstimateTestCase):
|
||||
|
||||
|
||||
class PreviousEstimateTestCase(EstimateTestCase):
|
||||
events = gen_estimates()
|
||||
|
||||
@classmethod
|
||||
def make_loader(cls, events, columns):
|
||||
return PreviousQuartersEstimatesLoader(events, columns)
|
||||
@@ -314,26 +465,13 @@ class QuarterShiftTestCase(ZiplineTestCase):
|
||||
This tests, in isolation, quarter calculation logic for shifting quarters
|
||||
backwards/forwards from a starting point.
|
||||
"""
|
||||
def test_calc_forward_shift(self):
|
||||
def test_quarter_normalization(self):
|
||||
input_yrs = pd.Series([0] * 4)
|
||||
input_qtrs = pd.Series(range(1, 5))
|
||||
expected = pd.DataFrame(([yr, qtr] for yr in range(0, 4) for qtr
|
||||
in range(1, 5)))
|
||||
for i in range(0, 8):
|
||||
years, quarters = shift_quarters(i, input_yrs, input_qtrs)
|
||||
# Can't use assert_series_equal here with check_names=False
|
||||
# because that still fails due to name differences.
|
||||
assert years.equals(expected[i:i+4].reset_index(drop=True)[0])
|
||||
assert quarters.equals(expected[i:i+4].reset_index(drop=True)[1])
|
||||
|
||||
def test_calc_backward_shift(self):
|
||||
input_yrs = pd.Series([0] * 4)
|
||||
input_qtrs = pd.Series(range(4, 0, -1))
|
||||
expected = pd.DataFrame(([yr, qtr] for yr in range(0, -4, -1) for qtr
|
||||
in range(4, 0, -1)))
|
||||
for i in range(0, 8, 1):
|
||||
years, quarters = shift_quarters(-i, input_yrs, input_qtrs)
|
||||
# Can't use assert_series_equal here with check_names=False
|
||||
# because that still fails due to name differences.
|
||||
assert years.equals(expected[i:i+4].reset_index(drop=True)[0])
|
||||
assert quarters.equals(expected[i:i+4].reset_index(drop=True)[1])
|
||||
result_years, result_quarters = split_normalized_quarters(
|
||||
normalize_quarters(input_yrs, input_qtrs)
|
||||
)
|
||||
# Can't use assert_series_equal here with check_names=False
|
||||
# because that still fails due to name differences.
|
||||
assert input_yrs.equals(result_years)
|
||||
assert input_qtrs.equals(result_quarters)
|
||||
|
||||
+10
-14
@@ -371,18 +371,6 @@ cdef class ArrayAdjustment(Adjustment):
|
||||
Subclasses should inherit and provide a `values` attribute and a `mutate`
|
||||
method.
|
||||
"""
|
||||
def __init__(self,
|
||||
int64_t first_row,
|
||||
int64_t last_row,
|
||||
int64_t first_col,
|
||||
int64_t last_col):
|
||||
super(ArrayAdjustment, self).__init__(
|
||||
first_row=first_row,
|
||||
last_row=last_row,
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
"%s(first_row=%d, last_row=%d,"
|
||||
@@ -441,7 +429,11 @@ cdef class Float641DArrayOverwrite(ArrayAdjustment):
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
assert (last_row + 1 - first_row) == len(values)
|
||||
if last_row + 1 - first_row != len(values):
|
||||
raise ValueError(
|
||||
"Mismatch: got %d values for rows starting at index %d and "
|
||||
"ending at index %d." % (len(values), first_row, last_row)
|
||||
)
|
||||
self.values = values
|
||||
|
||||
cpdef mutate(self, float64_t[:, :] data):
|
||||
@@ -497,7 +489,11 @@ cdef class Datetime641DArrayOverwrite(ArrayAdjustment):
|
||||
first_col=first_col,
|
||||
last_col=last_col,
|
||||
)
|
||||
assert (last_row + 1 - first_row) == len(values)
|
||||
if last_row + 1 - first_row != len(values):
|
||||
raise ValueError("Mismatch: got %d values for rows starting at"
|
||||
" index %d and ending at index %d." % (
|
||||
len(values), first_row, last_row)
|
||||
)
|
||||
self.values = asarray([datetime_to_int(value) for value in values])
|
||||
|
||||
cpdef mutate(self, int64_t[:, :] data):
|
||||
|
||||
@@ -178,7 +178,8 @@ from zipline.pipeline.loaders.utils import (
|
||||
last_in_date_group,
|
||||
normalize_data_query_bounds,
|
||||
normalize_timestamp_to_query_time,
|
||||
ffill_across_cols)
|
||||
ffill_across_cols
|
||||
)
|
||||
from zipline.pipeline.sentinels import NotSpecified
|
||||
from zipline.lib.adjusted_array import AdjustedArray, can_represent_dtype
|
||||
from zipline.lib.adjustment import Float64Overwrite
|
||||
|
||||
@@ -167,7 +167,6 @@ class EventsLoader(PipelineLoader):
|
||||
return {}
|
||||
|
||||
return self._load_events(
|
||||
rows=self.events,
|
||||
name_map=self.next_value_columns,
|
||||
indexer=self.next_event_indexer(dates, sids),
|
||||
columns=columns,
|
||||
@@ -181,7 +180,6 @@ class EventsLoader(PipelineLoader):
|
||||
return {}
|
||||
|
||||
return self._load_events(
|
||||
rows=self.events,
|
||||
name_map=self.previous_value_columns,
|
||||
indexer=self.previous_event_indexer(dates, sids),
|
||||
columns=columns,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
from numpy.ma import asarray
|
||||
import pandas as pd
|
||||
from six import viewvalues
|
||||
from toolz import groupby, curry
|
||||
from toolz import groupby
|
||||
from zipline.lib.adjusted_array import AdjustedArray
|
||||
from zipline.lib.adjustment import (Datetime641DArrayOverwrite,
|
||||
Float641DArrayOverwrite)
|
||||
@@ -18,10 +16,15 @@ from zipline.pipeline.common import (
|
||||
TS_FIELD_NAME,
|
||||
)
|
||||
from zipline.pipeline.loaders.base import PipelineLoader
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.utils.numpy_utils import datetime64ns_dtype
|
||||
from zipline.utils.pandas_utils import cross_product
|
||||
from zipline.pipeline.loaders.utils import last_in_date_group, ffill_across_cols
|
||||
from zipline.pipeline.loaders.utils import (
|
||||
ffill_across_cols,
|
||||
last_in_date_group
|
||||
)
|
||||
|
||||
NORMALIZED_QUARTERS = 'normalized_quarters'
|
||||
|
||||
SHIFTED_NORMALIZED_QTRS = 'shifted_normalized_quarters'
|
||||
|
||||
NEXT_FISCAL_QUARTER = 'next_fiscal_quarter'
|
||||
NEXT_FISCAL_YEAR = 'next_fiscal_year'
|
||||
@@ -101,47 +104,184 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
def load_quarters(self, num_quarters, last, dates):
|
||||
pass
|
||||
|
||||
def get_adjustments(self, result, col_result, last,
|
||||
def get_requested_data_for_col(self, stacked_last_per_qtr, idx, dates):
|
||||
"""
|
||||
Selects the requested data for each date.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
The latest estimate known per sid per date per quarter with the
|
||||
dates, normalized quarter, and sid as the index.
|
||||
idx : pd.MultiIndex
|
||||
The index of the row of the requested quarter from each date for
|
||||
each sid.
|
||||
dates : pd.DatetimeIndex
|
||||
The calendar dates for which estimates data is requested.
|
||||
|
||||
Returns
|
||||
--------
|
||||
requested_qtr_data : pd.DataFrame
|
||||
The DataFrame with final values for the requested quarter for all
|
||||
columns; `dates` are the index and columns are a MultiIndex with
|
||||
sids at the top level and the dataset columns on the bottom.
|
||||
"""
|
||||
requested_qtr_data = stacked_last_per_qtr.loc[idx]
|
||||
# We no longer need this in the index, but we do need it as a column
|
||||
# to calculate adjustments.
|
||||
requested_qtr_data = requested_qtr_data.reset_index(
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
)
|
||||
(requested_qtr_data[FISCAL_YEAR_FIELD_NAME],
|
||||
requested_qtr_data[FISCAL_QUARTER_FIELD_NAME]) = \
|
||||
split_normalized_quarters(
|
||||
requested_qtr_data[SHIFTED_NORMALIZED_QTRS]
|
||||
)
|
||||
# Move sids into the columns. Once we're left with just dates
|
||||
# as the index, we can reindex by all dates so that we have a
|
||||
# value for each calendar date.
|
||||
requested_qtr_data = requested_qtr_data.unstack(
|
||||
SID_FIELD_NAME
|
||||
).reindex(dates)
|
||||
return requested_qtr_data
|
||||
|
||||
def get_adjustments(self,
|
||||
zero_qtr_idx,
|
||||
requested_qtr_idx,
|
||||
stacked_last_per_qtr,
|
||||
last_per_qtr,
|
||||
dates,
|
||||
column_name,
|
||||
column, mask,
|
||||
assets):
|
||||
column,
|
||||
mask,
|
||||
assets,
|
||||
qtr_crossover_point):
|
||||
"""
|
||||
Creates an AdjustedArray from the given estimates data for the given
|
||||
dates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
zero_qtr_idx : pd.MultiIndex
|
||||
The index of the row of the zeroth (immediately next/previous)
|
||||
quarter from each date for each sid.
|
||||
requested_qtr_idx : pd.MultiIndex
|
||||
The index of the row of the requested quarter from each date for
|
||||
each sid.
|
||||
stacked_last_per_qtr : pd.DataFrame
|
||||
The latest estimate known per sid per date per quarter with the
|
||||
dates, normalized quarter, and sid as the index.
|
||||
last_per_qtr : pd.DataFrame
|
||||
The latest estimate known per sid per date per quarter with
|
||||
dates as the index and normalized quarter and sid in the columns
|
||||
MultiIndex; allows easy access to the timeline of estimates
|
||||
across all dates for a sid for a particular quarter.
|
||||
dates : pd.DatetimeIndex
|
||||
The calendar dates for which estimates data is requested.
|
||||
column_name : string
|
||||
The name of the column for which the AdjustedArray is being
|
||||
computed.
|
||||
column : BoundColumn
|
||||
The column for which the AdjustedArray is being computed.
|
||||
mask :
|
||||
assets :
|
||||
qtr_crossover_point :
|
||||
Whether we should use the 'right' or 'left' side when doing
|
||||
searchsorted on the dates for quarter boundaries.
|
||||
|
||||
Returns
|
||||
-------
|
||||
adjusted_array : AdjustedArray
|
||||
The array of data and overwrites for the given column.
|
||||
"""
|
||||
adjustments = defaultdict(list)
|
||||
requested_qtr_data = self.get_requested_data_for_col(
|
||||
stacked_last_per_qtr, requested_qtr_idx, dates
|
||||
)
|
||||
zero_qtr_data = stacked_last_per_qtr.loc[zero_qtr_idx]
|
||||
# We no longer need this in the index, but we do need it as a column
|
||||
# to calculate adjustments.
|
||||
zero_qtr_data = zero_qtr_data.reset_index(NORMALIZED_QUARTERS)
|
||||
if column.dtype == datetime64ns_dtype:
|
||||
overwrite = Datetime641DArrayOverwrite
|
||||
missing_value = np.datetime64('NaT', 'ns')
|
||||
else:
|
||||
overwrite = Float641DArrayOverwrite
|
||||
missing_value = np.NaN
|
||||
for sid_idx, sid in enumerate(assets):
|
||||
sid_result = result[result.index.get_level_values(
|
||||
SID_FIELD_NAME
|
||||
) == sid]
|
||||
sid_result = sid_result.reset_index(
|
||||
level='shifted_normalized_quarters'
|
||||
) # Remove qtrs from index to find shifts
|
||||
# Figure out where we think quarters are changing.
|
||||
qtr_shifts = sid_result[
|
||||
sid_result['shifted_normalized_quarters'] !=
|
||||
sid_result['shifted_normalized_quarters'].shift(1)
|
||||
zero_qtr_sid_data = zero_qtr_data[
|
||||
zero_qtr_data.index.get_level_values(SID_FIELD_NAME) == sid
|
||||
]
|
||||
# Iterate backwards. No adjustment for 1st quarter.
|
||||
for row_indexer in list(reversed(qtr_shifts.index))[:-1]:
|
||||
# We want to write the values for this row's quarter over
|
||||
# everything that comes before this quarter when we are at
|
||||
# the date before this quarter starts.
|
||||
qtr_start_idx = last.index.get_loc(row_indexer[0])
|
||||
quarter = qtr_shifts.loc[row_indexer][
|
||||
'shifted_normalized_quarters'
|
||||
]
|
||||
adjustments[qtr_start_idx] = \
|
||||
[overwrite(0,
|
||||
qtr_start_idx - 1, # get index date
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
last[column_name, quarter,
|
||||
sid][:qtr_start_idx].values)
|
||||
]
|
||||
# Determine where quarters are changing for this sid.
|
||||
qtr_shifts = zero_qtr_sid_data[
|
||||
zero_qtr_sid_data[NORMALIZED_QUARTERS] !=
|
||||
zero_qtr_sid_data[NORMALIZED_QUARTERS].shift(1)
|
||||
]
|
||||
# On dates where we don't have any information about quarters,
|
||||
# we will get nulls, and each of these will be interpreted as
|
||||
# quarter shifts. We need to remove these here.
|
||||
qtr_shifts = qtr_shifts[
|
||||
qtr_shifts[NORMALIZED_QUARTERS].notnull()
|
||||
]
|
||||
# For the given sid, determine which quarters we have estimates
|
||||
# for.
|
||||
quarters_with_estimates_for_sid = last_per_qtr.xs(
|
||||
sid, axis=1, level=SID_FIELD_NAME
|
||||
).groupby(axis=1, level=1).first().columns.values
|
||||
for row_indexer in list(qtr_shifts.index):
|
||||
# Find the starting index of the quarter that comes right
|
||||
# after this row. This isn't the starting index of the
|
||||
# requested quarter, but simply the date we cross over into a
|
||||
# new quarter.
|
||||
qtr_start_idx = dates.searchsorted(
|
||||
zero_qtr_data.loc[
|
||||
row_indexer
|
||||
][EVENT_DATE_FIELD_NAME],
|
||||
side=qtr_crossover_point
|
||||
)
|
||||
|
||||
# Only add adjustments if the next quarter starts somewhere in
|
||||
# our date index for this sid. Our 'next' quarter can never
|
||||
# start at index 0; a starting index of 0 means that the next
|
||||
# quarter's event date was NaT.
|
||||
if 0 < qtr_start_idx < len(dates):
|
||||
# Find the quarter being requested in the quarter we're
|
||||
# crossing into.
|
||||
requested_quarter = requested_qtr_data[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
][sid].iloc[qtr_start_idx]
|
||||
|
||||
# If there are estimates for the requested quarter,
|
||||
# overwrite all values going up to the starting index of
|
||||
# that quarter with estimates for that quarter.
|
||||
if requested_quarter in quarters_with_estimates_for_sid:
|
||||
adjustments[qtr_start_idx] = \
|
||||
[overwrite(
|
||||
0,
|
||||
qtr_start_idx - 1, # overwrite thru last qtr
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
last_per_qtr[column_name,
|
||||
requested_quarter,
|
||||
sid][:qtr_start_idx].values)]
|
||||
# There are no estimates for the quarter. Overwrite all
|
||||
# values going up to the starting index of that quarter
|
||||
# with the missing value for this column.
|
||||
else:
|
||||
adjustments[qtr_start_idx] = [
|
||||
overwrite(
|
||||
0,
|
||||
qtr_start_idx - 1,
|
||||
sid_idx,
|
||||
sid_idx,
|
||||
np.array(
|
||||
[missing_value] *
|
||||
len(last_per_qtr.index[:qtr_start_idx]))
|
||||
)
|
||||
]
|
||||
|
||||
return AdjustedArray(
|
||||
col_result.values.astype(column.dtype),
|
||||
requested_qtr_data[column_name].values.astype(column.dtype),
|
||||
mask,
|
||||
dict(adjustments),
|
||||
column.missing_value,
|
||||
@@ -152,110 +292,101 @@ class QuarterEstimatesLoader(PipelineLoader):
|
||||
# attribute, given that they're created dynamically?
|
||||
groups = groupby(lambda x: x.dataset.num_quarters, columns)
|
||||
groups_columns = dict(groups)
|
||||
if (pd.Series(groups_columns) < 0).any():
|
||||
if (pd.Series(groups_columns.keys()) < 0).any():
|
||||
raise ValueError("Must pass a number of quarters >= 0")
|
||||
out = {}
|
||||
date_values = pd.DataFrame({SIMULTATION_DATES: dates})
|
||||
# dates column must be of type datetime64[ns] in order for subsequent
|
||||
# comparisons to work correctly.
|
||||
date_values[SIMULTATION_DATES] = date_values[
|
||||
SIMULTATION_DATES
|
||||
].astype('datetime64[ns]')
|
||||
self.estimates['normalized_quarters'] = normalize_quarters(
|
||||
self.estimates[NORMALIZED_QUARTERS] = normalize_quarters(
|
||||
self.estimates[FISCAL_YEAR_FIELD_NAME],
|
||||
self.estimates[FISCAL_QUARTER_FIELD_NAME],
|
||||
).astype(float)
|
||||
for num_quarters, columns in groups_columns.iteritems():
|
||||
name_map = {c:
|
||||
self.base_column_name_map[
|
||||
)
|
||||
for num_quarters, columns in groups_columns.items():
|
||||
# The column's dataset is itself dynamic and the mapping we
|
||||
# actually want is to its dataset's parent's column name.
|
||||
name_map = {c: self.base_column_name_map[
|
||||
getattr(c.dataset.__base__, c.name)
|
||||
] for c in columns}
|
||||
# Determine the last piece of information we know for each column
|
||||
# on each date in the index.
|
||||
last = last_in_date_group(self.estimates, True, dates,
|
||||
assets,
|
||||
extra_groupers=[
|
||||
'normalized_quarters'])
|
||||
# Forward fill values for each quarter.
|
||||
ffill_across_cols(last, columns)
|
||||
stacked = last.stack(1).stack(1)
|
||||
# on each date in the index for each sid and quarter.
|
||||
last_per_qtr = last_in_date_group(
|
||||
self.estimates, True, dates, assets,
|
||||
extra_groupers=[NORMALIZED_QUARTERS]
|
||||
)
|
||||
|
||||
result = self.load_quarters(num_quarters, stacked)
|
||||
# Forward fill values for each quarter/sid/dataset column.
|
||||
ffill_across_cols(last_per_qtr, columns)
|
||||
# Stack quarter and sid into the index.
|
||||
stacked_last_per_qtr = last_per_qtr.stack([NORMALIZED_QUARTERS,
|
||||
SID_FIELD_NAME])
|
||||
# Set date index name for ease of reference
|
||||
stacked_last_per_qtr.index.set_names(SIMULTATION_DATES, 0, True)
|
||||
# Determine which quarter is next/previous for each date.
|
||||
shifted_qtr_data = self.load_quarters(num_quarters,
|
||||
stacked_last_per_qtr)
|
||||
zero_qtr_idx = shifted_qtr_data.index
|
||||
requested_qtr_idx = shifted_qtr_data.set_index([
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SIMULTATION_DATES
|
||||
),
|
||||
shifted_qtr_data[SHIFTED_NORMALIZED_QTRS],
|
||||
shifted_qtr_data.index.get_level_values(
|
||||
SID_FIELD_NAME
|
||||
)]
|
||||
).index
|
||||
|
||||
for c in columns:
|
||||
column_name = name_map[c]
|
||||
col_result = result[
|
||||
column_name
|
||||
].reset_index(1, drop=True).unstack(1).reindex(dates)
|
||||
adjusted_array = self.get_adjustments(result,
|
||||
col_result,
|
||||
last,
|
||||
adjusted_array = self.get_adjustments(zero_qtr_idx,
|
||||
requested_qtr_idx,
|
||||
stacked_last_per_qtr,
|
||||
last_per_qtr,
|
||||
dates,
|
||||
column_name,
|
||||
c,
|
||||
mask,
|
||||
assets)
|
||||
assets,
|
||||
self.qtr_crossover_point)
|
||||
out[c] = adjusted_array
|
||||
return out
|
||||
|
||||
|
||||
class NextQuartersEstimatesLoader(QuarterEstimatesLoader):
|
||||
qtr_crossover_point = 'right'
|
||||
|
||||
def load_quarters(self, num_quarters, stacked):
|
||||
def load_quarters(self, num_quarters, stacked_last_per_qtr):
|
||||
# Filter for releases that are on or after each simulation date and
|
||||
# determine the next quarter by picking out the upcoming release for
|
||||
# each date in the index.
|
||||
stacked = stacked.sort(EVENT_DATE_FIELD_NAME)
|
||||
next_releases = stacked.loc[
|
||||
stacked[EVENT_DATE_FIELD_NAME] >= stacked.index.get_level_values(
|
||||
0
|
||||
)].groupby(level=[0, 2]).nth(0)
|
||||
next_releases[
|
||||
'shifted_normalized_quarters'
|
||||
] = next_releases.index.get_level_values(
|
||||
'normalized_quarters'
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(
|
||||
EVENT_DATE_FIELD_NAME
|
||||
)
|
||||
next_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] >=
|
||||
stacked_last_per_qtr.index.get_level_values(SIMULTATION_DATES)
|
||||
].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(0)
|
||||
next_releases_per_date[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
] = next_releases_per_date.index.get_level_values(
|
||||
NORMALIZED_QUARTERS
|
||||
) + (num_quarters - 1)
|
||||
next_releases = next_releases.set_index([
|
||||
next_releases.index.get_level_values(0), # dates
|
||||
'shifted_normalized_quarters',
|
||||
next_releases.index.get_level_values(2) # sids
|
||||
])
|
||||
return stacked.loc[next_releases.index]
|
||||
return next_releases_per_date
|
||||
|
||||
|
||||
class PreviousQuartersEstimatesLoader(QuarterEstimatesLoader):
|
||||
def __init__(self,
|
||||
estimates,
|
||||
columns):
|
||||
super(PreviousQuartersEstimatesLoader, self).__init__(estimates,
|
||||
columns)
|
||||
qtr_crossover_point = 'left'
|
||||
|
||||
def load_quarters(self, num_quarters, dates_sids, final_releases_per_qtr):
|
||||
# Filter for releases that are on or before each simulation date.
|
||||
eligible_previous_releases = final_releases_per_qtr[
|
||||
final_releases_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
final_releases_per_qtr[SIMULTATION_DATES]
|
||||
]
|
||||
# For each sid, get the latest release.
|
||||
eligible_previous_releases.sort(EVENT_DATE_FIELD_NAME)
|
||||
previous_releases = eligible_previous_releases.groupby(
|
||||
[SIMULTATION_DATES, SID_FIELD_NAME]
|
||||
).nth(-1).reset_index() # We use nth here to avoid forward filling
|
||||
# NaNs, which `last()` will do.
|
||||
previous_releases = previous_releases.rename(columns={
|
||||
FISCAL_YEAR_FIELD_NAME: PREVIOUS_FISCAL_YEAR,
|
||||
FISCAL_QUARTER_FIELD_NAME: PREVIOUS_FISCAL_QUARTER
|
||||
})
|
||||
# The previous fiscal quarter is already our starting point,
|
||||
# so we should offset `num_quarters` by 1.
|
||||
(previous_releases[FISCAL_YEAR_FIELD_NAME],
|
||||
previous_releases[FISCAL_QUARTER_FIELD_NAME]) = shift_quarters(
|
||||
-(num_quarters - 1),
|
||||
previous_releases[PREVIOUS_FISCAL_YEAR],
|
||||
previous_releases[PREVIOUS_FISCAL_QUARTER],
|
||||
)
|
||||
# Do a left merge to get values for each date.
|
||||
result = dates_sids.merge(previous_releases,
|
||||
on=([SIMULTATION_DATES,
|
||||
SID_FIELD_NAME]),
|
||||
how='left')
|
||||
return result
|
||||
def load_quarters(self, num_quarters, stacked_last_per_qtr):
|
||||
# Filter for releases that are on or before each simulation date and
|
||||
# determine the previous quarter by picking out the most recent release
|
||||
# for each date in the index.
|
||||
stacked_last_per_qtr = stacked_last_per_qtr.sort(EVENT_DATE_FIELD_NAME)
|
||||
previous_releases_per_date = stacked_last_per_qtr.loc[
|
||||
stacked_last_per_qtr[EVENT_DATE_FIELD_NAME] <=
|
||||
stacked_last_per_qtr.index.get_level_values(
|
||||
SIMULTATION_DATES
|
||||
)].groupby(level=[SIMULTATION_DATES, SID_FIELD_NAME]).nth(-1)
|
||||
previous_releases_per_date[
|
||||
SHIFTED_NORMALIZED_QTRS
|
||||
] = previous_releases_per_date.index.get_level_values(
|
||||
NORMALIZED_QUARTERS
|
||||
) - (num_quarters - 1)
|
||||
return previous_releases_per_date
|
||||
|
||||
@@ -278,11 +278,41 @@ def check_data_query_args(data_query_time, data_query_tz):
|
||||
|
||||
def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
extra_groupers=[]):
|
||||
"""
|
||||
Determine the last piece of information known on each date in the date
|
||||
index for each group.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The DataFrame containing the data to be grouped.
|
||||
reindex : bool
|
||||
Whether or not the DataFrame should be reindexed against the date
|
||||
index. This will add back any dates to the index that were grouped
|
||||
away.
|
||||
dates : pd.DatetimeIndex
|
||||
The dates to use for grouping and reindexing.
|
||||
assets : pd.Int64Index
|
||||
The assets that should be included in the column multiindex.
|
||||
have_sids : bool
|
||||
Whether or not the DataFrame has sids. If it does, they will be used
|
||||
in the groupby.
|
||||
extra_groupers : list of str
|
||||
Any extra field names that should be included in the groupby.
|
||||
|
||||
Returns
|
||||
-------
|
||||
last_in_group : pd.DataFrame
|
||||
A DataFrame with dates as the index and fields used in the groupby as
|
||||
levels of a multiindex of columns.
|
||||
|
||||
"""
|
||||
idx = dates[dates.searchsorted(
|
||||
df[TS_FIELD_NAME].values.astype('datetime64[D]')
|
||||
)]
|
||||
if have_sids:
|
||||
idx = [idx, SID_FIELD_NAME] + extra_groupers
|
||||
idx = [idx, SID_FIELD_NAME]
|
||||
idx += extra_groupers
|
||||
|
||||
last_in_group = df.drop(TS_FIELD_NAME, axis=1).groupby(
|
||||
idx,
|
||||
@@ -291,8 +321,7 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
|
||||
# For the number of things that we're grouping by (except TS), unstack
|
||||
# the df
|
||||
for _ in range(len(idx) - 1):
|
||||
last_in_group = last_in_group.unstack()
|
||||
last_in_group = last_in_group.unstack([-1, -2])
|
||||
|
||||
if reindex:
|
||||
if have_sids:
|
||||
@@ -311,6 +340,18 @@ def last_in_date_group(df, reindex, dates, assets, have_sids=True,
|
||||
|
||||
|
||||
def ffill_across_cols(df, columns):
|
||||
"""
|
||||
Forward fill values in a DataFrame with special logic to handle cases
|
||||
that pd.DataFrame.ffill cannot handle, and cast columns to appropriate types.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pd.DataFrame
|
||||
The DataFrame to do forward-filling on.
|
||||
columns : list of BoundColumn
|
||||
The BoundColumns that correspond to columns in the DataFrame to which
|
||||
special filling and/or casting logic should be applied.
|
||||
"""
|
||||
df.ffill(inplace=True)
|
||||
|
||||
# Fill in missing values specified by each column. This is made
|
||||
|
||||
Reference in New Issue
Block a user