mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 03:40:51 +08:00
TST: Ensure batch_order_target_percent orders like order_target_percent
This commit is contained in:
@@ -98,6 +98,7 @@ from zipline.testing import (
|
||||
to_utc,
|
||||
trades_by_sid_to_dfs,
|
||||
)
|
||||
from zipline.testing import RecordBatchBlotter
|
||||
from zipline.testing.fixtures import (
|
||||
WithDataPortal,
|
||||
WithLogger,
|
||||
@@ -171,6 +172,7 @@ from zipline.test_algorithms import (
|
||||
set_benchmark_algo,
|
||||
no_handle_data,
|
||||
)
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
|
||||
from zipline.utils.calendars import get_calendar, register_calendar
|
||||
from zipline.utils.context_tricks import CallbackManager
|
||||
@@ -1737,6 +1739,75 @@ def handle_data(context, data):
|
||||
)
|
||||
test_algo.run(self.data_portal)
|
||||
|
||||
def test_batch_order_target_percent_matches_multi_order(self):
|
||||
weights = pd.Series([.3, .7])
|
||||
|
||||
multi_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
|
||||
self.asset_finder)
|
||||
multi_test_algo = TradingAlgorithm(
|
||||
script=dedent("""\
|
||||
from collections import OrderedDict
|
||||
from six import iteritems
|
||||
|
||||
from zipline.api import sid, order_target_percent
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.assets = [sid(0), sid(3)]
|
||||
context.placed = False
|
||||
|
||||
def handle_data(context, data):
|
||||
if not context.placed:
|
||||
for asset, weight in iteritems(OrderedDict(zip(
|
||||
context.assets, {weights}
|
||||
))):
|
||||
order_target_percent(asset, weight)
|
||||
|
||||
context.placed = True
|
||||
|
||||
""").format(weights=list(weights)),
|
||||
blotter=multi_blotter,
|
||||
env=self.env,
|
||||
)
|
||||
multi_stats = multi_test_algo.run(self.data_portal)
|
||||
self.assertFalse(multi_blotter.order_batch_called)
|
||||
|
||||
batch_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
|
||||
self.asset_finder)
|
||||
batch_test_algo = TradingAlgorithm(
|
||||
script=dedent("""\
|
||||
from collections import OrderedDict
|
||||
|
||||
from zipline.api import sid, batch_order_target_percent
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.assets = [sid(0), sid(3)]
|
||||
context.placed = False
|
||||
|
||||
def handle_data(context, data):
|
||||
if not context.placed:
|
||||
batch_order_target_percent(OrderedDict(zip(
|
||||
context.assets, {weights}
|
||||
)))
|
||||
context.placed = True
|
||||
|
||||
""").format(weights=list(weights)),
|
||||
blotter=batch_blotter,
|
||||
env=self.env,
|
||||
)
|
||||
batch_stats = batch_test_algo.run(self.data_portal)
|
||||
self.assertTrue(batch_blotter.order_batch_called)
|
||||
|
||||
for stats in (multi_stats, batch_stats):
|
||||
stats.orders = stats.orders.apply(
|
||||
lambda orders: [toolz.dissoc(o, 'id') for o in orders]
|
||||
)
|
||||
stats.transactions = stats.transactions.apply(
|
||||
lambda txns: [toolz.dissoc(txn, 'order_id') for txn in txns]
|
||||
)
|
||||
assert_equal(multi_stats, batch_stats)
|
||||
|
||||
def test_order_dead_asset(self):
|
||||
# after asset 0 is dead
|
||||
params = SimulationParameters(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .core import ( # noqa
|
||||
FetcherDataPortal,
|
||||
MockDailyBarReader,
|
||||
OpenPrice,
|
||||
RecordBatchBlotter,
|
||||
add_security_data,
|
||||
all_pairs_matching_predicate,
|
||||
all_subindices,
|
||||
|
||||
@@ -39,6 +39,7 @@ from zipline.data.us_equity_pricing import (
|
||||
BcolzDailyBarWriter,
|
||||
SQLiteAdjustmentWriter,
|
||||
)
|
||||
from zipline.finance.blotter import Blotter
|
||||
from zipline.finance.trading import TradingEnvironment
|
||||
from zipline.finance.order import ORDER_STATUS
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
@@ -1502,6 +1503,20 @@ def ensure_doctest(f, name=None):
|
||||
return f
|
||||
|
||||
|
||||
class RecordBatchBlotter(Blotter):
|
||||
"""Blotter that tracks how its batch_order method was called.
|
||||
"""
|
||||
def __init__(self, data_frequency, asset_finder):
|
||||
super(RecordBatchBlotter, self).__init__(
|
||||
data_frequency, asset_finder,
|
||||
)
|
||||
self.order_batch_called = []
|
||||
|
||||
def batch_order(self, *args, **kwargs):
|
||||
self.order_batch_called.append((args, kwargs))
|
||||
return super(RecordBatchBlotter, self).batch_order(*args, **kwargs)
|
||||
|
||||
|
||||
####################################
|
||||
# Shared factors for pipeline tests.
|
||||
####################################
|
||||
|
||||
Reference in New Issue
Block a user