mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 14:29:26 +08:00
ENH: Add trading controls to zipline API.
Adds four new methods to the Zipline API that can be used as circuit-breakers to interrupt the execution of an algorithm. The API methods are: `set_max_position_size` `set_max_order_size` `set_max_order_count` `set_long_only` Internally, these methods are implemented by each registering a TradingControl callback object with the TradingAlgorithm. During TradingAlgorithm.__validate_order_params (and thus before any side-effects of the order call occur), each callback's `validate` method is called with information about the order to be placed and the algorithm's current state, raising an exception if the callback detects that an error condition has been breached.
This commit is contained in:
+247
-3
@@ -13,16 +13,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import TestCase
|
||||
from datetime import timedelta
|
||||
from mock import MagicMock
|
||||
from six.moves import range
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from mock import MagicMock
|
||||
|
||||
from zipline.utils.test_utils import setup_logger
|
||||
from zipline.utils.test_utils import (
|
||||
nullctx,
|
||||
setup_logger
|
||||
)
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.utils.simfactory as simfactory
|
||||
|
||||
from zipline.errors import (
|
||||
RegisterTradingControlPostInit,
|
||||
TradingControlViolation,
|
||||
)
|
||||
from zipline.test_algorithms import (
|
||||
AmbitiousStopLimitAlgorithm,
|
||||
EmptyPositionsAlgorithm,
|
||||
@@ -37,6 +46,10 @@ from zipline.test_algorithms import (
|
||||
TestTargetAlgorithm,
|
||||
TestTargetPercentAlgorithm,
|
||||
TestTargetValueAlgorithm,
|
||||
SetLongOnlyAlgorithm,
|
||||
SetMaxPositionSizeAlgorithm,
|
||||
SetMaxOrderCountAlgorithm,
|
||||
SetMaxOrderSizeAlgorithm,
|
||||
api_algo,
|
||||
api_symbol_algo,
|
||||
call_all_order_methods,
|
||||
@@ -560,3 +573,234 @@ def handle_data(context, data):
|
||||
algo = TradingAlgorithm(script=history_algo, data_frequency='minute')
|
||||
output = algo.run(source)
|
||||
self.assertIsNot(output, None)
|
||||
|
||||
|
||||
class TestTradingControls(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.sim_params = factory.create_simulation_parameters(num_days=4)
|
||||
self.sid = 133
|
||||
self.trade_history = factory.create_trade_history(
|
||||
self.sid,
|
||||
[10.0, 10.0, 11.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(days=1),
|
||||
self.sim_params
|
||||
)
|
||||
|
||||
self.source = SpecificEquityTrades(event_list=self.trade_history)
|
||||
|
||||
def _check_algo(self,
|
||||
algo,
|
||||
handle_data,
|
||||
expected_order_count,
|
||||
expected_exc):
|
||||
|
||||
algo._handle_data = handle_data
|
||||
with self.assertRaises(expected_exc) if expected_exc else nullctx():
|
||||
algo.run(self.source)
|
||||
self.assertEqual(algo.order_count, expected_order_count)
|
||||
self.source.rewind()
|
||||
|
||||
def check_algo_succeeds(self, algo, handle_data, order_count=4):
|
||||
# Default for order_count assumes one order per handle_data call.
|
||||
self._check_algo(algo, handle_data, order_count, None)
|
||||
|
||||
def check_algo_fails(self, algo, handle_data, order_count):
|
||||
self._check_algo(algo,
|
||||
handle_data,
|
||||
order_count,
|
||||
TradingControlViolation)
|
||||
|
||||
def test_set_max_position_size(self):
|
||||
|
||||
# Buy one share four times. Should be fine.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy three shares four times. Should bail on the fourth before it's
|
||||
# placed.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 3)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Buy two shares four times. Should bail due to max_notional on the
|
||||
# third attempt.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 3)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=61.0)
|
||||
self.check_algo_fails(algo, handle_data, 2)
|
||||
|
||||
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
|
||||
# Should continue normally.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
|
||||
max_shares=10,
|
||||
max_notional=61.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Set the trading control sid to None, then BUY ALL THE THINGS!. Should
|
||||
# fail because setting sid to None makes the control apply to all sids.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0)
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
def test_set_max_order_size(self):
|
||||
|
||||
# Buy one share.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
|
||||
# because we exceed shares.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, algo.order_count + 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=3,
|
||||
max_notional=500.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
|
||||
# because we exceed notional.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, algo.order_count + 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=40.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
|
||||
# Should continue normally.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
|
||||
max_shares=1,
|
||||
max_notional=1.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Set the trading control sid to None, then BUY ALL THE THINGS!.
|
||||
# Should fail because not specifying a sid makes the trading control
|
||||
# apply to all sids.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(max_shares=1,
|
||||
max_notional=1.0)
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
def test_set_max_order_count(self):
|
||||
|
||||
# Override the default setUp to use six-hour intervals instead of full
|
||||
# days so we can exercise trading-session rollover logic.
|
||||
trade_history = factory.create_trade_history(
|
||||
self.sid,
|
||||
[10.0, 10.0, 11.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(hours=6),
|
||||
self.sim_params
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
def handle_data(algo, data):
|
||||
for i in range(5):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderCountAlgorithm(3)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Second call to handle_data is the same day as the first, so the last
|
||||
# order of the second call should fail.
|
||||
algo = SetMaxOrderCountAlgorithm(9)
|
||||
self.check_algo_fails(algo, handle_data, 9)
|
||||
|
||||
# Only ten orders are placed per day, so this should pass even though
|
||||
# in total more than 20 orders are placed.
|
||||
algo = SetMaxOrderCountAlgorithm(10)
|
||||
self.check_algo_succeeds(algo, handle_data, order_count=20)
|
||||
|
||||
def test_long_only(self):
|
||||
|
||||
# Sell immediately -> fail immediately.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, -1)
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
# Buy on even days, sell on odd days. Never takes a short position, so
|
||||
# should succeed.
|
||||
def handle_data(algo, data):
|
||||
if (algo.order_count % 2) == 0:
|
||||
algo.order(self.sid, 1)
|
||||
else:
|
||||
algo.order(self.sid, -1)
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy on first three days, then sell off holdings. Should succeed.
|
||||
def handle_data(algo, data):
|
||||
amounts = [1, 1, 1, -3]
|
||||
algo.order(self.sid, amounts[algo.order_count])
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy on first three days, then sell off holdings plus an extra share.
|
||||
# Should fail on the last sale.
|
||||
def handle_data(algo, data):
|
||||
amounts = [1, 1, 1, -4]
|
||||
algo.order(self.sid, amounts[algo.order_count])
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
def test_register_post_init(self):
|
||||
|
||||
def initialize(algo):
|
||||
algo.initialized = True
|
||||
|
||||
def handle_data(algo, data):
|
||||
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_position_size(self.sid, 1, 1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_order_size(self.sid, 1, 1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_order_count(1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_long_only()
|
||||
|
||||
algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
algo.run(self.source)
|
||||
self.source.rewind()
|
||||
|
||||
+86
-4
@@ -26,16 +26,23 @@ from six import iteritems, exec_
|
||||
from operator import attrgetter
|
||||
|
||||
from zipline.errors import (
|
||||
UnsupportedSlippageModel,
|
||||
OverrideSlippagePostInit,
|
||||
UnsupportedCommissionModel,
|
||||
OverrideCommissionPostInit,
|
||||
UnsupportedOrderParameters
|
||||
OverrideSlippagePostInit,
|
||||
RegisterTradingControlPostInit,
|
||||
UnsupportedCommissionModel,
|
||||
UnsupportedOrderParameters,
|
||||
UnsupportedSlippageModel,
|
||||
)
|
||||
|
||||
from zipline.finance import trading
|
||||
from zipline.finance.blotter import Blotter
|
||||
from zipline.finance.commission import PerShare, PerTrade, PerDollar
|
||||
from zipline.finance.controls import (
|
||||
LongOnly,
|
||||
MaxOrderCount,
|
||||
MaxOrderSize,
|
||||
MaxPositionSize,
|
||||
)
|
||||
from zipline.finance.constants import ANNUALIZER
|
||||
from zipline.finance.execution import (
|
||||
LimitOrder,
|
||||
@@ -125,6 +132,9 @@ class TradingAlgorithm(object):
|
||||
self.transforms = []
|
||||
self.sources = []
|
||||
|
||||
# List of trading controls to be used to validate orders.
|
||||
self.trading_controls = []
|
||||
|
||||
self._recorded_vars = {}
|
||||
self.namespace = kwargs.get('namespace', {})
|
||||
|
||||
@@ -525,6 +535,13 @@ class TradingAlgorithm(object):
|
||||
msg="Passing both stop_price and style is not supported."
|
||||
)
|
||||
|
||||
for control in self.trading_controls:
|
||||
control.validate(sid,
|
||||
amount,
|
||||
self.updated_portfolio(),
|
||||
self.get_datetime(),
|
||||
self.trading_client.current_data)
|
||||
|
||||
@staticmethod
|
||||
def __convert_order_params_for_blotter(limit_price, stop_price, style):
|
||||
"""
|
||||
@@ -799,3 +816,68 @@ class TradingAlgorithm(object):
|
||||
bar_count, frequency, field, ffill)
|
||||
history_spec = self.history_specs[spec_key_str]
|
||||
return self.history_container.get_history(history_spec, self.datetime)
|
||||
|
||||
####################
|
||||
# Trading Controls #
|
||||
####################
|
||||
|
||||
def register_trading_control(self, control):
|
||||
"""
|
||||
Register a new TradingControl to be checked prior to order calls.
|
||||
"""
|
||||
if self.initialized:
|
||||
raise RegisterTradingControlPostInit()
|
||||
self.trading_controls.append(control)
|
||||
|
||||
@api_method
|
||||
def set_max_position_size(self,
|
||||
sid=None,
|
||||
max_shares=None,
|
||||
max_notional=None):
|
||||
"""
|
||||
Set a limit on the number of shares and/or dollar value held for the
|
||||
given sid. Limits are treated as absolute values and are enforced at
|
||||
the time that the algo attempts to place an order for sid. This means
|
||||
that it's possible to end up with more than the max number of shares
|
||||
due to splits/dividends, and more than the max notional due to price
|
||||
improvement.
|
||||
|
||||
If an algorithm attempts to place an order that would result in
|
||||
increasing the absolute value of shares/dollar value exceeding one of
|
||||
these limits, raise a TradingControlException.
|
||||
"""
|
||||
control = MaxPositionSize(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.register_trading_control(control)
|
||||
|
||||
@api_method
|
||||
def set_max_order_size(self, sid=None, max_shares=None, max_notional=None):
|
||||
"""
|
||||
Set a limit on the number of shares and/or dollar value of any single
|
||||
order placed for sid. Limits are treated as absolute values and are
|
||||
enforced at the time that the algo attempts to place an order for sid.
|
||||
|
||||
If an algorithm attempts to place an order that would result in
|
||||
exceeding one of these limits, raise a TradingControlException.
|
||||
"""
|
||||
control = MaxOrderSize(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.register_trading_control(control)
|
||||
|
||||
@api_method
|
||||
def set_max_order_count(self, max_count):
|
||||
"""
|
||||
Set a limit on the number of orders that can be placed within the given
|
||||
time interval.
|
||||
"""
|
||||
control = MaxOrderCount(max_count)
|
||||
self.register_trading_control(control)
|
||||
|
||||
@api_method
|
||||
def set_long_only(self):
|
||||
"""
|
||||
Set a rule specifying that this algorithm cannot take short positions.
|
||||
"""
|
||||
self.register_trading_control(LongOnly())
|
||||
|
||||
@@ -60,6 +60,15 @@ method.
|
||||
""".strip()
|
||||
|
||||
|
||||
class RegisterTradingControlPostInit(ZiplineError):
|
||||
# Raised if a user's script register's a trading control after initialize
|
||||
# has been run.
|
||||
msg = """
|
||||
You attempted to set a trading control after the simulation has \
|
||||
started. Trading controls may only be set during initialize.
|
||||
""".strip()
|
||||
|
||||
|
||||
class UnsupportedCommissionModel(ZiplineError):
|
||||
"""
|
||||
Raised if a user script calls the override_commission magic
|
||||
@@ -128,3 +137,12 @@ class UnsupportedOrderParameters(ZiplineError):
|
||||
call.
|
||||
"""
|
||||
msg = "{msg}"
|
||||
|
||||
|
||||
class TradingControlViolation(ZiplineError):
|
||||
"""
|
||||
Raised if an order would violate a constraint set by a TradingControl.
|
||||
"""
|
||||
msg = """
|
||||
Order for {amount} shares of {sid} violates trading constraint {constraint}.
|
||||
""".strip()
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
#
|
||||
# Copyright 2014 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
|
||||
from six import with_metaclass
|
||||
|
||||
from zipline.errors import TradingControlViolation
|
||||
|
||||
|
||||
class TradingControl(with_metaclass(abc.ABCMeta)):
|
||||
"""
|
||||
Abstract base class representing a fail-safe control on the behavior of any
|
||||
algorithm.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
Track any arguments that should be printed in the error message
|
||||
generated by self.fail.
|
||||
"""
|
||||
self.__fail_args = kwargs
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate(self,
|
||||
sid,
|
||||
amount,
|
||||
portfolio,
|
||||
algo_datetime,
|
||||
algo_current_data):
|
||||
"""
|
||||
Before any order is executed by TradingAlgorithm, this method should be
|
||||
called *exactly once* on each registered TradingControl object.
|
||||
|
||||
If the specified sid and amount do not violate this TradingControl's
|
||||
restraint given the information in `portfolio`, this method should
|
||||
return None and have no externally-visible side-effects.
|
||||
|
||||
If the desired order violates this TradingControl's contraint, this
|
||||
method should call self.fail(sid, amount).
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def fail(self, sid, amount):
|
||||
"""
|
||||
Raise a TradingControlViolation with information about the failure.
|
||||
"""
|
||||
raise TradingControlViolation(sid=sid,
|
||||
amount=amount,
|
||||
constraint=repr(self))
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}({attrs})".format(name=self.__class__.__name__,
|
||||
attrs=self.__fail_args)
|
||||
|
||||
|
||||
class MaxOrderCount(TradingControl):
|
||||
"""
|
||||
TradingControl representing a limit on the number of orders that can be
|
||||
placed in a given trading day.
|
||||
"""
|
||||
|
||||
def __init__(self, max_count):
|
||||
|
||||
super(MaxOrderCount, self).__init__(max_count=max_count)
|
||||
self.orders_placed = 0
|
||||
self.max_count = max_count
|
||||
self.current_date = None
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
amount,
|
||||
_portfolio,
|
||||
algo_datetime,
|
||||
_algo_current_data):
|
||||
"""
|
||||
Fail if we've already placed self.max_count orders today.
|
||||
"""
|
||||
algo_date = algo_datetime.date()
|
||||
|
||||
# Reset order count if it's a new day.
|
||||
if self.current_date and self.current_date != algo_date:
|
||||
self.orders_placed = 0
|
||||
self.current_date = algo_date
|
||||
|
||||
if self.orders_placed >= self.max_count:
|
||||
self.fail(sid, amount)
|
||||
self.orders_placed += 1
|
||||
|
||||
|
||||
class MaxOrderSize(TradingControl):
|
||||
"""
|
||||
TradingControl representing a limit on the magnitude of any single order
|
||||
placed with the given security. Can be specified by share or by dollar
|
||||
value.
|
||||
"""
|
||||
|
||||
def __init__(self, sid=None, max_shares=None, max_notional=None):
|
||||
super(MaxOrderSize, self).__init__(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.sid = sid
|
||||
self.max_shares = max_shares
|
||||
self.max_notional = max_notional
|
||||
|
||||
if max_shares is None and max_notional is None:
|
||||
raise ValueError(
|
||||
"Must supply at least one of max_shares and max_notional"
|
||||
)
|
||||
|
||||
if max_shares and max_shares < 0:
|
||||
raise ValueError(
|
||||
"max_shares cannot be negative."
|
||||
)
|
||||
|
||||
if max_notional and max_notional < 0:
|
||||
raise ValueError(
|
||||
"max_notional must be positive."
|
||||
)
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
amount,
|
||||
portfolio,
|
||||
_algo_datetime,
|
||||
algo_current_data):
|
||||
"""
|
||||
Fail if the magnitude of the given order exceeds either self.max_shares
|
||||
or self.max_notional.
|
||||
"""
|
||||
|
||||
if self.sid is not None and self.sid != sid:
|
||||
return
|
||||
|
||||
if self.max_shares is not None and abs(amount) > self.max_shares:
|
||||
self.fail(sid, amount)
|
||||
|
||||
current_sid_price = algo_current_data[sid].price
|
||||
order_value = amount * current_sid_price
|
||||
|
||||
too_much_value = (self.max_notional is not None and
|
||||
abs(order_value) > self.max_notional)
|
||||
|
||||
if too_much_value:
|
||||
self.fail(sid, amount)
|
||||
|
||||
|
||||
class MaxPositionSize(TradingControl):
|
||||
"""
|
||||
TradingControl representing a limit on the maximum position size that can
|
||||
be held by an algo for a given security.
|
||||
"""
|
||||
|
||||
def __init__(self, sid=None, max_shares=None, max_notional=None):
|
||||
super(MaxPositionSize, self).__init__(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
self.sid = sid
|
||||
self.max_shares = max_shares
|
||||
self.max_notional = max_notional
|
||||
|
||||
if max_shares is None and max_notional is None:
|
||||
raise ValueError(
|
||||
"Must supply at least one of max_shares and max_notional"
|
||||
)
|
||||
|
||||
if max_shares and max_shares < 0:
|
||||
raise ValueError(
|
||||
"max_shares cannot be negative."
|
||||
)
|
||||
|
||||
if max_notional and max_notional < 0:
|
||||
raise ValueError(
|
||||
"max_notional must be positive."
|
||||
)
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
amount,
|
||||
portfolio,
|
||||
algo_datetime,
|
||||
algo_current_data):
|
||||
"""
|
||||
Fail if the given order would cause the magnitude of our position to be
|
||||
greater in shares than self.max_shares or greater in dollar value than
|
||||
self.max_notional.
|
||||
"""
|
||||
|
||||
if self.sid is not None and self.sid != sid:
|
||||
return
|
||||
|
||||
current_share_count = portfolio.positions[sid].amount
|
||||
shares_post_order = current_share_count + amount
|
||||
|
||||
too_many_shares = (self.max_shares is not None and
|
||||
abs(shares_post_order) > self.max_shares)
|
||||
if too_many_shares:
|
||||
self.fail(sid, amount)
|
||||
|
||||
current_price = algo_current_data[sid].price
|
||||
value_post_order = shares_post_order * current_price
|
||||
|
||||
too_much_value = (self.max_notional is not None and
|
||||
abs(value_post_order) > self.max_notional)
|
||||
|
||||
if too_much_value:
|
||||
self.fail(sid, amount)
|
||||
|
||||
|
||||
class LongOnly(TradingControl):
|
||||
"""
|
||||
TradingControl representing a prohibition against holding short positions.
|
||||
"""
|
||||
|
||||
def validate(self,
|
||||
sid,
|
||||
amount,
|
||||
portfolio,
|
||||
_algo_datetime,
|
||||
_algo_current_data):
|
||||
"""
|
||||
Fail if we would hold negative shares of sid after completing this
|
||||
order.
|
||||
"""
|
||||
if portfolio.positions[sid].amount + amount < 0:
|
||||
self.fail(sid, amount)
|
||||
@@ -109,7 +109,7 @@ class TestAlgorithm(TradingAlgorithm):
|
||||
self.sid_filter = [self.sid]
|
||||
|
||||
def handle_data(self, data):
|
||||
# place an order for 100 shares of sid
|
||||
# place an order for amount shares of sid
|
||||
if self.incr < self.count:
|
||||
self.order(self.sid, self.amount)
|
||||
self.incr += 1
|
||||
@@ -399,7 +399,39 @@ class TestTargetValueAlgorithm(TradingAlgorithm):
|
||||
self.target_shares = np.round(20 / data[0].price)
|
||||
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
############################
|
||||
# TradingControl Test Algos#
|
||||
############################
|
||||
|
||||
|
||||
class SetMaxPositionSizeAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, sid=None, max_shares=None, max_notional=None):
|
||||
self.order_count = 0
|
||||
self.set_max_position_size(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
|
||||
|
||||
class SetMaxOrderSizeAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, sid=None, max_shares=None, max_notional=None):
|
||||
self.order_count = 0
|
||||
self.set_max_order_size(sid=sid,
|
||||
max_shares=max_shares,
|
||||
max_notional=max_notional)
|
||||
|
||||
|
||||
class SetMaxOrderCountAlgorithm(TradingAlgorithm):
|
||||
def initialize(self, count):
|
||||
self.order_count = 0
|
||||
self.set_max_order_count(count)
|
||||
|
||||
|
||||
class SetLongOnlyAlgorithm(TradingAlgorithm):
|
||||
def initialize(self):
|
||||
self.order_count = 0
|
||||
self.set_long_only()
|
||||
|
||||
|
||||
from zipline.transforms import BatchTransform, batch_transform
|
||||
from zipline.transforms import MovingAverage
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from contextlib import contextmanager
|
||||
from logbook import FileHandler
|
||||
from zipline.finance.blotter import ORDER_STATUS
|
||||
|
||||
@@ -110,3 +111,15 @@ class ExceptionTransform(object):
|
||||
|
||||
def update(self, event):
|
||||
assert False, "An assertion message"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def nullctx():
|
||||
"""
|
||||
Null context manager. Useful for conditionally adding a contextmanager in
|
||||
a single line, e.g.:
|
||||
|
||||
with SomeContextManager() if some_expr else nullcontext:
|
||||
do_stuff()
|
||||
"""
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user