From 37352210c083f8bedabb8003f93568fa6a158feb Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Wed, 10 Jul 2013 13:52:15 -0400 Subject: [PATCH] MAINT: Make TALib zipline_wrapper a module level function. Prepare for making the zipline_wrapper operate on multiple sids, as the needed nested logic will get cramped within the nested function. Also, should help clearly define the inputs of the zipline_wrapper function that are needed before it is passed to the BatchTransform constructor. --- zipline/transforms/ta.py | 75 +++++++++++++++++++++------------------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/zipline/transforms/ta.py b/zipline/transforms/ta.py index 1cc1ae30..a4dd2f02 100644 --- a/zipline/transforms/ta.py +++ b/zipline/transforms/ta.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import math import numpy as np @@ -20,6 +21,42 @@ import copy from zipline.transforms import BatchTransform +def zipline_wrapper(talib_fn, key_map, data): + # get required TA-Lib input names + if 'price' in talib_fn.input_names: + req_inputs = [talib_fn.input_names['price']] + elif 'prices' in talib_fn.input_names: + req_inputs = talib_fn.input_names['prices'] + else: + req_inputs = [] + + # build talib_data from zipline data + talib_data = dict() + for talib_key, zipline_key in key_map.iteritems(): + # if zipline_key is found, add it to talib_data + if zipline_key in data: + talib_data[talib_key] = data[zipline_key].values[:, 0] + # if zipline_key is not found and not required, add zeros + elif talib_key not in req_inputs: + talib_data[talib_key] = np.zeros(data.shape[1]) + # if zipline key is not found and required, raise error + else: + raise KeyError( + 'Tried to set required TA-Lib data with key ' + '\'{0}\' but no Zipline data is available under ' + 'expected key \'{1}\'.'.format( + talib_key, zipline_key)) + + # call talib + result = talib_fn(talib_data) + + # keep only the most recent result + if isinstance(result, (list, tuple)): + return tuple([r[-1] for r in result]) + else: + return result[-1] + + def make_transform(talib_fn, name): """ A factory for BatchTransforms based on TALIB abstract functions. @@ -125,43 +162,11 @@ def make_transform(talib_fn, name): # Ensure that window_length is at least 1 day's worth of data. window_length = max(lookback, 1) - def zipline_wrapper(data): - # get required TA-Lib input names - if 'price' in self.talib_fn.input_names: - req_inputs = [self.talib_fn.input_names['price']] - elif 'prices' in self.talib_fn.input_names: - req_inputs = self.talib_fn.input_names['prices'] - else: - req_inputs = [] - - # build talib_data from zipline data - talib_data = dict() - for talib_key, zipline_key in key_map.iteritems(): - # if zipline_key is found, add it to talib_data - if zipline_key in data: - talib_data[talib_key] = data[zipline_key].values[:, 0] - # if zipline_key is not found and not required, add zeros - elif talib_key not in req_inputs: - talib_data[talib_key] = np.zeros(data.shape[1]) - # if zipline key is not found and required, raise error - else: - raise KeyError( - 'Tried to set required TA-Lib data with key ' - '\'{0}\' but no Zipline data is available under ' - 'expected key \'{1}\'.'.format( - talib_key, zipline_key)) - - # call talib - result = self.talib_fn(talib_data) - - # keep only the most recent result - if isinstance(result, (list, tuple)): - return tuple([r[-1] for r in result]) - else: - return result[-1] + transform_func = functools.partial( + zipline_wrapper, self.talib_fn, key_map) super(TALibTransform, self).__init__( - func=zipline_wrapper, + func=transform_func, refresh_period=refresh_period, window_length=window_length, compute_only_full=False,