MAINT: Reworked Term atomicity

This commit is contained in:
Richard Frank
2015-10-09 17:32:22 -04:00
parent 940831e1cf
commit 2dabda6b76
5 changed files with 69 additions and 84 deletions
+2 -4
View File
@@ -86,10 +86,8 @@ class DependencyResolutionTestCase(TestCase):
seen = set()
for term in ordered_terms:
if not term.atomic:
for input_ in term.inputs:
self.assertIn(input_, seen)
self.assertIn(term.mask, seen)
for dep in term.dependencies:
self.assertIn(dep, seen)
seen.add(term)
+5 -2
View File
@@ -7,7 +7,7 @@ from six import (
with_metaclass,
)
from zipline.pipeline.term import AtomicTerm
from zipline.pipeline.term import Term, AssetExists
from zipline.pipeline.factors import Latest
@@ -26,10 +26,13 @@ class Column(object):
return BoundColumn(dtype=self.dtype, dataset=dataset, name=name)
class BoundColumn(AtomicTerm):
class BoundColumn(Term):
"""
A Column of data that's been concretely bound to a particular dataset.
"""
mask = AssetExists()
extra_input_rows = 0
inputs = ()
def __new__(cls, dtype, dataset, name):
return super(BoundColumn, cls).__new__(
+1 -1
View File
@@ -236,7 +236,7 @@ class SimplePipelineEngine(object):
"""
Load mask and mask row labels for term.
"""
mask = term.mask if not term.atomic else self._root_mask_term
mask = term.mask
offset = graph.extra_rows[mask] - graph.extra_rows[term]
return workspace[mask][offset:], dates[offset:]
+13 -28
View File
@@ -102,16 +102,9 @@ class TermGraph(DiGraph):
zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term
zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term
"""
out = {}
for term in self:
if not term.atomic:
extra_input_rows = term.extra_input_rows
for input_ in term.inputs:
out[term, input_] = (self.extra_rows[input_]
- extra_input_rows)
mask = term.mask
out[term, mask] = self.extra_rows[mask] - extra_input_rows
return out
return {(term, dep): self.extra_rows[dep] - term.extra_input_rows
for term in self
for dep in term.dependencies}
@lazyval
def extra_rows(self):
@@ -192,25 +185,17 @@ class TermGraph(DiGraph):
# Make sure we're going to compute at least `extra_rows` of `term`.
self._ensure_extra_rows(term, extra_rows)
if not term.atomic:
# Number of extra rows we need to compute for this term's
# dependencies.
dependency_extra_rows = extra_rows + term.extra_input_rows
# Number of extra rows we need to compute for this term's dependencies.
dependency_extra_rows = extra_rows + term.extra_input_rows
# Recursively add dependencies.
for dependency in term.inputs:
self._add_to_graph(
dependency,
parents,
extra_rows=dependency_extra_rows,
)
self.add_edge(dependency, term)
# Add term's mask, which is really just a specially-enumerated
# input.
mask = term.mask
self._add_to_graph(mask, parents, extra_rows=dependency_extra_rows)
self.add_edge(mask, term)
# Recursively add dependencies.
for dependency in term.dependencies:
self._add_to_graph(
dependency,
parents,
extra_rows=dependency_extra_rows,
)
self.add_edge(dependency, term)
parents.remove(term)
+48 -49
View File
@@ -1,9 +1,11 @@
"""
Base class for Filters, Factors and Classifiers
"""
from abc import ABCMeta, abstractproperty
from weakref import WeakValueDictionary
from numpy import bool_, full, nan
from six import with_metaclass
from zipline.errors import (
DTypeNotSpecified,
@@ -38,7 +40,7 @@ class NotSpecified(object):
return self
class Term(object):
class Term(with_metaclass(ABCMeta, object)):
"""
Base class for terms in a Pipeline API compute graph.
"""
@@ -135,15 +137,55 @@ class Term(object):
if self.dtype is NotSpecified:
raise DTypeNotSpecified(termname=type(self).__name__)
@property
def atomic(self):
@abstractproperty
def inputs(self):
"""
Whether or not this term has dependencies.
If term.atomic is truthy, it should have dataset and dtype attributes.
A tuple of other Terms that this Term requires for computation.
"""
raise NotImplementedError()
@abstractproperty
def mask(self):
    """
    Filter selecting the asset/date pairs this Term is computed over.

    Abstract: concrete terms must supply a value (for example as a
    class-level attribute).  The value is a 2D Filter in which True marks
    an asset/date pair to include and False marks one to exclude.
    """
    raise NotImplementedError()
@lazyval
def dependencies(self):
    """
    Everything this Term depends on: its inputs followed by its mask.

    The mask is appended as the final element of the returned tuple.
    """
    required_inputs = self.inputs
    return required_inputs + (self.mask,)
@lazyval
def atomic(self):
    """
    Whether this Term depends on nothing other than ``AssetExists()``.

    Returns
    -------
    bool
        True when every entry of ``self.dependencies`` is the
        ``AssetExists()`` term (including when there are no dependencies
        at all), False otherwise.
    """
    # Look the root mask term up once instead of once per dependency.
    # NOTE(review): the identity comparison assumes ``AssetExists()``
    # yields the same object on every call -- confirm against Term.__new__.
    root_mask = AssetExists()
    # Compare identities explicitly.  Handing the dependencies themselves
    # to ``any`` would test their truthiness, so a dependency that happens
    # to be falsy would be silently ignored.
    return all(dep is root_mask for dep in self.dependencies)
class AssetExists(Term):
    """
    Pseudo-filter that is True wherever an asset existed on a given day.

    Terms that have not been passed a mask explicitly use this as their
    default mask.

    Conceptually this is a Filter, since it yields one boolean for every
    asset on every date.  It deliberately does not subclass Filter, because
    `AssetExists` is computed directly by the PipelineEngine.

    See Also
    --------
    zipline.assets.AssetFinder.lifetimes
    """
    # Dependency-related attributes: no inputs, no mask, no dependencies.
    mask = None
    inputs = ()
    dependencies = ()
    extra_input_rows = 0

    # Output-related attributes: boolean values, not tied to any dataset.
    dataset = None
    dtype = bool_

    def __repr__(self):
        return "AssetExists()"
# TODO: Move mixins to a separate file?
class SingleInputMixin(object):
@@ -220,17 +262,6 @@ class CustomTermMixin(object):
return out
class AtomicTerm(Term):
    """
    A Term that always reports itself as atomic.

    ``dataset`` is left for subclasses to provide; the implementation here
    only raises ``NotImplementedError``.
    """

    @property
    def dataset(self):
        raise NotImplementedError()

    @property
    def atomic(self):
        return True
class CompositeTerm(Term):
inputs = NotSpecified
window_length = NotSpecified
@@ -295,10 +326,6 @@ class CompositeTerm(Term):
return super(CompositeTerm, self)._validate()
@property
def atomic(self):
return False
def _compute(self, inputs, dates, assets, mask):
"""
Subclasses should implement this to perform actual computation.
@@ -339,31 +366,3 @@ class CompositeTerm(Term):
inputs=self.inputs,
window_length=self.window_length,
)
class AssetExists(AtomicTerm):
    """
    Pseudo-filter that is True wherever an asset existed on a given day.

    Terms that have not been passed a mask explicitly use this as their
    default mask.

    Conceptually this is a Filter, since it yields one boolean for every
    asset on every date.  It deliberately does not subclass Filter, because
    `AssetExists` is computed directly by the PipelineEngine.

    See Also
    --------
    zipline.assets.AssetFinder.lifetimes
    """
    dataset = None
    dtype = bool_

    def __repr__(self):
        return "AssetExists()"

    def _compute(self, *args, **kwargs):
        # TODO: Consider moving the bulk of the logic from
        # SimplePipelineEngine._compute_root_mask here.
        message = "Direct computation of AssetExists is not supported!"
        raise NotImplementedError(message)