From 0bb0df525a5191b7907f9e173f83d6b99b89df41 Mon Sep 17 00:00:00 2001 From: fawce Date: Fri, 14 Sep 2012 12:33:10 -0400 Subject: [PATCH] pluggable slippage. --- zipline/finance/slippage.py | 9 +------- zipline/gens/tradesimulation.py | 6 ++++++ zipline/test_algorithms.py | 38 ++++++++++++++++++++++++++++++++- zipline/utils/test_utils.py | 16 ++++++++++---- 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/zipline/finance/slippage.py b/zipline/finance/slippage.py index 3b3570d9..9e1decb0 100644 --- a/zipline/finance/slippage.py +++ b/zipline/finance/slippage.py @@ -18,17 +18,10 @@ class VolumeShareSlippage(object): def __init__(self, volume_limit=.25, price_impact=0.1, - commission=0.03, - ttl=None): + commission=0.03): self.volume_limit = volume_limit self.price_impact = price_impact self.commission = commission - if ttl: - assert isinstance(ttl, timedelta), \ - "ttl must be a datetime.timedelta" - self.ttl = ttl - else: - self.ttl = timedelta(days=1) def simulate(self, event, open_orders): diff --git a/zipline/gens/tradesimulation.py b/zipline/gens/tradesimulation.py index ee0ca65b..9c2df532 100644 --- a/zipline/gens/tradesimulation.py +++ b/zipline/gens/tradesimulation.py @@ -130,6 +130,9 @@ class AlgorithmSimulator(object): self.algolog = Logger("AlgoLog") self.algo.set_logger(self.algolog) + # Porived user algorithm with slippage override. + self.algo.set_slippage_override(self.override_slippage) + # Handler for heartbeats during calls to handle_data. def log_heartbeats(beat_count, stackframe): t = beat_count * HEARTBEAT_INTERVAL @@ -174,6 +177,9 @@ class AlgorithmSimulator(object): # to monkey patch sys.stdout with a logbook interface. self.stdout_capture = stdout_only_pipe + def override_slippage(self, slippage): + self.order_book.slippage = slippage + def order(self, sid, amount): """ Closure to pass into the user's algo to allow placing orders diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index 62c78ffb..bcff7f81 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -44,6 +44,12 @@ The algorithm must expose methods: self.Portfolio[sid(133)]['cost_basis'] + - set_slippage_override: method that accepts a callable. Will + be set as the value of the set_slippage_override method of + the trading_client. This allows an algorithm to change the + slippage model used to predict transactions based on orders + and trade events. + """ @@ -90,7 +96,10 @@ class TestAlgorithm(): def get_sid_filter(self): return self.sid_filter -# + def set_slippage_override(self, slippage_callable): + pass + + # class HeavyBuyAlgorithm(): """ This algorithm will send a specified number of orders, to allow unit tests @@ -128,6 +137,9 @@ class HeavyBuyAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class NoopAlgorithm(object): """ Dolce fa niente. @@ -151,6 +163,9 @@ class NoopAlgorithm(object): def get_sid_filter(self): return [] + def set_slippage_override(self, slippage_callable): + pass + class ExceptionAlgorithm(object): """ Throw an exception from the method name specified in the @@ -194,6 +209,9 @@ class ExceptionAlgorithm(object): else: return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class DivByZeroAlgorithm(): def __init__(self, sid): @@ -221,6 +239,9 @@ class DivByZeroAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class InitializeTimeoutAlgorithm(): def __init__(self, sid): self.sid = sid @@ -247,6 +268,9 @@ class InitializeTimeoutAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class TooMuchProcessingAlgorithm(): def __init__(self, sid): self.sid = sid @@ -272,6 +296,9 @@ class TooMuchProcessingAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class TimeoutAlgorithm(): def __init__(self, sid): @@ -299,6 +326,9 @@ class TimeoutAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class TestPrintAlgorithm(): def __init__(self, sid): @@ -323,6 +353,9 @@ class TestPrintAlgorithm(): def get_sid_filter(self): return [self.sid] + def set_slippage_override(self, slippage_callable): + pass + class TestLoggingAlgorithm(): def __init__(self, sid): @@ -346,3 +379,6 @@ class TestLoggingAlgorithm(): def get_sid_filter(self): return [self.sid] + + def set_slippage_override(self, slippage_callable): + pass diff --git a/zipline/utils/test_utils.py b/zipline/utils/test_utils.py index 65b25e8d..4fdc3918 100644 --- a/zipline/utils/test_utils.py +++ b/zipline/utils/test_utils.py @@ -80,10 +80,16 @@ def assert_single_position(test, zipline): output, transaction_count = drain_zipline(test, zipline) - test.assertEqual( - test.zipline_test_config['order_count'], - transaction_count - ) + if 'expected_transactions' in test.zipline_test_config: + test.assertEqual( + test.zipline_test_config['expected_transactions'], + transaction_count + ) + else: + test.assertEqual( + test.zipline_test_config['order_count'], + transaction_count + ) # the final message is the risk report, the second to # last is the final day's results. Positions is a list of @@ -103,6 +109,8 @@ def assert_single_position(test, zipline): "Portfolio should have one position in " + str(sid) ) + return output, transaction_count + class ExceptionSource(object):