mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 19:47:43 +08:00
TST: Ensure batch_order_target_percent orders like order_target_percent
This commit is contained in:
@@ -98,6 +98,7 @@ from zipline.testing import (
|
||||
to_utc,
|
||||
trades_by_sid_to_dfs,
|
||||
)
|
||||
from zipline.testing import RecordBatchBlotter
|
||||
from zipline.testing.fixtures import (
|
||||
WithDataPortal,
|
||||
WithLogger,
|
||||
@@ -171,6 +172,7 @@ from zipline.test_algorithms import (
|
||||
set_benchmark_algo,
|
||||
no_handle_data,
|
||||
)
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
|
||||
from zipline.utils.calendars import get_calendar, register_calendar
|
||||
from zipline.utils.context_tricks import CallbackManager
|
||||
@@ -1737,6 +1739,75 @@ def handle_data(context, data):
|
||||
)
|
||||
test_algo.run(self.data_portal)
|
||||
|
||||
def test_batch_order_target_percent_matches_multi_order(self):
|
||||
weights = pd.Series([.3, .7])
|
||||
|
||||
multi_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
|
||||
self.asset_finder)
|
||||
multi_test_algo = TradingAlgorithm(
|
||||
script=dedent("""\
|
||||
from collections import OrderedDict
|
||||
from six import iteritems
|
||||
|
||||
from zipline.api import sid, order_target_percent
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.assets = [sid(0), sid(3)]
|
||||
context.placed = False
|
||||
|
||||
def handle_data(context, data):
|
||||
if not context.placed:
|
||||
for asset, weight in iteritems(OrderedDict(zip(
|
||||
context.assets, {weights}
|
||||
))):
|
||||
order_target_percent(asset, weight)
|
||||
|
||||
context.placed = True
|
||||
|
||||
""").format(weights=list(weights)),
|
||||
blotter=multi_blotter,
|
||||
env=self.env,
|
||||
)
|
||||
multi_stats = multi_test_algo.run(self.data_portal)
|
||||
self.assertFalse(multi_blotter.order_batch_called)
|
||||
|
||||
batch_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY,
|
||||
self.asset_finder)
|
||||
batch_test_algo = TradingAlgorithm(
|
||||
script=dedent("""\
|
||||
from collections import OrderedDict
|
||||
|
||||
from zipline.api import sid, batch_order_target_percent
|
||||
|
||||
|
||||
def initialize(context):
|
||||
context.assets = [sid(0), sid(3)]
|
||||
context.placed = False
|
||||
|
||||
def handle_data(context, data):
|
||||
if not context.placed:
|
||||
batch_order_target_percent(OrderedDict(zip(
|
||||
context.assets, {weights}
|
||||
)))
|
||||
context.placed = True
|
||||
|
||||
""").format(weights=list(weights)),
|
||||
blotter=batch_blotter,
|
||||
env=self.env,
|
||||
)
|
||||
batch_stats = batch_test_algo.run(self.data_portal)
|
||||
self.assertTrue(batch_blotter.order_batch_called)
|
||||
|
||||
for stats in (multi_stats, batch_stats):
|
||||
stats.orders = stats.orders.apply(
|
||||
lambda orders: [toolz.dissoc(o, 'id') for o in orders]
|
||||
)
|
||||
stats.transactions = stats.transactions.apply(
|
||||
lambda txns: [toolz.dissoc(txn, 'order_id') for txn in txns]
|
||||
)
|
||||
assert_equal(multi_stats, batch_stats)
|
||||
|
||||
def test_order_dead_asset(self):
|
||||
# after asset 0 is dead
|
||||
params = SimulationParameters(
|
||||
|
||||
Reference in New Issue
Block a user