ENH: Add support for TALib based transforms.

Provide subclasses of BatchTransform that are powered by the TA-Lib
library.
This commit is contained in:
Jeremiah Lowin
2013-04-27 21:31:34 -04:00
committed by Eddie Hebert
parent beecebc7d8
commit cc39ec3aef
4 changed files with 311 additions and 0 deletions
+81
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytz
import numpy as np
@@ -29,6 +30,8 @@ from zipline.transforms import MovingStandardDev
from zipline.transforms import Returns
import zipline.utils.factory as factory
from zipline.test_algorithms import TALIBAlgorithm
def to_dt(msg):
    """Wrap ``msg`` in an Event whose single field, 'dt', holds it."""
    payload = {'dt': msg}
    return Event(payload)
@@ -270,3 +273,81 @@ class TestFinanceTransforms(TestCase):
self.assertIsNone(v2)
continue
self.assertEquals(round(v1, 5), round(v2, 5))
############################################################
# Test TALIB
import talib
import zipline.transforms.ta as ta
class TestTALIB(TestCase):
    """
    Compare the TA-Lib backed zipline transforms with direct calls to the
    corresponding TA-Lib abstract functions on the same data window.
    """

    def setUp(self):
        setup_logger(self)
        sim_params = factory.create_simulation_parameters(
            start=datetime(1990, 1, 1, tzinfo=pytz.utc),
            end=datetime(1990, 3, 30, tzinfo=pytz.utc))
        self.source, self.panel = \
            factory.create_test_panel_ohlc_source(sim_params)

    @staticmethod
    def _talib_inputs(window):
        """
        Build the input dict handed to a TA-Lib abstract function from a
        transform's data window.

        'open', 'high', 'low' and 'volume' are copied when present in the
        window; the window's 'price' field is always supplied as TA-Lib's
        'close'. Values are taken for sid 0, the only sid in the test panel.
        """
        talib_data = dict()
        for key in ['open', 'high', 'low', 'volume']:
            if key in window:
                talib_data[key] = window[key][0].values
        talib_data['close'] = window['price'][0].values
        return talib_data

    def test_talib_with_default_params(self):
        BLACKLIST = ['make_transform', 'BatchTransform']
        names = [n for n in dir(ta) if n[0].isupper()
                 and n not in BLACKLIST]

        for name in names:
            zipline_transform = getattr(ta, name)(sid=0)
            talib_fn = getattr(talib.abstract, name)

            # A fresh source is built for every transform, long enough to
            # cover its lookback plus ten extra days.
            start = datetime(1990, 1, 1, tzinfo=pytz.utc)
            end = start + timedelta(days=zipline_transform.lookback + 10)
            sim_params = factory.create_simulation_parameters(
                start=start, end=end)
            source, _ = factory.create_test_panel_ohlc_source(sim_params)

            algo = TALIBAlgorithm(talib=zipline_transform)
            algo.run(source)

            zipline_result = np.array(
                algo.talib_results[zipline_transform][-1])

            expected_result = talib_fn(
                self._talib_inputs(zipline_transform.window))
            # Multi-output functions return a list of arrays; keep only the
            # most recent value of each output.
            if isinstance(expected_result, list):
                expected_result = np.array([e[-1] for e in expected_result])
            else:
                expected_result = np.array(expected_result[-1])

            # NaN never compares equal, so two all-NaN results are accepted
            # as agreeing without going through allclose.
            if (np.all(np.isnan(zipline_result))
                    and np.all(np.isnan(expected_result))):
                continue

            self.assertTrue(
                np.allclose(zipline_result, expected_result),
                '{0}: zipline result {1!r} does not match TA-Lib '
                'result {2!r}'.format(name, zipline_result, expected_result))

    def test_multiple_talib_with_args(self):
        zipline_transforms = [ta.MA(0, timeperiod=10), ta.MA(0, timeperiod=25)]
        talib_fn = talib.abstract.MA
        algo = TALIBAlgorithm(talib=zipline_transforms)
        algo.run(self.source)

        for t in zipline_transforms:
            talib_result = np.array(algo.talib_results[t][-1])
            talib_data = self._talib_inputs(t.window)
            expected_result = talib_fn(talib_data, **t.call_kwargs)[-1]
            self.assertTrue(
                np.allclose(talib_result, expected_result),
                '{0!r} with {1!r}: zipline result {2!r} does not match '
                'TA-Lib result {3!r}'.format(
                    t, t.call_kwargs, talib_result, expected_result))
+29
View File
@@ -447,3 +447,32 @@ class SetPortfolioAlgorithm(TradingAlgorithm):
def handle_data(self, data):
    """Overwrite ``self.portfolio`` with the integer 3; ``data`` is unused."""
    # NOTE(review): presumably exists so tests can check how the framework
    # reacts to an algorithm assigning to its portfolio -- confirm against
    # the tests that use SetPortfolioAlgorithm.
    self.portfolio = 3
