Files
catalyst/tests/modelling/test_term.py
T
Scott Sanderson 26fd6fda8b ENH/BUG: Modeling API enhancements.
- Fixes an error where Modeling API data known as of the close of `day
  N` would be shown to algorithms during `before_trading_start` as of
  the close of the same day.  Algorithms should now only receive data
  during `before_trading_start/handle_data` that was known as of the
  simulation time at which the function would be called.

- All Term instances now have a `mask` attribute that must be a `Filter`
  or an instance of `AssetExists()`.  `mask` can be used to specify that
  a Factor should be computed in a manner that ignores the values that
  were not `True` in the mask.

- Changed the interface for `FFCLoader.load_adjusted_array` and
  `Term._compute` from `(columns, mask)`, with mask as a DataFrame, to
  `(columns, dates, assets, mask)`, where mask is a numpy array.  This
  is primarily to avoid having to reconstruct extra DataFrames when
  using masks produced by non `AssetExists` filters.

- Adds `BoundColumn.latest`, which gives the most-recently-known value
  of a column.
2015-09-16 01:47:11 -04:00

294 lines
8.2 KiB
Python

"""
Tests for Term.
"""
from itertools import product
from unittest import TestCase
from numpy import (
float32,
uint32,
uint8,
)
from zipline.data.dataset import (
Column,
DataSet,
)
from zipline.errors import (
InputTermNotAtomic,
TermInputsNotSpecified,
WindowLengthNotSpecified,
)
from zipline.modelling.expression import NUMEXPR_MATH_FUNCS
from zipline.modelling.factor import Factor
from zipline.modelling.graph import TermGraph
from zipline.modelling.term import AssetExists, NotSpecified
class SomeDataSet(DataSet):
    """
    Small DataSet fixture declaring three columns, each with a different
    numpy dtype.
    """
    foo = Column(float32)
    bar = Column(uint32)
    buzz = Column(uint8)
class SomeFactor(Factor):
    """
    Factor fixture with class-level defaults: a window_length of 5 and
    inputs of SomeDataSet.foo followed by SomeDataSet.bar.
    """
    window_length = 5
    inputs = [SomeDataSet.foo, SomeDataSet.bar]
class NoLookbackFactor(Factor):
    """
    Factor fixture that declares a window_length of 0 and no default inputs.
    """
    window_length = 0
class SomeOtherFactor(Factor):
    """
    Second Factor fixture: same window_length as SomeFactor (5), but with
    inputs of SomeDataSet.bar followed by SomeDataSet.buzz.
    """
    window_length = 5
    inputs = [SomeDataSet.bar, SomeDataSet.buzz]
# Second module-level name bound to the same class object as SomeFactor.
# Tests use it to reach the original class after locally rebinding the name
# `SomeFactor`, and to check that the alias produces the same instances.
SomeFactorAlias = SomeFactor
def gen_equivalent_factors():
    """
    Lazily yield SomeFactor instances, each built with a different spelling
    of the class's default arguments.  Every yielded value is expected to be
    the same object.
    """
    def explicit_inputs():
        # Build a fresh list on every call, exactly as a literal written in
        # place at each call site would.
        return [SomeDataSet.foo, SomeDataSet.bar]

    # No arguments at all, or inputs explicitly left unspecified.
    yield SomeFactor()
    yield SomeFactor(inputs=NotSpecified)

    # Inputs passed explicitly, positionally and by keyword.
    yield SomeFactor(SomeFactor.inputs)
    yield SomeFactor(inputs=SomeFactor.inputs)
    yield SomeFactor(explicit_inputs())

    # Window length passed explicitly, or explicitly left unspecified.
    yield SomeFactor(window_length=SomeFactor.window_length)
    yield SomeFactor(window_length=NotSpecified)

    # Both inputs and window length supplied together.
    yield SomeFactor(explicit_inputs(), window_length=NotSpecified)
    yield SomeFactor(
        explicit_inputs(),
        window_length=SomeFactor.window_length,
    )

    # The same class reached through its alias.
    yield SomeFactorAlias()
def to_dict(l):
    """
    Build a dict from a sized iterable, keyed by each element's position
    rendered as a string: '0', '1', '2', ...

    Example
    -------
    >>> to_dict([2, 3, 4])
    {'0': 2, '1': 3, '2': 4}
    """
    positions = (str(index) for index in range(len(l)))
    return dict(zip(positions, l))
