From af3e1016a0b359dcf0d8a36d83465c1260440dbe Mon Sep 17 00:00:00 2001 From: Joe Jevnik Date: Wed, 12 Oct 2016 19:16:30 -0400 Subject: [PATCH] TST: add tests for postprocess and to_workspace_value --- tests/pipeline/test_classifier.py | 81 ++++++++++++++++++++++ tests/pipeline/test_factor.py | 39 +++++++++++ tests/pipeline/test_filter.py | 36 ++++++++++ zipline/pipeline/classifiers/classifier.py | 23 ++++-- zipline/pipeline/term.py | 19 ++++- 5 files changed, 190 insertions(+), 8 deletions(-) diff --git a/tests/pipeline/test_classifier.py b/tests/pipeline/test_classifier.py index aba52225..8eddb6f8 100644 --- a/tests/pipeline/test_classifier.py +++ b/tests/pipeline/test_classifier.py @@ -2,10 +2,13 @@ from functools import reduce from operator import or_ import numpy as np +import pandas as pd from zipline.lib.labelarray import LabelArray from zipline.pipeline import Classifier from zipline.testing import parameter_space +from zipline.testing.fixtures import ZiplineTestCase +from zipline.testing.predicates import assert_equal from zipline.utils.numpy_utils import ( categorical_dtype, int64_dtype, @@ -464,3 +467,81 @@ class ClassifierTestCase(BasePipelineTestCase): "TypeError(\"unhashable type: 'dict'\",)." ) self.assertEqual(errmsg, expected) + + +class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase): + def test_reversability_categorical(self): + class F(Classifier): + inputs = () + window_length = 0 + dtype = categorical_dtype + missing_value = '' + + f = F() + column_data = LabelArray( + np.array( + [['a', f.missing_value], + ['b', f.missing_value], + ['c', 'd']], + ), + missing_value=f.missing_value, + ) + + assert_equal( + f.postprocess(column_data.ravel()), + pd.Categorical( + ['a', f.missing_value, 'b', f.missing_value, 'c', 'd'], + ), + ) + + # only include the non-missing data + pipeline_output = pd.Series( + data=['a', 'b', 'c', 'd'], + index=pd.MultiIndex.from_arrays([ + [pd.Timestamp('2014-01-01'), + pd.Timestamp('2014-01-02'), + pd.Timestamp('2014-01-03'), + pd.Timestamp('2014-01-03')], + [0, 0, 0, 1], + ]), + dtype='category', + ) + + assert_equal( + f.to_workspace_value(pipeline_output, pd.Index([0, 1])), + column_data, + ) + + def test_reversability_int64(self): + class F(Classifier): + inputs = () + window_length = 0 + dtype = int64_dtype + missing_value = -1 + + f = F() + column_data =np.array( + [[0, f.missing_value], + [1, f.missing_value], + [2, 3]], + ) + + assert_equal(f.postprocess(column_data.ravel()), column_data.ravel()) + + # only include the non-missing data + pipeline_output = pd.Series( + data=[0, 1, 2, 3], + index=pd.MultiIndex.from_arrays([ + [pd.Timestamp('2014-01-01'), + pd.Timestamp('2014-01-02'), + pd.Timestamp('2014-01-03'), + pd.Timestamp('2014-01-03')], + [0, 0, 0, 1], + ]), + dtype=int64_dtype, + ) + + assert_equal( + f.to_workspace_value(pipeline_output, pd.Index([0, 1])), + column_data, + ) diff --git a/tests/pipeline/test_factor.py b/tests/pipeline/test_factor.py index f020ae8b..2d81dc18 100644 --- a/tests/pipeline/test_factor.py +++ b/tests/pipeline/test_factor.py @@ -21,6 +21,7 @@ from numpy import ( where, ) from numpy.random import randn, seed +import pandas as pd from zipline.errors import UnknownRankMethod from zipline.lib.labelarray import LabelArray @@ -37,6 +38,8 @@ from zipline.testing import ( parameter_space, permute_rows, ) +from zipline.testing.fixtures import ZiplineTestCase +from zipline.testing.predicates import assert_equal from zipline.utils.numpy_utils import ( categorical_dtype, datetime64ns_dtype, @@ -1058,3 +1061,39 @@ class TestWindowSafety(TestCase): self.assertFalse(F().demean().window_safe) self.assertFalse(F(window_safe=False).demean().window_safe) self.assertTrue(F(window_safe=True).demean().window_safe) + + +class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase): + @parameter_space(dtype_=(float64_dtype, datetime64ns_dtype)) + def test_reversability(self, dtype_): + class F(Factor): + inputs = () + dtype = dtype_ + window_length = 0 + + f = F() + column_data = array( + [[0, f.missing_value], + [1, f.missing_value], + [2, 3]], + dtype=dtype_, + ) + + assert_equal(f.postprocess(column_data.ravel()), column_data.ravel()) + + # only include the non-missing data + pipeline_output = pd.Series( + data=array([0, 1, 2, 3], dtype=dtype_), + index=pd.MultiIndex.from_arrays([ + [pd.Timestamp('2014-01-01'), + pd.Timestamp('2014-01-02'), + pd.Timestamp('2014-01-03'), + pd.Timestamp('2014-01-03')], + [0, 0, 0, 1], + ]), + ) + + assert_equal( + f.to_workspace_value(pipeline_output, pd.Index([0, 1])), + column_data, + ) diff --git a/tests/pipeline/test_filter.py b/tests/pipeline/test_filter.py index 1051cee2..aa578a5a 100644 --- a/tests/pipeline/test_filter.py +++ b/tests/pipeline/test_filter.py @@ -24,6 +24,7 @@ from numpy import ( sum as np_sum ) from numpy.random import randn, seed as random_seed +import pandas as pd from zipline.errors import BadPercentileBounds from zipline.pipeline import Filter, Factor, Pipeline @@ -859,3 +860,38 @@ class SpecificAssetsTestCase(WithSeededRandomPipelineEngine, assert_equal(results.odds, (sids % 2).astype(bool)) assert_equal(results.first_five, sids < 5) assert_equal(results.last_three, sids >= 7) + + +class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase): + def test_reversability(self): + class F(Filter): + inputs = () + window_length = 0 + missing_value = False + + f = F() + column_data = array( + [[True, f.missing_value], + [True, f.missing_value], + [True, True]], + dtype=bool, + ) + + assert_equal(f.postprocess(column_data.ravel()), column_data.ravel()) + + # only include the non-missing data + pipeline_output = pd.Series( + data=True, + index=pd.MultiIndex.from_arrays([ + [pd.Timestamp('2014-01-01'), + pd.Timestamp('2014-01-02'), + pd.Timestamp('2014-01-03'), + pd.Timestamp('2014-01-03')], + [0, 0, 0, 1], + ]), + ) + + assert_equal( + f.to_workspace_value(pipeline_output, pd.Index([0, 1])), + column_data, + ) diff --git a/zipline/pipeline/classifiers/classifier.py b/zipline/pipeline/classifiers/classifier.py index 5e49e655..86a91a0e 100644 --- a/zipline/pipeline/classifiers/classifier.py +++ b/zipline/pipeline/classifiers/classifier.py @@ -312,13 +312,26 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm): This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`. """ - data = super(Classifier, self).to_workspace_value(result, assets) if self.dtype == int64_dtype: - return data - assert isinstance(data, pd.Categorical), ( - 'Expected a Categorical, got %r.' % type(data).__name__ + return super(Classifier, self).to_workspace_value(result, assets) + + assert isinstance(result.values, pd.Categorical), ( + 'Expected a Categorical, got %r.' % type(result.values) + ) + with_missing = pd.Series( + data=pd.Categorical( + result.values, + result.values.categories.union([self.missing_value]), + ), + index=result.index, + ) + return LabelArray( + super(Classifier, self).to_workspace_value( + with_missing, + assets, + ), + self.missing_value, ) - return LabelArray.from_categorical(data, self.missing_value) @classlazyval def _downsampled_type(self): diff --git a/zipline/pipeline/term.py b/zipline/pipeline/term.py index 50236c5b..284f536d 100644 --- a/zipline/pipeline/term.py +++ b/zipline/pipeline/term.py @@ -593,10 +593,23 @@ class ComputableTerm(Term): def to_workspace_value(self, result, assets): """ - Called with the result of a pipeline. This needs to return an object - which can be put into the workspace to continue doing computations. + Called with a column of the result of a pipeline. This needs to put + the data into a format that can be used in a workspace to continue + doing computations. - This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`. + Parameters + ---------- + result : pd.Series + A multiindexed series with (dates, assets) whose values are the + results of running this pipeline term over the dates. + assets : pd.Index + All of the assets being requested. This allows us to correctly + shape the workspace value. + + Returns + ------- + workspace_value : array-like + An array like value that the engine can consume. """ return result.unstack().fillna(self.missing_value).reindex( columns=assets,