class TALIBAlgorithm(TradingAlgorithm):
    """
    Algorithm that drives one or more TA-Lib transforms.

    The transform, or a list/tuple of transforms, is supplied at
    initialization through the 'talib' keyword argument. On every
    handle_data call the latest value of each transform is appended to
    ``talib_results``, a dict mapping each transform object to the list of
    values it has produced so far.
    """

    def initialize(self, *args, **kwargs):
        if 'talib' not in kwargs:
            raise KeyError('No TA-LIB transform specified '
                           '(use keyword \'talib\').')

        transforms = kwargs['talib']
        # A lone transform is wrapped in a 1-tuple so the rest of the class
        # can always iterate.
        if not isinstance(transforms, (list, tuple)):
            transforms = (transforms,)
        self.talib_transforms = transforms

        self.talib_results = {}
        for transform in transforms:
            self.talib_results[transform] = []

    def handle_data(self, data):
        for transform in self.talib_transforms:
            value = transform.handle_data(data)
            if value is None:
                # No value produced for this bar: record NaN, shaped like
                # the transform's output (scalar for one output, otherwise
                # a tuple with one NaN per output).
                n_outputs = len(transform.talib_fn.output_names)
                if n_outputs == 1:
                    value = np.nan
                else:
                    value = (np.nan,) * n_outputs
            self.talib_results[transform].append(value)
+174
View File
@@ -0,0 +1,174 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import talib
import copy
from zipline.transforms import BatchTransform
def make_transform(talib_fn):
    """
    A factory for BatchTransforms based on TALIB abstract functions.

    Parameters
    ----------
    talib_fn : TA-Lib abstract function
        The function the generated transform wraps. It is deep-copied per
        transform instance, so the object passed in is never reconfigured.

    Returns
    -------
    A new BatchTransform subclass whose docstring combines the TA-Lib
    documentation of the function with the Zipline usage notes.
    """
    class TALibTransform(BatchTransform):
        """
        TA-Lib keyword arguments must be passed at initialization. For
        example, to construct a moving average with timeperiod of 5, pass
        "timeperiod=5" during initialization.

        All abstract TA-Lib functions accept a data dictionary containing
        'open', 'high', 'low', 'close', and 'volume' keys, even if they do
        not require those keys to run. For example, talib.MA (moving
        average) is always computed using the data under the 'close'
        key. By default, Zipline constructs this data dictionary with the
        appropriate sid data, but users may overwrite this by passing
        mappings as keyword arguments. For example, to compute the moving
        average of the sid's high, provide "close = 'high'" and Zipline's
        'high' data will be used as TA-Lib's 'close' data. Similarly, if a
        user had a data column named 'Oil', they could compute its moving
        average by passing "close='Oil'".

        Example
        --------

        A moving average of a data column called 'Oil' with timeperiod 5,
        for sid 'XYZ':

            zipline.transforms.ta.MA('XYZ', close='Oil', timeperiod=5)

        The user could find the default arguments and mappings by calling:

            help(zipline.transforms.ta.MA)

        Arguments
        ---------

        sid : zipline sid
        open : string, default 'open'
        high : string, default 'high'
        low : string, default 'low'
        close : string, default 'price'
        volume : string, default 'volume'
        refresh_period : int, default 0
            The refresh_period of the BatchTransform determines the number
            of iterations that pass before the BatchTransform updates its
            internal data.
        **kwargs : any arguments to be passed to the TA-Lib function.
        """

        def __init__(self,
                     sid,
                     close='price',
                     open='open',
                     high='high',
                     low='low',
                     volume='volume',
                     refresh_period=0,
                     **kwargs):

            # TA-Lib input name -> name of the Zipline field that feeds it.
            key_map = {'high': high,
                       'low': low,
                       'open': open,
                       'volume': volume,
                       'close': close}

            self.call_kwargs = kwargs

            # Make deepcopy of talib abstract function.
            # This is necessary because talib abstract functions remember
            # state, including parameters, and we need to set the parameters
            # in order to compute the lookback period that will determine the
            # BatchTransform window_length. TALIB has no way to restore default
            # parameters, so the deepcopy lets us change this function's
            # parameters without affecting other TALibTransforms of the same
            # function.
            self.talib_fn = copy.deepcopy(talib_fn)

            # Set the parameters. Only keywords that name a parameter of
            # this TA-Lib function are applied; any others are ignored here.
            for param in self.talib_fn.get_parameters():
                if param in kwargs:
                    self.talib_fn.set_parameters({param: kwargs[param]})

            # get the lookback
            self.lookback = self.talib_fn.lookback

            def zipline_wrapper(data):
                """
                Translate the batch window into TA-Lib inputs, call the
                function and return only its most recent value(s).
                """
                # get required TA-Lib input names
                if 'price' in self.talib_fn.input_names:
                    req_inputs = [self.talib_fn.input_names['price']]
                elif 'prices' in self.talib_fn.input_names:
                    req_inputs = self.talib_fn.input_names['prices']
                else:
                    req_inputs = []

                # build talib_data from zipline data
                talib_data = dict()
                # .items() rather than the Python 2-only .iteritems().
                for talib_key, zipline_key in key_map.items():
                    # if zipline_key is found, add it to talib_data
                    if zipline_key in data:
                        talib_data[talib_key] = data[zipline_key].values[:, 0]
                    # if zipline_key is not found and not required, add zeros
                    elif talib_key not in req_inputs:
                        talib_data[talib_key] = np.zeros(data.shape[1])
                    # if zipline key is not found and required, raise error
                    else:
                        raise KeyError(
                            'Tried to set required TA-Lib data with key '
                            '\'{0}\' but no Zipline data is available under '
                            'expected key \'{1}\'.'.format(
                                talib_key, zipline_key))

                # call talib
                result = self.talib_fn(talib_data)

                # keep only the most recent result
                if isinstance(result, (list, tuple)):
                    return tuple(r[-1] for r in result)
                return result[-1]

            # NOTE(review): TA-Lib's lookback is presumably the number of
            # bars that precede the first valid output, in which case a
            # window holding exactly `lookback` bars would be one bar short.
            # Confirm how BatchTransform counts window_length before
            # relying on this.
            super(TALibTransform, self).__init__(
                func=zipline_wrapper,
                sids=sid,
                refresh_period=refresh_period,
                window_length=max(1, self.lookback))

        def __repr__(self):
            return 'Zipline BatchTransform: {0}'.format(
                self.talib_fn.info['name'])

    # make class docstring
    # __doc__ is None when docstrings are unavailable (for example under
    # ``python -OO``); fall back to '' so building the help text cannot
    # raise TypeError.
    header = '\n#---- TA-Lib docs\n\n'
    talib_docs = getattr(talib, talib_fn.info['name']).__doc__ or ''
    divider1 = '\n#---- Default mapping (TA-Lib : Zipline)\n\n'
    mappings = '\n'.join('        {0} : {1}'.format(k, v)
                         for k, v in talib_fn.input_names.items())
    divider2 = '\n\n#---- Zipline docs\n'
    zipline_docs = TALibTransform.__doc__ or ''
    help_str = (header + talib_docs + divider1 + mappings
                + divider2 + zipline_docs)
    TALibTransform.__doc__ = help_str

    # return class
    return TALibTransform