class DependencyResolutionTestCase(TestCase):
    """
    Tests for the order in which a TermGraph resolves terms and for the
    number of extra rows it records for each term.
    """

    def setUp(self):
        # unittest only invokes the camelCase hook names.  No fixtures are
        # required by these tests, so the hook is intentionally empty.
        pass

    def tearDown(self):
        # Nothing to clean up; see setUp.
        pass

    def test_single_factor(self):
        """
        Test dependency resolution for a single factor.
        """
        def check_output(graph):
            resolution_order = list(graph.ordered())

            # AssetExists(), the factor's two inputs, and the factor itself.
            self.assertEqual(len(resolution_order), 4)
            self.assertIs(resolution_order[0], AssetExists())

            # The two inputs may be resolved in either order.
            self.assertEqual(
                set([resolution_order[1], resolution_order[2]]),
                set([SomeDataSet.foo, SomeDataSet.bar]),
            )
            self.assertEqual(resolution_order[-1], SomeFactor())
            self.assertEqual(graph.node[SomeDataSet.foo]['extra_rows'], 4)
            self.assertEqual(graph.node[SomeDataSet.bar]['extra_rows'], 4)

        # Every equivalent spelling of the factor must give the same graph.
        for foobar in gen_equivalent_factors():
            check_output(TermGraph(to_dict([foobar])))

    def test_single_factor_instance_args(self):
        """
        Test dependency resolution for a single factor with arguments passed to
        the constructor.
        """
        bar, buzz = SomeDataSet.bar, SomeDataSet.buzz
        graph = TermGraph(to_dict([SomeFactor([bar, buzz], window_length=5)]))

        resolution_order = list(graph.ordered())

        # SomeFactor, its inputs, and AssetExists()
        self.assertEqual(len(resolution_order), 4)
        self.assertIs(resolution_order[0], AssetExists())
        self.assertEqual(graph.extra_rows[AssetExists()], 4)

        # The two inputs may be resolved in either order.
        self.assertEqual(
            set([resolution_order[1], resolution_order[2]]),
            set([bar, buzz]),
        )
        self.assertEqual(
            resolution_order[-1],
            SomeFactor([bar, buzz], window_length=5),
        )
        self.assertEqual(graph.extra_rows[bar], 4)
        self.assertEqual(graph.extra_rows[buzz], 4)

    def test_reuse_atomic_terms(self):
        """
        Test that raw inputs only show up in the dependency graph once.
        """
        f1 = SomeFactor([SomeDataSet.foo, SomeDataSet.bar])
        f2 = SomeOtherFactor([SomeDataSet.bar, SomeDataSet.buzz])

        graph = TermGraph(to_dict([f1, f2]))
        resolution_order = list(graph.ordered())

        # bar should only appear once.
        self.assertEqual(len(resolution_order), 6)

        # Map each term to its position in the resolution order.
        indices = {
            term: resolution_order.index(term)
            for term in resolution_order
        }

        self.assertEqual(indices[AssetExists()], 0)

        # Verify that f1's dependencies will be computed before f1.
        self.assertLess(indices[SomeDataSet.foo], indices[f1])
        self.assertLess(indices[SomeDataSet.bar], indices[f1])

        # Verify that f2's dependencies will be computed before f2.
        self.assertLess(indices[SomeDataSet.bar], indices[f2])
        self.assertLess(indices[SomeDataSet.buzz], indices[f2])

    def test_disallow_recursive_lookback(self):
        """
        Test that passing a Factor instance as an input to another Factor
        raises InputTermNotAtomic.
        """
        with self.assertRaises(InputTermNotAtomic):
            SomeFactor(inputs=[SomeFactor(), SomeDataSet.foo])
