TST: add tests for postprocess and to_workspace_value

This commit is contained in:
Joe Jevnik
2016-10-12 19:16:30 -04:00
parent bac7af580b
commit af3e1016a0
5 changed files with 190 additions and 8 deletions
+81
View File
@@ -2,10 +2,13 @@ from functools import reduce
from operator import or_
import numpy as np
import pandas as pd
from zipline.lib.labelarray import LabelArray
from zipline.pipeline import Classifier
from zipline.testing import parameter_space
from zipline.testing.fixtures import ZiplineTestCase
from zipline.testing.predicates import assert_equal
from zipline.utils.numpy_utils import (
categorical_dtype,
int64_dtype,
@@ -464,3 +467,81 @@ class ClassifierTestCase(BasePipelineTestCase):
"TypeError(\"unhashable type: 'dict'\",)."
)
self.assertEqual(errmsg, expected)
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
    # Round-trip checks for classifiers: ``postprocess`` flattens workspace
    # data into pipeline-output form, and ``to_workspace_value`` is expected
    # to rebuild the original 2D (dates x assets) workspace data from it.

    def test_reversability_categorical(self):
        # String-labelled classifier backed by a LabelArray.
        class F(Classifier):
            inputs = ()
            window_length = 0
            dtype = categorical_dtype
            missing_value = '<missing>'

        term = F()
        missing = term.missing_value

        # Second asset has no label on the first two dates.
        raw_labels = np.array(
            [['a', missing],
             ['b', missing],
             ['c', 'd']],
        )
        workspace_value = LabelArray(raw_labels, missing_value=missing)

        flat_expected = pd.Categorical(
            ['a', missing, 'b', missing, 'c', 'd'],
        )
        assert_equal(
            term.postprocess(workspace_value.ravel()),
            flat_expected,
        )

        # only include the non-missing data
        dates = [
            pd.Timestamp('2014-01-01'),
            pd.Timestamp('2014-01-02'),
            pd.Timestamp('2014-01-03'),
            pd.Timestamp('2014-01-03'),
        ]
        sids = [0, 0, 0, 1]
        flat_result = pd.Series(
            data=['a', 'b', 'c', 'd'],
            index=pd.MultiIndex.from_arrays([dates, sids]),
            dtype='category',
        )

        rebuilt = term.to_workspace_value(flat_result, pd.Index([0, 1]))
        assert_equal(rebuilt, workspace_value)

    def test_reversability_int64(self):
        # Integer-labelled classifier backed by a plain ndarray.
        class F(Classifier):
            inputs = ()
            window_length = 0
            dtype = int64_dtype
            missing_value = -1

        term = F()
        missing = term.missing_value

        # Second asset has no label on the first two dates.
        workspace_value = np.array(
            [[0, missing],
             [1, missing],
             [2, 3]],
        )

        assert_equal(
            term.postprocess(workspace_value.ravel()),
            workspace_value.ravel(),
        )

        # only include the non-missing data
        dates = [
            pd.Timestamp('2014-01-01'),
            pd.Timestamp('2014-01-02'),
            pd.Timestamp('2014-01-03'),
            pd.Timestamp('2014-01-03'),
        ]
        sids = [0, 0, 0, 1]
        flat_result = pd.Series(
            data=[0, 1, 2, 3],
            index=pd.MultiIndex.from_arrays([dates, sids]),
            dtype=int64_dtype,
        )

        rebuilt = term.to_workspace_value(flat_result, pd.Index([0, 1]))
        assert_equal(rebuilt, workspace_value)
