From bb6f908036c05e91a8301a196fdf290ad3fc39ab Mon Sep 17 00:00:00 2001 From: Scott Sanderson Date: Fri, 29 Apr 2016 19:52:11 -0400 Subject: [PATCH] TEST: Add test for categorical postprocessing. --- tests/pipeline/test_engine.py | 55 ++++++++++++++++---- zipline/pipeline/data/testing.py | 2 +- zipline/pipeline/loaders/synthetic.py | 3 +- zipline/testing/fixtures.py | 73 +++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 13 deletions(-) diff --git a/tests/pipeline/test_engine.py b/tests/pipeline/test_engine.py index f00d349a..254936ff 100644 --- a/tests/pipeline/test_engine.py +++ b/tests/pipeline/test_engine.py @@ -22,6 +22,7 @@ from numpy import ( ) from numpy.testing import assert_almost_equal from pandas import ( + Categorical, DataFrame, date_range, ewma, @@ -40,17 +41,10 @@ from toolz import merge from zipline.assets.synthetic import make_rotating_equity_info from zipline.lib.adjustment import MULTIPLY -from zipline.pipeline.loaders.synthetic import PrecomputedLoader -from zipline.pipeline import Pipeline -from zipline.pipeline.data import USEquityPricing, DataSet, Column -from zipline.pipeline.loaders.equity_pricing_loader import ( - USEquityPricingLoader, -) -from zipline.pipeline.factors import CustomFactor -from zipline.pipeline.loaders.synthetic import ( - make_bar_data, - expected_bar_values_2d, -) +from zipline.lib.labelarray import LabelArray +from zipline.pipeline import CustomFactor, Pipeline +from zipline.pipeline.data import Column, DataSet, USEquityPricing +from zipline.pipeline.data.testing import TestingDataSet from zipline.pipeline.engine import SimplePipelineEngine from zipline.pipeline.factors import ( AverageDollarVolume, @@ -61,7 +55,15 @@ from zipline.pipeline.factors import ( MaxDrawdown, SimpleMovingAverage, ) +from zipline.pipeline.loaders.equity_pricing_loader import ( + USEquityPricingLoader, +) from zipline.pipeline.loaders.frame import DataFrameLoader +from zipline.pipeline.loaders.synthetic import ( + 
PrecomputedLoader, + make_bar_data, + expected_bar_values_2d, +) from zipline.pipeline.term import NotSpecified from zipline.testing import ( product_upper_triangle, @@ -69,6 +71,7 @@ from zipline.testing import ( ) from zipline.testing.fixtures import ( WithAdjustmentReader, + WithSeededRandomPipelineEngine, WithTradingEnvironment, ZiplineTestCase, ) @@ -1238,3 +1241,33 @@ class ParameterizedFactorTestCase(WithTradingEnvironment, ZiplineTestCase): expected_5 = rolling_mean((self.raw_data ** 2) * 2, window=5)[5:] assert_frame_equal(results['dv5'].unstack(), expected_5) + + +class StringColumnTestCase(WithSeededRandomPipelineEngine, + ZiplineTestCase): + + def test_string_classifiers_produce_categoricals(self): + """ + Test that string-based classifiers produce pandas categoricals as their + outputs. + """ + col = TestingDataSet.categorical_col + pipe = Pipeline(columns={'c': col.latest}) + + run_dates = self.trading_days[-10:] + start_date, end_date = run_dates[[0, -1]] + + result = self.run_pipeline(pipe, start_date, end_date) + assert isinstance(result.c.values, Categorical) + + expected_raw_data = self.raw_expected_values( + col, + start_date, + end_date, + ) + expected_labels = LabelArray(expected_raw_data, col.missing_value) + expected_final_result = expected_labels.as_categorical_frame( + index=run_dates, + columns=self.asset_finder.retrieve_all(self.asset_finder.sids), + ) + assert_frame_equal(result.c.unstack(), expected_final_result) diff --git a/zipline/pipeline/data/testing.py b/zipline/pipeline/data/testing.py index 30f92642..23712000 100644 --- a/zipline/pipeline/data/testing.py +++ b/zipline/pipeline/data/testing.py @@ -27,5 +27,5 @@ class TestingDataSet(DataSet): categorical_col = Column(dtype=categorical_dtype, missing_value=u'') categorical_default_NULL = Column( dtype=categorical_dtype, - missing_value=u'NULL', + missing_value=u'<>', ) diff --git a/zipline/pipeline/loaders/synthetic.py b/zipline/pipeline/loaders/synthetic.py index 
dfea5829..ab02f878 100644 --- a/zipline/pipeline/loaders/synthetic.py +++ b/zipline/pipeline/loaders/synthetic.py @@ -194,7 +194,8 @@ class SeededRandomLoader(PrecomputedLoader): return self.state.randn(*shape) < 0 def _object_values(self, shape): - return self._int_values(shape).astype(str).astype(object) + res = self._int_values(shape).astype(str).astype(object) + return res OHLCV = ('open', 'high', 'low', 'close', 'volume') diff --git a/zipline/testing/fixtures.py b/zipline/testing/fixtures.py index ff82042c..77c33671 100644 --- a/zipline/testing/fixtures.py +++ b/zipline/testing/fixtures.py @@ -40,6 +40,7 @@ from ..utils.final import FinalMeta, final from ..utils.metautils import compose_types from .core import tmp_asset_finder, make_simple_equity_info, gen_calendars from zipline.pipeline import Pipeline, SimplePipelineEngine +from zipline.pipeline.loaders.testing import make_seeded_random_loader from zipline.utils.numpy_utils import make_datetime64D from zipline.utils.numpy_utils import NaTD from zipline.pipeline.common import TS_FIELD_NAME @@ -986,6 +987,78 @@ class WithPipelineEventDataLoader(with_metaclass( check_names=False) +class WithSeededRandomPipelineEngine(WithNYSETradingDays, WithAssetFinder): + """ + ZiplineTestCase mixin providing class-level fixtures for running pipelines + against deterministically-generated random data. + + Attributes + ---------- + SEEDED_RANDOM_PIPELINE_SEED : int + Fixture input. Random seed used to initialize the random state loader. + seeded_random_loader : SeededRandomLoader + Fixture output. Loader capable of providing columns for + zipline.pipeline.data.testing.TestingDataSet. + seeded_random_engine : SimplePipelineEngine + Fixture output. A pipeline engine that will use seeded_random_loader + as its only data provider. + + Methods + ------- + run_pipeline(pipeline, start_date, end_date) + Run a pipeline with self.seeded_random_engine. 
+ + See Also + -------- + zipline.pipeline.loaders.synthetic.SeededRandomLoader + zipline.pipeline.loaders.testing.make_seeded_random_loader + zipline.pipeline.engine.SimplePipelineEngine + """ + SEEDED_RANDOM_PIPELINE_SEED = 42 + + @classmethod + def init_class_fixtures(cls): + super(WithSeededRandomPipelineEngine, cls).init_class_fixtures() + cls._sids = cls.asset_finder.sids + cls.seeded_random_loader = loader = make_seeded_random_loader( + cls.SEEDED_RANDOM_PIPELINE_SEED, + cls.trading_days, + cls._sids, + ) + cls.seeded_random_engine = SimplePipelineEngine( + get_loader=lambda column: loader, + calendar=cls.trading_days, + asset_finder=cls.asset_finder, + ) + + def raw_expected_values(self, column, start_date, end_date): + """ + Get an array containing the raw values we expect to be produced for the + given dates between start_date and end_date, inclusive. + """ + all_values = self.seeded_random_loader.values( + column.dtype, + self.trading_days, + self._sids, + ) + row_slice = self.trading_days.slice_indexer(start_date, end_date) + return all_values[row_slice] + + def run_pipeline(self, pipeline, start_date, end_date): + """ + Run a pipeline with self.seeded_random_engine. + """ + if start_date not in self.trading_days: + raise AssertionError("Start date not in calendar: %s" % start_date) + if end_date not in self.trading_days: + raise AssertionError("End date not in calendar: %s" % end_date) + return self.seeded_random_engine.run_pipeline( + pipeline, + start_date, + end_date, + ) + + class WithDataPortal(WithBcolzMinuteBarReader, WithAdjustmentReader): """ ZiplineTestCase mixin providing self.data_portal as an instance level