From 2dabda6b76caa7e2437863b5908c5bea87f8a7ec Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Fri, 9 Oct 2015 17:32:22 -0400 Subject: [PATCH] MAINT: Reworked Term atomicity --- tests/pipeline/test_term.py | 6 +- zipline/pipeline/data/dataset.py | 7 ++- zipline/pipeline/engine.py | 2 +- zipline/pipeline/graph.py | 41 +++++--------- zipline/pipeline/term.py | 97 ++++++++++++++++---------------- 5 files changed, 69 insertions(+), 84 deletions(-) diff --git a/tests/pipeline/test_term.py b/tests/pipeline/test_term.py index c10bf16b..ffa2df7b 100644 --- a/tests/pipeline/test_term.py +++ b/tests/pipeline/test_term.py @@ -86,10 +86,8 @@ class DependencyResolutionTestCase(TestCase): seen = set() for term in ordered_terms: - if not term.atomic: - for input_ in term.inputs: - self.assertIn(input_, seen) - self.assertIn(term.mask, seen) + for dep in term.dependencies: + self.assertIn(dep, seen) seen.add(term) diff --git a/zipline/pipeline/data/dataset.py b/zipline/pipeline/data/dataset.py index dd64fd50..fe8be41a 100644 --- a/zipline/pipeline/data/dataset.py +++ b/zipline/pipeline/data/dataset.py @@ -7,7 +7,7 @@ from six import ( with_metaclass, ) -from zipline.pipeline.term import AtomicTerm +from zipline.pipeline.term import Term, AssetExists from zipline.pipeline.factors import Latest @@ -26,10 +26,13 @@ class Column(object): return BoundColumn(dtype=self.dtype, dataset=dataset, name=name) -class BoundColumn(AtomicTerm): +class BoundColumn(Term): """ A Column of data that's been concretely bound to a particular dataset. """ + mask = AssetExists() + extra_input_rows = 0 + inputs = () def __new__(cls, dtype, dataset, name): return super(BoundColumn, cls).__new__( diff --git a/zipline/pipeline/engine.py b/zipline/pipeline/engine.py index a8347776..8abccb83 100644 --- a/zipline/pipeline/engine.py +++ b/zipline/pipeline/engine.py @@ -236,7 +236,7 @@ class SimplePipelineEngine(object): """ Load mask and mask row labels for term. """ - mask = term.mask if not term.atomic else self._root_mask_term + mask = term.mask offset = graph.extra_rows[mask] - graph.extra_rows[term] return workspace[mask][offset:], dates[offset:] diff --git a/zipline/pipeline/graph.py b/zipline/pipeline/graph.py index 3a91a37c..af4a8326 100644 --- a/zipline/pipeline/graph.py +++ b/zipline/pipeline/graph.py @@ -102,16 +102,9 @@ class TermGraph(DiGraph): zipline.pipeline.engine.SimplePipelineEngine._inputs_for_term zipline.pipeline.engine.SimplePipelineEngine._mask_and_dates_for_term """ - out = {} - for term in self: - if not term.atomic: - extra_input_rows = term.extra_input_rows - for input_ in term.inputs: - out[term, input_] = (self.extra_rows[input_] - - extra_input_rows) - mask = term.mask - out[term, mask] = self.extra_rows[mask] - extra_input_rows - return out + return {(term, dep): self.extra_rows[dep] - term.extra_input_rows + for term in self + for dep in term.dependencies} @lazyval def extra_rows(self): @@ -192,25 +185,17 @@ class TermGraph(DiGraph): # Make sure we're going to compute at least `extra_rows` of `term`. self._ensure_extra_rows(term, extra_rows) - if not term.atomic: - # Number of extra rows we need to compute for this term's - # dependencies. - dependency_extra_rows = extra_rows + term.extra_input_rows + # Number of extra rows we need to compute for this term's dependencies. + dependency_extra_rows = extra_rows + term.extra_input_rows - # Recursively add dependencies. - for dependency in term.inputs: - self._add_to_graph( - dependency, - parents, - extra_rows=dependency_extra_rows, - ) - self.add_edge(dependency, term) - - # Add term's mask, which is really just a specially-enumerated - # input. - mask = term.mask - self._add_to_graph(mask, parents, extra_rows=dependency_extra_rows) - self.add_edge(mask, term) + # Recursively add dependencies. + for dependency in term.dependencies: + self._add_to_graph( + dependency, + parents, + extra_rows=dependency_extra_rows, + ) + self.add_edge(dependency, term) parents.remove(term) diff --git a/zipline/pipeline/term.py b/zipline/pipeline/term.py index 1dc3974a..31301053 100644 --- a/zipline/pipeline/term.py +++ b/zipline/pipeline/term.py @@ -1,9 +1,11 @@ """ Base class for Filters, Factors and Classifiers """ +from abc import ABCMeta, abstractproperty from weakref import WeakValueDictionary from numpy import bool_, full, nan +from six import with_metaclass from zipline.errors import ( DTypeNotSpecified, @@ -38,7 +40,7 @@ class NotSpecified(object): return self -class Term(object): +class Term(with_metaclass(ABCMeta, object)): """ Base class for terms in a Pipeline API compute graph. """ @@ -135,15 +137,55 @@ class Term(object): if self.dtype is NotSpecified: raise DTypeNotSpecified(termname=type(self).__name__) - @property - def atomic(self): + @abstractproperty + def inputs(self): """ - Whether or not this term has dependencies. - - If term.atomic is truthy, it should have dataset and dtype attributes. + A tuple of other Terms that this Term requires for computation. """ raise NotImplementedError() + @abstractproperty + def mask(self): + """ + A 2D Filter representing asset/date pairs to include while + computing this Term. (True means include; False means exclude.) + """ + raise NotImplementedError() + + @lazyval + def dependencies(self): + return self.inputs + (self.mask,) + + @lazyval + def atomic(self): + return not any(dep for dep in self.dependencies + if dep is not AssetExists()) + + +class AssetExists(Term): + """ + Pseudo-filter describing whether or not an asset existed on a given day. + This is the default mask for all terms that haven't been passed a mask + explicitly. + + This is morally a Filter, in the sense that it produces a boolean value for + every asset on every date. We don't subclass Filter, however, because + `AssetExists` is computed directly by the PipelineEngine. + + See Also + -------- + zipline.assets.AssetFinder.lifetimes + """ + dtype = bool_ + dataset = None + extra_input_rows = 0 + inputs = () + dependencies = () + mask = None + + def __repr__(self): + return "AssetExists()" + # TODO: Move mixins to a separate file? class SingleInputMixin(object): @@ -220,17 +262,6 @@ class CustomTermMixin(object): return out -class AtomicTerm(Term): - - @property - def atomic(self): - return True - - @property - def dataset(self): - raise NotImplementedError() - - class CompositeTerm(Term): inputs = NotSpecified window_length = NotSpecified @@ -295,10 +326,6 @@ class CompositeTerm(Term): return super(CompositeTerm, self)._validate() - @property - def atomic(self): - return False - def _compute(self, inputs, dates, assets, mask): """ Subclasses should implement this to perform actual computation. @@ -339,31 +366,3 @@ class CompositeTerm(Term): inputs=self.inputs, window_length=self.window_length, ) - - -class AssetExists(AtomicTerm): - """ - Pseudo-filter describing whether or not an asset existed on a given day. - This is the default mask for all terms that haven't been passed a mask - explicitly. - - This is morally a Filter, in the sense that it produces a boolean value for - every asset on every date. We don't subclass Filter, however, because - `AssetExists` is computed directly by the PipelineEngine. - - See Also - -------- - zipline.assets.AssetFinder.lifetimes - """ - dtype = bool_ - dataset = None - - def _compute(self, *args, **kwargs): - # TODO: Consider moving the bulk of the logic from - # SimplePipelineEngine._compute_root_mask here. - raise NotImplementedError( - "Direct computation of AssetExists is not supported!" - ) - - def __repr__(self): - return "AssetExists()"