diff --git a/tests/test_algorithm.py b/tests/test_algorithm.py index d32ea9d9..5245f012 100644 --- a/tests/test_algorithm.py +++ b/tests/test_algorithm.py @@ -13,16 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest import TestCase from datetime import timedelta +from mock import MagicMock +from six.moves import range +from unittest import TestCase + import numpy as np import pandas as pd -from mock import MagicMock -from zipline.utils.test_utils import setup_logger +from zipline.utils.test_utils import ( + nullctx, + setup_logger +) import zipline.utils.factory as factory import zipline.utils.simfactory as simfactory +from zipline.errors import ( + RegisterTradingControlPostInit, + TradingControlViolation, +) from zipline.test_algorithms import ( AmbitiousStopLimitAlgorithm, EmptyPositionsAlgorithm, @@ -37,6 +46,10 @@ from zipline.test_algorithms import ( TestTargetAlgorithm, TestTargetPercentAlgorithm, TestTargetValueAlgorithm, + SetLongOnlyAlgorithm, + SetMaxPositionSizeAlgorithm, + SetMaxOrderCountAlgorithm, + SetMaxOrderSizeAlgorithm, api_algo, api_symbol_algo, call_all_order_methods, @@ -560,3 +573,234 @@ def handle_data(context, data): algo = TradingAlgorithm(script=history_algo, data_frequency='minute') output = algo.run(source) self.assertIsNot(output, None) + + +class TestTradingControls(TestCase): + + def setUp(self): + self.sim_params = factory.create_simulation_parameters(num_days=4) + self.sid = 133 + self.trade_history = factory.create_trade_history( + self.sid, + [10.0, 10.0, 11.0, 11.0], + [100, 100, 100, 300], + timedelta(days=1), + self.sim_params + ) + + self.source = SpecificEquityTrades(event_list=self.trade_history) + + def _check_algo(self, + algo, + handle_data, + expected_order_count, + expected_exc): + + algo._handle_data = handle_data + with self.assertRaises(expected_exc) if expected_exc else nullctx(): + algo.run(self.source) + self.assertEqual(algo.order_count, expected_order_count) + self.source.rewind() + + def check_algo_succeeds(self, algo, handle_data, order_count=4): + # Default for order_count assumes one order per handle_data call. + self._check_algo(algo, handle_data, order_count, None) + + def check_algo_fails(self, algo, handle_data, order_count): + self._check_algo(algo, + handle_data, + order_count, + TradingControlViolation) + + def test_set_max_position_size(self): + + # Buy one share four times. Should be fine. + def handle_data(algo, data): + algo.order(self.sid, 1) + algo.order_count += 1 + algo = SetMaxPositionSizeAlgorithm(sid=self.sid, + max_shares=10, + max_notional=500.0) + self.check_algo_succeeds(algo, handle_data) + + # Buy three shares four times. Should bail on the fourth before it's + # placed. + def handle_data(algo, data): + algo.order(self.sid, 3) + algo.order_count += 1 + + algo = SetMaxPositionSizeAlgorithm(sid=self.sid, + max_shares=10, + max_notional=500.0) + self.check_algo_fails(algo, handle_data, 3) + + # Buy two shares four times. Should bail due to max_notional on the + # third attempt. + def handle_data(algo, data): + algo.order(self.sid, 3) + algo.order_count += 1 + + algo = SetMaxPositionSizeAlgorithm(sid=self.sid, + max_shares=10, + max_notional=61.0) + self.check_algo_fails(algo, handle_data, 2) + + # Set the trading control to a different sid, then BUY ALL THE THINGS!. + # Should continue normally. + def handle_data(algo, data): + algo.order(self.sid, 10000) + algo.order_count += 1 + algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1, + max_shares=10, + max_notional=61.0) + self.check_algo_succeeds(algo, handle_data) + + # Set the trading control sid to None, then BUY ALL THE THINGS!. Should + # fail because setting sid to None makes the control apply to all sids. + def handle_data(algo, data): + algo.order(self.sid, 10000) + algo.order_count += 1 + algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0) + self.check_algo_fails(algo, handle_data, 0) + + def test_set_max_order_size(self): + + # Buy one share. + def handle_data(algo, data): + algo.order(self.sid, 1) + algo.order_count += 1 + algo = SetMaxOrderSizeAlgorithm(sid=self.sid, + max_shares=10, + max_notional=500.0) + self.check_algo_succeeds(algo, handle_data) + + # Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt + # because we exceed shares. + def handle_data(algo, data): + algo.order(self.sid, algo.order_count + 1) + algo.order_count += 1 + + algo = SetMaxOrderSizeAlgorithm(sid=self.sid, + max_shares=3, + max_notional=500.0) + self.check_algo_fails(algo, handle_data, 3) + + # Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt + # because we exceed notional. + def handle_data(algo, data): + algo.order(self.sid, algo.order_count + 1) + algo.order_count += 1 + + algo = SetMaxOrderSizeAlgorithm(sid=self.sid, + max_shares=10, + max_notional=40.0) + self.check_algo_fails(algo, handle_data, 3) + + # Set the trading control to a different sid, then BUY ALL THE THINGS!. + # Should continue normally. + def handle_data(algo, data): + algo.order(self.sid, 10000) + algo.order_count += 1 + algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1, + max_shares=1, + max_notional=1.0) + self.check_algo_succeeds(algo, handle_data) + + # Set the trading control sid to None, then BUY ALL THE THINGS!. + # Should fail because not specifying a sid makes the trading control + # apply to all sids. + def handle_data(algo, data): + algo.order(self.sid, 10000) + algo.order_count += 1 + algo = SetMaxOrderSizeAlgorithm(max_shares=1, + max_notional=1.0) + self.check_algo_fails(algo, handle_data, 0) + + def test_set_max_order_count(self): + + # Override the default setUp to use six-hour intervals instead of full + # days so we can exercise trading-session rollover logic. + trade_history = factory.create_trade_history( + self.sid, + [10.0, 10.0, 11.0, 11.0], + [100, 100, 100, 300], + timedelta(hours=6), + self.sim_params + ) + self.source = SpecificEquityTrades(event_list=trade_history) + + def handle_data(algo, data): + for i in range(5): + algo.order(self.sid, 1) + algo.order_count += 1 + + algo = SetMaxOrderCountAlgorithm(3) + self.check_algo_fails(algo, handle_data, 3) + + # Second call to handle_data is the same day as the first, so the last + # order of the second call should fail. + algo = SetMaxOrderCountAlgorithm(9) + self.check_algo_fails(algo, handle_data, 9) + + # Only ten orders are placed per day, so this should pass even though + # in total more than 20 orders are placed. + algo = SetMaxOrderCountAlgorithm(10) + self.check_algo_succeeds(algo, handle_data, order_count=20) + + def test_long_only(self): + + # Sell immediately -> fail immediately. + def handle_data(algo, data): + algo.order(self.sid, -1) + algo.order_count += 1 + algo = SetLongOnlyAlgorithm() + self.check_algo_fails(algo, handle_data, 0) + + # Buy on even days, sell on odd days. Never takes a short position, so + # should succeed. + def handle_data(algo, data): + if (algo.order_count % 2) == 0: + algo.order(self.sid, 1) + else: + algo.order(self.sid, -1) + algo.order_count += 1 + algo = SetLongOnlyAlgorithm() + self.check_algo_succeeds(algo, handle_data) + + # Buy on first three days, then sell off holdings. Should succeed. + def handle_data(algo, data): + amounts = [1, 1, 1, -3] + algo.order(self.sid, amounts[algo.order_count]) + algo.order_count += 1 + algo = SetLongOnlyAlgorithm() + self.check_algo_succeeds(algo, handle_data) + + # Buy on first three days, then sell off holdings plus an extra share. + # Should fail on the last sale. + def handle_data(algo, data): + amounts = [1, 1, 1, -4] + algo.order(self.sid, amounts[algo.order_count]) + algo.order_count += 1 + algo = SetLongOnlyAlgorithm() + self.check_algo_fails(algo, handle_data, 3) + + def test_register_post_init(self): + + def initialize(algo): + algo.initialized = True + + def handle_data(algo, data): + + with self.assertRaises(RegisterTradingControlPostInit): + algo.set_max_position_size(self.sid, 1, 1) + with self.assertRaises(RegisterTradingControlPostInit): + algo.set_max_order_size(self.sid, 1, 1) + with self.assertRaises(RegisterTradingControlPostInit): + algo.set_max_order_count(1) + with self.assertRaises(RegisterTradingControlPostInit): + algo.set_long_only() + + algo = TradingAlgorithm(initialize=initialize, + handle_data=handle_data) + algo.run(self.source) + self.source.rewind() diff --git a/zipline/algorithm.py b/zipline/algorithm.py index fc418f65..d8e20fbd 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -26,16 +26,23 @@ from six import iteritems, exec_ from operator import attrgetter from zipline.errors import ( - UnsupportedSlippageModel, - OverrideSlippagePostInit, - UnsupportedCommissionModel, OverrideCommissionPostInit, - UnsupportedOrderParameters + OverrideSlippagePostInit, + RegisterTradingControlPostInit, + UnsupportedCommissionModel, + UnsupportedOrderParameters, + UnsupportedSlippageModel, ) from zipline.finance import trading from zipline.finance.blotter import Blotter from zipline.finance.commission import PerShare, PerTrade, PerDollar +from zipline.finance.controls import ( + LongOnly, + MaxOrderCount, + MaxOrderSize, + MaxPositionSize, +) from zipline.finance.constants import ANNUALIZER from zipline.finance.execution import ( LimitOrder, @@ -125,6 +132,9 @@ class TradingAlgorithm(object): self.transforms = [] self.sources = [] + # List of trading controls to be used to validate orders. + self.trading_controls = [] + self._recorded_vars = {} self.namespace = kwargs.get('namespace', {}) @@ -525,6 +535,13 @@ class TradingAlgorithm(object): msg="Passing both stop_price and style is not supported." ) + for control in self.trading_controls: + control.validate(sid, + amount, + self.updated_portfolio(), + self.get_datetime(), + self.trading_client.current_data) + @staticmethod def __convert_order_params_for_blotter(limit_price, stop_price, style): """ @@ -799,3 +816,68 @@ class TradingAlgorithm(object): bar_count, frequency, field, ffill) history_spec = self.history_specs[spec_key_str] return self.history_container.get_history(history_spec, self.datetime) + + #################### + # Trading Controls # + #################### + + def register_trading_control(self, control): + """ + Register a new TradingControl to be checked prior to order calls. + """ + if self.initialized: + raise RegisterTradingControlPostInit() + self.trading_controls.append(control) + + @api_method + def set_max_position_size(self, + sid=None, + max_shares=None, + max_notional=None): + """ + Set a limit on the number of shares and/or dollar value held for the + given sid. Limits are treated as absolute values and are enforced at + the time that the algo attempts to place an order for sid. This means + that it's possible to end up with more than the max number of shares + due to splits/dividends, and more than the max notional due to price + improvement. + + If an algorithm attempts to place an order that would result in + increasing the absolute value of shares/dollar value exceeding one of + these limits, raise a TradingControlException. + """ + control = MaxPositionSize(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + self.register_trading_control(control) + + @api_method + def set_max_order_size(self, sid=None, max_shares=None, max_notional=None): + """ + Set a limit on the number of shares and/or dollar value of any single + order placed for sid. Limits are treated as absolute values and are + enforced at the time that the algo attempts to place an order for sid. + + If an algorithm attempts to place an order that would result in + exceeding one of these limits, raise a TradingControlException. + """ + control = MaxOrderSize(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + self.register_trading_control(control) + + @api_method + def set_max_order_count(self, max_count): + """ + Set a limit on the number of orders that can be placed within the given + time interval. + """ + control = MaxOrderCount(max_count) + self.register_trading_control(control) + + @api_method + def set_long_only(self): + """ + Set a rule specifying that this algorithm cannot take short positions. + """ + self.register_trading_control(LongOnly()) diff --git a/zipline/errors.py b/zipline/errors.py index 0fc2d5d1..47e81f90 100644 --- a/zipline/errors.py +++ b/zipline/errors.py @@ -60,6 +60,15 @@ method. """.strip() +class RegisterTradingControlPostInit(ZiplineError): + # Raised if a user's script register's a trading control after initialize + # has been run. + msg = """ +You attempted to set a trading control after the simulation has \ +started. Trading controls may only be set during initialize. +""".strip() + + class UnsupportedCommissionModel(ZiplineError): """ Raised if a user script calls the override_commission magic @@ -128,3 +137,12 @@ class UnsupportedOrderParameters(ZiplineError): call. """ msg = "{msg}" + + +class TradingControlViolation(ZiplineError): + """ + Raised if an order would violate a constraint set by a TradingControl. + """ + msg = """ +Order for {amount} shares of {sid} violates trading constraint {constraint}. +""".strip() diff --git a/zipline/finance/controls.py b/zipline/finance/controls.py new file mode 100644 index 00000000..4e5f6511 --- /dev/null +++ b/zipline/finance/controls.py @@ -0,0 +1,237 @@ +# +# Copyright 2014 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc + +from six import with_metaclass + +from zipline.errors import TradingControlViolation + + +class TradingControl(with_metaclass(abc.ABCMeta)): + """ + Abstract base class representing a fail-safe control on the behavior of any + algorithm. + """ + + def __init__(self, **kwargs): + """ + Track any arguments that should be printed in the error message + generated by self.fail. + """ + self.__fail_args = kwargs + + @abc.abstractmethod + def validate(self, + sid, + amount, + portfolio, + algo_datetime, + algo_current_data): + """ + Before any order is executed by TradingAlgorithm, this method should be + called *exactly once* on each registered TradingControl object. + + If the specified sid and amount do not violate this TradingControl's + restraint given the information in `portfolio`, this method should + return None and have no externally-visible side-effects. + + If the desired order violates this TradingControl's contraint, this + method should call self.fail(sid, amount). + """ + raise NotImplementedError + + def fail(self, sid, amount): + """ + Raise a TradingControlViolation with information about the failure. + """ + raise TradingControlViolation(sid=sid, + amount=amount, + constraint=repr(self)) + + def __repr__(self): + return "{name}({attrs})".format(name=self.__class__.__name__, + attrs=self.__fail_args) + + +class MaxOrderCount(TradingControl): + """ + TradingControl representing a limit on the number of orders that can be + placed in a given trading day. + """ + + def __init__(self, max_count): + + super(MaxOrderCount, self).__init__(max_count=max_count) + self.orders_placed = 0 + self.max_count = max_count + self.current_date = None + + def validate(self, + sid, + amount, + _portfolio, + algo_datetime, + _algo_current_data): + """ + Fail if we've already placed self.max_count orders today. + """ + algo_date = algo_datetime.date() + + # Reset order count if it's a new day. + if self.current_date and self.current_date != algo_date: + self.orders_placed = 0 + self.current_date = algo_date + + if self.orders_placed >= self.max_count: + self.fail(sid, amount) + self.orders_placed += 1 + + +class MaxOrderSize(TradingControl): + """ + TradingControl representing a limit on the magnitude of any single order + placed with the given security. Can be specified by share or by dollar + value. + """ + + def __init__(self, sid=None, max_shares=None, max_notional=None): + super(MaxOrderSize, self).__init__(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + self.sid = sid + self.max_shares = max_shares + self.max_notional = max_notional + + if max_shares is None and max_notional is None: + raise ValueError( + "Must supply at least one of max_shares and max_notional" + ) + + if max_shares and max_shares < 0: + raise ValueError( + "max_shares cannot be negative." + ) + + if max_notional and max_notional < 0: + raise ValueError( + "max_notional must be positive." + ) + + def validate(self, + sid, + amount, + portfolio, + _algo_datetime, + algo_current_data): + """ + Fail if the magnitude of the given order exceeds either self.max_shares + or self.max_notional. + """ + + if self.sid is not None and self.sid != sid: + return + + if self.max_shares is not None and abs(amount) > self.max_shares: + self.fail(sid, amount) + + current_sid_price = algo_current_data[sid].price + order_value = amount * current_sid_price + + too_much_value = (self.max_notional is not None and + abs(order_value) > self.max_notional) + + if too_much_value: + self.fail(sid, amount) + + +class MaxPositionSize(TradingControl): + """ + TradingControl representing a limit on the maximum position size that can + be held by an algo for a given security. + """ + + def __init__(self, sid=None, max_shares=None, max_notional=None): + super(MaxPositionSize, self).__init__(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + self.sid = sid + self.max_shares = max_shares + self.max_notional = max_notional + + if max_shares is None and max_notional is None: + raise ValueError( + "Must supply at least one of max_shares and max_notional" + ) + + if max_shares and max_shares < 0: + raise ValueError( + "max_shares cannot be negative." + ) + + if max_notional and max_notional < 0: + raise ValueError( + "max_notional must be positive." + ) + + def validate(self, + sid, + amount, + portfolio, + algo_datetime, + algo_current_data): + """ + Fail if the given order would cause the magnitude of our position to be + greater in shares than self.max_shares or greater in dollar value than + self.max_notional. + """ + + if self.sid is not None and self.sid != sid: + return + + current_share_count = portfolio.positions[sid].amount + shares_post_order = current_share_count + amount + + too_many_shares = (self.max_shares is not None and + abs(shares_post_order) > self.max_shares) + if too_many_shares: + self.fail(sid, amount) + + current_price = algo_current_data[sid].price + value_post_order = shares_post_order * current_price + + too_much_value = (self.max_notional is not None and + abs(value_post_order) > self.max_notional) + + if too_much_value: + self.fail(sid, amount) + + +class LongOnly(TradingControl): + """ + TradingControl representing a prohibition against holding short positions. + """ + + def validate(self, + sid, + amount, + portfolio, + _algo_datetime, + _algo_current_data): + """ + Fail if we would hold negative shares of sid after completing this + order. + """ + if portfolio.positions[sid].amount + amount < 0: + self.fail(sid, amount) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index c5113536..6349ba1c 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -109,7 +109,7 @@ class TestAlgorithm(TradingAlgorithm): self.sid_filter = [self.sid] def handle_data(self, data): - # place an order for 100 shares of sid + # place an order for amount shares of sid if self.incr < self.count: self.order(self.sid, self.amount) self.incr += 1 @@ -399,7 +399,39 @@ class TestTargetValueAlgorithm(TradingAlgorithm): self.target_shares = np.round(20 / data[0].price) -from zipline.algorithm import TradingAlgorithm +############################ +# TradingControl Test Algos# +############################ + + +class SetMaxPositionSizeAlgorithm(TradingAlgorithm): + def initialize(self, sid=None, max_shares=None, max_notional=None): + self.order_count = 0 + self.set_max_position_size(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + + +class SetMaxOrderSizeAlgorithm(TradingAlgorithm): + def initialize(self, sid=None, max_shares=None, max_notional=None): + self.order_count = 0 + self.set_max_order_size(sid=sid, + max_shares=max_shares, + max_notional=max_notional) + + +class SetMaxOrderCountAlgorithm(TradingAlgorithm): + def initialize(self, count): + self.order_count = 0 + self.set_max_order_count(count) + + +class SetLongOnlyAlgorithm(TradingAlgorithm): + def initialize(self): + self.order_count = 0 + self.set_long_only() + + from zipline.transforms import BatchTransform, batch_transform from zipline.transforms import MovingAverage diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 9f348c3b..e027b802 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from logbook import FileHandler from zipline.finance.blotter import ORDER_STATUS @@ -110,3 +111,15 @@ class ExceptionTransform(object): def update(self, event): assert False, "An assertion message" + + +@contextmanager +def nullctx(): + """ + Null context manager. Useful for conditionally adding a contextmanager in + a single line, e.g.: + + with SomeContextManager() if some_expr else nullcontext: + do_stuff() + """ + yield