mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 08:51:44 +08:00
BUG: Fix off-by-one error in TALib wrapper.
When setting timeperiod in the talib function it subtracts by 1. We then used this subtracted value to set the window_length in the batch_transform which was then not passing a big enough panel. Ultimately this caused the talib transforms to always return nans. This also makes the unittest more stringent by explicitly comparing the output of the wrapped TALib moving average to pandas rolling_mean(). Finally, this also allows passing of window_length instead of timeperiod to allow usage of the same interface as before.
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
import pytz
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from datetime import timedelta, datetime
|
||||
from unittest import TestCase
|
||||
@@ -343,10 +344,18 @@ class TestTALIB(TestCase):
|
||||
# factory.create_test_panel_ohlc_source(self.sim_params)
|
||||
|
||||
def test_multiple_talib_with_args(self):
|
||||
zipline_transforms = [ta.MA(0, timeperiod=10), ta.MA(0, timeperiod=25)]
|
||||
zipline_transforms = [ta.MA(0, window_length=10),
|
||||
ta.MA(0, window_length=25)]
|
||||
talib_fn = talib.abstract.MA
|
||||
algo = TALIBAlgorithm(talib=zipline_transforms)
|
||||
algo.run(self.source)
|
||||
# Test if computed values match those computed by pandas rolling mean.
|
||||
np.testing.assert_array_equal(np.array(algo.talib_results.values()[0]),
|
||||
pd.rolling_mean(self.panel[0]['price'],
|
||||
10).values)
|
||||
np.testing.assert_array_equal(np.array(algo.talib_results.values()[1]),
|
||||
pd.rolling_mean(self.panel[0]['price'],
|
||||
25).values)
|
||||
for t in zipline_transforms:
|
||||
talib_result = np.array(algo.talib_results[t][-1])
|
||||
talib_data = dict()
|
||||
|
||||
@@ -98,6 +98,11 @@ def make_transform(talib_fn, name):
|
||||
'volume': volume,
|
||||
'close': close}
|
||||
|
||||
# Rename window_length to timeperiod to conform with
|
||||
# external batch_transform interface.
|
||||
if 'window_length' in kwargs:
|
||||
kwargs['timeperiod'] = kwargs['window_length']
|
||||
|
||||
self.call_kwargs = kwargs
|
||||
|
||||
# Make deepcopy of talib abstract function.
|
||||
@@ -157,7 +162,7 @@ def make_transform(talib_fn, name):
|
||||
func=zipline_wrapper,
|
||||
sids=sid,
|
||||
refresh_period=refresh_period,
|
||||
window_length=max(1, self.lookback))
|
||||
window_length=max(1, self.lookback + 1))
|
||||
|
||||
def __repr__(self):
|
||||
return 'Zipline BatchTransform: {0}'.format(
|
||||
|
||||
Reference in New Issue
Block a user