mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-05 03:02:00 +08:00
ENH: Add run_chunked_pipeline method to PipelineEngine
This commit is contained in:
@@ -51,6 +51,7 @@ from zipline.pipeline.factors import (
|
||||
ExponentialWeightedMovingAverage,
|
||||
ExponentialWeightedMovingStdDev,
|
||||
MaxDrawdown,
|
||||
Returns,
|
||||
SimpleMovingAverage,
|
||||
)
|
||||
from zipline.pipeline.loaders.equity_pricing_loader import (
|
||||
@@ -77,6 +78,7 @@ from zipline.testing import (
|
||||
)
|
||||
from zipline.testing.fixtures import (
|
||||
WithAdjustmentReader,
|
||||
WithEquityPricingPipelineEngine,
|
||||
WithSeededRandomPipelineEngine,
|
||||
WithTradingEnvironment,
|
||||
ZiplineTestCase,
|
||||
@@ -1497,3 +1499,35 @@ class PopulateInitialWorkspaceTestCase(WithConstantInputs, ZiplineTestCase):
|
||||
precomputed_term_value,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ChunkedPipelineTestCase(WithEquityPricingPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
PIPELINE_START_DATE = Timestamp('2006-01-05', tz='UTC')
|
||||
END_DATE = Timestamp('2006-12-29', tz='UTC')
|
||||
|
||||
def test_run_chunked_pipeline(self):
|
||||
"""
|
||||
Test that running a pipeline in chunks produces the same result as if
|
||||
it were run all at once
|
||||
"""
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'close': USEquityPricing.close.latest,
|
||||
'returns': Returns(window_length=2),
|
||||
'categorical': USEquityPricing.close.latest.quantiles(5)
|
||||
},
|
||||
)
|
||||
pipeline_result = self.pipeline_engine.run_pipeline(
|
||||
pipe,
|
||||
start_date=self.PIPELINE_START_DATE,
|
||||
end_date=self.END_DATE,
|
||||
)
|
||||
chunked_result = self.pipeline_engine.run_chunked_pipeline(
|
||||
pipeline=pipe,
|
||||
start_date=self.PIPELINE_START_DATE,
|
||||
end_date=self.END_DATE,
|
||||
chunksize=22
|
||||
)
|
||||
self.assertTrue(chunked_result.equals(pipeline_result))
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
from zipline.pipeline import Pipeline, run_chunked_pipeline
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.factors import Returns
|
||||
from zipline.testing import ZiplineTestCase
|
||||
from zipline.testing.fixtures import WithEquityPricingPipelineEngine
|
||||
|
||||
|
||||
class ChunkedPipelineTestCase(WithEquityPricingPipelineEngine,
|
||||
ZiplineTestCase):
|
||||
|
||||
def test_run_chunked_pipeline(self):
|
||||
"""
|
||||
Test that running a pipeline in chunks produces the same result as if
|
||||
it were run all at once
|
||||
"""
|
||||
pipe = Pipeline(
|
||||
columns={
|
||||
'close': USEquityPricing.close.latest,
|
||||
'returns': Returns(window_length=2),
|
||||
},
|
||||
)
|
||||
sessions = self.nyse_calendar.all_sessions
|
||||
start_date = sessions[sessions.get_loc(self.START_DATE) + 2]
|
||||
|
||||
pipeline_result = self.pipeline_engine.run_pipeline(
|
||||
pipe,
|
||||
start_date=start_date,
|
||||
end_date=self.END_DATE,
|
||||
)
|
||||
chunked_result = run_chunked_pipeline(
|
||||
engine=self.pipeline_engine,
|
||||
pipeline=pipe,
|
||||
start_date=start_date,
|
||||
end_date=self.END_DATE,
|
||||
chunksize=22
|
||||
)
|
||||
self.assertTrue(chunked_result.equals(pipeline_result))
|
||||
@@ -3,26 +3,84 @@ from nose_parameterized import parameterized
|
||||
|
||||
from zipline.testing import ZiplineTestCase
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.date_utils import roll_dates_to_previous_session
|
||||
from zipline.utils.date_utils import compute_date_range_chunks
|
||||
|
||||
|
||||
class TestRollDatesToPreviousSession(ZiplineTestCase):
|
||||
def T(s):
|
||||
"""
|
||||
Helpful function to improve readibility.
|
||||
"""
|
||||
return Timestamp(s, tz='UTC')
|
||||
|
||||
|
||||
class TestDateUtils(ZiplineTestCase):
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
super(TestDateUtils, cls).init_class_fixtures()
|
||||
cls.calendar = get_calendar('NYSE')
|
||||
|
||||
@parameterized.expand([
|
||||
(
|
||||
Timestamp('05-19-2017', tz='UTC'), # actual trading date
|
||||
Timestamp('05-19-2017', tz='UTC'),
|
||||
),
|
||||
(
|
||||
Timestamp('07-04-2015', tz='UTC'), # weekend nyse holiday
|
||||
Timestamp('07-02-2015', tz='UTC'),
|
||||
),
|
||||
(
|
||||
Timestamp('01-16-2017', tz='UTC'), # weeknight nyse holiday
|
||||
Timestamp('01-13-2017', tz='UTC'),
|
||||
),
|
||||
(None, [(T('2017-01-03'), T('2017-01-31'))]),
|
||||
(10, [
|
||||
(T('2017-01-03'), T('2017-01-17')),
|
||||
(T('2017-01-18'), T('2017-01-31'))
|
||||
]),
|
||||
(15, [
|
||||
(T('2017-01-03'), T('2017-01-24')),
|
||||
(T('2017-01-25'), T('2017-01-31'))
|
||||
]),
|
||||
])
|
||||
def test_roll_dates_to_previous_session(self, date, expected_rolled_date):
|
||||
calendar = get_calendar('NYSE')
|
||||
result = roll_dates_to_previous_session(calendar, date)
|
||||
self.assertEqual(result[0], expected_rolled_date)
|
||||
def test_compute_date_range_chunks(self, chunksize, expected):
|
||||
# This date range results in 20 business days
|
||||
start_date = T('2017-01-03')
|
||||
end_date = T('2017-01-31')
|
||||
|
||||
date_ranges = compute_date_range_chunks(
|
||||
self.calendar.all_sessions,
|
||||
start_date,
|
||||
end_date,
|
||||
chunksize
|
||||
)
|
||||
|
||||
self.assertListEqual(list(date_ranges), expected)
|
||||
|
||||
def test_compute_date_range_chunks_invalid_input(self):
|
||||
# Start date not found in calendar
|
||||
with self.assertRaises(KeyError) as cm:
|
||||
compute_date_range_chunks(
|
||||
self.calendar.all_sessions,
|
||||
T('2017-05-07'), # Sunday
|
||||
T('2017-06-01'),
|
||||
None
|
||||
)
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
"'Start date 2017-05-07 is not found in calendar.'"
|
||||
)
|
||||
|
||||
# End date not found in calendar
|
||||
with self.assertRaises(KeyError) as cm:
|
||||
compute_date_range_chunks(
|
||||
self.calendar.all_sessions,
|
||||
T('2017-05-01'),
|
||||
T('2017-05-27'), # Saturday
|
||||
None
|
||||
)
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
"'End date 2017-05-27 is not found in calendar.'"
|
||||
)
|
||||
|
||||
# End date before start date
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
compute_date_range_chunks(
|
||||
self.calendar.all_sessions,
|
||||
T('2017-06-01'),
|
||||
T('2017-05-01'),
|
||||
None
|
||||
)
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
"End date 2017-05-01 cannot precede start date 2017-06-01."
|
||||
)
|
||||
|
||||
@@ -173,8 +173,16 @@ class TestCatDFConcat(ZiplineTestCase):
|
||||
),
|
||||
]
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
categorical_df_concat(mismatched_dtypes)
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
"Input DataFrames must have the same columns/dtypes."
|
||||
)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
categorical_df_concat(mismatched_column_names)
|
||||
self.assertEqual(
|
||||
str(cm.exception),
|
||||
"Input DataFrames must have the same columns/dtypes."
|
||||
)
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
from __future__ import print_function
|
||||
from zipline.assets import AssetFinder
|
||||
from zipline.utils.calendars import get_calendar
|
||||
from zipline.utils.date_utils import compute_date_range_chunks
|
||||
from zipline.utils.pandas_utils import categorical_df_concat
|
||||
|
||||
from .classifiers import Classifier, CustomClassifier
|
||||
from .engine import SimplePipelineEngine
|
||||
@@ -50,39 +47,6 @@ def engine_from_files(daily_bar_path,
|
||||
)
|
||||
|
||||
|
||||
def run_chunked_pipeline(engine, pipeline, start_date, end_date, chunksize):
|
||||
"""Run a pipeline to collect the results.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
engine : Engine
|
||||
The pipeline engine.
|
||||
pipeline : Pipeline
|
||||
The pipeline to run.
|
||||
start_date : pd.Timestamp
|
||||
The start date to run the pipeline for.
|
||||
end_date : pd.Timestamp
|
||||
The end date to run the pipeline for.
|
||||
chunksize : int or None
|
||||
The number of days to execute at a time. If this is None, all the days
|
||||
will be run at once.
|
||||
|
||||
Returns
|
||||
-------
|
||||
results : pd.DataFrame
|
||||
The results for each output term in the pipeline.
|
||||
"""
|
||||
ranges = compute_date_range_chunks(
|
||||
get_calendar('NYSE'),
|
||||
start_date,
|
||||
end_date,
|
||||
chunksize,
|
||||
)
|
||||
chunks = [engine.run_pipeline(pipeline, s, e) for s, e in ranges]
|
||||
|
||||
return categorical_df_concat(chunks, inplace=True)
|
||||
|
||||
|
||||
__all__ = (
|
||||
'Classifier',
|
||||
'CustomFactor',
|
||||
@@ -94,7 +58,6 @@ __all__ = (
|
||||
'Filter',
|
||||
'Pipeline',
|
||||
'SimplePipelineEngine',
|
||||
'run_chunked_pipeline',
|
||||
'Term',
|
||||
'TermGraph',
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from six import (
|
||||
with_metaclass,
|
||||
)
|
||||
from numpy import array
|
||||
from odo.utils import copydoc
|
||||
from pandas import DataFrame, MultiIndex
|
||||
from toolz import groupby, juxt
|
||||
from toolz.curried.operator import getitem
|
||||
@@ -27,6 +28,9 @@ from zipline.utils.pandas_utils import explode
|
||||
|
||||
from .term import AssetExists, InputDates, LoadableTerm
|
||||
|
||||
from zipline.utils.date_utils import compute_date_range_chunks
|
||||
from zipline.utils.pandas_utils import categorical_df_concat
|
||||
|
||||
|
||||
class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
|
||||
@@ -62,6 +66,45 @@ class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
"""
|
||||
raise NotImplementedError("run_pipeline")
|
||||
|
||||
@abstractmethod
|
||||
def run_chunked_pipeline(self, pipeline, start_date, end_date, chunksize):
|
||||
"""
|
||||
Compute values for `pipeline` in number of days equal to `chunksize`
|
||||
and return stitched up result. Computing in chunks is useful for
|
||||
pipelines computed over a long period of time.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pipeline : Pipeline
|
||||
The pipeline to run.
|
||||
start_date : pd.Timestamp
|
||||
The start date to run the pipeline for.
|
||||
end_date : pd.Timestamp
|
||||
The end date to run the pipeline for.
|
||||
chunksize : int or None
|
||||
The number of days to execute at a time. If None, then
|
||||
results will be calculated for entire date range at once.
|
||||
|
||||
Returns
|
||||
-------
|
||||
result : pd.DataFrame
|
||||
A frame of computed results.
|
||||
|
||||
The columns `result` correspond to the entries of
|
||||
`pipeline.columns`, which should be a dictionary mapping strings to
|
||||
instances of `zipline.pipeline.term.Term`.
|
||||
|
||||
For each date between `start_date` and `end_date`, `result` will
|
||||
contain a row for each asset that passed `pipeline.screen`. A
|
||||
screen of None indicates that a row should be returned for each
|
||||
asset that existed each day.
|
||||
|
||||
See Also
|
||||
--------
|
||||
:meth:`PipelineEngine.run_pipeline`
|
||||
"""
|
||||
raise NotImplementedError("run_chunked_pipeline")
|
||||
|
||||
|
||||
class NoEngineRegistered(Exception):
|
||||
"""
|
||||
@@ -80,6 +123,12 @@ class ExplodingPipelineEngine(PipelineEngine):
|
||||
"resources were registered."
|
||||
)
|
||||
|
||||
def run_chunked_pipeline(self, pipeline, start_date, end_date, chunksize):
|
||||
raise NoEngineRegistered(
|
||||
"Attempted to run a chunked pipeline but no pipeline "
|
||||
"resources were registered."
|
||||
)
|
||||
|
||||
|
||||
def default_populate_initial_workspace(initial_workspace,
|
||||
root_mask_term,
|
||||
@@ -114,7 +163,7 @@ def default_populate_initial_workspace(initial_workspace,
|
||||
return initial_workspace
|
||||
|
||||
|
||||
class SimplePipelineEngine(object):
|
||||
class SimplePipelineEngine(PipelineEngine):
|
||||
"""
|
||||
PipelineEngine class that computes each term independently.
|
||||
|
||||
@@ -146,7 +195,6 @@ class SimplePipelineEngine(object):
|
||||
'_root_mask_term',
|
||||
'_root_mask_dates_term',
|
||||
'_populate_initial_workspace',
|
||||
'__weakref__',
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
@@ -210,7 +258,8 @@ class SimplePipelineEngine(object):
|
||||
|
||||
See Also
|
||||
--------
|
||||
PipelineEngine.run_pipeline
|
||||
:meth:`PipelineEngine.run_pipeline`
|
||||
:meth:`PipelineEngine.run_chunked_pipeline`
|
||||
"""
|
||||
if end_date < start_date:
|
||||
raise ValueError(
|
||||
@@ -256,6 +305,18 @@ class SimplePipelineEngine(object):
|
||||
assets,
|
||||
)
|
||||
|
||||
@copydoc(PipelineEngine.run_chunked_pipeline)
|
||||
def run_chunked_pipeline(self, pipeline, start_date, end_date, chunksize):
|
||||
ranges = compute_date_range_chunks(
|
||||
self._calendar,
|
||||
start_date,
|
||||
end_date,
|
||||
chunksize,
|
||||
)
|
||||
chunks = [self.run_pipeline(pipeline, s, e) for s, e in ranges]
|
||||
|
||||
return categorical_df_concat(chunks, inplace=True)
|
||||
|
||||
def _compute_root_mask(self, start_date, end_date, extra_rows):
|
||||
"""
|
||||
Compute a lifetimes matrix from our AssetFinder, then drop columns that
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from itertools import repeat
|
||||
import os
|
||||
import sqlite3
|
||||
from unittest import TestCase
|
||||
@@ -1333,12 +1332,15 @@ class WithEquityPricingPipelineEngine(WithAdjustmentReader,
|
||||
cls.bcolz_equity_daily_bar_reader,
|
||||
SQLiteAdjustmentReader(cls.adjustments_db_path),
|
||||
)
|
||||
dispatcher = dict(
|
||||
zip(USEquityPricing.columns, repeat(loader))
|
||||
).__getitem__
|
||||
|
||||
def get_loader(column):
|
||||
if column in USEquityPricing.columns:
|
||||
return loader
|
||||
else:
|
||||
raise AssertionError("No loader registered for %s" % column)
|
||||
|
||||
cls.pipeline_engine = SimplePipelineEngine(
|
||||
get_loader=dispatcher,
|
||||
get_loader=get_loader,
|
||||
calendar=cls.nyse_sessions,
|
||||
asset_finder=cls.asset_finder,
|
||||
)
|
||||
|
||||
@@ -1,29 +1,6 @@
|
||||
from toolz import partition_all
|
||||
|
||||
|
||||
def roll_dates_to_previous_session(sessions, *dates):
|
||||
"""
|
||||
Roll `dates` to the last session of `calendar`. Return input date if it
|
||||
is a valid session.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sessions : pandas.tseries.index.DatetimeIndex
|
||||
The list of valid session dates.
|
||||
*dates : pd.Timestamp
|
||||
The dates for which the last trading date is needed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
rolled_dates: pandas.tseries.index.DatetimeIndex
|
||||
The last trading date of the input dates, inclusive.
|
||||
|
||||
"""
|
||||
# Find the previous index value if there is no exact match.
|
||||
locs = [sessions.get_loc(dt, method='ffill') for dt in dates]
|
||||
return sessions[locs].tolist()
|
||||
|
||||
|
||||
def compute_date_range_chunks(sessions, start_date, end_date, chunksize):
|
||||
"""Compute the start and end dates to run a pipeline for.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user