mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 15:55:02 +08:00
MAINT: Reworked Term atomicity
This commit is contained in:
@@ -86,10 +86,8 @@ class DependencyResolutionTestCase(TestCase):
|
||||
seen = set()
|
||||
|
||||
for term in ordered_terms:
|
||||
if not term.atomic:
|
||||
for input_ in term.inputs:
|
||||
self.assertIn(input_, seen)
|
||||
self.assertIn(term.mask, seen)
|
||||
for dep in term.dependencies:
|
||||
self.assertIn(dep, seen)
|
||||
|
||||
seen.add(term)
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from six import (
|
||||
with_metaclass,
|
||||
)
|
||||
|
||||
from zipline.pipeline.term import AtomicTerm
|
||||
from zipline.pipeline.term import Term, AssetExists
|
||||
from zipline.pipeline.factors import Latest
|
||||
|
||||
|
||||
@@ -26,10 +26,13 @@ class Column(object):
|
||||
return BoundColumn(dtype=self.dtype, dataset=dataset, name=name)
|
||||
|
||||
|
||||
class BoundColumn(AtomicTerm):
|
||||
class BoundColumn(Term):
|
||||
"""
|
||||
A Column of data that's been concretely bound to a particular dataset.
|
||||
"""
|
||||
mask = AssetExists()
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
|
||||
def __new__(cls, dtype, dataset, name):
|
||||
return super(BoundColumn, cls).__new__(
|
||||
|
||||
@@ -236,7 +236,7 @@ class SimplePipelineEngine(object):
|
||||
"""
|
||||
Load mask and mask row labels for term.
|
||||
"""
|
||||
mask = term.mask if not term.atomic else self._root_mask_term
|
||||
mask = term.mask
|
||||
offset = graph.extra_rows[mask] - graph.extra_rows[term]
|
||||
return workspace[mask][offset:], dates[offset:]
|
||||
|
||||
|
||||
+13
-28
@@ -102,16 +102,9 @@ class TermGraph(DiGraph):
|
||||
zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term
|
||||
zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term
|
||||
"""
|
||||
out = {}
|
||||
for term in self:
|
||||
if not term.atomic:
|
||||
extra_input_rows = term.extra_input_rows
|
||||
for input_ in term.inputs:
|
||||
out[term, input_] = (self.extra_rows[input_]
|
||||
- extra_input_rows)
|
||||
mask = term.mask
|
||||
out[term, mask] = self.extra_rows[mask] - extra_input_rows
|
||||
return out
|
||||
return {(term, dep): self.extra_rows[dep] - term.extra_input_rows
|
||||
for term in self
|
||||
for dep in term.dependencies}
|
||||
|
||||
@lazyval
|
||||
def extra_rows(self):
|
||||
@@ -192,25 +185,17 @@ class TermGraph(DiGraph):
|
||||
# Make sure we're going to compute at least `extra_rows` of `term`.
|
||||
self._ensure_extra_rows(term, extra_rows)
|
||||
|
||||
if not term.atomic:
|
||||
# Number of extra rows we need to compute for this term's
|
||||
# dependencies.
|
||||
dependency_extra_rows = extra_rows + term.extra_input_rows
|
||||
# Number of extra rows we need to compute for this term's dependencies.
|
||||
dependency_extra_rows = extra_rows + term.extra_input_rows
|
||||
|
||||
# Recursively add dependencies.
|
||||
for dependency in term.inputs:
|
||||
self._add_to_graph(
|
||||
dependency,
|
||||
parents,
|
||||
extra_rows=dependency_extra_rows,
|
||||
)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
# Add term's mask, which is really just a specially-enumerated
|
||||
# input.
|
||||
mask = term.mask
|
||||
self._add_to_graph(mask, parents, extra_rows=dependency_extra_rows)
|
||||
self.add_edge(mask, term)
|
||||
# Recursively add dependencies.
|
||||
for dependency in term.dependencies:
|
||||
self._add_to_graph(
|
||||
dependency,
|
||||
parents,
|
||||
extra_rows=dependency_extra_rows,
|
||||
)
|
||||
self.add_edge(dependency, term)
|
||||
|
||||
parents.remove(term)
|
||||
|
||||
|
||||
+48
-49
@@ -1,9 +1,11 @@
|
||||
"""
|
||||
Base class for Filters, Factors and Classifiers
|
||||
"""
|
||||
from abc import ABCMeta, abstractproperty
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from numpy import bool_, full, nan
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import (
|
||||
DTypeNotSpecified,
|
||||
@@ -38,7 +40,7 @@ class NotSpecified(object):
|
||||
return self
|
||||
|
||||
|
||||
class Term(object):
|
||||
class Term(with_metaclass(ABCMeta, object)):
|
||||
"""
|
||||
Base class for terms in a Pipeline API compute graph.
|
||||
"""
|
||||
@@ -135,15 +137,55 @@ class Term(object):
|
||||
if self.dtype is NotSpecified:
|
||||
raise DTypeNotSpecified(termname=type(self).__name__)
|
||||
|
||||
@property
|
||||
def atomic(self):
|
||||
@abstractproperty
|
||||
def inputs(self):
|
||||
"""
|
||||
Whether or not this term has dependencies.
|
||||
|
||||
If term.atomic is truthy, it should have dataset and dtype attributes.
|
||||
A tuple of other Terms that this Term requires for computation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractproperty
|
||||
def mask(self):
|
||||
"""
|
||||
A 2D Filter representing asset/date pairs to include while
|
||||
computing this Term. (True means include; False means exclude.)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@lazyval
|
||||
def dependencies(self):
|
||||
return self.inputs + (self.mask,)
|
||||
|
||||
@lazyval
|
||||
def atomic(self):
|
||||
return not any(dep for dep in self.dependencies
|
||||
if dep is not AssetExists())
|
||||
|
||||
|
||||
class AssetExists(Term):
|
||||
"""
|
||||
Pseudo-filter describing whether or not an asset existed on a given day.
|
||||
This is the default mask for all terms that haven't been passed a mask
|
||||
explicitly.
|
||||
|
||||
This is morally a Filter, in the sense that it produces a boolean value for
|
||||
every asset on every date. We don't subclass Filter, however, because
|
||||
`AssetExists` is computed directly by the PipelineEngine.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
"""
|
||||
dtype = bool_
|
||||
dataset = None
|
||||
extra_input_rows = 0
|
||||
inputs = ()
|
||||
dependencies = ()
|
||||
mask = None
|
||||
|
||||
def __repr__(self):
|
||||
return "AssetExists()"
|
||||
|
||||
|
||||
# TODO: Move mixins to a separate file?
|
||||
class SingleInputMixin(object):
|
||||
@@ -220,17 +262,6 @@ class CustomTermMixin(object):
|
||||
return out
|
||||
|
||||
|
||||
class AtomicTerm(Term):
|
||||
|
||||
@property
|
||||
def atomic(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CompositeTerm(Term):
|
||||
inputs = NotSpecified
|
||||
window_length = NotSpecified
|
||||
@@ -295,10 +326,6 @@ class CompositeTerm(Term):
|
||||
|
||||
return super(CompositeTerm, self)._validate()
|
||||
|
||||
@property
|
||||
def atomic(self):
|
||||
return False
|
||||
|
||||
def _compute(self, inputs, dates, assets, mask):
|
||||
"""
|
||||
Subclasses should implement this to perform actual computation.
|
||||
@@ -339,31 +366,3 @@ class CompositeTerm(Term):
|
||||
inputs=self.inputs,
|
||||
window_length=self.window_length,
|
||||
)
|
||||
|
||||
|
||||
class AssetExists(AtomicTerm):
|
||||
"""
|
||||
Pseudo-filter describing whether or not an asset existed on a given day.
|
||||
This is the default mask for all terms that haven't been passed a mask
|
||||
explicitly.
|
||||
|
||||
This is morally a Filter, in the sense that it produces a boolean value for
|
||||
every asset on every date. We don't subclass Filter, however, because
|
||||
`AssetExists` is computed directly by the PipelineEngine.
|
||||
|
||||
See Also
|
||||
--------
|
||||
zipline.assets.AssetFinder.lifetimes
|
||||
"""
|
||||
dtype = bool_
|
||||
dataset = None
|
||||
|
||||
def _compute(self, *args, **kwargs):
|
||||
# TODO: Consider moving the bulk of the logic from
|
||||
# SimplePipelineEngine._compute_root_mask here.
|
||||
raise NotImplementedError(
|
||||
"Direct computation of AssetExists is not supported!"
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return "AssetExists()"
|
||||
|
||||
Reference in New Issue
Block a user