diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index b19eb601..4c6b1965 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -58,7 +58,15 @@ from zipline.errors import ( UnsupportedDatetimeFormat, CannotOrderDelistedAsset, SetCancelPolicyPostInit, - UnsupportedCancelPolicy + UnsupportedCancelPolicy, + OrderInBeforeTradingStart) +from zipline.api import ( + order, + order_value, + order_percent, + order_target, + order_target_value, + order_target_percent ) from zipline.finance.commission import PerShare @@ -362,8 +370,8 @@ def handle_data(context, data): self.assertEqual(all_orders[2], orders_2) self.assertEqual(len(all_orders[2]), 3) - for order in orders_2: - algo.cancel_order(order) + for order_ in orders_2: + algo.cancel_order(order_) all_orders = algo.get_open_orders() self.assertEqual(all_orders, {}) @@ -759,6 +767,32 @@ class TestTransformAlgorithm(WithLogger, ) algo.run(self.data_portal) + @parameterized.expand([ + (order, 1), + (order_value, 1000), + (order_target, 1), + (order_target_value, 1000), + (order_percent, 1), + (order_target_percent, 1), + ]) + def test_cannot_order_in_before_trading_start(self, order_method, amount): + algotext = """ +from zipline.api import sid +from zipline.api import {order_func} + +def initialize(context): + context.asset = sid(133) + +def before_trading_start(context, data): + {order_func}(context.asset, {arg}) + """.format(order_func=order_method.__name__, arg=amount) + + algo = TradingAlgorithm(script=algotext, sim_params=self.sim_params, + data_frequency='daily', env=self.env) + + with self.assertRaises(OrderInBeforeTradingStart): + algo.run(self.data_portal) + def test_run_twice(self): algo1 = TestRegisterTransformAlgorithm( sim_params=self.sim_params, diff --git a/zipline/algorithm.py b/zipline/algorithm.py index dea5c74e..16ef2612 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -51,7 +51,8 @@ from zipline.errors import ( UnsupportedDatetimeFormat, UnsupportedOrderParameters, UnsupportedSlippageModel, - CannotOrderDelistedAsset, UnsupportedCancelPolicy, SetCancelPolicyPostInit) + CannotOrderDelistedAsset, UnsupportedCancelPolicy, SetCancelPolicyPostInit, + OrderInBeforeTradingStart) from zipline.finance.trading import TradingEnvironment from zipline.finance.blotter import Blotter from zipline.finance.commission import PerShare, PerTrade, PerDollar @@ -87,7 +88,7 @@ from zipline.utils.api_support import ( require_initialized, require_not_initialized, ZiplineAPI, -) + disallowed_in_before_trading_start) from zipline.utils.input_validation import ensure_upper_case from zipline.utils.cache import CachedObject, Expired import zipline.utils.events @@ -968,6 +969,7 @@ class TradingAlgorithm(object): return True @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order(self, asset, amount, limit_price=None, stop_price=None, @@ -1055,6 +1057,7 @@ class TradingAlgorithm(object): return MarketOrder() @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order_value(self, asset, value, limit_price=None, stop_price=None, style=None): """ @@ -1205,6 +1208,7 @@ class TradingAlgorithm(object): self.sim_params.data_frequency = value @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order_percent(self, asset, percent, limit_price=None, stop_price=None, style=None): """ @@ -1223,6 +1227,7 @@ class TradingAlgorithm(object): style=style) @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order_target(self, asset, target, limit_price=None, stop_price=None, style=None): """ @@ -1249,6 +1254,7 @@ class TradingAlgorithm(object): style=style) @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order_target_value(self, asset, target, limit_price=None, stop_price=None, style=None): """ @@ -1270,6 +1276,7 @@ class TradingAlgorithm(object): style=style) @api_method + @disallowed_in_before_trading_start(OrderInBeforeTradingStart()) def order_target_percent(self, asset, target, limit_price=None, stop_price=None, style=None): """ diff --git a/zipline/errors.py b/zipline/errors.py index e809d93f..675c3eac 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -263,6 +263,13 @@ class HistoryInInitialize(ZiplineError): msg = "history() should only be called in handle_data()" +class OrderInBeforeTradingStart(ZiplineError): + """ + Raised when an algorithm calls an order method in before_trading_start. + """ + msg = "Cannot place orders inside before_trading_start." + + class MultipleSymbolsFound(ZiplineError): """ Raised when a symbol() call contains a symbol that changed over diff --git a/zipline/utils/api_support.py b/zipline/utils/api_support.py index 5096fc63..75444460 100644 --- a/zipline/utils/api_support.py +++ b/zipline/utils/api_support.py @@ -98,3 +98,25 @@ def require_initialized(exception): return method(self, *args, **kwargs) return wrapped_method return decorator + + +def disallowed_in_before_trading_start(exception): + """ + Decorator for API methods that cannot be called from within + TradingAlgorithm.before_trading_start. `exception` will be raised if the + method is called inside `before_trading_start`. + + Usage + ----- + @disallowed_in_before_trading_start(SomeException("Don't do that!")) + def method(self): + # Do stuff that is not allowed inside before_trading_start. + """ + def decorator(method): + @wraps(method) + def wrapped_method(self, *args, **kwargs): + if self._in_before_trading_start: + raise exception + return method(self, *args, **kwargs) + return wrapped_method + return decorator