ENH: Add trading controls to zipline API.

Adds four new methods to the Zipline API that can be used as circuit-breakers
to interrupt the execution of an algorithm.  The API methods are:

`set_max_position_size`
`set_max_order_size`
`set_max_order_count`
`set_long_only`

Internally, these methods are implemented by each registering a TradingControl
callback object with the TradingAlgorithm.  During
TradingAlgorithm.__validate_order_params (and thus before any side-effects of
the order call occur), each callback's `validate` method is called with
information about the order to be placed and the algorithm's current state,
raising an exception if the callback detects that an error condition has been breached.
This commit is contained in:
Scott Sanderson
2014-05-03 01:25:19 -04:00
parent 9953c7ea28
commit 644486e6da
6 changed files with 635 additions and 9 deletions
+247 -3
View File
@@ -13,16 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import TestCase
from datetime import timedelta
from mock import MagicMock
from six.moves import range
from unittest import TestCase
import numpy as np
import pandas as pd
from mock import MagicMock
from zipline.utils.test_utils import setup_logger
from zipline.utils.test_utils import (
nullctx,
setup_logger
)
import zipline.utils.factory as factory
import zipline.utils.simfactory as simfactory
from zipline.errors import (
RegisterTradingControlPostInit,
TradingControlViolation,
)
from zipline.test_algorithms import (
AmbitiousStopLimitAlgorithm,
EmptyPositionsAlgorithm,
@@ -37,6 +46,10 @@ from zipline.test_algorithms import (
TestTargetAlgorithm,
TestTargetPercentAlgorithm,
TestTargetValueAlgorithm,
SetLongOnlyAlgorithm,
SetMaxPositionSizeAlgorithm,
SetMaxOrderCountAlgorithm,
SetMaxOrderSizeAlgorithm,
api_algo,
api_symbol_algo,
call_all_order_methods,
@@ -560,3 +573,234 @@ def handle_data(context, data):
algo = TradingAlgorithm(script=history_algo, data_frequency='minute')
output = algo.run(source)
self.assertIsNot(output, None)
class TestTradingControls(TestCase):
def setUp(self):
self.sim_params = factory.create_simulation_parameters(num_days=4)
self.sid = 133
self.trade_history = factory.create_trade_history(
self.sid,
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(days=1),
self.sim_params
)
self.source = SpecificEquityTrades(event_list=self.trade_history)
def _check_algo(self,
algo,
handle_data,
expected_order_count,
expected_exc):
algo._handle_data = handle_data
with self.assertRaises(expected_exc) if expected_exc else nullctx():
algo.run(self.source)
self.assertEqual(algo.order_count, expected_order_count)
self.source.rewind()
def check_algo_succeeds(self, algo, handle_data, order_count=4):
# Default for order_count assumes one order per handle_data call.
self._check_algo(algo, handle_data, order_count, None)
def check_algo_fails(self, algo, handle_data, order_count):
self._check_algo(algo,
handle_data,
order_count,
TradingControlViolation)
def test_set_max_position_size(self):
# Buy one share four times. Should be fine.
def handle_data(algo, data):
algo.order(self.sid, 1)
algo.order_count += 1
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
max_shares=10,
max_notional=500.0)
self.check_algo_succeeds(algo, handle_data)
# Buy three shares four times. Should bail on the fourth before it's
# placed.
def handle_data(algo, data):
algo.order(self.sid, 3)
algo.order_count += 1
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
max_shares=10,
max_notional=500.0)
self.check_algo_fails(algo, handle_data, 3)
# Buy two shares four times. Should bail due to max_notional on the
# third attempt.
def handle_data(algo, data):
algo.order(self.sid, 3)
algo.order_count += 1
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
max_shares=10,
max_notional=61.0)
self.check_algo_fails(algo, handle_data, 2)
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
# Should continue normally.
def handle_data(algo, data):
algo.order(self.sid, 10000)
algo.order_count += 1
algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
max_shares=10,
max_notional=61.0)
self.check_algo_succeeds(algo, handle_data)
# Set the trading control sid to None, then BUY ALL THE THINGS!. Should
# fail because setting sid to None makes the control apply to all sids.
def handle_data(algo, data):
algo.order(self.sid, 10000)
algo.order_count += 1
algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0)
self.check_algo_fails(algo, handle_data, 0)
def test_set_max_order_size(self):
# Buy one share.
def handle_data(algo, data):
algo.order(self.sid, 1)
algo.order_count += 1
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
max_shares=10,
max_notional=500.0)
self.check_algo_succeeds(algo, handle_data)
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
# because we exceed shares.
def handle_data(algo, data):
algo.order(self.sid, algo.order_count + 1)
algo.order_count += 1
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
max_shares=3,
max_notional=500.0)
self.check_algo_fails(algo, handle_data, 3)
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
# because we exceed notional.
def handle_data(algo, data):
algo.order(self.sid, algo.order_count + 1)
algo.order_count += 1
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
max_shares=10,
max_notional=40.0)
self.check_algo_fails(algo, handle_data, 3)
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
# Should continue normally.
def handle_data(algo, data):
algo.order(self.sid, 10000)
algo.order_count += 1
algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
max_shares=1,
max_notional=1.0)
self.check_algo_succeeds(algo, handle_data)
# Set the trading control sid to None, then BUY ALL THE THINGS!.
# Should fail because not specifying a sid makes the trading control
# apply to all sids.
def handle_data(algo, data):
algo.order(self.sid, 10000)
algo.order_count += 1
algo = SetMaxOrderSizeAlgorithm(max_shares=1,
max_notional=1.0)
self.check_algo_fails(algo, handle_data, 0)
def test_set_max_order_count(self):
# Override the default setUp to use six-hour intervals instead of full
# days so we can exercise trading-session rollover logic.
trade_history = factory.create_trade_history(
self.sid,
[10.0, 10.0, 11.0, 11.0],
[100, 100, 100, 300],
timedelta(hours=6),
self.sim_params
)
self.source = SpecificEquityTrades(event_list=trade_history)
def handle_data(algo, data):
for i in range(5):
algo.order(self.sid, 1)
algo.order_count += 1
algo = SetMaxOrderCountAlgorithm(3)
self.check_algo_fails(algo, handle_data, 3)
# Second call to handle_data is the same day as the first, so the last
# order of the second call should fail.
algo = SetMaxOrderCountAlgorithm(9)
self.check_algo_fails(algo, handle_data, 9)
# Only ten orders are placed per day, so this should pass even though
# in total more than 20 orders are placed.
algo = SetMaxOrderCountAlgorithm(10)
self.check_algo_succeeds(algo, handle_data, order_count=20)
def test_long_only(self):
# Sell immediately -> fail immediately.
def handle_data(algo, data):
algo.order(self.sid, -1)
algo.order_count += 1
algo = SetLongOnlyAlgorithm()
self.check_algo_fails(algo, handle_data, 0)
# Buy on even days, sell on odd days. Never takes a short position, so
# should succeed.
def handle_data(algo, data):
if (algo.order_count % 2) == 0:
algo.order(self.sid, 1)
else:
algo.order(self.sid, -1)
algo.order_count += 1
algo = SetLongOnlyAlgorithm()
self.check_algo_succeeds(algo, handle_data)
# Buy on first three days, then sell off holdings. Should succeed.
def handle_data(algo, data):
amounts = [1, 1, 1, -3]
algo.order(self.sid, amounts[algo.order_count])
algo.order_count += 1
algo = SetLongOnlyAlgorithm()
self.check_algo_succeeds(algo, handle_data)
# Buy on first three days, then sell off holdings plus an extra share.
# Should fail on the last sale.
def handle_data(algo, data):
amounts = [1, 1, 1, -4]
algo.order(self.sid, amounts[algo.order_count])
algo.order_count += 1
algo = SetLongOnlyAlgorithm()
self.check_algo_fails(algo, handle_data, 3)
def test_register_post_init(self):
def initialize(algo):
algo.initialized = True
def handle_data(algo, data):
with self.assertRaises(RegisterTradingControlPostInit):
algo.set_max_position_size(self.sid, 1, 1)
with self.assertRaises(RegisterTradingControlPostInit):
algo.set_max_order_size(self.sid, 1, 1)
with self.assertRaises(RegisterTradingControlPostInit):
algo.set_max_order_count(1)
with self.assertRaises(RegisterTradingControlPostInit):
algo.set_long_only()
algo = TradingAlgorithm(initialize=initialize,
handle_data=handle_data)
algo.run(self.source)
self.source.rewind()
+86 -4
View File
@@ -26,16 +26,23 @@ from six import iteritems, exec_
from operator import attrgetter
from zipline.errors import (
UnsupportedSlippageModel,
OverrideSlippagePostInit,
UnsupportedCommissionModel,
OverrideCommissionPostInit,
UnsupportedOrderParameters
OverrideSlippagePostInit,
RegisterTradingControlPostInit,
UnsupportedCommissionModel,
UnsupportedOrderParameters,
UnsupportedSlippageModel,
)
from zipline.finance import trading
from zipline.finance.blotter import Blotter
from zipline.finance.commission import PerShare, PerTrade, PerDollar
from zipline.finance.controls import (
LongOnly,
MaxOrderCount,
MaxOrderSize,
MaxPositionSize,
)
from zipline.finance.constants import ANNUALIZER
from zipline.finance.execution import (
LimitOrder,
@@ -125,6 +132,9 @@ class TradingAlgorithm(object):
self.transforms = []
self.sources = []
# List of trading controls to be used to validate orders.
self.trading_controls = []
self._recorded_vars = {}
self.namespace = kwargs.get('namespace', {})
@@ -525,6 +535,13 @@ class TradingAlgorithm(object):
msg="Passing both stop_price and style is not supported."
)
for control in self.trading_controls:
control.validate(sid,
amount,
self.updated_portfolio(),
self.get_datetime(),
self.trading_client.current_data)
@staticmethod
def __convert_order_params_for_blotter(limit_price, stop_price, style):
"""
@@ -799,3 +816,68 @@ class TradingAlgorithm(object):
bar_count, frequency, field, ffill)
history_spec = self.history_specs[spec_key_str]
return self.history_container.get_history(history_spec, self.datetime)
####################
# Trading Controls #
####################
def register_trading_control(self, control):
"""
Register a new TradingControl to be checked prior to order calls.
"""
if self.initialized:
raise RegisterTradingControlPostInit()
self.trading_controls.append(control)
@api_method
def set_max_position_size(self,
sid=None,
max_shares=None,
max_notional=None):
"""
Set a limit on the number of shares and/or dollar value held for the
given sid. Limits are treated as absolute values and are enforced at
the time that the algo attempts to place an order for sid. This means
that it's possible to end up with more than the max number of shares
due to splits/dividends, and more than the max notional due to price
improvement.
If an algorithm attempts to place an order that would result in
increasing the absolute value of shares/dollar value exceeding one of
these limits, raise a TradingControlException.
"""
control = MaxPositionSize(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
self.register_trading_control(control)
@api_method
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
"""
Set a limit on the number of shares and/or dollar value of any single
order placed for sid. Limits are treated as absolute values and are
enforced at the time that the algo attempts to place an order for sid.
If an algorithm attempts to place an order that would result in
exceeding one of these limits, raise a TradingControlException.
"""
control = MaxOrderSize(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
self.register_trading_control(control)
@api_method
def set_max_order_count(self, max_count):
"""
Set a limit on the number of orders that can be placed within the given
time interval.
"""
control = MaxOrderCount(max_count)
self.register_trading_control(control)
@api_method
def set_long_only(self):
"""
Set a rule specifying that this algorithm cannot take short positions.
"""
self.register_trading_control(LongOnly())
+18
View File
@@ -60,6 +60,15 @@ method.
""".strip()
class RegisterTradingControlPostInit(ZiplineError):
# Raised if a user's script register's a trading control after initialize
# has been run.
msg = """
You attempted to set a trading control after the simulation has \
started. Trading controls may only be set during initialize.
""".strip()
class UnsupportedCommissionModel(ZiplineError):
"""
Raised if a user script calls the override_commission magic
@@ -128,3 +137,12 @@ class UnsupportedOrderParameters(ZiplineError):
call.
"""
msg = "{msg}"
class TradingControlViolation(ZiplineError):
"""
Raised if an order would violate a constraint set by a TradingControl.
"""
msg = """
Order for {amount} shares of {sid} violates trading constraint {constraint}.
""".strip()
+237
View File
@@ -0,0 +1,237 @@
#
# Copyright 2014 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from six import with_metaclass
from zipline.errors import TradingControlViolation
class TradingControl(with_metaclass(abc.ABCMeta)):
"""
Abstract base class representing a fail-safe control on the behavior of any
algorithm.
"""
def __init__(self, **kwargs):
"""
Track any arguments that should be printed in the error message
generated by self.fail.
"""
self.__fail_args = kwargs
@abc.abstractmethod
def validate(self,
sid,
amount,
portfolio,
algo_datetime,
algo_current_data):
"""
Before any order is executed by TradingAlgorithm, this method should be
called *exactly once* on each registered TradingControl object.
If the specified sid and amount do not violate this TradingControl's
restraint given the information in `portfolio`, this method should
return None and have no externally-visible side-effects.
If the desired order violates this TradingControl's contraint, this
method should call self.fail(sid, amount).
"""
raise NotImplementedError
def fail(self, sid, amount):
"""
Raise a TradingControlViolation with information about the failure.
"""
raise TradingControlViolation(sid=sid,
amount=amount,
constraint=repr(self))
def __repr__(self):
return "{name}({attrs})".format(name=self.__class__.__name__,
attrs=self.__fail_args)
class MaxOrderCount(TradingControl):
"""
TradingControl representing a limit on the number of orders that can be
placed in a given trading day.
"""
def __init__(self, max_count):
super(MaxOrderCount, self).__init__(max_count=max_count)
self.orders_placed = 0
self.max_count = max_count
self.current_date = None
def validate(self,
sid,
amount,
_portfolio,
algo_datetime,
_algo_current_data):
"""
Fail if we've already placed self.max_count orders today.
"""
algo_date = algo_datetime.date()
# Reset order count if it's a new day.
if self.current_date and self.current_date != algo_date:
self.orders_placed = 0
self.current_date = algo_date
if self.orders_placed >= self.max_count:
self.fail(sid, amount)
self.orders_placed += 1
class MaxOrderSize(TradingControl):
"""
TradingControl representing a limit on the magnitude of any single order
placed with the given security. Can be specified by share or by dollar
value.
"""
def __init__(self, sid=None, max_shares=None, max_notional=None):
super(MaxOrderSize, self).__init__(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
self.sid = sid
self.max_shares = max_shares
self.max_notional = max_notional
if max_shares is None and max_notional is None:
raise ValueError(
"Must supply at least one of max_shares and max_notional"
)
if max_shares and max_shares < 0:
raise ValueError(
"max_shares cannot be negative."
)
if max_notional and max_notional < 0:
raise ValueError(
"max_notional must be positive."
)
def validate(self,
sid,
amount,
portfolio,
_algo_datetime,
algo_current_data):
"""
Fail if the magnitude of the given order exceeds either self.max_shares
or self.max_notional.
"""
if self.sid is not None and self.sid != sid:
return
if self.max_shares is not None and abs(amount) > self.max_shares:
self.fail(sid, amount)
current_sid_price = algo_current_data[sid].price
order_value = amount * current_sid_price
too_much_value = (self.max_notional is not None and
abs(order_value) > self.max_notional)
if too_much_value:
self.fail(sid, amount)
class MaxPositionSize(TradingControl):
"""
TradingControl representing a limit on the maximum position size that can
be held by an algo for a given security.
"""
def __init__(self, sid=None, max_shares=None, max_notional=None):
super(MaxPositionSize, self).__init__(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
self.sid = sid
self.max_shares = max_shares
self.max_notional = max_notional
if max_shares is None and max_notional is None:
raise ValueError(
"Must supply at least one of max_shares and max_notional"
)
if max_shares and max_shares < 0:
raise ValueError(
"max_shares cannot be negative."
)
if max_notional and max_notional < 0:
raise ValueError(
"max_notional must be positive."
)
def validate(self,
sid,
amount,
portfolio,
algo_datetime,
algo_current_data):
"""
Fail if the given order would cause the magnitude of our position to be
greater in shares than self.max_shares or greater in dollar value than
self.max_notional.
"""
if self.sid is not None and self.sid != sid:
return
current_share_count = portfolio.positions[sid].amount
shares_post_order = current_share_count + amount
too_many_shares = (self.max_shares is not None and
abs(shares_post_order) > self.max_shares)
if too_many_shares:
self.fail(sid, amount)
current_price = algo_current_data[sid].price
value_post_order = shares_post_order * current_price
too_much_value = (self.max_notional is not None and
abs(value_post_order) > self.max_notional)
if too_much_value:
self.fail(sid, amount)
class LongOnly(TradingControl):
"""
TradingControl representing a prohibition against holding short positions.
"""
def validate(self,
sid,
amount,
portfolio,
_algo_datetime,
_algo_current_data):
"""
Fail if we would hold negative shares of sid after completing this
order.
"""
if portfolio.positions[sid].amount + amount < 0:
self.fail(sid, amount)
+34 -2
View File
@@ -109,7 +109,7 @@ class TestAlgorithm(TradingAlgorithm):
self.sid_filter = [self.sid]
def handle_data(self, data):
# place an order for 100 shares of sid
# place an order for amount shares of sid
if self.incr < self.count:
self.order(self.sid, self.amount)
self.incr += 1
@@ -399,7 +399,39 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
self.target_shares = np.round(20 / data[0].price)
from zipline.algorithm import TradingAlgorithm
############################
# TradingControl Test Algos#
############################
class SetMaxPositionSizeAlgorithm(TradingAlgorithm):
def initialize(self, sid=None, max_shares=None, max_notional=None):
self.order_count = 0
self.set_max_position_size(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
class SetMaxOrderSizeAlgorithm(TradingAlgorithm):
def initialize(self, sid=None, max_shares=None, max_notional=None):
self.order_count = 0
self.set_max_order_size(sid=sid,
max_shares=max_shares,
max_notional=max_notional)
class SetMaxOrderCountAlgorithm(TradingAlgorithm):
def initialize(self, count):
self.order_count = 0
self.set_max_order_count(count)
class SetLongOnlyAlgorithm(TradingAlgorithm):
def initialize(self):
self.order_count = 0
self.set_long_only()
from zipline.transforms import BatchTransform, batch_transform
from zipline.transforms import MovingAverage
+13
View File
@@ -1,3 +1,4 @@
from contextlib import contextmanager
from logbook import FileHandler
from zipline.finance.blotter import ORDER_STATUS
@@ -110,3 +111,15 @@ class ExceptionTransform(object):
def update(self, event):
assert False, "An assertion message"
@contextmanager
def nullctx():
"""
Null context manager. Useful for conditionally adding a contextmanager in
a single line, e.g.:
with SomeContextManager() if some_expr else nullcontext:
do_stuff()
"""
yield