mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 09:53:09 +08:00
ENH: Support multiple outputs for custom factors
This commit is contained in:
@@ -24,6 +24,14 @@ Enhancements
|
||||
dataframe. This model allows us to pass these writer objects around as a
|
||||
resource for other classes and functions to consume (:issue:`1109`).
|
||||
|
||||
* Implemented :class:`zipline.pipeline.factors.RecarrayField`, a new pipeline
|
||||
term designed to be the output type of a CustomFactor with multiple outputs.
|
||||
(:issue:`1119`)
|
||||
|
||||
* Added optional `outputs` parameter to :class:`zipline.pipeline.CustomFactor`.
|
||||
Custom factors are now capable of computing and returning multiple outputs,
|
||||
each of which are themselves a Factor. (:issue:`1119`)
|
||||
|
||||
Experimental Features
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
+209
-15
@@ -4,6 +4,7 @@ Tests for SimplePipelineEngine
|
||||
from __future__ import division
|
||||
from collections import OrderedDict
|
||||
from itertools import product
|
||||
from operator import add, sub
|
||||
|
||||
from nose_parameterized import parameterized
|
||||
from numpy import (
|
||||
@@ -11,6 +12,7 @@ from numpy import (
|
||||
array,
|
||||
concatenate,
|
||||
float32,
|
||||
float64,
|
||||
full,
|
||||
log,
|
||||
nan,
|
||||
@@ -38,19 +40,9 @@ from toolz import merge
|
||||
|
||||
from zipline.assets.synthetic import make_rotating_equity_info
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.pipeline.loaders.synthetic import PrecomputedLoader
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing, DataSet, Column
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
make_daily_bar_data,
|
||||
expected_daily_bar_values_2d,
|
||||
)
|
||||
from zipline.pipeline import CustomFactor, Pipeline
|
||||
from zipline.pipeline.data import Column, DataSet, USEquityPricing
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline import CustomFactor
|
||||
from zipline.pipeline.factors import (
|
||||
AverageDollarVolume,
|
||||
EWMA,
|
||||
@@ -60,6 +52,16 @@ from zipline.pipeline.factors import (
|
||||
MaxDrawdown,
|
||||
SimpleMovingAverage,
|
||||
)
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
expected_daily_bar_values_2d,
|
||||
make_daily_bar_data,
|
||||
PrecomputedLoader,
|
||||
)
|
||||
from zipline.pipeline.term import NotSpecified
|
||||
from zipline.testing import (
|
||||
product_upper_triangle,
|
||||
check_arrays,
|
||||
@@ -112,6 +114,28 @@ class OpenPrice(CustomFactor):
|
||||
out[:] = open
|
||||
|
||||
|
||||
class MultipleOutputs(CustomFactor):
|
||||
window_length = 1
|
||||
inputs = [USEquityPricing.open, USEquityPricing.close]
|
||||
outputs = ['open', 'close']
|
||||
|
||||
def compute(self, today, assets, out, open, close):
|
||||
out.open[:] = open
|
||||
out.close[:] = close
|
||||
|
||||
|
||||
class OpenCloseSumAndDiff(CustomFactor):
|
||||
"""
|
||||
Used for testing a CustomFactor with multiple outputs operating over a non-
|
||||
trivial window length.
|
||||
"""
|
||||
inputs = [USEquityPricing.open, USEquityPricing.close]
|
||||
|
||||
def compute(self, today, assets, out, open, close):
|
||||
out.sum_[:] = open.sum(axis=0) + close.sum(axis=0)
|
||||
out.diff[:] = open.sum(axis=0) - close.sum(axis=0)
|
||||
|
||||
|
||||
def assert_multi_index_is_product(testcase, index, *levels):
|
||||
"""Assert that a MultiIndex contains the product of `*levels`."""
|
||||
testcase.assertIsInstance(
|
||||
@@ -407,9 +431,9 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
alternating_mask = (AssetIDPlusDay() % 2).eq(0)
|
||||
expected_alternating_mask_result = array(
|
||||
[[False, True, False, True],
|
||||
[True, False, True, False],
|
||||
[False, True, False, True]],
|
||||
[[False, True, False, True],
|
||||
[True, False, True, False],
|
||||
[False, True, False, True]],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
@@ -510,6 +534,176 @@ class ConstantInputTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
),
|
||||
)
|
||||
|
||||
def test_factor_with_single_output(self):
|
||||
"""
|
||||
Test passing an `outputs` parameter of length 1 to a CustomFactor.
|
||||
"""
|
||||
dates = self.dates[5:10]
|
||||
assets = self.assets
|
||||
num_dates = len(dates)
|
||||
open = USEquityPricing.open
|
||||
open_values = [self.constants[open]] * num_dates
|
||||
open_values_as_tuple = [(self.constants[open],)] * num_dates
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: self.loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
single_output = OpenPrice(outputs=['open'])
|
||||
pipeline = Pipeline(
|
||||
columns={
|
||||
'open_instance': single_output,
|
||||
'open_attribute': single_output.open,
|
||||
},
|
||||
)
|
||||
results = engine.run_pipeline(pipeline, dates[0], dates[-1])
|
||||
|
||||
# The instance `single_output` itself will compute a numpy.recarray
|
||||
# when added as a column to our pipeline, so we expect its output
|
||||
# values to be 1-tuples.
|
||||
open_instance_expected = {
|
||||
asset: open_values_as_tuple for asset in assets
|
||||
}
|
||||
open_attribute_expected = {asset: open_values for asset in assets}
|
||||
|
||||
for colname, expected_values in (
|
||||
('open_instance', open_instance_expected),
|
||||
('open_attribute', open_attribute_expected)):
|
||||
column_results = results[colname].unstack()
|
||||
expected_results = DataFrame(
|
||||
expected_values, index=dates, columns=assets, dtype=float64,
|
||||
)
|
||||
assert_frame_equal(column_results, expected_results)
|
||||
|
||||
def test_factor_with_multiple_outputs(self):
|
||||
dates = self.dates[5:10]
|
||||
assets = self.assets
|
||||
asset_ids = self.asset_ids
|
||||
constants = self.constants
|
||||
open = USEquityPricing.open
|
||||
close = USEquityPricing.close
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: self.loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
def create_expected_results(expected_value, mask):
|
||||
expected_values = where(mask, expected_value, nan)
|
||||
return DataFrame(expected_values, index=dates, columns=assets)
|
||||
|
||||
cascading_mask = AssetIDPlusDay() < (asset_ids[-1] + dates[0].day)
|
||||
expected_cascading_mask_result = array(
|
||||
[[True, True, True, False],
|
||||
[True, True, False, False],
|
||||
[True, False, False, False],
|
||||
[False, False, False, False],
|
||||
[False, False, False, False]],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
alternating_mask = (AssetIDPlusDay() % 2).eq(0)
|
||||
expected_alternating_mask_result = array(
|
||||
[[False, True, False, True],
|
||||
[True, False, True, False],
|
||||
[False, True, False, True],
|
||||
[True, False, True, False],
|
||||
[False, True, False, True]],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
expected_no_mask_result = array(
|
||||
[[True, True, True, True],
|
||||
[True, True, True, True],
|
||||
[True, True, True, True],
|
||||
[True, True, True, True],
|
||||
[True, True, True, True]],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
masks = cascading_mask, alternating_mask, NotSpecified
|
||||
expected_mask_results = (
|
||||
expected_cascading_mask_result,
|
||||
expected_alternating_mask_result,
|
||||
expected_no_mask_result,
|
||||
)
|
||||
for mask, expected_mask in zip(masks, expected_mask_results):
|
||||
open_price, close_price = MultipleOutputs(mask=mask)
|
||||
pipeline = Pipeline(
|
||||
columns={'open_price': open_price, 'close_price': close_price},
|
||||
)
|
||||
if mask is not NotSpecified:
|
||||
pipeline.add(mask, 'mask')
|
||||
|
||||
results = engine.run_pipeline(pipeline, dates[0], dates[-1])
|
||||
for colname, case_column in (('open_price', open),
|
||||
('close_price', close)):
|
||||
if mask is not NotSpecified:
|
||||
mask_results = results['mask'].unstack()
|
||||
check_arrays(mask_results.values, expected_mask)
|
||||
output_results = results[colname].unstack()
|
||||
output_expected = create_expected_results(
|
||||
constants[case_column], expected_mask,
|
||||
)
|
||||
assert_frame_equal(output_results, output_expected)
|
||||
|
||||
def test_instance_of_factor_with_multiple_outputs(self):
|
||||
"""
|
||||
Test adding a CustomFactor instance, which has multiple outputs, as a
|
||||
pipeline column directly. Its computed values should be tuples
|
||||
containing the computed values of each of its outputs.
|
||||
"""
|
||||
dates = self.dates[5:10]
|
||||
assets = self.assets
|
||||
num_dates = len(dates)
|
||||
num_assets = len(assets)
|
||||
constants = self.constants
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: self.loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
open_values = [constants[USEquityPricing.open]] * num_assets
|
||||
close_values = [constants[USEquityPricing.close]] * num_assets
|
||||
expected_values = [list(zip(open_values, close_values))] * num_dates
|
||||
expected_results = DataFrame(
|
||||
expected_values, index=dates, columns=assets, dtype=float64,
|
||||
)
|
||||
|
||||
multiple_outputs = MultipleOutputs()
|
||||
pipeline = Pipeline(columns={'instance': multiple_outputs})
|
||||
results = engine.run_pipeline(pipeline, dates[0], dates[-1])
|
||||
instance_results = results['instance'].unstack()
|
||||
assert_frame_equal(instance_results, expected_results)
|
||||
|
||||
def test_custom_factor_outputs_parameter(self):
|
||||
dates = self.dates[5:10]
|
||||
assets = self.assets
|
||||
num_dates = len(dates)
|
||||
num_assets = len(assets)
|
||||
constants = self.constants
|
||||
engine = SimplePipelineEngine(
|
||||
lambda column: self.loader, self.dates, self.asset_finder,
|
||||
)
|
||||
|
||||
def create_expected_results(expected_value):
|
||||
expected_values = full(
|
||||
(num_dates, num_assets), expected_value, float64,
|
||||
)
|
||||
return DataFrame(expected_values, index=dates, columns=assets)
|
||||
|
||||
for window_length in range(1, 3):
|
||||
sum_, diff = OpenCloseSumAndDiff(
|
||||
outputs=['sum_', 'diff'], window_length=window_length,
|
||||
)
|
||||
pipeline = Pipeline(columns={'sum_': sum_, 'diff': diff})
|
||||
results = engine.run_pipeline(pipeline, dates[0], dates[-1])
|
||||
for colname, op in ('sum_', add), ('diff', sub):
|
||||
output_results = results[colname].unstack()
|
||||
output_expected = create_expected_results(
|
||||
op(
|
||||
constants[USEquityPricing.open] * window_length,
|
||||
constants[USEquityPricing.close] * window_length,
|
||||
)
|
||||
)
|
||||
assert_frame_equal(output_results, output_expected)
|
||||
|
||||
def test_loader_given_multiple_columns(self):
|
||||
|
||||
class Loader1DataSet1(DataSet):
|
||||
|
||||
+103
-1
@@ -10,10 +10,17 @@ from zipline.errors import (
|
||||
WindowedInputToWindowedTerm,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
TermOutputsEmpty,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
from zipline.pipeline import Classifier, Factor, Filter, TermGraph
|
||||
from zipline.pipeline import (
|
||||
Classifier,
|
||||
CustomFactor,
|
||||
Factor,
|
||||
Filter,
|
||||
TermGraph,
|
||||
)
|
||||
from zipline.pipeline.data import Column, DataSet
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
from zipline.pipeline.term import AssetExists, NotSpecified
|
||||
@@ -67,6 +74,19 @@ class NoLookbackFactor(Factor):
|
||||
window_length = 0
|
||||
|
||||
|
||||
class GenericCustomFactor(CustomFactor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.foo]
|
||||
|
||||
|
||||
class MultipleOutputs(CustomFactor):
|
||||
dtype = float64_dtype
|
||||
window_length = 5
|
||||
inputs = [SomeDataSet.foo, SomeDataSet.bar]
|
||||
outputs = ['alpha', 'beta']
|
||||
|
||||
|
||||
def gen_equivalent_factors():
|
||||
"""
|
||||
Return an iterator of SomeFactor instances that should all be the same
|
||||
@@ -210,6 +230,35 @@ class ObjectIdentityTestCase(TestCase):
|
||||
SomeFactor(inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]]),
|
||||
)
|
||||
|
||||
mask = SomeFactor() + SomeOtherFactor()
|
||||
self.assertIs(SomeFactor(mask=mask), SomeFactor(mask=mask))
|
||||
|
||||
def test_instance_caching_multiple_outputs(self):
|
||||
self.assertIs(MultipleOutputs(), MultipleOutputs())
|
||||
self.assertIs(
|
||||
MultipleOutputs(),
|
||||
MultipleOutputs(outputs=MultipleOutputs.outputs),
|
||||
)
|
||||
self.assertIs(
|
||||
MultipleOutputs(
|
||||
outputs=[
|
||||
MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
|
||||
],
|
||||
),
|
||||
MultipleOutputs(
|
||||
outputs=[
|
||||
MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure that both methods of accessing our outputs return the same
|
||||
# things.
|
||||
multiple_outputs = MultipleOutputs()
|
||||
alpha, beta = MultipleOutputs()
|
||||
self.assertIs(alpha, multiple_outputs.alpha)
|
||||
self.assertIs(beta, multiple_outputs.beta)
|
||||
|
||||
def test_instance_non_caching(self):
|
||||
|
||||
f = SomeFactor()
|
||||
@@ -243,6 +292,30 @@ class ObjectIdentityTestCase(TestCase):
|
||||
|
||||
self.assertIsNot(orig_foobar_instance, SomeFactor())
|
||||
|
||||
def test_instance_non_caching_multiple_outputs(self):
|
||||
multiple_outputs = MultipleOutputs()
|
||||
|
||||
# Different outputs.
|
||||
self.assertIsNot(
|
||||
MultipleOutputs(), MultipleOutputs(outputs=['beta', 'gamma']),
|
||||
)
|
||||
|
||||
# Reordering outputs.
|
||||
self.assertIsNot(
|
||||
multiple_outputs,
|
||||
MultipleOutputs(
|
||||
outputs=[
|
||||
MultipleOutputs.outputs[1], MultipleOutputs.outputs[0],
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Different factors sharing an output name should produce different
|
||||
# RecarrayField factors.
|
||||
orig_beta = multiple_outputs.beta
|
||||
beta, gamma = MultipleOutputs(outputs=['beta', 'gamma'])
|
||||
self.assertIsNot(beta, orig_beta)
|
||||
|
||||
def test_instance_caching_binops(self):
|
||||
f = SomeFactor()
|
||||
g = SomeOtherFactor()
|
||||
@@ -343,6 +416,35 @@ class ObjectIdentityTestCase(TestCase):
|
||||
with self.assertRaises(UnsupportedDType):
|
||||
SomeFactor(dtype=complex128_dtype)
|
||||
|
||||
with self.assertRaises(TermOutputsEmpty):
|
||||
MultipleOutputs(outputs=[])
|
||||
|
||||
def test_bad_output_access(self):
|
||||
with self.assertRaises(AttributeError) as e:
|
||||
SomeFactor().not_an_attr
|
||||
|
||||
errmsg = str(e.exception)
|
||||
self.assertEqual(
|
||||
errmsg, "'SomeFactor' object has no attribute 'not_an_attr'",
|
||||
)
|
||||
|
||||
with self.assertRaises(AttributeError) as e:
|
||||
MultipleOutputs().not_an_attr
|
||||
|
||||
errmsg = str(e.exception)
|
||||
self.assertEqual(
|
||||
errmsg,
|
||||
"Instance of MultipleOutputs has no output called 'not_an_attr'.",
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError) as e:
|
||||
alpha, beta = GenericCustomFactor()
|
||||
|
||||
errmsg = str(e.exception)
|
||||
self.assertEqual(
|
||||
errmsg, "GenericCustomFactor does not have multiple outputs.",
|
||||
)
|
||||
|
||||
def test_require_super_call_in_validate(self):
|
||||
|
||||
class MyFactor(Factor):
|
||||
|
||||
+12
-2
@@ -433,10 +433,20 @@ class TermInputsNotSpecified(ZiplineError):
|
||||
msg = "{termname} requires inputs, but no inputs list was passed."
|
||||
|
||||
|
||||
class TermOutputsEmpty(ZiplineError):
|
||||
"""
|
||||
Raised if a user attempts to construct a term with an empty outputs list.
|
||||
"""
|
||||
msg = (
|
||||
"{termname} requires at least one output when passed an outputs "
|
||||
"argument."
|
||||
)
|
||||
|
||||
|
||||
class WindowLengthNotSpecified(ZiplineError):
|
||||
"""
|
||||
Raised if a user attempts to construct a term without specifying inputs and
|
||||
that term does not have class-level default inputs.
|
||||
Raised if a user attempts to construct a term without specifying window
|
||||
length and that term does not have a class-level default window length.
|
||||
"""
|
||||
msg = (
|
||||
"{termname} requires a window_length, but no window_length was passed."
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from .factor import (
|
||||
CustomFactor,
|
||||
Factor,
|
||||
Latest
|
||||
Latest,
|
||||
RecarrayField,
|
||||
)
|
||||
from .events import (
|
||||
BusinessDaysSinceCashBuybackAuth,
|
||||
@@ -44,6 +45,7 @@ __all__ = [
|
||||
'Latest',
|
||||
'MaxDrawdown',
|
||||
'RSI',
|
||||
'RecarrayField',
|
||||
'Returns',
|
||||
'SimpleMovingAverage',
|
||||
'VWAP',
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
factor.py
|
||||
"""
|
||||
from functools import wraps
|
||||
from functools import partial, wraps
|
||||
from operator import attrgetter
|
||||
from numbers import Number
|
||||
|
||||
@@ -1130,8 +1130,13 @@ class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
|
||||
inputs : iterable, optional
|
||||
An iterable of `BoundColumn` instances (e.g. USEquityPricing.close),
|
||||
describing the data to load and pass to `self.compute`. If this
|
||||
argument is passed to the CustomFactor constructor, we look for a
|
||||
argument is not passed to the CustomFactor constructor, we look for a
|
||||
class-level attribute named `inputs`.
|
||||
outputs : iterable[str], optional
|
||||
An iterable of strings which represent the names of each output this
|
||||
factor should compute and return. If this argument is not passed to the
|
||||
CustomFactor constructor, we look for a class-level attribute named
|
||||
`outputs`.
|
||||
window_length : int, optional
|
||||
Number of rows to pass for each input. If this argument is not passed
|
||||
to the CustomFactor constructor, we look for a class-level attribute
|
||||
@@ -1164,7 +1169,9 @@ class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
|
||||
Column labels for `out` and`inputs`.
|
||||
out : np.array[self.dtype, ndim=1]
|
||||
Output array of the same shape as `assets`. `compute` should write
|
||||
its desired return values into `out`.
|
||||
its desired return values into `out`. If multiple outputs are
|
||||
specified, `compute` should write its desired return values into
|
||||
`out.<output_name>` for each output name in `self.outputs`.
|
||||
*inputs : tuple of np.array
|
||||
Raw data arrays corresponding to the values of `self.inputs`.
|
||||
|
||||
@@ -1229,9 +1236,86 @@ class CustomFactor(PositiveWindowLengthMixin, CustomTermMixin, Factor):
|
||||
# MedianValue.
|
||||
median_close10 = MedianValue([USEquityPricing.close], window_length=10)
|
||||
median_low15 = MedianValue([USEquityPricing.low], window_length=15)
|
||||
|
||||
A CustomFactor with multiple outputs:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
class MultipleOutputs(CustomFactor):
|
||||
inputs = [USEquityPricing.close]
|
||||
outputs = ['alpha', 'beta']
|
||||
window_length = N
|
||||
|
||||
def compute(self, today, assets, out, close):
|
||||
computed_alpha, computed_beta = some_function(close)
|
||||
out.alpha[:] = computed_alpha
|
||||
out.beta[:] = computed_beta
|
||||
|
||||
# Each output is returned as its own Factor upon instantiation.
|
||||
alpha, beta = MultipleOutputs()
|
||||
|
||||
# Equivalently, we can create a single factor instance and access each
|
||||
# output as an attribute of that instance.
|
||||
multiple_outputs = MultipleOutputs()
|
||||
alpha = multiple_outputs.alpha
|
||||
beta = multiple_outputs.beta
|
||||
|
||||
Note: If a CustomFactor has multiple outputs, all outputs must have the
|
||||
same dtype. For instance, in the example above, if alpha is a float then
|
||||
beta must also be a float.
|
||||
'''
|
||||
dtype = float64_dtype
|
||||
|
||||
def __getattr__(self, attribute_name):
|
||||
if self.outputs is NotSpecified:
|
||||
return getattr(super(CustomFactor, self), attribute_name)
|
||||
if attribute_name in self.outputs:
|
||||
return RecarrayField(factor=self, attribute=attribute_name)
|
||||
else:
|
||||
raise AttributeError(
|
||||
'Instance of {factor} has no output called {attr!r}.'.format(
|
||||
factor=type(self).__name__, attr=attribute_name,
|
||||
)
|
||||
)
|
||||
|
||||
def __iter__(self):
|
||||
if self.outputs is NotSpecified:
|
||||
raise ValueError(
|
||||
'{factor} does not have multiple outputs.'.format(
|
||||
factor=type(self).__name__,
|
||||
)
|
||||
)
|
||||
RecarrayField_ = partial(RecarrayField, self)
|
||||
return iter(map(RecarrayField_, self.outputs))
|
||||
|
||||
|
||||
class RecarrayField(SingleInputMixin, Factor):
|
||||
|
||||
def __new__(cls, factor, attribute):
|
||||
return super(RecarrayField, cls).__new__(
|
||||
cls,
|
||||
attribute=attribute,
|
||||
inputs=[factor],
|
||||
window_length=0,
|
||||
mask=factor.mask,
|
||||
dtype=factor.dtype,
|
||||
missing_value=factor.missing_value,
|
||||
)
|
||||
|
||||
def _init(self, attribute, *args, **kwargs):
|
||||
self._attribute = attribute
|
||||
return super(RecarrayField, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, attribute, *args, **kwargs):
|
||||
return (
|
||||
super(RecarrayField, cls).static_identity(*args, **kwargs),
|
||||
attribute,
|
||||
)
|
||||
|
||||
def _compute(self, windows, dates, assets, mask):
|
||||
return windows[0][self._attribute]
|
||||
|
||||
|
||||
class Latest(LatestMixin, CustomFactor):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Mixins classes for use with Filters and Factors.
|
||||
"""
|
||||
from numpy import full_like
|
||||
from numpy import full_like, recarray
|
||||
|
||||
from zipline.utils.control_flow import nullctx
|
||||
from zipline.errors import WindowLengthNotPositive, UnsupportedDataType
|
||||
@@ -69,6 +69,7 @@ class CustomTermMixin(object):
|
||||
|
||||
def __new__(cls,
|
||||
inputs=NotSpecified,
|
||||
outputs=NotSpecified,
|
||||
window_length=NotSpecified,
|
||||
mask=NotSpecified,
|
||||
dtype=NotSpecified,
|
||||
@@ -88,6 +89,7 @@ class CustomTermMixin(object):
|
||||
return super(CustomTermMixin, cls).__new__(
|
||||
cls,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
window_length=window_length,
|
||||
mask=mask,
|
||||
dtype=dtype,
|
||||
@@ -109,7 +111,16 @@ class CustomTermMixin(object):
|
||||
compute = self.compute
|
||||
missing_value = self.missing_value
|
||||
params = self.params
|
||||
out = full_like(mask, missing_value, dtype=self.dtype)
|
||||
outputs = self.outputs
|
||||
if outputs is not NotSpecified:
|
||||
out = recarray(
|
||||
mask.shape,
|
||||
formats=[self.dtype.str] * len(outputs),
|
||||
names=outputs,
|
||||
)
|
||||
out[:] = missing_value
|
||||
else:
|
||||
out = full_like(mask, missing_value, dtype=self.dtype)
|
||||
with self.ctx:
|
||||
# TODO: Consider pre-filtering columns that are all-nan at each
|
||||
# time-step?
|
||||
|
||||
@@ -11,6 +11,7 @@ from zipline.errors import (
|
||||
WindowedInputToWindowedTerm,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
TermOutputsEmpty,
|
||||
UnsupportedDType,
|
||||
WindowLengthNotSpecified,
|
||||
)
|
||||
@@ -349,11 +350,13 @@ class ComputableTerm(Term):
|
||||
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Factor`.
|
||||
"""
|
||||
inputs = NotSpecified
|
||||
outputs = NotSpecified
|
||||
window_length = NotSpecified
|
||||
mask = NotSpecified
|
||||
|
||||
def __new__(cls,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
window_length=window_length,
|
||||
mask=mask,
|
||||
*args, **kwargs):
|
||||
@@ -368,6 +371,11 @@ class ComputableTerm(Term):
|
||||
# normalize to a tuple so that inputs is hashable.
|
||||
inputs = tuple(inputs)
|
||||
|
||||
if outputs is NotSpecified:
|
||||
outputs = cls.outputs
|
||||
if outputs is not NotSpecified:
|
||||
outputs = tuple(outputs)
|
||||
|
||||
if mask is NotSpecified:
|
||||
mask = cls.mask
|
||||
if mask is NotSpecified:
|
||||
@@ -379,22 +387,31 @@ class ComputableTerm(Term):
|
||||
return super(ComputableTerm, cls).__new__(
|
||||
cls,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
mask=mask,
|
||||
window_length=window_length,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
def _init(self, inputs, window_length, mask, *args, **kwargs):
|
||||
def _init(self, inputs, outputs, window_length, mask, *args, **kwargs):
|
||||
self.inputs = inputs
|
||||
self.outputs = outputs
|
||||
self.window_length = window_length
|
||||
self.mask = mask
|
||||
return super(ComputableTerm, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
|
||||
def static_identity(cls,
|
||||
inputs,
|
||||
outputs,
|
||||
window_length,
|
||||
mask,
|
||||
*args,
|
||||
**kwargs):
|
||||
return (
|
||||
super(ComputableTerm, cls).static_identity(*args, **kwargs),
|
||||
inputs,
|
||||
outputs,
|
||||
window_length,
|
||||
mask,
|
||||
)
|
||||
@@ -405,6 +422,9 @@ class ComputableTerm(Term):
|
||||
if self.inputs is NotSpecified:
|
||||
raise TermInputsNotSpecified(termname=type(self).__name__)
|
||||
|
||||
if not self.outputs:
|
||||
raise TermOutputsEmpty(termname=type(self).__name__)
|
||||
|
||||
if self.window_length is NotSpecified:
|
||||
raise WindowLengthNotSpecified(termname=type(self).__name__)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user