From aa7d5de073e8cf2e6bfedc58ff8cc0494b1b32be Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 9 May 2013 11:52:21 -0400 Subject: [PATCH] BUG: Fix off-by-one error in TALib wrapper. When setting timeperiod in the talib function it subtracts by 1. We then used this subtracted value to set the window_length in the batch_transform which was then not passing a big enough panel. Ultimately this caused the talib transforms to always return nans. This also makes the unittest more stringent by explicitly comparing the output of the wrapped TALib moving average to pandas rolling_mean(). Finally, this also allows passing of window_length instead of timeperiod to allow usage of the same interface as before. --- tests/test_transforms.py | 11 ++++++++++- zipline/transforms/ta.py | 7 ++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index e922747b..aa44fbc2 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -15,6 +15,7 @@ import pytz import numpy as np +import pandas as pd from datetime import timedelta, datetime from unittest import TestCase @@ -343,10 +344,18 @@ class TestTALIB(TestCase): # factory.create_test_panel_ohlc_source(self.sim_params) def test_multiple_talib_with_args(self): - zipline_transforms = [ta.MA(0, timeperiod=10), ta.MA(0, timeperiod=25)] + zipline_transforms = [ta.MA(0, window_length=10), + ta.MA(0, window_length=25)] talib_fn = talib.abstract.MA algo = TALIBAlgorithm(talib=zipline_transforms) algo.run(self.source) + # Test if computed values match those computed by pandas rolling mean. + np.testing.assert_array_equal(np.array(algo.talib_results.values()[0]), + pd.rolling_mean(self.panel[0]['price'], + 10).values) + np.testing.assert_array_equal(np.array(algo.talib_results.values()[1]), + pd.rolling_mean(self.panel[0]['price'], + 25).values) for t in zipline_transforms: talib_result = np.array(algo.talib_results[t][-1]) talib_data = dict() diff --git a/zipline/transforms/ta.py b/zipline/transforms/ta.py index 9c7c918b..16d6a053 100644 --- a/zipline/transforms/ta.py +++ b/zipline/transforms/ta.py @@ -98,6 +98,11 @@ def make_transform(talib_fn, name): 'volume': volume, 'close': close} + # Rename window_length to timeperiod to conform with + # external batch_transform interface. + if 'window_length' in kwargs: + kwargs['timeperiod'] = kwargs['window_length'] + self.call_kwargs = kwargs # Make deepcopy of talib abstract function. @@ -157,7 +162,7 @@ def make_transform(talib_fn, name): func=zipline_wrapper, sids=sid, refresh_period=refresh_period, - window_length=max(1, self.lookback)) + window_length=max(1, self.lookback + 1)) def __repr__(self): return 'Zipline BatchTransform: {0}'.format(