From 581e8272082a2b62f247c50c8a72b86e3efdce5b Mon Sep 17 00:00:00 2001 From: Richard Frank Date: Tue, 20 Dec 2016 18:56:40 -0500 Subject: [PATCH] TST: Ensure batch_order_target_percent orders like order_target_percent --- tests/test_algorithm.py | 71 +++++++++++++++++++++++++++++++++++++ zipline/testing/__init__.py | 1 + zipline/testing/core.py | 15 ++++++++ 3 files changed, 87 insertions(+) diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 85a7a65f..309d96a5 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -98,6 +98,7 @@ from zipline.testing import ( to_utc, trades_by_sid_to_dfs, ) +from zipline.testing import RecordBatchBlotter from zipline.testing.fixtures import ( WithDataPortal, WithLogger, @@ -171,6 +172,7 @@ from zipline.test_algorithms import ( set_benchmark_algo, no_handle_data, ) +from zipline.testing.predicates import assert_equal from zipline.utils.api_support import ZiplineAPI, set_algo_instance from zipline.utils.calendars import get_calendar, register_calendar from zipline.utils.context_tricks import CallbackManager @@ -1737,6 +1739,75 @@ def handle_data(context, data): ) test_algo.run(self.data_portal) + def test_batch_order_target_percent_matches_multi_order(self): + weights = pd.Series([.3, .7]) + + multi_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY, + self.asset_finder) + multi_test_algo = TradingAlgorithm( + script=dedent("""\ + from collections import OrderedDict + from six import iteritems + + from zipline.api import sid, order_target_percent + + + def initialize(context): + context.assets = [sid(0), sid(3)] + context.placed = False + + def handle_data(context, data): + if not context.placed: + for asset, weight in iteritems(OrderedDict(zip( + context.assets, {weights} + ))): + order_target_percent(asset, weight) + + context.placed = True + + """).format(weights=list(weights)), + blotter=multi_blotter, + env=self.env, + ) + multi_stats = multi_test_algo.run(self.data_portal) + self.assertFalse(multi_blotter.order_batch_called) + + batch_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY, + self.asset_finder) + batch_test_algo = TradingAlgorithm( + script=dedent("""\ + from collections import OrderedDict + + from zipline.api import sid, batch_order_target_percent + + + def initialize(context): + context.assets = [sid(0), sid(3)] + context.placed = False + + def handle_data(context, data): + if not context.placed: + batch_order_target_percent(OrderedDict(zip( + context.assets, {weights} + ))) + context.placed = True + + """).format(weights=list(weights)), + blotter=batch_blotter, + env=self.env, + ) + batch_stats = batch_test_algo.run(self.data_portal) + self.assertTrue(batch_blotter.order_batch_called) + + for stats in (multi_stats, batch_stats): + stats.orders = stats.orders.apply( + lambda orders: [toolz.dissoc(o, 'id') for o in orders] + ) + stats.transactions = stats.transactions.apply( + lambda txns: [toolz.dissoc(txn, 'order_id') for txn in txns] + ) + assert_equal(multi_stats, batch_stats) + def test_order_dead_asset(self): # after asset 0 is dead params = SimulationParameters( diff --git a/zipline/testing/__init__.py b/zipline/testing/__init__.py index ee2aad9d..1a5449b0 100644 --- a/zipline/testing/__init__.py +++ b/zipline/testing/__init__.py @@ -7,6 +7,7 @@ from .core import ( # noqa FetcherDataPortal, MockDailyBarReader, OpenPrice, + RecordBatchBlotter, add_security_data, all_pairs_matching_predicate, all_subindices, diff --git a/zipline/testing/core.py b/zipline/testing/core.py index 24821d93..faaa2857 100644 --- a/zipline/testing/core.py +++ b/zipline/testing/core.py @@ -39,6 +39,7 @@ from zipline.data.us_equity_pricing import ( BcolzDailyBarWriter, SQLiteAdjustmentWriter, ) +from zipline.finance.blotter import Blotter from zipline.finance.trading import TradingEnvironment from zipline.finance.order import ORDER_STATUS from zipline.lib.labelarray import LabelArray @@ -1502,6 +1503,20 @@ def ensure_doctest(f, name=None): return f +class RecordBatchBlotter(Blotter): + """Blotter that tracks how its batch_order method was called. + """ + def __init__(self, data_frequency, asset_finder): + super(RecordBatchBlotter, self).__init__( + data_frequency, asset_finder, + ) + self.order_batch_called = [] + + def batch_order(self, *args, **kwargs): + self.order_batch_called.append((args, kwargs)) + return super(RecordBatchBlotter, self).batch_order(*args, **kwargs) + + #################################### # Shared factors for pipeline tests. ####################################