diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 81e842ee..71df3863 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -1787,9 +1787,14 @@ def handle_data(context, data): def handle_data(context, data): if not context.placed: - batch_order_target_percent(OrderedDict(zip( + orders = batch_order_target_percent(OrderedDict(zip( context.assets, {weights} ))) + assert len(orders) == 2, \ + "len(orders) was %s but expected 2" % len(orders) + for o in orders: + assert o is not None, "An order is None" + context.placed = True """).format(weights=list(weights)), @@ -1808,6 +1813,41 @@ def handle_data(context, data): ) assert_equal(multi_stats, batch_stats) + def test_batch_order_target_percent_filters_null_orders(self): + weights = pd.Series([1, 0]) + + batch_blotter = RecordBatchBlotter(self.SIM_PARAMS_DATA_FREQUENCY, + self.asset_finder) + batch_test_algo = TradingAlgorithm( + script=dedent("""\ + from collections import OrderedDict + + from zipline.api import sid, batch_order_target_percent + + + def initialize(context): + context.assets = [sid(0), sid(3)] + context.placed = False + + def handle_data(context, data): + if not context.placed: + orders = batch_order_target_percent(OrderedDict(zip( + context.assets, {weights} + ))) + assert len(orders) == 1, \ + "len(orders) was %s but expected 1" % len(orders) + for o in orders: + assert o is not None, "An order is None" + + context.placed = True + + """).format(weights=list(weights)), + blotter=batch_blotter, + env=self.env, + ) + batch_test_algo.run(self.data_portal) + self.assertTrue(batch_blotter.order_batch_called) + def test_order_dead_asset(self): # after asset 0 is dead params = SimulationParameters( diff --git a/zipline/algorithm.py b/zipline/algorithm.py index b076e71a..69962113 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -2003,7 +2003,8 @@ class TradingAlgorithm(object): order_args[asset] = (asset, amount, style) order_ids = self.blotter.batch_order(viewvalues(order_args)) - return pd.Series(data=order_ids, index=order_args) + order_ids = pd.Series(data=order_ids, index=order_args) + return order_ids[~order_ids.isnull()] @error_keywords(sid='Keyword argument `sid` is no longer supported for ' 'get_open_orders. Use `asset` instead.')