mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 05:03:26 +08:00
TEST: Add test for categorical postprocessing.
This commit is contained in:
@@ -22,6 +22,7 @@ from numpy import (
|
||||
)
|
||||
from numpy.testing import assert_almost_equal
|
||||
from pandas import (
|
||||
Categorical,
|
||||
DataFrame,
|
||||
date_range,
|
||||
ewma,
|
||||
@@ -40,17 +41,10 @@ from toolz import merge
|
||||
|
||||
from zipline.assets.synthetic import make_rotating_equity_info
|
||||
from zipline.lib.adjustment import MULTIPLY
|
||||
from zipline.pipeline.loaders.synthetic import PrecomputedLoader
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.data import USEquityPricing, DataSet, Column
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
from zipline.pipeline.factors import CustomFactor
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
make_bar_data,
|
||||
expected_bar_values_2d,
|
||||
)
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.pipeline import CustomFactor, Pipeline
|
||||
from zipline.pipeline.data import Column, DataSet, USEquityPricing
|
||||
from zipline.pipeline.data.testing import TestingDataSet
|
||||
from zipline.pipeline.engine import SimplePipelineEngine
|
||||
from zipline.pipeline.factors import (
|
||||
AverageDollarVolume,
|
||||
@@ -61,7 +55,15 @@ from zipline.pipeline.factors import (
|
||||
MaxDrawdown,
|
||||
SimpleMovingAverage,
|
||||
)
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
USEquityPricingLoader,
|
||||
)
|
||||
from zipline.pipeline.loaders.frame import DataFrameLoader
|
||||
from zipline.pipeline.loaders.synthetic import (
|
||||
PrecomputedLoader,
|
||||
make_bar_data,
|
||||
expected_bar_values_2d,
|
||||
)
|
||||
from zipline.pipeline.term import NotSpecified
|
||||
from zipline.testing import (
|
||||
product_upper_triangle,
|
||||
@@ -69,6 +71,7 @@ from zipline.testing import (
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithAdjustmentReader,
|
||||
WithSeededRandomPipelineEngine,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
)
|
||||
@@ -1238,3 +1241,33 @@ class ParameterizedFactorTestCase(WithTradingEnvironment, ZiplineTestCase):
|
||||
|
||||
expected_5 = rolling_mean((self.raw_data ** 2) * 2, window=5)[5:]
|
||||
assert_frame_equal(results['dv5'].unstack(), expected_5)
|
||||
|
||||
|
||||
class StringColumnTestCase(WithSeededRandomPipelineEngine,
                           ZiplineTestCase):

    def test_string_classifiers_produce_categoricals(self):
        """
        Check that the ``latest`` term of a string column comes out of the
        pipeline engine as a pandas Categorical, and that its unstacked
        contents equal the categorical frame built from the loader's raw
        values.
        """
        column = TestingDataSet.categorical_col
        pipeline = Pipeline(columns={'c': column.latest})

        # Run over the final ten sessions of the fixture calendar.
        dates = self.trading_days[-10:]
        first_day, last_day = dates[[0, -1]]

        output = self.run_pipeline(pipeline, first_day, last_day)
        assert isinstance(output.c.values, Categorical)

        # Rebuild the expected frame from the raw values the seeded loader
        # generates for the same date range.
        raw_values = self.raw_expected_values(column, first_day, last_day)
        labels = LabelArray(raw_values, column.missing_value)
        expected_frame = labels.as_categorical_frame(
            index=dates,
            columns=self.asset_finder.retrieve_all(self.asset_finder.sids),
        )
        assert_frame_equal(output.c.unstack(), expected_frame)
|
||||
|
||||
@@ -27,5 +27,5 @@ class TestingDataSet(DataSet):
|
||||
categorical_col = Column(dtype=categorical_dtype, missing_value=u'')
|
||||
categorical_default_NULL = Column(
|
||||
dtype=categorical_dtype,
|
||||
missing_value=u'NULL',
|
||||
missing_value=u'<<NULL>>',
|
||||
)
|
||||
|
||||
@@ -194,7 +194,8 @@ class SeededRandomLoader(PrecomputedLoader):
|
||||
return self.state.randn(*shape) < 0
|
||||
|
||||
def _object_values(self, shape):
    """
    Return an object-dtype array of strings with the given shape.

    The values are the output of ``self._int_values(shape)`` converted to
    their string representation and then stored with ``object`` dtype.
    """
    # Convert to str first so each element becomes text, then to object so
    # the array holds string objects rather than a fixed-width string dtype.
    return self._int_values(shape).astype(str).astype(object)
|
||||
|
||||
|
||||
OHLCV = ('open', 'high', 'low', 'close', 'volume')
|
||||
|
||||
@@ -40,6 +40,7 @@ from ..utils.final import FinalMeta, final
|
||||
from ..utils.metautils import compose_types
|
||||
from .core import tmp_asset_finder, make_simple_equity_info, gen_calendars
|
||||
from zipline.pipeline import Pipeline, SimplePipelineEngine
|
||||
from zipline.pipeline.loaders.testing import make_seeded_random_loader
|
||||
from zipline.utils.numpy_utils import make_datetime64D
|
||||
from zipline.utils.numpy_utils import NaTD
|
||||
from zipline.pipeline.common import TS_FIELD_NAME
|
||||
@@ -986,6 +987,78 @@ class WithPipelineEventDataLoader(with_metaclass(
|
||||
check_names=False)
|
||||
|
||||
|
||||
class WithSeededRandomPipelineEngine(WithNYSETradingDays, WithAssetFinder):
    """
    ZiplineTestCase mixin providing class-level fixtures for running pipelines
    against deterministically-generated random data.

    Attributes
    ----------
    SEEDED_RANDOM_PIPELINE_SEED : int
        Fixture input. Random seed used to initialize the random state loader.
    seeded_random_loader : SeededRandomLoader
        Fixture output. Loader capable of providing columns for
        zipline.pipeline.data.testing.TestingDataSet.
    seeded_random_engine : SimplePipelineEngine
        Fixture output. A pipeline engine that will use seeded_random_loader
        as its only data provider.

    Methods
    -------
    raw_expected_values(column, start_date, end_date)
        Get the raw loader values expected for a column over a date range.
    run_pipeline(pipeline, start_date, end_date)
        Run a pipeline with self.seeded_random_engine.

    See Also
    --------
    zipline.pipeline.loaders.synthetic.SeededRandomLoader
    zipline.pipeline.loaders.testing.make_seeded_random_loader
    zipline.pipeline.engine.SimplePipelineEngine
    """
    SEEDED_RANDOM_PIPELINE_SEED = 42

    @classmethod
    def init_class_fixtures(cls):
        """
        Build the seeded loader and the engine that serves every column
        from it.
        """
        super(WithSeededRandomPipelineEngine, cls).init_class_fixtures()
        # Cache the sids so raw_expected_values queries the loader with the
        # same assets it was constructed with.
        cls._sids = cls.asset_finder.sids
        cls.seeded_random_loader = loader = make_seeded_random_loader(
            cls.SEEDED_RANDOM_PIPELINE_SEED,
            cls.trading_days,
            cls._sids,
        )
        cls.seeded_random_engine = SimplePipelineEngine(
            # Every column is served by the single seeded loader.
            get_loader=lambda column: loader,
            calendar=cls.trading_days,
            asset_finder=cls.asset_finder,
        )

    def raw_expected_values(self, column, start_date, end_date):
        """
        Get an array containing the raw values we expect to be produced for the
        given dates between start_date and end_date, inclusive.

        The loader is queried for ``column.dtype`` over the full
        ``self.trading_days`` calendar, and the rows for the requested range
        are then sliced out.
        """
        all_values = self.seeded_random_loader.values(
            column.dtype,
            self.trading_days,
            self._sids,
        )
        row_slice = self.trading_days.slice_indexer(start_date, end_date)
        return all_values[row_slice]

    def run_pipeline(self, pipeline, start_date, end_date):
        """
        Run a pipeline with self.seeded_random_engine.

        Raises
        ------
        AssertionError
            If ``start_date`` or ``end_date`` is not in ``self.trading_days``.
        """
        if start_date not in self.trading_days:
            raise AssertionError("Start date not in calendar: %s" % start_date)
        if end_date not in self.trading_days:
            # Report the offending end date, not the start date.
            raise AssertionError("End date not in calendar: %s" % end_date)
        return self.seeded_random_engine.run_pipeline(
            pipeline,
            start_date,
            end_date,
        )
|
||||
|
||||
|
||||
class WithDataPortal(WithBcolzMinuteBarReader, WithAdjustmentReader):
|
||||
"""
|
||||
ZiplineTestCase mixin providing self.data_portal as an instance level
|
||||
|
||||
Reference in New Issue
Block a user