mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 21:59:10 +08:00
Merge pull request #1035 from quantopian/refactor-atomic
MAINT: Remove notion of "atomic" pipeline terms.
This commit is contained in:
@@ -7,7 +7,7 @@ from unittest import TestCase
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
WindowedInputToWindowedTerm,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
@@ -157,7 +157,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
self.assertEqual(graph.extra_rows[bar], 4)
|
||||
self.assertEqual(graph.extra_rows[buzz], 4)
|
||||
|
||||
def test_reuse_atomic_terms(self):
|
||||
def test_reuse_loadable_terms(self):
|
||||
"""
|
||||
Test that raw inputs only show up in the dependency graph once.
|
||||
"""
|
||||
@@ -174,7 +174,7 @@ class DependencyResolutionTestCase(TestCase):
|
||||
|
||||
def test_disallow_recursive_lookback(self):
|
||||
|
||||
with self.assertRaises(InputTermNotAtomic):
|
||||
with self.assertRaises(WindowedInputToWindowedTerm):
|
||||
SomeFactor(inputs=[SomeFactor(), SomeDataSet.foo])
|
||||
|
||||
|
||||
|
||||
+8
-4
@@ -347,13 +347,17 @@ class WindowLengthNotPositive(ZiplineError):
|
||||
).strip()
|
||||
|
||||
|
||||
class InputTermNotAtomic(ZiplineError):
|
||||
class WindowedInputToWindowedTerm(ZiplineError):
|
||||
"""
|
||||
Raised when a non-atomic term is specified as an input to a Pipeline API
|
||||
term with a lookback window.
|
||||
Raised when a windowed Pipeline API term is specified as an input to
|
||||
another windowed term.
|
||||
|
||||
This is an error because it's generally not safe to compose windowed
|
||||
functions on split/dividend adjusted data.
|
||||
"""
|
||||
msg = (
|
||||
"Can't compute {parent} with non-atomic input {child}."
|
||||
"Can't compute windowed expression {parent} with "
|
||||
"windowed input {child}."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
classifier.py
|
||||
"""
|
||||
|
||||
from zipline.pipeline.term import CompositeTerm
|
||||
from zipline.pipeline.term import ComputableTerm
|
||||
|
||||
|
||||
class Classifier(CompositeTerm):
|
||||
class Classifier(ComputableTerm):
|
||||
pass
|
||||
|
||||
@@ -8,9 +8,10 @@ from six import (
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import (
|
||||
Term,
|
||||
AssetExists,
|
||||
LoadableTerm,
|
||||
NotSpecified,
|
||||
Term,
|
||||
)
|
||||
from zipline.utils.input_validation import ensure_dtype
|
||||
from zipline.utils.numpy_utils import (
|
||||
@@ -87,7 +88,7 @@ class _BoundColumnDescr(object):
|
||||
)
|
||||
|
||||
|
||||
class BoundColumn(Term):
|
||||
class BoundColumn(LoadableTerm):
|
||||
"""
|
||||
A column of data that's been concretely bound to a particular dataset.
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ from zipline.errors import NoFurtherDataError
|
||||
from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis
|
||||
from zipline.utils.pandas_utils import explode
|
||||
|
||||
from .term import AssetExists
|
||||
from .term import AssetExists, LoadableTerm
|
||||
|
||||
|
||||
class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
@@ -83,7 +83,7 @@ class SimplePipelineEngine(object):
|
||||
Parameters
|
||||
----------
|
||||
get_loader : callable
|
||||
A function that is given an atomic term and returns a PipelineLoader
|
||||
A function that is given a loadable term and returns a PipelineLoader
|
||||
to use to retrieve raw data for that term.
|
||||
calendar : DatetimeIndex
|
||||
Array of dates to consider as trading days when computing a range
|
||||
@@ -279,12 +279,6 @@ class SimplePipelineEngine(object):
|
||||
return out
|
||||
|
||||
def get_loader(self, term):
|
||||
# AssetExists is one of the atomic terms in the graph, so we look up
|
||||
# a loader here when grouping by loader, but since it's already in the
|
||||
# workspace, we don't actually use that group.
|
||||
if term is AssetExists():
|
||||
return None
|
||||
|
||||
return self._get_loader(term)
|
||||
|
||||
def compute_chunk(self, graph, dates, assets, initial_workspace):
|
||||
@@ -316,14 +310,14 @@ class SimplePipelineEngine(object):
|
||||
# Copy the supplied initial workspace so we don't mutate it in place.
|
||||
workspace = initial_workspace.copy()
|
||||
|
||||
# If atomic terms share the same loader and extra_rows, load them all
|
||||
# If loadable terms share the same loader and extra_rows, load them all
|
||||
# together.
|
||||
atomic_group_key = juxt(get_loader, getitem(graph.extra_rows))
|
||||
atomic_groups = groupby(atomic_group_key, graph.atomic_terms)
|
||||
loader_group_key = juxt(get_loader, getitem(graph.extra_rows))
|
||||
loader_groups = groupby(loader_group_key, graph.loadable_terms)
|
||||
|
||||
for term in graph.ordered():
|
||||
# `term` may have been supplied in `initial_workspace`, and in the
|
||||
# future we may pre-compute atomic terms coming from the same
|
||||
# future we may pre-compute loadable terms coming from the same
|
||||
# dataset. In either case, we will already have an entry for this
|
||||
# term, which we shouldn't re-compute.
|
||||
if term in workspace:
|
||||
@@ -335,9 +329,9 @@ class SimplePipelineEngine(object):
|
||||
term, workspace, graph, dates
|
||||
)
|
||||
|
||||
if term.atomic:
|
||||
if isinstance(term, LoadableTerm):
|
||||
to_load = sorted(
|
||||
atomic_groups[atomic_group_key(term)],
|
||||
loader_groups[loader_group_key(term)],
|
||||
key=lambda t: t.dataset
|
||||
)
|
||||
loader = get_loader(term)
|
||||
|
||||
@@ -12,7 +12,7 @@ from numpy import (
|
||||
inf,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import Term, CompositeTerm
|
||||
from zipline.pipeline.term import Term, ComputableTerm
|
||||
|
||||
|
||||
_VARIABLE_NAME_RE = re.compile("^(x_)([0-9]+)$")
|
||||
@@ -164,7 +164,7 @@ def is_comparison(op):
|
||||
return op in COMPARISONS
|
||||
|
||||
|
||||
class NumericalExpression(CompositeTerm):
|
||||
class NumericalExpression(ComputableTerm):
|
||||
"""
|
||||
Term binding to a numexpr expression.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from zipline.pipeline.mixins import (
|
||||
PositiveWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.term import CompositeTerm, NotSpecified
|
||||
from zipline.pipeline.term import ComputableTerm, NotSpecified
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
COMPARISONS,
|
||||
@@ -343,7 +343,7 @@ def if_not_float64_tell_caller_to_use_isnull(f):
|
||||
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype, int64_dtype])
|
||||
|
||||
|
||||
class Factor(CompositeTerm):
|
||||
class Factor(ComputableTerm):
|
||||
"""
|
||||
Pipeline API expression producing numerically-valued outputs.
|
||||
"""
|
||||
|
||||
@@ -19,7 +19,7 @@ from zipline.pipeline.mixins import (
|
||||
PositiveWindowLengthMixin,
|
||||
SingleInputMixin,
|
||||
)
|
||||
from zipline.pipeline.term import CompositeTerm
|
||||
from zipline.pipeline.term import ComputableTerm
|
||||
from zipline.pipeline.expression import (
|
||||
BadBinaryOperator,
|
||||
FILTER_BINOPS,
|
||||
@@ -112,7 +112,7 @@ def unary_operator(op):
|
||||
return unary_operator
|
||||
|
||||
|
||||
class Filter(CompositeTerm):
|
||||
class Filter(ComputableTerm):
|
||||
"""
|
||||
Pipeline API expression producing boolean-valued outputs.
|
||||
"""
|
||||
|
||||
@@ -9,6 +9,8 @@ from six import itervalues, iteritems
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.pipeline.visualize import display_graph
|
||||
|
||||
from .term import LoadableTerm
|
||||
|
||||
|
||||
class CyclicDependency(Exception):
|
||||
pass
|
||||
@@ -163,8 +165,8 @@ class TermGraph(DiGraph):
|
||||
return iter(self._ordered)
|
||||
|
||||
@lazyval
|
||||
def atomic_terms(self):
|
||||
return tuple(term for term in self if term.atomic)
|
||||
def loadable_terms(self):
|
||||
return tuple(term for term in self if isinstance(term, LoadableTerm))
|
||||
|
||||
def _add_to_graph(self, term, parents, extra_rows):
|
||||
"""
|
||||
|
||||
@@ -25,7 +25,7 @@ class BlazeCashBuybackAuthorizationsLoader(BlazeEventsLoader):
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
resources : dict, optional
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
@@ -97,7 +97,7 @@ class BlazeShareBuybackAuthorizationsLoader(BlazeEventsLoader):
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
resources : dict, optional
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
|
||||
@@ -1059,7 +1059,7 @@ def bind_expression_to_resources(expr, resources):
|
||||
expr : bz.Expr
|
||||
The expression to which we want to bind resources.
|
||||
resources : dict[bz.Symbol -> any]
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
@@ -17,7 +17,7 @@ class BlazeEarningsCalendarLoader(BlazeEventsLoader):
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
resources : dict, optional
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
|
||||
@@ -29,7 +29,7 @@ class BlazeEventsLoader(PipelineLoader):
|
||||
expr : Expr
|
||||
The expression representing the data to load.
|
||||
resources : dict, optional
|
||||
Mapping from the atomic terms of ``expr`` to actual data resources.
|
||||
Mapping from the loadable terms of ``expr`` to actual data resources.
|
||||
odo_kwargs : dict, optional
|
||||
Extra keyword arguments to pass to odo when executing the expression.
|
||||
data_query_time : time, optional
|
||||
|
||||
+50
-21
@@ -8,7 +8,7 @@ from numpy import dtype as dtype_class
|
||||
from six import with_metaclass
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
InputTermNotAtomic,
|
||||
WindowedInputToWindowedTerm,
|
||||
NotDType,
|
||||
TermInputsNotSpecified,
|
||||
UnsupportedDType,
|
||||
@@ -268,27 +268,33 @@ class Term(with_metaclass(ABCMeta, object)):
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
"""
|
||||
A tuple of other Terms that this Term requires for computation.
|
||||
A tuple of other Terms needed as direct inputs for this Term.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError('inputs')
|
||||
|
||||
@abstractproperty
|
||||
def windowed(self):
|
||||
"""
|
||||
Boolean indicating whether this term is a trailing-window computation.
|
||||
"""
|
||||
raise NotImplementedError('windowed')
|
||||
|
||||
@abstractproperty
|
||||
def mask(self):
|
||||
"""
|
||||
A 2D Filter representing asset/date pairs to include while
|
||||
A Filter representing asset/date pairs to include while
|
||||
computing this Term. (True means include; False means exclude.)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError('mask')
|
||||
|
||||
@lazyval
|
||||
def dependencies(self):
|
||||
"""
|
||||
A tuple containing all terms that must be computed before this term can
|
||||
be loaded or computed.
|
||||
"""
|
||||
return self.inputs + (self.mask,)
|
||||
|
||||
@lazyval
|
||||
def atomic(self):
|
||||
return not any(dep for dep in self.dependencies
|
||||
if dep is not AssetExists())
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
@@ -310,12 +316,29 @@ class AssetExists(Term):
|
||||
inputs = ()
|
||||
dependencies = ()
|
||||
mask = None
|
||||
windowed = False
|
||||
|
||||
def __repr__(self):
|
||||
return "AssetExists()"
|
||||
|
||||
|
||||
class CompositeTerm(Term):
|
||||
class LoadableTerm(Term):
|
||||
"""
|
||||
A Term that should be loaded from an external resource by a PipelineLoader.
|
||||
|
||||
This is the base class for :class:`zipline.pipeline.data.BoundColumn`.
|
||||
"""
|
||||
inputs = ()
|
||||
windowed = False
|
||||
|
||||
|
||||
class ComputableTerm(Term):
|
||||
"""
|
||||
A Term that should be computed from a tuple of inputs.
|
||||
|
||||
This is the base class for :class:`zipline.pipeline.Factor`,
|
||||
:class:`zipline.pipeline.Filter`, and :class:`zipline.pipeline.Classifier`.
|
||||
"""
|
||||
inputs = NotSpecified
|
||||
window_length = NotSpecified
|
||||
mask = NotSpecified
|
||||
@@ -344,20 +367,24 @@ class CompositeTerm(Term):
|
||||
if window_length is NotSpecified:
|
||||
window_length = cls.window_length
|
||||
|
||||
return super(CompositeTerm, cls).__new__(cls, inputs=inputs, mask=mask,
|
||||
window_length=window_length,
|
||||
*args, **kwargs)
|
||||
return super(ComputableTerm, cls).__new__(
|
||||
cls,
|
||||
inputs=inputs,
|
||||
mask=mask,
|
||||
window_length=window_length,
|
||||
*args, **kwargs
|
||||
)
|
||||
|
||||
def _init(self, inputs, window_length, mask, *args, **kwargs):
|
||||
self.inputs = inputs
|
||||
self.window_length = window_length
|
||||
self.mask = mask
|
||||
return super(CompositeTerm, self)._init(*args, **kwargs)
|
||||
return super(ComputableTerm, self)._init(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
|
||||
return (
|
||||
super(CompositeTerm, cls).static_identity(*args, **kwargs),
|
||||
super(ComputableTerm, cls).static_identity(*args, **kwargs),
|
||||
inputs,
|
||||
window_length,
|
||||
mask,
|
||||
@@ -378,16 +405,18 @@ class CompositeTerm(Term):
|
||||
|
||||
if self.window_length:
|
||||
for child in self.inputs:
|
||||
if not child.atomic:
|
||||
raise InputTermNotAtomic(parent=self, child=child)
|
||||
if child.windowed:
|
||||
raise WindowedInputToWindowedTerm(parent=self, child=child)
|
||||
|
||||
return super(CompositeTerm, self)._validate()
|
||||
return super(ComputableTerm, self)._validate()
|
||||
|
||||
def _compute(self, inputs, dates, assets, mask):
|
||||
"""
|
||||
Subclasses should implement this to perform actual computation.
|
||||
This is `_compute` rather than just `compute` because `compute` is
|
||||
reserved for user-supplied functions in CustomFactor.
|
||||
|
||||
This is named ``_compute`` rather than just ``compute`` because
|
||||
``compute`` is reserved for user-supplied functions in
|
||||
CustomFilter/CustomFactor/CustomClassifier.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -98,7 +98,7 @@ def _render(g, out, format_, include_asset_exists=False):
|
||||
graph_attrs = {'rankdir': 'TB', 'splines': 'ortho'}
|
||||
cluster_attrs = {'style': 'filled', 'color': 'lightgoldenrod1'}
|
||||
|
||||
in_nodes = g.atomic_terms
|
||||
in_nodes = g.loadable_terms
|
||||
out_nodes = list(g.outputs.values())
|
||||
|
||||
f = BytesIO()
|
||||
|
||||
Reference in New Issue
Block a user