Merge pull request #1035 from quantopian/refactor-atomic

MAINT: Remove notion of "atomic" pipeline terms.
This commit is contained in:
Scott Sanderson
2016-03-08 14:50:24 -05:00
15 changed files with 90 additions and 60 deletions
+3 -3
View File
@@ -7,7 +7,7 @@ from unittest import TestCase
from zipline.errors import (
DTypeNotSpecified,
InputTermNotAtomic,
WindowedInputToWindowedTerm,
NotDType,
TermInputsNotSpecified,
UnsupportedDType,
@@ -157,7 +157,7 @@ class DependencyResolutionTestCase(TestCase):
self.assertEqual(graph.extra_rows[bar], 4)
self.assertEqual(graph.extra_rows[buzz], 4)
def test_reuse_atomic_terms(self):
def test_reuse_loadable_terms(self):
"""
Test that raw inputs only show up in the dependency graph once.
"""
@@ -174,7 +174,7 @@ class DependencyResolutionTestCase(TestCase):
def test_disallow_recursive_lookback(self):
with self.assertRaises(InputTermNotAtomic):
with self.assertRaises(WindowedInputToWindowedTerm):
SomeFactor(inputs=[SomeFactor(), SomeDataSet.foo])
+8 -4
View File
@@ -347,13 +347,17 @@ class WindowLengthNotPositive(ZiplineError):
).strip()
class InputTermNotAtomic(ZiplineError):
class WindowedInputToWindowedTerm(ZiplineError):
"""
Raised when a non-atomic term is specified as an input to a Pipeline API
term with a lookback window.
Raised when a windowed Pipeline API term is specified as an input to
another windowed term.
This is an error because it's generally not safe to compose windowed
functions on split/dividend adjusted data.
"""
msg = (
"Can't compute {parent} with non-atomic input {child}."
"Can't compute windowed expression {parent} with "
"windowed input {child}."
)
+2 -2
View File
@@ -2,8 +2,8 @@
classifier.py
"""
from zipline.pipeline.term import CompositeTerm
from zipline.pipeline.term import ComputableTerm
class Classifier(CompositeTerm):
class Classifier(ComputableTerm):
pass
+3 -2
View File
@@ -8,9 +8,10 @@ from six import (
)
from zipline.pipeline.term import (
Term,
AssetExists,
LoadableTerm,
NotSpecified,
Term,
)
from zipline.utils.input_validation import ensure_dtype
from zipline.utils.numpy_utils import (
@@ -87,7 +88,7 @@ class _BoundColumnDescr(object):
)
class BoundColumn(Term):
class BoundColumn(LoadableTerm):
"""
A column of data that's been concretely bound to a particular dataset.
+8 -14
View File
@@ -25,7 +25,7 @@ from zipline.errors import NoFurtherDataError
from zipline.utils.numpy_utils import repeat_first_axis, repeat_last_axis
from zipline.utils.pandas_utils import explode
from .term import AssetExists
from .term import AssetExists, LoadableTerm
class PipelineEngine(with_metaclass(ABCMeta)):
@@ -83,7 +83,7 @@ class SimplePipelineEngine(object):
Parameters
----------
get_loader : callable
A function that is given an atomic term and returns a PipelineLoader
A function that is given a loadable term and returns a PipelineLoader
to use to retrieve raw data for that term.
calendar : DatetimeIndex
Array of dates to consider as trading days when computing a range
@@ -279,12 +279,6 @@ class SimplePipelineEngine(object):
return out
def get_loader(self, term):
# AssetExists is one of the atomic terms in the graph, so we look up
# a loader here when grouping by loader, but since it's already in the
# workspace, we don't actually use that group.
if term is AssetExists():
return None
return self._get_loader(term)
def compute_chunk(self, graph, dates, assets, initial_workspace):
@@ -316,14 +310,14 @@ class SimplePipelineEngine(object):
# Copy the supplied initial workspace so we don't mutate it in place.
workspace = initial_workspace.copy()
# If atomic terms share the same loader and extra_rows, load them all
# If loadable terms share the same loader and extra_rows, load them all
# together.
atomic_group_key = juxt(get_loader, getitem(graph.extra_rows))
atomic_groups = groupby(atomic_group_key, graph.atomic_terms)
loader_group_key = juxt(get_loader, getitem(graph.extra_rows))
loader_groups = groupby(loader_group_key, graph.loadable_terms)
for term in graph.ordered():
# `term` may have been supplied in `initial_workspace`, and in the
# future we may pre-compute atomic terms coming from the same
# future we may pre-compute loadable terms coming from the same
# dataset. In either case, we will already have an entry for this
# term, which we shouldn't re-compute.
if term in workspace:
@@ -335,9 +329,9 @@ class SimplePipelineEngine(object):
term, workspace, graph, dates
)
if term.atomic:
if isinstance(term, LoadableTerm):
to_load = sorted(
atomic_groups[atomic_group_key(term)],
loader_groups[loader_group_key(term)],
key=lambda t: t.dataset
)
loader = get_loader(term)
+2 -2
View File
@@ -12,7 +12,7 @@ from numpy import (
inf,
)
from zipline.pipeline.term import Term, CompositeTerm
from zipline.pipeline.term import Term, ComputableTerm
_VARIABLE_NAME_RE = re.compile("^(x_)([0-9]+)$")
@@ -164,7 +164,7 @@ def is_comparison(op):
return op in COMPARISONS
class NumericalExpression(CompositeTerm):
class NumericalExpression(ComputableTerm):
"""
Term binding to a numexpr expression.
+2 -2
View File
@@ -18,7 +18,7 @@ from zipline.pipeline.mixins import (
PositiveWindowLengthMixin,
SingleInputMixin,
)
from zipline.pipeline.term import CompositeTerm, NotSpecified
from zipline.pipeline.term import ComputableTerm, NotSpecified
from zipline.pipeline.expression import (
BadBinaryOperator,
COMPARISONS,
@@ -343,7 +343,7 @@ def if_not_float64_tell_caller_to_use_isnull(f):
FACTOR_DTYPES = frozenset([datetime64ns_dtype, float64_dtype, int64_dtype])
class Factor(CompositeTerm):
class Factor(ComputableTerm):
"""
Pipeline API expression producing numerically-valued outputs.
"""
+2 -2
View File
@@ -19,7 +19,7 @@ from zipline.pipeline.mixins import (
PositiveWindowLengthMixin,
SingleInputMixin,
)
from zipline.pipeline.term import CompositeTerm
from zipline.pipeline.term import ComputableTerm
from zipline.pipeline.expression import (
BadBinaryOperator,
FILTER_BINOPS,
@@ -112,7 +112,7 @@ def unary_operator(op):
return unary_operator
class Filter(CompositeTerm):
class Filter(ComputableTerm):
"""
Pipeline API expression producing boolean-valued outputs.
"""
+4 -2
View File
@@ -9,6 +9,8 @@ from six import itervalues, iteritems
from zipline.utils.memoize import lazyval
from zipline.pipeline.visualize import display_graph
from .term import LoadableTerm
class CyclicDependency(Exception):
pass
@@ -163,8 +165,8 @@ class TermGraph(DiGraph):
return iter(self._ordered)
@lazyval
def atomic_terms(self):
return tuple(term for term in self if term.atomic)
def loadable_terms(self):
return tuple(term for term in self if isinstance(term, LoadableTerm))
def _add_to_graph(self, term, parents, extra_rows):
"""
@@ -25,7 +25,7 @@ class BlazeCashBuybackAuthorizationsLoader(BlazeEventsLoader):
expr : Expr
The expression representing the data to load.
resources : dict, optional
Mapping from the atomic terms of ``expr`` to actual data resources.
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
@@ -97,7 +97,7 @@ class BlazeShareBuybackAuthorizationsLoader(BlazeEventsLoader):
expr : Expr
The expression representing the data to load.
resources : dict, optional
Mapping from the atomic terms of ``expr`` to actual data resources.
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
+1 -1
View File
@@ -1059,7 +1059,7 @@ def bind_expression_to_resources(expr, resources):
expr : bz.Expr
The expression to which we want to bind resources.
resources : dict[bz.Symbol -> any]
Mapping from the atomic terms of ``expr`` to actual data resources.
Mapping from the loadable terms of ``expr`` to actual data resources.
Returns
-------
+1 -1
View File
@@ -17,7 +17,7 @@ class BlazeEarningsCalendarLoader(BlazeEventsLoader):
expr : Expr
The expression representing the data to load.
resources : dict, optional
Mapping from the atomic terms of ``expr`` to actual data resources.
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
+1 -1
View File
@@ -29,7 +29,7 @@ class BlazeEventsLoader(PipelineLoader):
expr : Expr
The expression representing the data to load.
resources : dict, optional
Mapping from the atomic terms of ``expr`` to actual data resources.
Mapping from the loadable terms of ``expr`` to actual data resources.
odo_kwargs : dict, optional
Extra keyword arguments to pass to odo when executing the expression.
data_query_time : time, optional
+50 -21
View File
@@ -8,7 +8,7 @@ from numpy import dtype as dtype_class
from six import with_metaclass
from zipline.errors import (
DTypeNotSpecified,
InputTermNotAtomic,
WindowedInputToWindowedTerm,
NotDType,
TermInputsNotSpecified,
UnsupportedDType,
@@ -268,27 +268,33 @@ class Term(with_metaclass(ABCMeta, object)):
@abstractproperty
def inputs(self):
"""
A tuple of other Terms that this Term requires for computation.
A tuple of other Terms needed as direct inputs for this Term.
"""
raise NotImplementedError()
raise NotImplementedError('inputs')
@abstractproperty
def windowed(self):
"""
Boolean indicating whether this term is a trailing-window computation.
"""
raise NotImplementedError('windowed')
@abstractproperty
def mask(self):
"""
A 2D Filter representing asset/date pairs to include while
A Filter representing asset/date pairs to include while
computing this Term. (True means include; False means exclude.)
"""
raise NotImplementedError()
raise NotImplementedError('mask')
@lazyval
def dependencies(self):
"""
A tuple containing all terms that must be computed before this term can
be loaded or computed.
"""
return self.inputs + (self.mask,)
@lazyval
def atomic(self):
return not any(dep for dep in self.dependencies
if dep is not AssetExists())
class AssetExists(Term):
"""
@@ -310,12 +316,29 @@ class AssetExists(Term):
inputs = ()
dependencies = ()
mask = None
windowed = False
def __repr__(self):
return "AssetExists()"
class CompositeTerm(Term):
class LoadableTerm(Term):
"""
A Term that should be loaded from an external resource by a PipelineLoader.
This is the base class for :class:`zipline.pipeline.data.BoundColumn`.
"""
inputs = ()
windowed = False
class ComputableTerm(Term):
"""
A Term that should be computed from a tuple of inputs.
This is the base class for :class:`zipline.pipeline.Factor`,
:class:`zipline.pipeline.Filter`, and
:class:`zipline.pipeline.classifier.Classifier`.
"""
inputs = NotSpecified
window_length = NotSpecified
mask = NotSpecified
@@ -344,20 +367,24 @@ class CompositeTerm(Term):
if window_length is NotSpecified:
window_length = cls.window_length
return super(CompositeTerm, cls).__new__(cls, inputs=inputs, mask=mask,
window_length=window_length,
*args, **kwargs)
return super(ComputableTerm, cls).__new__(
cls,
inputs=inputs,
mask=mask,
window_length=window_length,
*args, **kwargs
)
def _init(self, inputs, window_length, mask, *args, **kwargs):
self.inputs = inputs
self.window_length = window_length
self.mask = mask
return super(CompositeTerm, self)._init(*args, **kwargs)
return super(ComputableTerm, self)._init(*args, **kwargs)
@classmethod
def static_identity(cls, inputs, window_length, mask, *args, **kwargs):
return (
super(CompositeTerm, cls).static_identity(*args, **kwargs),
super(ComputableTerm, cls).static_identity(*args, **kwargs),
inputs,
window_length,
mask,
@@ -378,16 +405,18 @@ class CompositeTerm(Term):
if self.window_length:
for child in self.inputs:
if not child.atomic:
raise InputTermNotAtomic(parent=self, child=child)
if child.windowed:
raise WindowedInputToWindowedTerm(parent=self, child=child)
return super(CompositeTerm, self)._validate()
return super(ComputableTerm, self)._validate()
def _compute(self, inputs, dates, assets, mask):
"""
Subclasses should implement this to perform actual computation.
This is `_compute` rather than just `compute` because `compute` is
reserved for user-supplied functions in CustomFactor.
This is named ``_compute`` rather than just ``compute`` because
``compute`` is reserved for user-supplied functions in
CustomFilter/CustomFactor/CustomClassifier.
"""
raise NotImplementedError()
+1 -1
View File
@@ -98,7 +98,7 @@ def _render(g, out, format_, include_asset_exists=False):
graph_attrs = {'rankdir': 'TB', 'splines': 'ortho'}
cluster_attrs = {'style': 'filled', 'color': 'lightgoldenrod1'}
in_nodes = g.atomic_terms
in_nodes = g.loadable_terms
out_nodes = list(g.outputs.values())
f = BytesIO()