mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 11:59:14 +08:00
pluggable slippage.
This commit is contained in:
@@ -18,17 +18,10 @@ class VolumeShareSlippage(object):
|
||||
def __init__(self,
|
||||
volume_limit=.25,
|
||||
price_impact=0.1,
|
||||
commission=0.03,
|
||||
ttl=None):
|
||||
commission=0.03):
|
||||
self.volume_limit = volume_limit
|
||||
self.price_impact = price_impact
|
||||
self.commission = commission
|
||||
if ttl:
|
||||
assert isinstance(ttl, timedelta), \
|
||||
"ttl must be a datetime.timedelta"
|
||||
self.ttl = ttl
|
||||
else:
|
||||
self.ttl = timedelta(days=1)
|
||||
|
||||
def simulate(self, event, open_orders):
|
||||
|
||||
|
||||
@@ -130,6 +130,9 @@ class AlgorithmSimulator(object):
|
||||
self.algolog = Logger("AlgoLog")
|
||||
self.algo.set_logger(self.algolog)
|
||||
|
||||
# Porived user algorithm with slippage override.
|
||||
self.algo.set_slippage_override(self.override_slippage)
|
||||
|
||||
# Handler for heartbeats during calls to handle_data.
|
||||
def log_heartbeats(beat_count, stackframe):
|
||||
t = beat_count * HEARTBEAT_INTERVAL
|
||||
@@ -174,6 +177,9 @@ class AlgorithmSimulator(object):
|
||||
# to monkey patch sys.stdout with a logbook interface.
|
||||
self.stdout_capture = stdout_only_pipe
|
||||
|
||||
def override_slippage(self, slippage):
|
||||
self.order_book.slippage = slippage
|
||||
|
||||
def order(self, sid, amount):
|
||||
"""
|
||||
Closure to pass into the user's algo to allow placing orders
|
||||
|
||||
@@ -44,6 +44,12 @@ The algorithm must expose methods:
|
||||
|
||||
self.Portfolio[sid(133)]['cost_basis']
|
||||
|
||||
- set_slippage_override: method that accepts a callable. Will
|
||||
be set as the value of the set_slippage_override method of
|
||||
the trading_client. This allows an algorithm to change the
|
||||
slippage model used to predict transactions based on orders
|
||||
and trade events.
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -90,7 +96,10 @@ class TestAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return self.sid_filter
|
||||
|
||||
#
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
#
|
||||
class HeavyBuyAlgorithm():
|
||||
"""
|
||||
This algorithm will send a specified number of orders, to allow unit tests
|
||||
@@ -128,6 +137,9 @@ class HeavyBuyAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class NoopAlgorithm(object):
|
||||
"""
|
||||
Dolce fa niente.
|
||||
@@ -151,6 +163,9 @@ class NoopAlgorithm(object):
|
||||
def get_sid_filter(self):
|
||||
return []
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class ExceptionAlgorithm(object):
|
||||
"""
|
||||
Throw an exception from the method name specified in the
|
||||
@@ -194,6 +209,9 @@ class ExceptionAlgorithm(object):
|
||||
else:
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class DivByZeroAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
@@ -221,6 +239,9 @@ class DivByZeroAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class InitializeTimeoutAlgorithm():
|
||||
def __init__(self, sid):
|
||||
self.sid = sid
|
||||
@@ -247,6 +268,9 @@ class InitializeTimeoutAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class TooMuchProcessingAlgorithm():
|
||||
def __init__(self, sid):
|
||||
self.sid = sid
|
||||
@@ -272,6 +296,9 @@ class TooMuchProcessingAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class TimeoutAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
@@ -299,6 +326,9 @@ class TimeoutAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class TestPrintAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
@@ -323,6 +353,9 @@ class TestPrintAlgorithm():
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
class TestLoggingAlgorithm():
|
||||
|
||||
def __init__(self, sid):
|
||||
@@ -346,3 +379,6 @@ class TestLoggingAlgorithm():
|
||||
|
||||
def get_sid_filter(self):
|
||||
return [self.sid]
|
||||
|
||||
def set_slippage_override(self, slippage_callable):
|
||||
pass
|
||||
|
||||
@@ -80,10 +80,16 @@ def assert_single_position(test, zipline):
|
||||
|
||||
output, transaction_count = drain_zipline(test, zipline)
|
||||
|
||||
test.assertEqual(
|
||||
test.zipline_test_config['order_count'],
|
||||
transaction_count
|
||||
)
|
||||
if 'expected_transactions' in test.zipline_test_config:
|
||||
test.assertEqual(
|
||||
test.zipline_test_config['expected_transactions'],
|
||||
transaction_count
|
||||
)
|
||||
else:
|
||||
test.assertEqual(
|
||||
test.zipline_test_config['order_count'],
|
||||
transaction_count
|
||||
)
|
||||
|
||||
# the final message is the risk report, the second to
|
||||
# last is the final day's results. Positions is a list of
|
||||
@@ -103,6 +109,8 @@ def assert_single_position(test, zipline):
|
||||
"Portfolio should have one position in " + str(sid)
|
||||
)
|
||||
|
||||
return output, transaction_count
|
||||
|
||||
|
||||
class ExceptionSource(object):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user