mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 22:52:02 +08:00
MAINT: Don't name pipelines.
`Pipeline()` no longer takes a name. `attacH_pipeline` now takes a name. This is mainly for uniformity with how `Factors` and `Filters` are handled.
This commit is contained in:
@@ -117,7 +117,7 @@ class ConstantInputTestCase(TestCase):
|
||||
loader = self.loader
|
||||
engine = SimplePipelineEngine(loader, self.dates, self.asset_finder)
|
||||
|
||||
p = Pipeline('test')
|
||||
p = Pipeline()
|
||||
|
||||
msg = "start_date must be before end_date .*"
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
@@ -135,7 +135,7 @@ class ConstantInputTestCase(TestCase):
|
||||
|
||||
factor = AssetID()
|
||||
for asset in assets:
|
||||
p = Pipeline('test', columns={'f': factor}, screen=factor <= asset)
|
||||
p = Pipeline(columns={'f': factor}, screen=factor <= asset)
|
||||
result = engine.run_pipeline(p, dates[0], dates[-1])
|
||||
|
||||
expected_sids = assets[assets <= asset]
|
||||
@@ -161,9 +161,8 @@ class ConstantInputTestCase(TestCase):
|
||||
|
||||
# Since every asset will pass the screen, these should be equivalent.
|
||||
pipelines = [
|
||||
Pipeline('test', columns={'f': factor}),
|
||||
Pipeline(columns={'f': factor}),
|
||||
Pipeline(
|
||||
'test',
|
||||
columns={'f': factor},
|
||||
screen=factor.eq(expected_result),
|
||||
),
|
||||
@@ -198,7 +197,6 @@ class ConstantInputTestCase(TestCase):
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
'test',
|
||||
columns={
|
||||
'short': short_factor,
|
||||
'long': long_factor,
|
||||
@@ -242,7 +240,6 @@ class ConstantInputTestCase(TestCase):
|
||||
|
||||
results = engine.run_pipeline(
|
||||
Pipeline(
|
||||
'test',
|
||||
columns={
|
||||
'high_low': high_minus_low,
|
||||
'open_close': open_minus_close,
|
||||
@@ -373,7 +370,6 @@ class FrameInputTestCase(TestCase):
|
||||
for start, stop in bounds:
|
||||
results = engine.run_pipeline(
|
||||
Pipeline(
|
||||
'test',
|
||||
columns={'low': low_mavg, 'high': high_mavg}
|
||||
),
|
||||
dates[start],
|
||||
@@ -488,7 +484,7 @@ class SyntheticBcolzTestCase(TestCase):
|
||||
)
|
||||
|
||||
results = engine.run_pipeline(
|
||||
Pipeline('test', columns={'sma': SMA}),
|
||||
Pipeline(columns={'sma': SMA}),
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1],
|
||||
)
|
||||
@@ -540,7 +536,7 @@ class SyntheticBcolzTestCase(TestCase):
|
||||
)
|
||||
|
||||
results = engine.run_pipeline(
|
||||
Pipeline('test', columns={'drawdown': drawdown}),
|
||||
Pipeline(columns={'drawdown': drawdown}),
|
||||
dates_to_test[0],
|
||||
dates_to_test[-1],
|
||||
)
|
||||
@@ -594,7 +590,6 @@ class MultiColumnLoaderTestCase(TestCase):
|
||||
|
||||
result = engine.run_pipeline(
|
||||
Pipeline(
|
||||
'test',
|
||||
columns={
|
||||
'sumdiff': sumdiff,
|
||||
'open': open_.latest,
|
||||
|
||||
@@ -30,21 +30,20 @@ class SomeOtherFilter(Filter):
|
||||
class PipelineTestCase(TestCase):
|
||||
|
||||
def test_construction(self):
|
||||
p0 = Pipeline('arglebargle')
|
||||
self.assertEqual(p0.name, 'arglebargle')
|
||||
p0 = Pipeline()
|
||||
self.assertEqual(p0.columns, {})
|
||||
self.assertIs(p0.screen, None)
|
||||
|
||||
columns = {'f': SomeFactor()}
|
||||
p1 = Pipeline('test', columns=columns)
|
||||
p1 = Pipeline(columns=columns)
|
||||
self.assertEqual(p1.columns, columns)
|
||||
|
||||
screen = SomeFilter()
|
||||
p2 = Pipeline('test', screen=screen)
|
||||
p2 = Pipeline(screen=screen)
|
||||
self.assertEqual(p2.columns, {})
|
||||
self.assertEqual(p2.screen, screen)
|
||||
|
||||
p3 = Pipeline('test', columns=columns, screen=screen)
|
||||
p3 = Pipeline(columns=columns, screen=screen)
|
||||
self.assertEqual(p3.columns, columns)
|
||||
self.assertEqual(p3.screen, screen)
|
||||
|
||||
@@ -53,21 +52,18 @@ class PipelineTestCase(TestCase):
|
||||
with self.assertRaises(TypeError):
|
||||
Pipeline(1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
Pipeline('test', 1)
|
||||
|
||||
Pipeline('test', {})
|
||||
Pipeline({})
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
Pipeline('test', {}, 1)
|
||||
Pipeline({}, 1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
Pipeline('test', {}, SomeFactor())
|
||||
Pipeline({}, SomeFactor())
|
||||
|
||||
Pipeline('test', {}, SomeFactor() > 5)
|
||||
Pipeline({}, SomeFactor() > 5)
|
||||
|
||||
def test_add(self):
|
||||
p = Pipeline('test')
|
||||
p = Pipeline()
|
||||
f = SomeFactor()
|
||||
|
||||
p.add(f, 'f')
|
||||
@@ -80,7 +76,7 @@ class PipelineTestCase(TestCase):
|
||||
p.add(f, 1)
|
||||
|
||||
def test_overwrite(self):
|
||||
p = Pipeline('test')
|
||||
p = Pipeline()
|
||||
f = SomeFactor()
|
||||
other_f = SomeOtherFactor()
|
||||
|
||||
@@ -97,7 +93,7 @@ class PipelineTestCase(TestCase):
|
||||
|
||||
def test_remove(self):
|
||||
f = SomeFactor()
|
||||
p = Pipeline('test', columns={'f': f})
|
||||
p = Pipeline(columns={'f': f})
|
||||
|
||||
with self.assertRaises(KeyError) as e:
|
||||
p.remove('not_a_real_name')
|
||||
@@ -112,7 +108,7 @@ class PipelineTestCase(TestCase):
|
||||
def test_set_screen(self):
|
||||
f, g = SomeFilter(), SomeOtherFilter()
|
||||
|
||||
p = Pipeline('test')
|
||||
p = Pipeline()
|
||||
self.assertEqual(p.screen, None)
|
||||
|
||||
p.set_screen(f)
|
||||
|
||||
@@ -173,7 +173,7 @@ class ClosesOnly(TestCase):
|
||||
pass
|
||||
|
||||
def late_attach(context, data):
|
||||
attach_pipeline(Pipeline('test'))
|
||||
attach_pipeline(Pipeline(), 'test')
|
||||
raise AssertionError("Shouldn't make it past attach_pipeline!")
|
||||
|
||||
algo = TradingAlgorithm(
|
||||
@@ -211,7 +211,7 @@ class ClosesOnly(TestCase):
|
||||
Assert that calling pipeline_output after initialize raises correctly.
|
||||
"""
|
||||
def initialize(context):
|
||||
attach_pipeline(Pipeline('test'))
|
||||
attach_pipeline(Pipeline(), 'test')
|
||||
pipeline_output('test')
|
||||
raise AssertionError("Shouldn't make it past pipeline_output()")
|
||||
|
||||
@@ -240,7 +240,7 @@ class ClosesOnly(TestCase):
|
||||
Assert that calling add_pipeline after initialize raises appropriately.
|
||||
"""
|
||||
def initialize(context):
|
||||
attach_pipeline(Pipeline('test'))
|
||||
attach_pipeline(Pipeline(), 'test')
|
||||
|
||||
def handle_data(context, data):
|
||||
raise AssertionError("Shouldn't make it past before_trading_start")
|
||||
@@ -269,11 +269,9 @@ class ClosesOnly(TestCase):
|
||||
correctly-adjusted close price values.
|
||||
"""
|
||||
def initialize(context):
|
||||
p = Pipeline('test')
|
||||
p = attach_pipeline(Pipeline(), 'test')
|
||||
p.add(USEquityPricing.close.latest, 'close')
|
||||
|
||||
attach_pipeline(p)
|
||||
|
||||
def handle_data(context, data):
|
||||
results = pipeline_output('test')
|
||||
date = get_datetime().normalize()
|
||||
@@ -477,7 +475,7 @@ class PipelineAlgorithmTestCase(TestCase):
|
||||
return "vwap_%d" % length
|
||||
|
||||
def initialize(context):
|
||||
pipeline = Pipeline('test')
|
||||
pipeline = Pipeline()
|
||||
context.vwaps = []
|
||||
for length in vwaps:
|
||||
name = vwap_key(length)
|
||||
@@ -490,7 +488,7 @@ class PipelineAlgorithmTestCase(TestCase):
|
||||
if set_screen:
|
||||
pipeline.set_screen(filter_)
|
||||
|
||||
attach_pipeline(pipeline)
|
||||
attach_pipeline(pipeline, 'test')
|
||||
|
||||
def handle_data(context, data):
|
||||
today = get_datetime()
|
||||
|
||||
+12
-10
@@ -231,7 +231,7 @@ class TradingAlgorithm(object):
|
||||
|
||||
# Initialize Pipeline API data.
|
||||
self.init_engine(kwargs.pop('pipeline_loader', None))
|
||||
self._pipelines = []
|
||||
self._pipelines = {}
|
||||
# Create an always-expired cache so that we compute the first time data
|
||||
# is requested.
|
||||
self._pipeline_cache = CachedObject(None, pd.Timestamp(0, tz='UTC'))
|
||||
@@ -1337,13 +1337,17 @@ class TradingAlgorithm(object):
|
||||
##############
|
||||
@api_method
|
||||
@require_not_initialized(AttachPipelineAfterInitialize())
|
||||
def attach_pipeline(self, pipeline):
|
||||
def attach_pipeline(self, pipeline, name):
|
||||
"""
|
||||
Register a pipeline to be computed at the start of each day.
|
||||
"""
|
||||
if self._pipelines:
|
||||
raise NotImplementedError("Multiple pipelines are not supported.")
|
||||
self._pipelines.append(pipeline)
|
||||
self._pipelines[name] = pipeline
|
||||
|
||||
# Return the pipeline to allow expressions like
|
||||
# p = attach_pipeline(Pipeline(), 'name')
|
||||
return pipeline
|
||||
|
||||
@api_method
|
||||
@require_initialized(PipelineOutputDuringInitialize())
|
||||
@@ -1353,7 +1357,7 @@ class TradingAlgorithm(object):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str or None
|
||||
name : str
|
||||
Name of the pipeline for which results are requested.
|
||||
|
||||
Returns
|
||||
@@ -1373,14 +1377,12 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
# NOTE: We don't currently support multiple pipelines, but we plan to
|
||||
# in the future.
|
||||
for p in self._pipelines:
|
||||
if p.name == name:
|
||||
break
|
||||
# This is a for-else block. Yes, that's a thing in Python.
|
||||
else:
|
||||
try:
|
||||
p = self._pipelines[name]
|
||||
except KeyError:
|
||||
raise NoSuchPipeline(
|
||||
name=name,
|
||||
valid=[p.name for p in self._pipelines],
|
||||
valid=list(self._pipelines.keys()),
|
||||
)
|
||||
return self._pipeline_output(p)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ from .graph import TermGraph
|
||||
|
||||
class Pipeline(object):
|
||||
"""
|
||||
A computational Pipeline for use in trading algorithms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str, optional
|
||||
Name for this pipeline.
|
||||
columns : dict, optional
|
||||
Initial columns.
|
||||
screen : zipline.pipeline.term.Filter, optional
|
||||
@@ -20,34 +20,26 @@ class Pipeline(object):
|
||||
-------
|
||||
add
|
||||
remove
|
||||
apply_screen
|
||||
set_screen
|
||||
|
||||
Attributes
|
||||
----------
|
||||
columns
|
||||
screen
|
||||
"""
|
||||
__slots__ = ('_name', '_columns', '_screen', '__weakref__')
|
||||
__slots__ = ('_columns', '_screen', '__weakref__')
|
||||
|
||||
@expect_types(
|
||||
name=str,
|
||||
columns=optional(dict),
|
||||
screen=optional(Filter),
|
||||
)
|
||||
def __init__(self, name, columns=None, screen=None):
|
||||
self._name = name
|
||||
def __init__(self, columns=None, screen=None):
|
||||
|
||||
if columns is None:
|
||||
columns = {}
|
||||
self._columns = columns
|
||||
self._screen = screen
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""
|
||||
The name of this pipeline.
|
||||
"""
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
"""
|
||||
@@ -114,10 +106,7 @@ class Pipeline(object):
|
||||
@expect_types(screen=Filter)
|
||||
def set_screen(self, screen, overwrite=False):
|
||||
"""
|
||||
Apply a screen to this Pipeline.
|
||||
|
||||
If no screen has yet been applied to the pipeline, this method sets
|
||||
`screen` as the current screen.
|
||||
Set a screen on this Pipeline.
|
||||
|
||||
Parameter
|
||||
---------
|
||||
|
||||
Reference in New Issue
Block a user