# Publish one transform class per TA-Lib abstract function as a module
# attribute of the same name; the 'Function' entry is skipped. At module
# scope globals() is the very namespace locals() would return.
for name in talib.abstract.__all__:
    fn = getattr(talib.abstract, name)
    if name == 'Function':
        continue
    globals()[name] = make_transform(fn)
+27
View File
@@ -329,6 +329,33 @@ def create_test_panel_source(sim_params=None):
return DataPanelSource(panel), panel
def create_test_panel_ohlc_source(sim_params=None):
    """
    Build a deterministic OHLCV test panel for sid 0 and a source over it.

    The index runs from ``sim_params.first_open`` to
    ``sim_params.last_close`` when ``sim_params`` is given, otherwise from
    1990-01-03 to 1990-01-08 (UTC), at ``pd.datetools.day`` frequency.

    Columns:
      price     -- 100, 101, 102, ...
      high      -- price * 1.05
      low       -- price * 0.95
      open      -- price + 0.05 for odd prices, price - 0.05 for even ones
      volume    -- constant 1000
      arbitrary -- constant 1

    Returns a ``(DataPanelSource, panel)`` tuple.
    """
    if sim_params:
        first_bar = sim_params.first_open
        last_bar = sim_params.last_close
    else:
        first_bar = pd.datetime(1990, 1, 3, 0, 0, 0, 0, pytz.utc)
        last_bar = pd.datetime(1990, 1, 8, 0, 0, 0, 0, pytz.utc)

    index = pd.DatetimeIndex(start=first_bar,
                             end=last_bar,
                             freq=pd.datetools.day)
    n_bars = len(index)

    price = np.arange(0, n_bars) + 100
    columns = {
        'price': price,
        'high': price * 1.05,
        'low': price * 0.95,
        'open': price + .1 * (price % 2 - .5),
        'volume': np.ones(n_bars) * 1000,
        'arbitrary': np.ones(n_bars),
    }
    df = pd.DataFrame(columns, index=index)

    panel = pd.Panel.from_dict({0: df})
    return DataPanelSource(panel), panel
def _load_raw_yahoo_data(indexes=None, stocks=None, start=None, end=None):
"""Load closing prices from yahoo finance.