mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 14:09:44 +08:00
ENH: cleanup branch based on feedback
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from nose.tools import nottest
|
||||
import numpy as np
|
||||
|
||||
from zipline.testing.predicates import assert_equal
|
||||
@@ -7,7 +8,8 @@ from zipline.utils.numpy_utils import float64_dtype, int64_dtype
|
||||
from .base import BasePipelineTestCase
|
||||
|
||||
|
||||
class WithAlias(object):
|
||||
@nottest
|
||||
class BaseAliasTestCase(BasePipelineTestCase):
|
||||
|
||||
def test_alias(self):
|
||||
f = self.Term()
|
||||
@@ -42,20 +44,20 @@ class WithAlias(object):
|
||||
)
|
||||
|
||||
|
||||
class TestFactorAlias(WithAlias, BasePipelineTestCase):
|
||||
class TestFactorAlias(BaseAliasTestCase):
|
||||
class Term(Factor):
|
||||
dtype = float64_dtype
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class TestFilterAlias(WithAlias, BasePipelineTestCase):
|
||||
class TestFilterAlias(BaseAliasTestCase):
|
||||
class Term(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
|
||||
|
||||
class TestClassifierAlias(WithAlias, BasePipelineTestCase):
|
||||
class TestClassifierAlias(BaseAliasTestCase):
|
||||
class Term(Classifier):
|
||||
dtype = int64_dtype
|
||||
inputs = ()
|
||||
|
||||
@@ -14,6 +14,7 @@ from numpy import (
|
||||
float32,
|
||||
float64,
|
||||
full,
|
||||
full_like,
|
||||
log,
|
||||
nan,
|
||||
tile,
|
||||
@@ -79,6 +80,7 @@ from zipline.testing.fixtures import (
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.memoize import lazyval
|
||||
from zipline.utils.numpy_utils import bool_dtype, datetime64ns_dtype
|
||||
|
||||
@@ -1322,30 +1324,58 @@ class StringColumnTestCase(WithSeededRandomPipelineEngine,
|
||||
|
||||
class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
def test_populate_default_workspace(self):
|
||||
window_length = 5
|
||||
column = USEquityPricing.low
|
||||
base_term = column.latest
|
||||
term = (base_term + 1).alias('term')
|
||||
composed_term = term + 1
|
||||
precomputed_term = (base_term + 1).alias('precomputed_term')
|
||||
precomputed_term_with_window = SimpleMovingAverage(
|
||||
inputs=(column,),
|
||||
window_length=window_length,
|
||||
).alias('precomputed_term_with_window')
|
||||
depends_on_precomputed_term = precomputed_term + 1
|
||||
depends_on_precomputed_term_with_window = (
|
||||
precomputed_term_with_window + 1
|
||||
)
|
||||
column_value = self.constants[column]
|
||||
precomputed_value = -column_value
|
||||
precomputed_term_value = -column_value
|
||||
precomputed_term_with_window_value = -(column_value + 1)
|
||||
|
||||
def populate_initial_workspace(initial_workspace,
|
||||
root_mask_term,
|
||||
execution_plan,
|
||||
dates,
|
||||
assets):
|
||||
return assoc(
|
||||
ws = initial_workspace.copy()
|
||||
_, precomputed_term_dates = execution_plan.mask_and_dates_for_term(
|
||||
precomputed_term,
|
||||
root_mask_term,
|
||||
initial_workspace,
|
||||
term,
|
||||
full(
|
||||
(len(dates), len(assets)),
|
||||
precomputed_value,
|
||||
dtype=float64,
|
||||
),
|
||||
dates,
|
||||
)
|
||||
ws[precomputed_term] = full(
|
||||
(len(precomputed_term_dates), len(assets)),
|
||||
precomputed_term_value,
|
||||
dtype=float64,
|
||||
)
|
||||
(
|
||||
_,
|
||||
precomputed_term_with_window_dates,
|
||||
) = execution_plan.mask_and_dates_for_term(
|
||||
precomputed_term,
|
||||
root_mask_term,
|
||||
initial_workspace,
|
||||
dates,
|
||||
)
|
||||
|
||||
def dispatcher(column):
|
||||
if column is base_term:
|
||||
ws[precomputed_term_with_window] = full(
|
||||
(len(precomputed_term_with_window_dates), len(assets)),
|
||||
precomputed_term_with_window_value,
|
||||
dtype=float64,
|
||||
)
|
||||
return ws
|
||||
|
||||
def dispatcher(c):
|
||||
if c is column:
|
||||
# the base_term should never be loaded, its initial refcount
|
||||
# should be zero
|
||||
return ExplodingObject()
|
||||
@@ -1360,16 +1390,41 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
|
||||
results = engine.run_pipeline(
|
||||
Pipeline({
|
||||
'term': term,
|
||||
'composed_term': composed_term,
|
||||
'precomputed_term': precomputed_term,
|
||||
'precomputed_term_with_window': precomputed_term_with_window,
|
||||
'depends_on_precomputed_term': depends_on_precomputed_term,
|
||||
'depends_on_precomputed_term_with_window':
|
||||
depends_on_precomputed_term_with_window,
|
||||
}),
|
||||
self.dates[0],
|
||||
self.dates[window_length - 1],
|
||||
self.dates[-1],
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
(results['term'] == precomputed_value).all(),
|
||||
assert_equal(
|
||||
results['precomputed_term'].values,
|
||||
full_like(
|
||||
results['precomputed_term'],
|
||||
precomputed_term_value,
|
||||
),
|
||||
),
|
||||
assert_equal(
|
||||
results['precomputed_term_with_window'].values,
|
||||
full_like(
|
||||
results['precomputed_term_with_window'],
|
||||
precomputed_term_with_window_value,
|
||||
),
|
||||
),
|
||||
assert_equal(
|
||||
results['depends_on_precomputed_term'].values,
|
||||
full_like(
|
||||
results['depends_on_precomputed_term'],
|
||||
precomputed_term_value + 1,
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
(results['composed_term'] == (precomputed_value + 1)).all(),
|
||||
assert_equal(
|
||||
results['depends_on_precomputed_term_with_window'].values,
|
||||
full_like(
|
||||
results['depends_on_precomputed_term_with_window'],
|
||||
precomputed_term_with_window_value + 1,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -312,7 +312,7 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
|
||||
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
|
||||
"""
|
||||
data = super(Classifier, self).unprocess(result, assets)
|
||||
data = super(Classifier, self).to_workspace_value(result, assets)
|
||||
if self.dtype == int64_dtype:
|
||||
return data
|
||||
assert isinstance(data, pd.Categorical), (
|
||||
|
||||
@@ -146,13 +146,19 @@ class TermGraph(object):
|
||||
refcounts[t] += 1
|
||||
|
||||
for t in initial_terms:
|
||||
self._decref_recursive(t, refcounts, set())
|
||||
self._decref_depencies_recursive(t, refcounts, set())
|
||||
|
||||
return refcounts
|
||||
|
||||
def _decref_recursive(self, term, refcounts, garbage):
|
||||
def _decref_depencies_recursive(self, term, refcounts, garbage):
|
||||
"""
|
||||
Decrement terms recursivly to build the initial workspace.
|
||||
Decrement terms recursively.
|
||||
|
||||
Notes
|
||||
-----
|
||||
This should only be used to build the initial workspace, after that we
|
||||
should use:
|
||||
:meth:`~zipline.pipeline.graph.TermGraph.decref_dependencies`
|
||||
"""
|
||||
# Edges are tuple of (from, to).
|
||||
for parent, _ in self.graph.in_edges([term]):
|
||||
@@ -161,7 +167,7 @@ class TermGraph(object):
|
||||
# workspace to conserve memory.
|
||||
if refcounts[parent] == 0:
|
||||
garbage.add(parent)
|
||||
self._decref_recursive(parent, refcounts, garbage)
|
||||
self._decref_depencies_recursive(parent, refcounts, garbage)
|
||||
|
||||
def decref_dependencies(self, term, refcounts):
|
||||
"""
|
||||
@@ -430,7 +436,7 @@ class ExecutionPlan(TermGraph):
|
||||
mask = term.mask
|
||||
mask_offset = self.extra_rows[mask] - self.extra_rows[term]
|
||||
|
||||
# This offset is computed against _root_mask_term because that is what
|
||||
# This offset is computed against root_mask_term because that is what
|
||||
# determines the shape of the top-level dates array.
|
||||
dates_offset = (
|
||||
self.extra_rows[root_mask_term] - self.extra_rows[term]
|
||||
|
||||
@@ -20,7 +20,7 @@ from zipline.utils.control_flow import nullctx
|
||||
from zipline.utils.input_validation import expect_types
|
||||
from zipline.utils.sharedoc import (
|
||||
format_docstring,
|
||||
PIPELINE_ALIAS_DOC,
|
||||
PIPELINE_ALIAS_NAME_DOC,
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
|
||||
)
|
||||
from zipline.utils.pandas_utils import nearest_unequal_elements
|
||||
@@ -300,12 +300,12 @@ class AliasedMixin(SingleInputMixin):
|
||||
doc = format_docstring(
|
||||
owner_name=other_base.__name__,
|
||||
docstring=docstring,
|
||||
formatters={'name': PIPELINE_ALIAS_DOC},
|
||||
formatters={'name': PIPELINE_ALIAS_NAME_DOC},
|
||||
)
|
||||
|
||||
return type(
|
||||
'Aliased' + other_base.__name__,
|
||||
(cls, other_base,),
|
||||
(cls, other_base),
|
||||
{'__doc__': doc,
|
||||
'__module__': other_base.__module__},
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ from zipline.utils.numpy_utils import (
|
||||
)
|
||||
from zipline.utils.sharedoc import (
|
||||
templated_docstring,
|
||||
PIPELINE_ALIAS_DOC,
|
||||
PIPELINE_ALIAS_NAME_DOC,
|
||||
PIPELINE_DOWNSAMPLING_FREQUENCY_DOC,
|
||||
)
|
||||
|
||||
@@ -633,7 +633,7 @@ class ComputableTerm(Term):
|
||||
"for instances of %s." % type(self).__name__
|
||||
)
|
||||
|
||||
@templated_docstring(name=PIPELINE_ALIAS_DOC)
|
||||
@templated_docstring(name=PIPELINE_ALIAS_NAME_DOC)
|
||||
def alias(self, name):
|
||||
"""
|
||||
Make a term from ``self`` that names the expression.
|
||||
|
||||
@@ -18,7 +18,7 @@ PIPELINE_DOWNSAMPLING_FREQUENCY_DOC = dedent(
|
||||
"""
|
||||
)
|
||||
|
||||
PIPELINE_ALIAS_DOC = dedent(
|
||||
PIPELINE_ALIAS_NAME_DOC = dedent(
|
||||
"""\
|
||||
name : str
|
||||
The name to alias this term as.
|
||||
|
||||
Reference in New Issue
Block a user