mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-29 00:53:42 +08:00
ENH: Add trading controls to zipline API.
Adds four new methods to the Zipline API that can be used as circuit-breakers to interrupt the execution of an algorithm. The API methods are: `set_max_position_size` `set_max_order_size` `set_max_order_count` `set_long_only` Internally, these methods are implemented by each registering a TradingControl callback object with the TradingAlgorithm. During TradingAlgorithm.__validate_order_params (and thus before any side-effects of the order call occur), each callback's `validate` method is called with information about the order to be placed and the algorithm's current state, raising an exception if the callback detects that an error condition has been breached.
This commit is contained in:
+247
-3
@@ -13,16 +13,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest import TestCase
|
||||
from datetime import timedelta
|
||||
from mock import MagicMock
|
||||
from six.moves import range
|
||||
from unittest import TestCase
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from mock import MagicMock
|
||||
|
||||
from zipline.utils.test_utils import setup_logger
|
||||
from zipline.utils.test_utils import (
|
||||
nullctx,
|
||||
setup_logger
|
||||
)
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.utils.simfactory as simfactory
|
||||
|
||||
from zipline.errors import (
|
||||
RegisterTradingControlPostInit,
|
||||
TradingControlViolation,
|
||||
)
|
||||
from zipline.test_algorithms import (
|
||||
AmbitiousStopLimitAlgorithm,
|
||||
EmptyPositionsAlgorithm,
|
||||
@@ -37,6 +46,10 @@ from zipline.test_algorithms import (
|
||||
TestTargetAlgorithm,
|
||||
TestTargetPercentAlgorithm,
|
||||
TestTargetValueAlgorithm,
|
||||
SetLongOnlyAlgorithm,
|
||||
SetMaxPositionSizeAlgorithm,
|
||||
SetMaxOrderCountAlgorithm,
|
||||
SetMaxOrderSizeAlgorithm,
|
||||
api_algo,
|
||||
api_symbol_algo,
|
||||
call_all_order_methods,
|
||||
@@ -560,3 +573,234 @@ def handle_data(context, data):
|
||||
algo = TradingAlgorithm(script=history_algo, data_frequency='minute')
|
||||
output = algo.run(source)
|
||||
self.assertIsNot(output, None)
|
||||
|
||||
|
||||
class TestTradingControls(TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.sim_params = factory.create_simulation_parameters(num_days=4)
|
||||
self.sid = 133
|
||||
self.trade_history = factory.create_trade_history(
|
||||
self.sid,
|
||||
[10.0, 10.0, 11.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(days=1),
|
||||
self.sim_params
|
||||
)
|
||||
|
||||
self.source = SpecificEquityTrades(event_list=self.trade_history)
|
||||
|
||||
def _check_algo(self,
|
||||
algo,
|
||||
handle_data,
|
||||
expected_order_count,
|
||||
expected_exc):
|
||||
|
||||
algo._handle_data = handle_data
|
||||
with self.assertRaises(expected_exc) if expected_exc else nullctx():
|
||||
algo.run(self.source)
|
||||
self.assertEqual(algo.order_count, expected_order_count)
|
||||
self.source.rewind()
|
||||
|
||||
def check_algo_succeeds(self, algo, handle_data, order_count=4):
|
||||
# Default for order_count assumes one order per handle_data call.
|
||||
self._check_algo(algo, handle_data, order_count, None)
|
||||
|
||||
def check_algo_fails(self, algo, handle_data, order_count):
|
||||
self._check_algo(algo,
|
||||
handle_data,
|
||||
order_count,
|
||||
TradingControlViolation)
|
||||
|
||||
def test_set_max_position_size(self):
|
||||
|
||||
# Buy one share four times. Should be fine.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy three shares four times. Should bail on the fourth before it's
|
||||
# placed.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 3)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Buy two shares four times. Should bail due to max_notional on the
|
||||
# third attempt.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 3)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=61.0)
|
||||
self.check_algo_fails(algo, handle_data, 2)
|
||||
|
||||
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
|
||||
# Should continue normally.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(sid=self.sid + 1,
|
||||
max_shares=10,
|
||||
max_notional=61.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Set the trading control sid to None, then BUY ALL THE THINGS!. Should
|
||||
# fail because setting sid to None makes the control apply to all sids.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxPositionSizeAlgorithm(max_shares=10, max_notional=61.0)
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
def test_set_max_order_size(self):
|
||||
|
||||
# Buy one share.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=500.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
|
||||
# because we exceed shares.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, algo.order_count + 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=3,
|
||||
max_notional=500.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Buy 1, then 2, then 3, then 4 shares. Bail on the last attempt
|
||||
# because we exceed notional.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, algo.order_count + 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid,
|
||||
max_shares=10,
|
||||
max_notional=40.0)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Set the trading control to a different sid, then BUY ALL THE THINGS!.
|
||||
# Should continue normally.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(sid=self.sid + 1,
|
||||
max_shares=1,
|
||||
max_notional=1.0)
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Set the trading control sid to None, then BUY ALL THE THINGS!.
|
||||
# Should fail because not specifying a sid makes the trading control
|
||||
# apply to all sids.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, 10000)
|
||||
algo.order_count += 1
|
||||
algo = SetMaxOrderSizeAlgorithm(max_shares=1,
|
||||
max_notional=1.0)
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
def test_set_max_order_count(self):
|
||||
|
||||
# Override the default setUp to use six-hour intervals instead of full
|
||||
# days so we can exercise trading-session rollover logic.
|
||||
trade_history = factory.create_trade_history(
|
||||
self.sid,
|
||||
[10.0, 10.0, 11.0, 11.0],
|
||||
[100, 100, 100, 300],
|
||||
timedelta(hours=6),
|
||||
self.sim_params
|
||||
)
|
||||
self.source = SpecificEquityTrades(event_list=trade_history)
|
||||
|
||||
def handle_data(algo, data):
|
||||
for i in range(5):
|
||||
algo.order(self.sid, 1)
|
||||
algo.order_count += 1
|
||||
|
||||
algo = SetMaxOrderCountAlgorithm(3)
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
# Second call to handle_data is the same day as the first, so the last
|
||||
# order of the second call should fail.
|
||||
algo = SetMaxOrderCountAlgorithm(9)
|
||||
self.check_algo_fails(algo, handle_data, 9)
|
||||
|
||||
# Only ten orders are placed per day, so this should pass even though
|
||||
# in total more than 20 orders are placed.
|
||||
algo = SetMaxOrderCountAlgorithm(10)
|
||||
self.check_algo_succeeds(algo, handle_data, order_count=20)
|
||||
|
||||
def test_long_only(self):
|
||||
|
||||
# Sell immediately -> fail immediately.
|
||||
def handle_data(algo, data):
|
||||
algo.order(self.sid, -1)
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_fails(algo, handle_data, 0)
|
||||
|
||||
# Buy on even days, sell on odd days. Never takes a short position, so
|
||||
# should succeed.
|
||||
def handle_data(algo, data):
|
||||
if (algo.order_count % 2) == 0:
|
||||
algo.order(self.sid, 1)
|
||||
else:
|
||||
algo.order(self.sid, -1)
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy on first three days, then sell off holdings. Should succeed.
|
||||
def handle_data(algo, data):
|
||||
amounts = [1, 1, 1, -3]
|
||||
algo.order(self.sid, amounts[algo.order_count])
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_succeeds(algo, handle_data)
|
||||
|
||||
# Buy on first three days, then sell off holdings plus an extra share.
|
||||
# Should fail on the last sale.
|
||||
def handle_data(algo, data):
|
||||
amounts = [1, 1, 1, -4]
|
||||
algo.order(self.sid, amounts[algo.order_count])
|
||||
algo.order_count += 1
|
||||
algo = SetLongOnlyAlgorithm()
|
||||
self.check_algo_fails(algo, handle_data, 3)
|
||||
|
||||
def test_register_post_init(self):
|
||||
|
||||
def initialize(algo):
|
||||
algo.initialized = True
|
||||
|
||||
def handle_data(algo, data):
|
||||
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_position_size(self.sid, 1, 1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_order_size(self.sid, 1, 1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_max_order_count(1)
|
||||
with self.assertRaises(RegisterTradingControlPostInit):
|
||||
algo.set_long_only()
|
||||
|
||||
algo = TradingAlgorithm(initialize=initialize,
|
||||
handle_data=handle_data)
|
||||
algo.run(self.source)
|
||||
self.source.rewind()
|
||||
|
||||
Reference in New Issue
Block a user