diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index 5245f012..4490576d 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -70,6 +70,7 @@ from zipline.sources import (SpecificEquityTrades, RandomWalkSource) from zipline.transforms import MovingAverage +from zipline.finance.execution import LimitOrder from zipline.finance.trading import SimulationParameters from zipline.utils.api_support import set_algo_instance from zipline.algorithm import TradingAlgorithm @@ -100,11 +101,74 @@ class TestRecordAlgorithm(TestCase): range(1, len(output) + 1)) +class TestMiscellaneousAPI(TestCase): + def setUp(self): + setup_logger(self) + + sids = [1, 2] + self.sim_params = factory.create_simulation_parameters(num_days=2, + sids=sids) + self.source = factory.create_minutely_trade_source( + sids, + trade_count=100, + sim_params=self.sim_params, + concurrent=True, + ) + + def test_get_open_orders(self): + + def initialize(algo): + algo.minute = 0 + + def handle_data(algo, data): + if algo.minute == 0: + + # Should be filled by the next minute + algo.order(1, 1) + + # Won't be filled because the price is too low. + algo.order(2, 1, style=LimitOrder(0.01)) + algo.order(2, 1, style=LimitOrder(0.01)) + algo.order(2, 1, style=LimitOrder(0.01)) + + all_orders = algo.get_open_orders() + self.assertEqual(list(all_orders.keys()), [1, 2]) + + self.assertEqual(all_orders[1], algo.get_open_orders(1)) + self.assertEqual(len(all_orders[1]), 1) + + self.assertEqual(all_orders[2], algo.get_open_orders(2)) + self.assertEqual(len(all_orders[2]), 3) + + if algo.minute == 1: + # First order should have filled. + # Second order should still be open. + all_orders = algo.get_open_orders() + self.assertEqual(list(all_orders.keys()), [2]) + + self.assertEqual([], algo.get_open_orders(1)) + + orders_2 = algo.get_open_orders(2) + self.assertEqual(all_orders[2], orders_2) + self.assertEqual(len(all_orders[2]), 3) + + for order in orders_2: + algo.cancel_order(order) + + all_orders = algo.get_open_orders() + self.assertEqual(all_orders, {}) + + algo.minute += 1 + + algo = TradingAlgorithm(initialize=initialize, + handle_data=handle_data) + algo.run(self.source, sim_params=self.sim_params) + + class TestTransformAlgorithm(TestCase): def setUp(self): setup_logger(self) self.sim_params = factory.create_simulation_parameters(num_days=4) - setup_logger(self) trade_history = factory.create_trade_history( 133, diff --git a/zipline/algorithm.py b/zipline/algorithm.py index 4634c862..d67e4ab8 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -758,9 +758,11 @@ class TradingAlgorithm(object): @api_method def get_open_orders(self, sid=None): if sid is None: - return {key: [order.to_api_obj() for order in orders] - for key, orders - in self.blotter.open_orders.iteritems()} + return { + key: [order.to_api_obj() for order in orders] + for key, orders in iteritems(self.blotter.open_orders) + if orders + } if sid in self.blotter.open_orders: orders = self.blotter.open_orders[sid] return [order.to_api_obj() for order in orders]