mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 00:50:45 +08:00
Merge pull request #1230 from quantopian/pipeline-example
DOC/TEST: Add example algo using Pipeline.
This commit is contained in:
Binary file not shown.
@@ -8,10 +8,33 @@ import click
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from zipline import examples, run_algorithm
|
||||
from zipline import examples
|
||||
from zipline.data.bundles import clean, ingest, register, yahoo_equities
|
||||
from zipline.testing import test_resource_path, tmp_dir
|
||||
from zipline.utils.cache import dataframe_cache
|
||||
from zipline.data.bundles import register
|
||||
|
||||
|
||||
INPUT_DATA_START_DATE = pd.Timestamp('2004-01-02')
|
||||
INPUT_DATA_END_DATE = pd.Timestamp('2014-12-31')
|
||||
INPUT_DATA_SYMBOLS = (
|
||||
'AMD',
|
||||
'CERN',
|
||||
'COST',
|
||||
'DELL',
|
||||
'GPS',
|
||||
'INTC',
|
||||
'MMM',
|
||||
'AAPL',
|
||||
'MSFT',
|
||||
)
|
||||
TEST_BUNDLE_NAME = 'test'
|
||||
input_bundle = yahoo_equities(
|
||||
INPUT_DATA_SYMBOLS,
|
||||
INPUT_DATA_START_DATE,
|
||||
INPUT_DATA_END_DATE,
|
||||
)
|
||||
register(TEST_BUNDLE_NAME, input_bundle)
|
||||
|
||||
|
||||
banner = """
|
||||
Please verify that the new performance is more correct than the old
|
||||
@@ -20,6 +43,13 @@ performance.
|
||||
To do this, please inspect `new` and `old` which are mappings from the name of
|
||||
the example to the results.
|
||||
|
||||
The name `cols_to_check` has been bound to a list of perf columns that we
|
||||
expect to be reliably deterministic (excluding, e.g. `orders`, which contains
|
||||
UUIDs).
|
||||
|
||||
Calling `changed_results(new, old)` will compute a list of names of results
|
||||
that produced a different value in one of the `cols_to_check` fields.
|
||||
|
||||
If you are sure that the new results are more correct, or that the difference
|
||||
is acceptable, please call `correct()`. Otherwise, call `incorrect()`.
|
||||
|
||||
@@ -29,28 +59,58 @@ Remember to run this with the other supported versions of pandas!
|
||||
"""
|
||||
|
||||
|
||||
def changed_results(new, old):
|
||||
"""
|
||||
Get the names of results that changed since the last invocation.
|
||||
|
||||
Useful for verifying that only expected results changed.
|
||||
"""
|
||||
changed = []
|
||||
for col in new:
|
||||
if col not in old:
|
||||
changed.append(col)
|
||||
continue
|
||||
try:
|
||||
pd.util.testing.assert_frame_equal(
|
||||
new[col][examples._cols_to_check],
|
||||
old[col][examples._cols_to_check],
|
||||
)
|
||||
except AssertionError:
|
||||
changed.append(col)
|
||||
return changed
|
||||
|
||||
|
||||
def eof(*args, **kwargs):
|
||||
raise EOFError()
|
||||
|
||||
|
||||
def rebuild_input_data(environ):
|
||||
ingest(TEST_BUNDLE_NAME, environ=environ, show_progress=True)
|
||||
clean(TEST_BUNDLE_NAME, keep_last=1, environ=environ)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option(
|
||||
'--rebuild-input',
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Should we rebuild the input data from Yahoo?",
|
||||
)
|
||||
@click.pass_context
|
||||
def main(ctx):
|
||||
def main(ctx, rebuild_input):
|
||||
"""Rebuild the perf data for test_examples
|
||||
"""
|
||||
example_path = test_resource_path('example_data.tar.gz')
|
||||
|
||||
register('test', lambda *args: None)
|
||||
|
||||
with tmp_dir() as d:
|
||||
with tarfile.open(example_path) as tar:
|
||||
tar.extractall(d.path)
|
||||
|
||||
mods = (
|
||||
(e, getattr(examples, e))
|
||||
for e in dir(examples)
|
||||
if not e.startswith('_')
|
||||
)
|
||||
# The environ here should be the same (modulo the tempdir location)
|
||||
# as we use in test_examples.py.
|
||||
environ = {'ZIPLINE_ROOT': d.getpath('example_data/root')}
|
||||
if rebuild_input:
|
||||
rebuild_input_data(environ)
|
||||
|
||||
new_perf_path = d.getpath(
|
||||
'example_data/new_perf/%s' % pd.__version__.replace('.', '-'),
|
||||
@@ -60,21 +120,8 @@ def main(ctx):
|
||||
serialization='pickle:2',
|
||||
)
|
||||
with c:
|
||||
for name, mod in mods:
|
||||
c[name] = run_algorithm(
|
||||
handle_data=mod.handle_data,
|
||||
initialize=mod.initialize,
|
||||
before_trading_start=getattr(
|
||||
mod, 'before_trading_start', None,
|
||||
),
|
||||
analyze=getattr(mod, 'analyze', None),
|
||||
bundle='test',
|
||||
environ={
|
||||
'ZIPLINE_ROOT': d.getpath('example_data/root'),
|
||||
},
|
||||
capital_base=1e7,
|
||||
**mod._test_args()
|
||||
)
|
||||
for name in examples.EXAMPLE_MODULES:
|
||||
c[name] = examples.run_example(name, environ=environ)
|
||||
|
||||
correct_called = [False]
|
||||
|
||||
@@ -105,6 +152,8 @@ def main(ctx):
|
||||
serialization='pickle',
|
||||
),
|
||||
'pd': pd,
|
||||
'cols_to_check': examples._cols_to_check,
|
||||
'changed_results': changed_results,
|
||||
})
|
||||
console.interact(banner)
|
||||
|
||||
|
||||
@@ -3936,3 +3936,33 @@ class TestOrderAfterDelist(WithTradingEnvironment, ZiplineTestCase):
|
||||
"asset will be liquidated on "
|
||||
"2016-01-11 00:00:00+00:00.",
|
||||
w.message)
|
||||
|
||||
|
||||
class AlgoInputValidationTestCase(ZiplineTestCase):
|
||||
|
||||
def test_reject_passing_both_api_methods_and_script(self):
|
||||
script = dedent(
|
||||
"""
|
||||
def initialize(context):
|
||||
pass
|
||||
|
||||
def handle_data(context, data):
|
||||
pass
|
||||
|
||||
def before_trading_start(context, data):
|
||||
pass
|
||||
|
||||
def analyze(context, results):
|
||||
pass
|
||||
"""
|
||||
)
|
||||
for method in ('initialize',
|
||||
'handle_data',
|
||||
'before_trading_start',
|
||||
'analyze'):
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
TradingAlgorithm(
|
||||
script=script,
|
||||
**{method: lambda *args, **kwargs: None}
|
||||
)
|
||||
|
||||
+24
-50
@@ -13,13 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
import gc
|
||||
import tarfile
|
||||
|
||||
import matplotlib
|
||||
from nose_parameterized import parameterized
|
||||
import pandas as pd
|
||||
|
||||
from zipline import examples, run_algorithm
|
||||
from zipline import examples
|
||||
from zipline.data.bundles import register, unregister
|
||||
from zipline.testing import test_resource_path
|
||||
from zipline.testing.fixtures import WithTmpDir, ZiplineTestCase
|
||||
@@ -34,42 +35,6 @@ matplotlib.use('Agg')
|
||||
|
||||
class ExamplesTests(WithTmpDir, ZiplineTestCase):
|
||||
# some columns contain values with unique ids that will not be the same
|
||||
cols_to_check = [
|
||||
'algo_volatility',
|
||||
'algorithm_period_return',
|
||||
'alpha',
|
||||
'benchmark_period_return',
|
||||
'benchmark_volatility',
|
||||
'beta',
|
||||
'capital_used',
|
||||
'ending_cash',
|
||||
'ending_exposure',
|
||||
'ending_value',
|
||||
'excess_return',
|
||||
'gross_leverage',
|
||||
'long_exposure',
|
||||
'long_value',
|
||||
'longs_count',
|
||||
'max_drawdown',
|
||||
'max_leverage',
|
||||
'net_leverage',
|
||||
'period_close',
|
||||
'period_label',
|
||||
'period_open',
|
||||
'pnl',
|
||||
'portfolio_value',
|
||||
'positions',
|
||||
'returns',
|
||||
'short_exposure',
|
||||
'short_value',
|
||||
'shorts_count',
|
||||
'sortino',
|
||||
'starting_cash',
|
||||
'starting_exposure',
|
||||
'starting_value',
|
||||
'trading_days',
|
||||
'treasury_period_return',
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def init_class_fixtures(cls):
|
||||
@@ -89,24 +54,33 @@ class ExamplesTests(WithTmpDir, ZiplineTestCase):
|
||||
serialization='pickle',
|
||||
)
|
||||
|
||||
@parameterized.expand(e for e in dir(examples) if not e.startswith('_'))
|
||||
def test_example(self, example):
|
||||
mod = getattr(examples, example)
|
||||
actual_perf = run_algorithm(
|
||||
handle_data=mod.handle_data,
|
||||
initialize=mod.initialize,
|
||||
before_trading_start=getattr(mod, 'before_trading_start', None),
|
||||
analyze=getattr(mod, 'analyze', None),
|
||||
bundle='test',
|
||||
# We need to call gc.collect before tearing down our class because we
|
||||
# have a cycle between TradingAlgorithm and AlgorithmSimulator which
|
||||
# ultimately holds a reference to the pipeline engine passed to the
|
||||
# tests here.
|
||||
|
||||
# This means that we're not guaranteed to have deleted our disk-backed
|
||||
# resource readers (e.g. SQLiteAdjustmentReader) before trying to
|
||||
# delete the tempdir, which causes failures on Windows because Windows
|
||||
# doesn't allow you to delete a file if someone still has an open
|
||||
# handle to that file.
|
||||
|
||||
# :(
|
||||
cls.add_class_callback(gc.collect)
|
||||
|
||||
@parameterized.expand(examples.EXAMPLE_MODULES)
|
||||
def test_example(self, example_name):
|
||||
actual_perf = examples.run_example(
|
||||
example_name,
|
||||
# This should match the invocation in
|
||||
# zipline/tests/resources/rebuild_example_data
|
||||
environ={
|
||||
'ZIPLINE_ROOT': self.tmpdir.getpath('example_data/root'),
|
||||
},
|
||||
capital_base=1e7,
|
||||
**mod._test_args()
|
||||
)
|
||||
assert_equal(
|
||||
actual_perf[self.cols_to_check],
|
||||
self.expected_perf[example][self.cols_to_check],
|
||||
actual_perf[examples._cols_to_check],
|
||||
self.expected_perf[example_name][examples._cols_to_check],
|
||||
# There is a difference in the datetime columns in pandas
|
||||
# 0.16 and 0.17 because in 16 they are object and in 17 they are
|
||||
# datetime[ns, UTC]. We will just ignore the dtypes for now.
|
||||
|
||||
+1
-1
@@ -355,7 +355,7 @@ def bundles():
|
||||
if not pth.hidden(ing)),
|
||||
reverse=True,
|
||||
)
|
||||
except IOError as e:
|
||||
except OSError as e:
|
||||
if e.errno != errno.ENOENT:
|
||||
raise
|
||||
ingestions = []
|
||||
|
||||
+34
-20
@@ -31,6 +31,7 @@ from six import (
|
||||
iteritems,
|
||||
itervalues,
|
||||
string_types,
|
||||
viewkeys,
|
||||
)
|
||||
|
||||
from zipline._protocol import handle_non_market_minutes
|
||||
@@ -81,7 +82,7 @@ from zipline.assets import Asset, Future
|
||||
from zipline.assets.futures import FutureChain
|
||||
from zipline.gens.tradesimulation import AlgorithmSimulator
|
||||
from zipline.pipeline.engine import (
|
||||
NoOpPipelineEngine,
|
||||
ExplodingPipelineEngine,
|
||||
SimplePipelineEngine,
|
||||
)
|
||||
from zipline.utils.api_support import (
|
||||
@@ -332,29 +333,46 @@ class TradingAlgorithm(object):
|
||||
|
||||
self._handle_data = None
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
if self.algoscript is not None:
|
||||
api_methods = {
|
||||
'initialize',
|
||||
'handle_data',
|
||||
'before_trading_start',
|
||||
'analyze',
|
||||
}
|
||||
unexpected_api_methods = viewkeys(kwargs) & api_methods
|
||||
if unexpected_api_methods:
|
||||
raise ValueError(
|
||||
"TradingAlgorithm received a script and the following API"
|
||||
" methods as functions:\n{funcs}".format(
|
||||
funcs=unexpected_api_methods,
|
||||
)
|
||||
)
|
||||
|
||||
filename = kwargs.pop('algo_filename', None)
|
||||
if filename is None:
|
||||
filename = '<string>'
|
||||
code = compile(self.algoscript, filename, 'exec')
|
||||
exec_(code, self.namespace)
|
||||
self._initialize = self.namespace.get('initialize')
|
||||
if 'handle_data' in self.namespace:
|
||||
self._handle_data = self.namespace['handle_data']
|
||||
|
||||
self._before_trading_start = \
|
||||
self.namespace.get('before_trading_start')
|
||||
self._initialize = self.namespace.get('initialize', noop)
|
||||
self._handle_data = self.namespace.get('handle_data', noop)
|
||||
self._before_trading_start = self.namespace.get(
|
||||
'before_trading_start',
|
||||
)
|
||||
# Optional analyze function, gets called after run
|
||||
self._analyze = self.namespace.get('analyze')
|
||||
|
||||
elif kwargs.get('initialize') and kwargs.get('handle_data'):
|
||||
if self.algoscript is not None:
|
||||
raise ValueError('You can not set script and \
|
||||
initialize/handle_data.')
|
||||
self._initialize = kwargs.pop('initialize')
|
||||
self._handle_data = kwargs.pop('handle_data')
|
||||
self._before_trading_start = kwargs.pop('before_trading_start',
|
||||
None)
|
||||
else:
|
||||
self._initialize = kwargs.pop('initialize', noop)
|
||||
self._handle_data = kwargs.pop('handle_data', noop)
|
||||
self._before_trading_start = kwargs.pop(
|
||||
'before_trading_start',
|
||||
None,
|
||||
)
|
||||
self._analyze = kwargs.pop('analyze', None)
|
||||
|
||||
self.event_manager.add_event(
|
||||
@@ -367,10 +385,6 @@ class TradingAlgorithm(object):
|
||||
prepend=True,
|
||||
)
|
||||
|
||||
# If method not defined, NOOP
|
||||
if self._initialize is None:
|
||||
self._initialize = lambda x: None
|
||||
|
||||
# Alternative way of setting data_frequency for backwards
|
||||
# compatibility.
|
||||
if 'data_frequency' in kwargs:
|
||||
@@ -390,7 +404,7 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
Construct and store a PipelineEngine from loader.
|
||||
|
||||
If get_loader is None, constructs a NoOpPipelineEngine.
|
||||
If get_loader is None, constructs an ExplodingPipelineEngine
|
||||
"""
|
||||
if get_loader is not None:
|
||||
self.engine = SimplePipelineEngine(
|
||||
@@ -399,7 +413,7 @@ class TradingAlgorithm(object):
|
||||
self.asset_finder,
|
||||
)
|
||||
else:
|
||||
self.engine = NoOpPipelineEngine()
|
||||
self.engine = ExplodingPipelineEngine()
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# These imports are necessary to force module-scope register calls to happen.
|
||||
from . import quandl # noqa
|
||||
from .core import (
|
||||
UnknownBundle,
|
||||
|
||||
@@ -96,6 +96,7 @@ def fetch_symbol_metadata_frame(api_key,
|
||||
name: the full name of the asset
|
||||
start_date: the first date of data for this asset
|
||||
end_date: the last date of data for this asset
|
||||
auto_close_date: end_date + one day
|
||||
exchange: the exchange for the asset; this is always 'quandl'
|
||||
The index of the dataframe will be used for symbol->sid mappings but
|
||||
otherwise does not have specific meaning.
|
||||
@@ -119,6 +120,7 @@ def fetch_symbol_metadata_frame(api_key,
|
||||
# we need to escape the paren because it is actually splitting on a regex
|
||||
data.asset_name = data.asset_name.str.split(r' \(', 1).str.get(0)
|
||||
data['exchange'] = 'quandl'
|
||||
data['auto_close_date'] = data['end_date'] + pd.Timedelta(days=1)
|
||||
return data
|
||||
|
||||
|
||||
|
||||
@@ -74,6 +74,7 @@ def yahoo_equities(symbols, start=None, end=None):
|
||||
metadata = pd.DataFrame(np.empty(len(symbols), dtype=[
|
||||
('start_date', 'datetime64[ns]'),
|
||||
('end_date', 'datetime64[ns]'),
|
||||
('auto_close_date', 'datetime64[ns]'),
|
||||
('symbol', 'object'),
|
||||
]))
|
||||
|
||||
@@ -99,7 +100,12 @@ def yahoo_equities(symbols, start=None, end=None):
|
||||
|
||||
# the start date is the date of the first trade and
|
||||
# the end date is the date of the last trade
|
||||
metadata.iloc[sid] = df.index[0], df.index[-1], symbol
|
||||
start_date = df.index[0]
|
||||
end_date = df.index[-1]
|
||||
# The auto_close date is the day after the last trade.
|
||||
ac_date = end_date + pd.Timedelta(days=1)
|
||||
metadata.iloc[sid] = start_date, end_date, ac_date, symbol
|
||||
|
||||
df.rename(
|
||||
columns={
|
||||
'Open': 'open',
|
||||
|
||||
@@ -1,17 +1,77 @@
|
||||
from glob import glob
|
||||
from importlib import import_module
|
||||
import os
|
||||
|
||||
from toolz import merge
|
||||
|
||||
from zipline import run_algorithm
|
||||
|
||||
|
||||
# These are used by test_examples.py to discover the examples to run.
|
||||
EXAMPLE_MODULES = {}
|
||||
for f in os.listdir(os.path.dirname(__file__)):
|
||||
if not f.endswith('.py') or f == '__init__.py':
|
||||
continue
|
||||
modname = f[:-len('.py')]
|
||||
globals()[modname] = import_module('.' + modname, package=__name__)
|
||||
mod = import_module('.' + modname, package=__name__)
|
||||
EXAMPLE_MODULES[modname] = mod
|
||||
globals()[modname] = mod
|
||||
|
||||
del f
|
||||
try:
|
||||
del modname
|
||||
except NameError:
|
||||
pass
|
||||
# Remove noise from loop variables.
|
||||
del f, modname, mod
|
||||
|
||||
del os, import_module, glob
|
||||
|
||||
# Columns that we expect to be able to reliably deterministic
|
||||
# Doesn't include fields that have UUIDS.
|
||||
_cols_to_check = [
|
||||
'algo_volatility',
|
||||
'algorithm_period_return',
|
||||
'alpha',
|
||||
'benchmark_period_return',
|
||||
'benchmark_volatility',
|
||||
'beta',
|
||||
'capital_used',
|
||||
'ending_cash',
|
||||
'ending_exposure',
|
||||
'ending_value',
|
||||
'excess_return',
|
||||
'gross_leverage',
|
||||
'long_exposure',
|
||||
'long_value',
|
||||
'longs_count',
|
||||
'max_drawdown',
|
||||
'max_leverage',
|
||||
'net_leverage',
|
||||
'period_close',
|
||||
'period_label',
|
||||
'period_open',
|
||||
'pnl',
|
||||
'portfolio_value',
|
||||
'positions',
|
||||
'returns',
|
||||
'short_exposure',
|
||||
'short_value',
|
||||
'shorts_count',
|
||||
'sortino',
|
||||
'starting_cash',
|
||||
'starting_exposure',
|
||||
'starting_value',
|
||||
'trading_days',
|
||||
'treasury_period_return',
|
||||
]
|
||||
|
||||
|
||||
def run_example(example_name, environ):
|
||||
"""
|
||||
Run an example module from zipline.examples.
|
||||
"""
|
||||
mod = EXAMPLE_MODULES[example_name]
|
||||
return run_algorithm(
|
||||
initialize=getattr(mod, 'initialize', None),
|
||||
handle_data=getattr(mod, 'handle_data', None),
|
||||
before_trading_start=getattr(mod, 'before_trading_start', None),
|
||||
analyze=getattr(mod, 'analyze', None),
|
||||
bundle='test',
|
||||
environ=environ,
|
||||
# Provide a default capital base, but allow the test to override.
|
||||
**merge({'capital_base': 1e7}, mod._test_args())
|
||||
)
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
"""
|
||||
A simple Pipeline algorithm that longs the top 3 stocks by RSI and shorts
|
||||
the bottom 3 each day.
|
||||
"""
|
||||
from six import viewkeys
|
||||
from zipline.api import (
|
||||
attach_pipeline,
|
||||
date_rules,
|
||||
order_target_percent,
|
||||
pipeline_output,
|
||||
record,
|
||||
schedule_function,
|
||||
)
|
||||
from zipline.pipeline import Pipeline
|
||||
from zipline.pipeline.factors import RSI
|
||||
|
||||
|
||||
def make_pipeline():
|
||||
rsi = RSI()
|
||||
return Pipeline(
|
||||
columns={
|
||||
'longs': rsi.top(3),
|
||||
'shorts': rsi.bottom(3),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def rebalance(context, data):
|
||||
|
||||
# Pipeline data will be a dataframe with boolean columns named 'longs' and
|
||||
# 'shorts'.
|
||||
pipeline_data = context.pipeline_data
|
||||
all_assets = pipeline_data.index
|
||||
|
||||
longs = all_assets[pipeline_data.longs]
|
||||
shorts = all_assets[pipeline_data.shorts]
|
||||
|
||||
record(universe_size=len(all_assets))
|
||||
|
||||
# Build a 2x-leveraged, equal-weight, long-short portfolio.
|
||||
one_third = 1.0 / 3.0
|
||||
for asset in longs:
|
||||
order_target_percent(asset, one_third)
|
||||
|
||||
for asset in shorts:
|
||||
order_target_percent(asset, -one_third)
|
||||
|
||||
# Remove any assets that should no longer be in our portfolio.
|
||||
portfolio_assets = longs | shorts
|
||||
positions = context.portfolio.positions
|
||||
for asset in viewkeys(positions) - set(portfolio_assets):
|
||||
# This will fail if the asset was removed from our portfolio because it
|
||||
# was delisted.
|
||||
if data.can_trade(asset):
|
||||
order_target_percent(asset, 0)
|
||||
|
||||
|
||||
def initialize(context):
|
||||
attach_pipeline(make_pipeline(), 'my_pipeline')
|
||||
|
||||
# Rebalance each day. In daily mode, this is equivalent to putting
|
||||
# `rebalance` in our handle_data, but in minute mode, it's equivalent to
|
||||
# running at the start of the day each day.
|
||||
schedule_function(rebalance, date_rules.every_day())
|
||||
|
||||
|
||||
def before_trading_start(context, data):
|
||||
context.pipeline_data = pipeline_output('my_pipeline')
|
||||
|
||||
|
||||
def _test_args():
|
||||
"""
|
||||
Extra arguments to use when zipline's automated tests run this example.
|
||||
|
||||
Notes for testers:
|
||||
|
||||
Gross leverage should be roughly 2.0 on every day except the first.
|
||||
Net leverage should be roughly 2.0 on every day except the first.
|
||||
|
||||
Longs Count should always be 3 after the first day.
|
||||
Shorts Count should be 3 after the first day, except on 2013-10-30, when it
|
||||
dips to 2 for a day because DELL is delisted.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
return {
|
||||
# We run through october of 2013 because DELL is in the test data and
|
||||
# it went private on 2013-10-29.
|
||||
'start': pd.Timestamp('2013-10-07', tz='utc'),
|
||||
'end': pd.Timestamp('2013-11-30', tz='utc'),
|
||||
'capital_base': 100000,
|
||||
}
|
||||
+12
-11
@@ -12,11 +12,7 @@ from six import (
|
||||
with_metaclass,
|
||||
)
|
||||
from numpy import array
|
||||
from pandas import (
|
||||
DataFrame,
|
||||
date_range,
|
||||
MultiIndex,
|
||||
)
|
||||
from pandas import DataFrame, MultiIndex
|
||||
from toolz import groupby, juxt
|
||||
from toolz.curried.operator import getitem
|
||||
|
||||
@@ -63,16 +59,21 @@ class PipelineEngine(with_metaclass(ABCMeta)):
|
||||
raise NotImplementedError("run_pipeline")
|
||||
|
||||
|
||||
class NoOpPipelineEngine(PipelineEngine):
|
||||
class NoEngineRegistered(Exception):
|
||||
"""
|
||||
Raised if a user tries to call pipeline_output in an algorithm that hasn't
|
||||
set up a pipeline engine.
|
||||
"""
|
||||
|
||||
|
||||
class ExplodingPipelineEngine(PipelineEngine):
|
||||
"""
|
||||
A PipelineEngine that doesn't do anything.
|
||||
"""
|
||||
def run_pipeline(self, pipeline, start_date, end_date):
|
||||
return DataFrame(
|
||||
index=MultiIndex.from_product(
|
||||
[date_range(start=start_date, end=end_date, freq='D'), ()],
|
||||
),
|
||||
columns=sorted(pipeline.columns.keys()),
|
||||
raise NoEngineRegistered(
|
||||
"Attempted to run a pipeline but no pipeline "
|
||||
"resources were registered."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -107,13 +107,15 @@ class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
|
||||
@final
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls._class_teardown_stack.close()
|
||||
# We need to get this before it's deleted by the loop.
|
||||
stack = cls._class_teardown_stack
|
||||
for name in set(vars(cls)) - cls._static_class_attributes:
|
||||
# Remove all of the attributes that were added after the class was
|
||||
# constructed. This cleans up any large test data that is class
|
||||
# scoped while still allowing subclasses to access class level
|
||||
# attributes.
|
||||
delattr(cls, name)
|
||||
stack.close()
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
@@ -171,9 +173,11 @@ class ZiplineTestCase(with_metaclass(FinalMeta, TestCase)):
|
||||
|
||||
@final
|
||||
def tearDown(self):
|
||||
self._instance_teardown_stack.close()
|
||||
# We need to get this before it's deleted by the loop.
|
||||
stack = self._instance_teardown_stack
|
||||
for attr in set(vars(self)) - self._pre_setup_attrs:
|
||||
delattr(self, attr)
|
||||
stack.close()
|
||||
|
||||
@final
|
||||
def enter_instance_context(self, context_manager):
|
||||
|
||||
@@ -18,6 +18,8 @@ from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.data.bundles.core import load
|
||||
from zipline.data.data_portal import DataPortal
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.pipeline.data import USEquityPricing
|
||||
from zipline.pipeline.loaders import USEquityPricingLoader
|
||||
import zipline.utils.paths as pth
|
||||
|
||||
|
||||
@@ -133,12 +135,25 @@ def _run(handle_data,
|
||||
adjustment_reader=bundle_data.adjustment_reader,
|
||||
)
|
||||
|
||||
pipeline_loader = USEquityPricingLoader(
|
||||
bundle_data.daily_bar_reader,
|
||||
bundle_data.adjustment_reader,
|
||||
)
|
||||
|
||||
def choose_loader(column):
|
||||
if column in USEquityPricing.columns:
|
||||
return pipeline_loader
|
||||
raise ValueError(
|
||||
"No PipelineLoader registered for column %s." % column
|
||||
)
|
||||
|
||||
perf = TradingAlgorithm(
|
||||
namespace=namespace,
|
||||
capital_base=capital_base,
|
||||
start=start,
|
||||
end=end,
|
||||
env=env,
|
||||
get_pipeline_loader=choose_loader,
|
||||
**{
|
||||
'initialize': initialize,
|
||||
'handle_data': handle_data,
|
||||
|
||||
Reference in New Issue
Block a user