class ObjectIdentityTestCase(TestCase):
    """
    Tests that building a term twice from equivalent arguments gives back
    the very same object, and that differing arguments do not.
    """

    def assertSameObject(self, *objs):
        """
        Assert that every element of ``objs`` is identical to the first one.
        """
        reference = objs[0]
        for candidate in objs:
            self.assertIs(reference, candidate)

    def test_instance_caching(self):
        """
        Equivalent constructor calls must return the identical instance.
        """
        self.assertSameObject(*gen_equivalent_factors())

        def longer_window():
            return SomeFactor(window_length=SomeFactor.window_length + 1)

        def int_dtype():
            return SomeFactor(dtype=int)

        def swapped_inputs():
            return SomeFactor(
                inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]],
            )

        # Each non-default construction, performed twice, must also yield
        # a single shared object.
        for construct in (longer_window, int_dtype, swapped_inputs):
            self.assertIs(construct(), construct())

    def test_instance_non_caching(self):
        """
        Constructor calls that differ in any argument must not share an
        instance with the default construction.
        """
        baseline = SomeFactor()

        # Different window_length.
        self.assertIsNot(
            baseline,
            SomeFactor(window_length=SomeFactor.window_length + 1),
        )

        # Different dtype
        self.assertIsNot(baseline, SomeFactor(dtype=int))

        # Reordering inputs changes semantics.
        self.assertIsNot(
            baseline,
            SomeFactor(inputs=[SomeFactor.inputs[1], SomeFactor.inputs[0]]),
        )

    def test_instance_non_caching_redefine_class(self):
        """
        A new class with the same name and attributes must not share
        instances with the original class.
        """
        # Instantiate through the alias, since the name `SomeFactor` is
        # rebound locally just below.
        orig_foobar_instance = SomeFactorAlias()

        class SomeFactor(Factor):
            window_length = 5
            inputs = [SomeDataSet.foo, SomeDataSet.bar]

        self.assertIsNot(orig_foobar_instance, SomeFactor())

    def test_instance_caching_binops(self):
        """
        Arithmetic expressions built from the same operands must return the
        identical object each time they are built.
        """
        f = SomeFactor()
        g = SomeOtherFactor()
        same = self.assertIs

        for lhs, rhs in product([f, g], repeat=2):
            # Factor-with-factor arithmetic.
            same((lhs + rhs), (lhs + rhs))
            same((lhs - rhs), (lhs - rhs))
            same((lhs * rhs), (lhs * rhs))
            same((lhs / rhs), (lhs / rhs))
            same((lhs ** rhs), (lhs ** rhs))

            # Factor-with-scalar arithmetic, scalar on either side.
            same((1 + rhs), (1 + rhs))
            same((rhs + 1), (rhs + 1))

            same((1 - rhs), (1 - rhs))
            same((rhs - 1), (rhs - 1))

            same((2 * rhs), (2 * rhs))
            same((rhs * 2), (rhs * 2))

            same((2 / rhs), (2 / rhs))
            same((rhs / 2), (rhs / 2))

            same((2 ** rhs), (2 ** rhs))
            same((rhs ** 2), (rhs ** 2))

            # Expressions whose operands are themselves expressions.
            same((f + g) + (f + g), (f + g) + (f + g))

    def test_instance_caching_unary_ops(self):
        """
        Negation, applied one to three times, must return the identical
        object each time it is built.
        """
        f = SomeFactor()
        self.assertIs(-f, -f)
        self.assertIs(-(-f), -(-f))
        self.assertIs(-(-(-f)), -(-(-f)))

    def test_instance_caching_math_funcs(self):
        """
        Each method named in NUMEXPR_MATH_FUNCS must return the identical
        object on repeated calls.
        """
        f = SomeFactor()
        for funcname in NUMEXPR_MATH_FUNCS:
            self.assertIs(getattr(f, funcname)(), getattr(f, funcname)())

    def test_bad_input(self):
        """
        Constructing a Factor without inputs or without a window length
        must raise the corresponding error.
        """
        class SomeFactor(Factor):
            pass

        class SomeFactorDefaultInputs(Factor):
            inputs = (SomeDataSet.foo, SomeDataSet.bar)

        class SomeFactorDefaultLength(Factor):
            window_length = 10

        # A window length is available, but no inputs.
        self.assertRaises(TermInputsNotSpecified, SomeFactor, window_length=1)
        self.assertRaises(TermInputsNotSpecified, SomeFactorDefaultLength)

        # Inputs are available, but no window length.
        self.assertRaises(
            WindowLengthNotSpecified,
            SomeFactor,
            inputs=(SomeDataSet.foo,),
        )
        self.assertRaises(WindowLengthNotSpecified, SomeFactorDefaultInputs)