+39
View File
@@ -21,6 +21,7 @@ from numpy import (
where,
)
from numpy.random import randn, seed
import pandas as pd
from zipline.errors import UnknownRankMethod
from zipline.lib.labelarray import LabelArray
@@ -37,6 +38,8 @@ from zipline.testing import (
parameter_space,
permute_rows,
)
from zipline.testing.fixtures import ZiplineTestCase
from zipline.testing.predicates import assert_equal
from zipline.utils.numpy_utils import (
categorical_dtype,
datetime64ns_dtype,
@@ -1058,3 +1061,39 @@ class TestWindowSafety(TestCase):
self.assertFalse(F().demean().window_safe)
self.assertFalse(F(window_safe=False).demean().window_safe)
self.assertTrue(F(window_safe=True).demean().window_safe)
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
    # Round-trip check for factors: flattening workspace data with
    # ``postprocess`` and rebuilding it with ``to_workspace_value`` should
    # give back the original 2D (dates x assets) array.

    @parameter_space(dtype_=(float64_dtype, datetime64ns_dtype))
    def test_reversability(self, dtype_):
        class F(Factor):
            inputs = ()
            dtype = dtype_
            window_length = 0

        term = F()
        missing = term.missing_value

        # Second asset has no value on the first two dates.
        workspace_value = array(
            [[0, missing],
             [1, missing],
             [2, 3]],
            dtype=dtype_,
        )

        assert_equal(
            term.postprocess(workspace_value.ravel()),
            workspace_value.ravel(),
        )

        # only include the non-missing data
        dates = [
            pd.Timestamp('2014-01-01'),
            pd.Timestamp('2014-01-02'),
            pd.Timestamp('2014-01-03'),
            pd.Timestamp('2014-01-03'),
        ]
        sids = [0, 0, 0, 1]
        flat_result = pd.Series(
            data=array([0, 1, 2, 3], dtype=dtype_),
            index=pd.MultiIndex.from_arrays([dates, sids]),
        )

        rebuilt = term.to_workspace_value(flat_result, pd.Index([0, 1]))
        assert_equal(rebuilt, workspace_value)
+36
View File
@@ -24,6 +24,7 @@ from numpy import (
sum as np_sum
)
from numpy.random import randn, seed as random_seed
import pandas as pd
from zipline.errors import BadPercentileBounds
from zipline.pipeline import Filter, Factor, Pipeline
@@ -859,3 +860,38 @@ class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
assert_equal(results.odds, (sids % 2).astype(bool))
assert_equal(results.first_five, sids < 5)
assert_equal(results.last_three, sids >= 7)
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
    # Round-trip check for filters: flattening workspace data with
    # ``postprocess`` and rebuilding it with ``to_workspace_value`` should
    # give back the original 2D (dates x assets) boolean array.

    def test_reversability(self):
        class F(Filter):
            inputs = ()
            window_length = 0
            missing_value = False

        term = F()
        missing = term.missing_value

        # Second asset is missing on the first two dates.
        workspace_value = array(
            [[True, missing],
             [True, missing],
             [True, True]],
            dtype=bool,
        )

        assert_equal(
            term.postprocess(workspace_value.ravel()),
            workspace_value.ravel(),
        )

        # only include the non-missing data
        dates = [
            pd.Timestamp('2014-01-01'),
            pd.Timestamp('2014-01-02'),
            pd.Timestamp('2014-01-03'),
            pd.Timestamp('2014-01-03'),
        ]
        sids = [0, 0, 0, 1]
        # The scalar ``True`` is broadcast to every (date, asset) pair.
        flat_result = pd.Series(
            data=True,
            index=pd.MultiIndex.from_arrays([dates, sids]),
        )

        rebuilt = term.to_workspace_value(flat_result, pd.Index([0, 1]))
        assert_equal(rebuilt, workspace_value)
+18 -5
View File
@@ -312,13 +312,26 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
"""
data = super(Classifier, self).to_workspace_value(result, assets)
if self.dtype == int64_dtype:
return data
assert isinstance(data, pd.Categorical), (
'Expected a Categorical, got %r.' % type(data).__name__
return super(Classifier, self).to_workspace_value(result, assets)
assert isinstance(result.values, pd.Categorical), (
'Expected a Categorical, got %r.' % type(result.values)
)
with_missing = pd.Series(
data=pd.Categorical(
result.values,
result.values.categories.union([self.missing_value]),
),
index=result.index,
)
return LabelArray(
super(Classifier, self).to_workspace_value(
with_missing,
assets,
),
self.missing_value,
)
return LabelArray.from_categorical(data, self.missing_value)
@classlazyval
def _downsampled_type(self):
+16 -3
View File
@@ -593,10 +593,23 @@ class ComputableTerm(Term):
def to_workspace_value(self, result, assets):
"""
Called with the result of a pipeline. This needs to return an object
which can be put into the workspace to continue doing computations.
Called with a column of the result of a pipeline. This needs to put
the data into a format that can be used in a workspace to continue
doing computations.
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
Parameters
----------
result : pd.Series
A multiindexed series with (dates, assets) whose values are the
results of running this pipeline term over the dates.
assets : pd.Index
All of the assets being requested. This allows us to correctly
shape the workspace value.
Returns
-------
workspace_value : array-like
An array like value that the engine can consume.
"""
return result.unstack().fillna(self.missing_value).reindex(
columns=assets,