mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 00:33:16 +08:00
TST: add tests for postprocess and to_workspace_value
This commit is contained in:
@@ -2,10 +2,13 @@ from functools import reduce
|
||||
from operator import or_
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.pipeline import Classifier
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import ZiplineTestCase
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
int64_dtype,
|
||||
@@ -464,3 +467,81 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
"TypeError(\"unhashable type: 'dict'\",)."
|
||||
)
|
||||
self.assertEqual(errmsg, expected)
|
||||
|
||||
|
||||
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
|
||||
def test_reversability_categorical(self):
|
||||
class F(Classifier):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
dtype = categorical_dtype
|
||||
missing_value = '<missing>'
|
||||
|
||||
f = F()
|
||||
column_data = LabelArray(
|
||||
np.array(
|
||||
[['a', f.missing_value],
|
||||
['b', f.missing_value],
|
||||
['c', 'd']],
|
||||
),
|
||||
missing_value=f.missing_value,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
f.postprocess(column_data.ravel()),
|
||||
pd.Categorical(
|
||||
['a', f.missing_value, 'b', f.missing_value, 'c', 'd'],
|
||||
),
|
||||
)
|
||||
|
||||
# only include the non-missing data
|
||||
pipeline_output = pd.Series(
|
||||
data=['a', 'b', 'c', 'd'],
|
||||
index=pd.MultiIndex.from_arrays([
|
||||
[pd.Timestamp('2014-01-01'),
|
||||
pd.Timestamp('2014-01-02'),
|
||||
pd.Timestamp('2014-01-03'),
|
||||
pd.Timestamp('2014-01-03')],
|
||||
[0, 0, 0, 1],
|
||||
]),
|
||||
dtype='category',
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
f.to_workspace_value(pipeline_output, pd.Index([0, 1])),
|
||||
column_data,
|
||||
)
|
||||
|
||||
def test_reversability_int64(self):
|
||||
class F(Classifier):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
dtype = int64_dtype
|
||||
missing_value = -1
|
||||
|
||||
f = F()
|
||||
column_data =np.array(
|
||||
[[0, f.missing_value],
|
||||
[1, f.missing_value],
|
||||
[2, 3]],
|
||||
)
|
||||
|
||||
assert_equal(f.postprocess(column_data.ravel()), column_data.ravel())
|
||||
|
||||
# only include the non-missing data
|
||||
pipeline_output = pd.Series(
|
||||
data=[0, 1, 2, 3],
|
||||
index=pd.MultiIndex.from_arrays([
|
||||
[pd.Timestamp('2014-01-01'),
|
||||
pd.Timestamp('2014-01-02'),
|
||||
pd.Timestamp('2014-01-03'),
|
||||
pd.Timestamp('2014-01-03')],
|
||||
[0, 0, 0, 1],
|
||||
]),
|
||||
dtype=int64_dtype,
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
f.to_workspace_value(pipeline_output, pd.Index([0, 1])),
|
||||
column_data,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from numpy import (
|
||||
where,
|
||||
)
|
||||
from numpy.random import randn, seed
|
||||
import pandas as pd
|
||||
|
||||
from zipline.errors import UnknownRankMethod
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
@@ -37,6 +38,8 @@ from zipline.testing import (
|
||||
parameter_space,
|
||||
permute_rows,
|
||||
)
|
||||
from zipline.testing.fixtures import ZiplineTestCase
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.numpy_utils import (
|
||||
categorical_dtype,
|
||||
datetime64ns_dtype,
|
||||
@@ -1058,3 +1061,39 @@ class TestWindowSafety(TestCase):
|
||||
self.assertFalse(F().demean().window_safe)
|
||||
self.assertFalse(F(window_safe=False).demean().window_safe)
|
||||
self.assertTrue(F(window_safe=True).demean().window_safe)
|
||||
|
||||
|
||||
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
|
||||
@parameter_space(dtype_=(float64_dtype, datetime64ns_dtype))
|
||||
def test_reversability(self, dtype_):
|
||||
class F(Factor):
|
||||
inputs = ()
|
||||
dtype = dtype_
|
||||
window_length = 0
|
||||
|
||||
f = F()
|
||||
column_data = array(
|
||||
[[0, f.missing_value],
|
||||
[1, f.missing_value],
|
||||
[2, 3]],
|
||||
dtype=dtype_,
|
||||
)
|
||||
|
||||
assert_equal(f.postprocess(column_data.ravel()), column_data.ravel())
|
||||
|
||||
# only include the non-missing data
|
||||
pipeline_output = pd.Series(
|
||||
data=array([0, 1, 2, 3], dtype=dtype_),
|
||||
index=pd.MultiIndex.from_arrays([
|
||||
[pd.Timestamp('2014-01-01'),
|
||||
pd.Timestamp('2014-01-02'),
|
||||
pd.Timestamp('2014-01-03'),
|
||||
pd.Timestamp('2014-01-03')],
|
||||
[0, 0, 0, 1],
|
||||
]),
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
f.to_workspace_value(pipeline_output, pd.Index([0, 1])),
|
||||
column_data,
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from numpy import (
|
||||
sum as np_sum
|
||||
)
|
||||
from numpy.random import randn, seed as random_seed
|
||||
import pandas as pd
|
||||
|
||||
from zipline.errors import BadPercentileBounds
|
||||
from zipline.pipeline import Filter, Factor, Pipeline
|
||||
@@ -859,3 +860,38 @@ class SpecificAssetsTestCase(WithSeededRandomPipelineEngine,
|
||||
assert_equal(results.odds, (sids % 2).astype(bool))
|
||||
assert_equal(results.first_five, sids < 5)
|
||||
assert_equal(results.last_three, sids >= 7)
|
||||
|
||||
|
||||
class TestPostProcessAndToWorkSpaceValue(ZiplineTestCase):
|
||||
def test_reversability(self):
|
||||
class F(Filter):
|
||||
inputs = ()
|
||||
window_length = 0
|
||||
missing_value = False
|
||||
|
||||
f = F()
|
||||
column_data = array(
|
||||
[[True, f.missing_value],
|
||||
[True, f.missing_value],
|
||||
[True, True]],
|
||||
dtype=bool,
|
||||
)
|
||||
|
||||
assert_equal(f.postprocess(column_data.ravel()), column_data.ravel())
|
||||
|
||||
# only include the non-missing data
|
||||
pipeline_output = pd.Series(
|
||||
data=True,
|
||||
index=pd.MultiIndex.from_arrays([
|
||||
[pd.Timestamp('2014-01-01'),
|
||||
pd.Timestamp('2014-01-02'),
|
||||
pd.Timestamp('2014-01-03'),
|
||||
pd.Timestamp('2014-01-03')],
|
||||
[0, 0, 0, 1],
|
||||
]),
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
f.to_workspace_value(pipeline_output, pd.Index([0, 1])),
|
||||
column_data,
|
||||
)
|
||||
|
||||
@@ -312,13 +312,26 @@ class Classifier(RestrictedDTypeMixin, ComputableTerm):
|
||||
|
||||
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
|
||||
"""
|
||||
data = super(Classifier, self).to_workspace_value(result, assets)
|
||||
if self.dtype == int64_dtype:
|
||||
return data
|
||||
assert isinstance(data, pd.Categorical), (
|
||||
'Expected a Categorical, got %r.' % type(data).__name__
|
||||
return super(Classifier, self).to_workspace_value(result, assets)
|
||||
|
||||
assert isinstance(result.values, pd.Categorical), (
|
||||
'Expected a Categorical, got %r.' % type(result.values)
|
||||
)
|
||||
with_missing = pd.Series(
|
||||
data=pd.Categorical(
|
||||
result.values,
|
||||
result.values.categories.union([self.missing_value]),
|
||||
),
|
||||
index=result.index,
|
||||
)
|
||||
return LabelArray(
|
||||
super(Classifier, self).to_workspace_value(
|
||||
with_missing,
|
||||
assets,
|
||||
),
|
||||
self.missing_value,
|
||||
)
|
||||
return LabelArray.from_categorical(data, self.missing_value)
|
||||
|
||||
@classlazyval
|
||||
def _downsampled_type(self):
|
||||
|
||||
@@ -593,10 +593,23 @@ class ComputableTerm(Term):
|
||||
|
||||
def to_workspace_value(self, result, assets):
|
||||
"""
|
||||
Called with the result of a pipeline. This needs to return an object
|
||||
which can be put into the workspace to continue doing computations.
|
||||
Called with a column of the result of a pipeline. This needs to put
|
||||
the data into a format that can be used in a workspace to continue
|
||||
doing computations.
|
||||
|
||||
This is the inverse of :func:`~zipline.pipeline.term.Term.postprocess`.
|
||||
Parameters
|
||||
----------
|
||||
result : pd.Series
|
||||
A multiindexed series with (dates, assets) whose values are the
|
||||
results of running this pipeline term over the dates.
|
||||
assets : pd.Index
|
||||
All of the assets being requested. This allows us to correctly
|
||||
shape the workspace value.
|
||||
|
||||
Returns
|
||||
-------
|
||||
workspace_value : array-like
|
||||
An array like value that the engine can consume.
|
||||
"""
|
||||
return result.unstack().fillna(self.missing_value).reindex(
|
||||
columns=assets,
|
||||
|
||||
Reference in New Issue
Block a user