BUG: Fix off-by-one error in TALib wrapper.

When setting timeperiod in the talib function it subtracts by 1. We then used this subtracted value to set the window_length in the batch_transform which was then not passing a big enough panel. Ultimately this caused the talib transforms to always return nans.

This also makes the unittest more stringent by explicitly comparing the output of the wrapped TALib moving average to pandas rolling_mean().

Finally, this also allows passing of window_length instead of timeperiod to allow usage of the same interface as before.
This commit is contained in:
Thomas Wiecki
2013-05-09 11:52:21 -04:00
parent 5cf1b2880d
commit aa7d5de073
2 changed files with 16 additions and 2 deletions
+10 -1
View File
@@ -15,6 +15,7 @@
import pytz
import numpy as np
import pandas as pd
from datetime import timedelta, datetime
from unittest import TestCase
@@ -343,10 +344,18 @@ class TestTALIB(TestCase):
# factory.create_test_panel_ohlc_source(self.sim_params)
def test_multiple_talib_with_args(self):
zipline_transforms = [ta.MA(0, timeperiod=10), ta.MA(0, timeperiod=25)]
zipline_transforms = [ta.MA(0, window_length=10),
ta.MA(0, window_length=25)]
talib_fn = talib.abstract.MA
algo = TALIBAlgorithm(talib=zipline_transforms)
algo.run(self.source)
# Test if computed values match those computed by pandas rolling mean.
np.testing.assert_array_equal(np.array(algo.talib_results.values()[0]),
pd.rolling_mean(self.panel[0]['price'],
10).values)
np.testing.assert_array_equal(np.array(algo.talib_results.values()[1]),
pd.rolling_mean(self.panel[0]['price'],
25).values)
for t in zipline_transforms:
talib_result = np.array(algo.talib_results[t][-1])
talib_data = dict()
+6 -1
View File
@@ -98,6 +98,11 @@ def make_transform(talib_fn, name):
'volume': volume,
'close': close}
# Rename window_length to timeperiod to conform with
# external batch_transform interface.
if 'window_length' in kwargs:
kwargs['timeperiod'] = kwargs['window_length']
self.call_kwargs = kwargs
# Make deepcopy of talib abstract function.
@@ -157,7 +162,7 @@ def make_transform(talib_fn, name):
func=zipline_wrapper,
sids=sid,
refresh_period=refresh_period,
window_length=max(1, self.lookback))
window_length=max(1, self.lookback + 1))
def __repr__(self):
return 'Zipline BatchTransform: {0}'.format(