TST: Ensure batch_order_target_percent orders like order_target_percent

This commit is contained in:
Richard Frank
2016-12-20 18:56:40 -05:00
parent 1cb85b70f2
commit 581e827208
3 changed files with 87 additions and 0 deletions
+71
View File
@@ -98,6 +98,7 @@ from zipline.testing import (
to_utc,
trades_by_sid_to_dfs,
)
from zipline.testing import RecordBatchBlotter
from zipline.testing.fixtures import (
WithDataPortal,
WithLogger,
@@ -171,6 +172,7 @@ from zipline.test_algorithms import (
set_benchmark_algo,
no_handle_data,
)
from zipline.testing.predicates import assert_equal
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
from zipline.utils.calendars import get_calendar, register_calendar
from zipline.utils.context_tricks import CallbackManager
@@ -1737,6 +1739,75 @@ def handle_data(context, data):
)
test_algo.run(self.data_portal)
def test_batch_order_target_percent_matches_multi_order(self):
weights = pd.Series([.3, .7])
multi_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
self.asset_finder)
multi_test_algo = TradingAlgorithm(
script=dedent("""\
from collections import OrderedDict
from six import iteritems
from zipline.api import sid, order_target_percent
def initialize(context):
context.assets = [sid(0), sid(3)]
context.placed = False
def handle_data(context, data):
if not context.placed:
for asset, weight in iteritems(OrderedDict(zip(
context.assets, {weights}
))):
order_target_percent(asset, weight)
context.placed = True
""").format(weights=list(weights)),
blotter=multi_blotter,
env=self.env,
)
multi_stats = multi_test_algo.run(self.data_portal)
self.assertFalse(multi_blotter.order_batch_called)
batch_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
self.asset_finder)
batch_test_algo = TradingAlgorithm(
script=dedent("""\
from collections import OrderedDict
from zipline.api import sid, batch_order_target_percent
def initialize(context):
context.assets = [sid(0), sid(3)]
context.placed = False
def handle_data(context, data):
if not context.placed:
batch_order_target_percent(OrderedDict(zip(
context.assets, {weights}
)))
context.placed = True
""").format(weights=list(weights)),
blotter=batch_blotter,
env=self.env,
)
batch_stats = batch_test_algo.run(self.data_portal)
self.assertTrue(batch_blotter.order_batch_called)
for stats in (multi_stats, batch_stats):
stats.orders = stats.orders.apply(
lambda orders: [toolz.dissoc(o, 'id') for o in orders]
)
stats.transactions = stats.transactions.apply(
lambda txns: [toolz.dissoc(txn, 'order_id') for txn in txns]
)
assert_equal(multi_stats, batch_stats)
def test_order_dead_asset(self):
# after asset 0 is dead
params = SimulationParameters(
+1
View File
@@ -7,6 +7,7 @@ from .core import ( # noqa
FetcherDataPortal,
MockDailyBarReader,
OpenPrice,
RecordBatchBlotter,
add_security_data,
all_pairs_matching_predicate,
all_subindices,
+15
View File
@@ -39,6 +39,7 @@ from zipline.data.us_equity_pricing import (
BcolzDailyBarWriter,
SQLiteAdjustmentWriter,
)
from zipline.finance.blotter import Blotter
from zipline.finance.trading import TradingEnvironment
from zipline.finance.order import ORDER_STATUS
from zipline.lib.labelarray import LabelArray
@@ -1502,6 +1503,20 @@ def ensure_doctest(f, name=None):
return f
class RecordBatchBlotter(Blotter):
"""Blotter that tracks how its batch_order method was called.
"""
def __init__(self, data_frequency, asset_finder):
super(RecordBatchBlotter, self).__init__(
data_frequency, asset_finder,
)
self.order_batch_called = []
def batch_order(self, *args, **kwargs):
self.order_batch_called.append((args, kwargs))
return super(RecordBatchBlotter, self).batch_order(*args, **kwargs)
####################################
# Shared factors for pipeline tests.
####################################