From c5f4d00bf137aead51bba77ac4aec5d7906ba450 Mon Sep 17 00:00:00 2001 From: Wes McKinney Date: Thu, 21 Mar 2013 23:16:32 -0400 Subject: [PATCH 1/5] ENH: prototype data structure for managing a rolling datapanel Manage a rolling window collection of panels for computation purposes. --- tests/test_data_util.py | 108 ++++++++++++++++++++++++++++++++++++++++ zipline/utils/data.py | 86 ++++++++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 tests/test_data_util.py create mode 100644 zipline/utils/data.py diff --git a/tests/test_data_util.py b/tests/test_data_util.py new file mode 100644 index 00000000..ded2c8ca --- /dev/null +++ b/tests/test_data_util.py @@ -0,0 +1,108 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from collections import deque + +import numpy as np + +import pandas as pd +import pandas.util.testing as tm + +from zipline.utils.data import RollingPanel + + +class TestRollingPanel(unittest.TestCase): + + def test_basics(self): + items = ['foo', 'bar', 'baz'] + minor = ['A', 'B', 'C', 'D'] + + window = 10 + + rp = RollingPanel(window, items, minor, cap_multiple=2) + + dates = pd.date_range('2000-01-01', periods=30, tz='utc') + + major_deque = deque() + + frames = {} + + for i in range(30): + frame = pd.DataFrame(np.random.randn(3, 4), index=items, + columns=minor) + date = dates[i] + + rp.add_frame(date, frame) + + frames[date] = frame + major_deque.append(date) + + if i >= window: + major_deque.popleft() + + result = rp.get_current() + expected = pd.Panel(frames, items=list(major_deque), + major_axis=items, minor_axis=minor) + tm.assert_panel_equal(result, expected.swapaxes(0, 1)) + + +def f(option='clever', n=500, copy=False): + items = range(5) + minor = range(20) + window = 100 + periods = n + + dates = pd.date_range('2000-01-01', periods=periods, tz='utc') + frames = {} + + if option == 'clever': + rp = RollingPanel(window, items, minor, cap_multiple=2) + major_deque = deque() + dummy = pd.DataFrame(np.random.randn(len(items), len(minor)), + index=items, columns=minor) + + for i in range(periods): + frame = dummy * (1 + 0.001 * i) + date = dates[i] + + rp.add_frame(date, frame) + + frames[date] = frame + major_deque.append(date) + + if i >= window: + del frames[major_deque.popleft()] + + result = rp.get_current() + if copy: + result = result.copy() + else: + major_deque = deque() + dummy = pd.DataFrame(np.random.randn(len(items), len(minor)), + index=items, columns=minor) + + for i in range(periods): + frame = dummy * (1 + 0.001 * i) + date = dates[i] + frames[date] = frame + major_deque.append(date) + + if i >= window: + del frames[major_deque.popleft()] + + result = pd.Panel(frames, items=list(major_deque), + major_axis=items, 
minor_axis=minor) diff --git a/zipline/utils/data.py b/zipline/utils/data.py new file mode 100644 index 00000000..ca9457b6 --- /dev/null +++ b/zipline/utils/data.py @@ -0,0 +1,86 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd + + +def _ensure_index(x): + if not isinstance(x, pd.Index): + x = pd.Index(x) + + return x + + +class RollingPanel(object): + """ + Preallocation strategies for rolling window over expanding data set + + Restrictions: major_axis can only be a DatetimeIndex for now + """ + + def __init__(self, window, items, minor_axis, cap_multiple=2, + dtype=np.float64): + self.pos = 0 + self.window = window + + self.items = _ensure_index(items) + self.minor_axis = _ensure_index(minor_axis) + + self.cap_multiple = cap_multiple + self.cap = cap_multiple * window + + self.dtype = dtype + self.buffer = np.empty((len(items), self.cap, len(minor_axis)), + dtype=dtype) + self.index_buf = np.empty(self.cap, dtype='M8[ns]') + + def add_frame(self, tick, frame): + """ + + TODO: this assumes the DataFrame has the right shape + """ + if self.pos == self.cap: + self._roll_data() + + self.buffer[:, self.pos, :] = frame.values + self.index_buf[self.pos] = tick + + self.pos += 1 + + def get_current(self): + """ + Get a Panel that is the current data in view. 
It is not safe to persist + these objects because internal data might change + """ + where = slice(max(self.pos - self.window, 0), self.pos) + major_axis = pd.DatetimeIndex(self.index_buf[where], tz='utc') + return pd.Panel(self.buffer[:, where, :], self.items, major_axis, + self.minor_axis) + + def _roll_data(self): + """ + Roll window worth of data up to position zero. + Save the effort of having to expensively roll at each iteration + """ + self.buffer[:, :self.window, :] = self.buffer[:, -self.window:] + self.index_buf[:self.window] = self.index_buf[-self.window:] + self.pos = self.window + + +class NaiveRollingPanel(object): + + def __init__(self, window, items, minor_axis, cap_multiple=2): + pass From 2be7014d516b6f215b9087f1ab2bd0e203459f77 Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Tue, 2 Apr 2013 12:33:56 -0700 Subject: [PATCH 2/5] ENH: Rewrite of batch_transform to use rolling panel. - Added unittest to test for newly appearing sids. - Fixed logic bug where window was only full after window_length+1 events got passed. --- tests/test_batchtransform.py | 157 ++++++++++++++++++-- zipline/test_algorithms.py | 6 +- zipline/transforms/utils.py | 269 ++++++++++++++++------------------- zipline/utils/data.py | 22 +-- 4 files changed, 285 insertions(+), 169 deletions(-) diff --git a/tests/test_batchtransform.py b/tests/test_batchtransform.py index 35e8e2cc..ad73239d 100644 --- a/tests/test_batchtransform.py +++ b/tests/test_batchtransform.py @@ -1,3 +1,18 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + from collections import deque import pytz @@ -9,9 +24,105 @@ from unittest import TestCase from zipline.utils.test_utils import setup_logger +from zipline.sources.data_source import DataSource import zipline.utils.factory as factory -from zipline.test_algorithms import BatchTransformAlgorithm +from zipline.test_algorithms import (BatchTransformAlgorithm, + batch_transform, + ReturnPriceBatchTransform) + +from zipline.algorithm import TradingAlgorithm +from zipline.utils.tradingcalendar import trading_days +from copy import deepcopy + + +@batch_transform +def return_price(data): + return data.price + + +class BatchTransformAlgorithmSetSid(TradingAlgorithm): + def initialize(self, sids): + self.history = [] + + self.batch_transform = return_price( + refresh_period=1, + window_length=10, + clean_nans=False, + sids=sids, + compute_only_full=False + ) + + def handle_data(self, data): + self.history.append( + deepcopy(self.batch_transform.handle_data(data))) + + +class DifferentSidSource(DataSource): + def __init__(self): + self.dates = pd.date_range('1990-01-01', periods=180, tz='utc') + self.start = self.dates[0] + self.end = self.dates[-1] + self._raw_data = None + self.sids = range(90) + self.sid = 0 + self.trading_days = [] + + @property + def instance_hash(self): + return '1234' + + @property + def raw_data(self): + if not self._raw_data: + self._raw_data = self.raw_data_gen() + return self._raw_data + + @property + def mapping(self): + return { + 'dt': (lambda x: x, 'dt'), + 'sid': (lambda x: x, 'sid'), + 'price': (float, 'price'), + 'volume': (int, 'volume'), + } + + def raw_data_gen(self): + # Create differente sid for each event + for date in self.dates: + if date not in trading_days: + continue + event = {'dt': date, + 'sid': self.sid, + 'price': self.sid, + 'volume': self.sid} + self.sid += 1 + self.trading_days.append(date) + yield event + + +class 
TestChangeOfSids(TestCase): + def setUp(self): + self.sids = range(90) + self.sim_params = factory.create_simulation_parameters( + start=datetime(1990, 1, 1, tzinfo=pytz.utc), + end=datetime(1990, 1, 8, tzinfo=pytz.utc) + ) + + def test_all_sids_passed(self): + algo = BatchTransformAlgorithmSetSid(self.sids, + sim_params=self.sim_params) + source = DifferentSidSource() + algo.run(source) + for df, date in zip(algo.history, source.trading_days): + self.assertEqual(df.index[-1], date, "Newest event doesn't \ + match.") + + for sid in self.sids: + self.assertIn(sid, df.columns) + + last_elem = len(df) - 1 + self.assertEqual(df[last_elem][last_elem], last_elem) class TestBatchTransform(TestCase): @@ -24,20 +135,23 @@ class TestBatchTransform(TestCase): self.source, self.df = \ factory.create_test_df_source(self.sim_params) - def test_event_window(self): + def test_core_functionality(self): algo = BatchTransformAlgorithm(sim_params=self.sim_params) algo.run(self.source) wl = algo.window_length # The following assertion depend on window length of 3 self.assertEqual(wl, 3) - self.assertEqual(algo.history_return_price_class[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + + # If window_length is 3, there should be 2 None events, as the + # window fills up on the 3rd day. + n_none_events = 2 + self.assertEqual(algo.history_return_price_class[:n_none_events], + [None] * n_none_events, + "First two iterations should return None." + "\n" + "i.e. no returned values until window is full'" + "%s" % (algo.history_return_price_class,)) - self.assertEqual(algo.history_return_price_decorator[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + + self.assertEqual(algo.history_return_price_decorator[:n_none_events], + [None] * n_none_events, + "First two iterations should return None." + "\n" + "i.e. 
no returned values until window is full'" + "%s" % (algo.history_return_price_decorator,)) @@ -90,8 +204,8 @@ class TestBatchTransform(TestCase): ) def test_passing_of_args(self): - algo = BatchTransformAlgorithm(1, - kwarg='str', sim_params=self.sim_params) + algo = BatchTransformAlgorithm(1, kwarg='str', + sim_params=self.sim_params) self.assertEqual(algo.args, (1,)) self.assertEqual(algo.kwargs, {'kwarg': 'str'}) @@ -105,10 +219,29 @@ class TestBatchTransform(TestCase): None, # 1990-01-03 - window not full None, - # 1990-01-04 - window not full, 3rd event - None, + # 1990-01-04 - window now full, 3rd event + expected_item, # 1990-01-05 - window now full expected_item, # 1990-01-08 - window now full expected_item ]) + + +def run_batchtransform(window_length=10): + sim_params = factory.create_simulation_parameters( + start=datetime(1990, 1, 1, tzinfo=pytz.utc), + end=datetime(1995, 1, 8, tzinfo=pytz.utc) + ) + source, df = factory.create_test_df_source(sim_params) + + return_price_class = ReturnPriceBatchTransform( + refresh_period=1, + window_length=window_length, + clean_nans=False + ) + + for raw_event in source: + raw_event['datetime'] = raw_event.dt + event = {0: raw_event} + return_price_class.handle_data(event) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index c5ba42be..785df560 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -340,7 +340,7 @@ class BatchTransformAlgorithm(TradingAlgorithm): ) self.return_not_full = return_data( - refresh_period=0, + refresh_period=1, window_length=self.window_length, compute_only_full=False ) @@ -378,9 +378,9 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.price_multiple.handle_data(data, 1, extra_arg=1) if self.price_multiple.full: - pre = len(self.price_multiple.ticks) + pre = self.price_multiple.rolling_panel.get_current().shape[0] result1 = self.price_multiple.handle_data(data, 1, extra_arg=1) - post = len(self.price_multiple.ticks) + post = 
self.price_multiple.rolling_panel.get_current().shape[0] assert pre == post, "batch transform is appending redundant events" result2 = self.price_multiple.handle_data(data, 1, extra_arg=1) assert result1 is result2, "batch transform is not idempotent" diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index dabad6d8..2c436822 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -20,17 +20,22 @@ Generator versions of transforms. import functools import types import logbook + import numpy +from numbers import Integral + +import pandas as pd + +from zipline.utils.data import RollingPanel +from zipline.protocol import Event + from copy import deepcopy from datetime import datetime from collections import deque from abc import ABCMeta, abstractmethod -from numbers import Integral -import pandas as pd - -from zipline.protocol import Event, DATASOURCE_TYPE +from zipline.protocol import DATASOURCE_TYPE from zipline.gens.utils import assert_sort_unframe_protocol, hash_args import zipline.finance.trading as trading @@ -263,7 +268,7 @@ class EventWindow(object): (event, self.ticks[0]) -class BatchTransform(EventWindow): +class BatchTransform(object): """Base class for batch transforms with a trailing window of variable length. As opposed to pure EventWindows that get a stream of events and are bound to a single SID, this class creates stream @@ -336,9 +341,6 @@ class BatchTransform(EventWindow): Only call the user-defined function once the window is full. Returns None if window is not full yet. """ - - super(BatchTransform, self).__init__(True, window_length=window_length) - if func is not None: self.compute_transform_value = func else: @@ -351,8 +353,8 @@ class BatchTransform(EventWindow): # to operate on the data, but to also allow new symbols to # enter the batch transform's window IFF a sid filter is not # specified. 
- self.sids = None - if sids: + self.sids = set() + if sids is not None: self.static_sids = True self.sids = sids if isinstance(sids, (basestring, Integral)): @@ -369,12 +371,12 @@ class BatchTransform(EventWindow): self.refresh_period = refresh_period self.window_length = window_length - self.trading_days_since_update = 0 self.trading_days_total = 0 self.window = None self.full = False - self.last_dt = None + # Set to -inf essentially to cause update on first attempt. + self.last_dt = pd.Timestamp('1900-1-1', tz='UTC') self.updated = False self.cached = None @@ -387,10 +389,11 @@ class BatchTransform(EventWindow): # set of stocks per quarter self.supplemental_data = None + self.rolling_panel = None + def handle_data(self, data, *args, **kwargs): """ - New method to handle a data frame as sent to the algorithm's - handle_data method. + Point of entry. Process an event frame. """ # extract dates #dts = [data[sid].datetime for sid in self.sids] @@ -412,20 +415,115 @@ class BatchTransform(EventWindow): # only modify the trailing window if this is # a new event. This is intended to make handle_data # idempotent. - if event not in self.ticks: - # append data frame to window. 
update() will call handle_add() and - # handle_remove() appropriately, and self.updated - # will be modified based on the refresh_period - self.update(event) + if self.last_dt < event.dt: + self.updated = True + self._append_to_window(event) else: - # we are recalculating based on an old event, so - # there is no change in the contents of the trailing - # window self.updated = False # return newly computed or cached value return self.get_transform_value(*args, **kwargs) + def _append_to_window(self, event): + self.field_names = self._get_field_names(event) + + if not self.static_sids: + event_sids = set(event.data.keys()) + self.sids = set.union(self.sids, event_sids) + + # Create rolling panel if not existant + if self.rolling_panel is None: + self.rolling_panel = RollingPanel(self.window_length, + self.field_names, self.sids) + + # Store event in rolling frame + self.rolling_panel.add_frame(event.dt, + pd.DataFrame(event.data, + index=self.field_names, + columns=self.sids)) + + # update trading day counters + if self.last_dt.day != event.dt.day: + self.last_dt = event.dt + self.trading_days_total += 1 + + if self.trading_days_total >= self.window_length: + self.full = True + + def get_transform_value(self, *args, **kwargs): + """Call user-defined batch-transform function passing all + arguments. + + Note that this will only call the transform if the datapanel + has actually been updated. Otherwise, the previously, cached + value will be returned. + """ + if self.compute_only_full and not self.full: + return None + + ################################################# + # Determine whether we should call the transform + # 1. Is the refresh period over? + period_over = self.trading_days_total % self.refresh_period == 0 + # 2. Have the args or kwargs been changed since last time? 
+ args_updated = args != self.last_args or kwargs != self.last_kwargs + recalculate_needed = args_updated or (period_over and + self.updated) + + if recalculate_needed: + self.cached = self.compute_transform_value( + self.get_data(), + *args, + **kwargs + ) + + self.last_args = args + self.last_kwargs = kwargs + return self.cached + + def get_data(self): + """Create a pandas.Panel (i.e. 3d DataFrame) from the + events in the current window. + + Returns: + The resulting panel looks like this: + index : field_name (e.g. price) + major axis/rows : dt + minor axis/colums : sid + """ + data = self.rolling_panel.get_current() + + if self.supplemental_data: + # item will be a date stamp + for item in data.items: + try: + data[item] = self.supplemental_data[item].combine_first( + data[item]) + except KeyError: + # Only filling in data available in supplemental data. + pass + + if self.clean_nans: + # Fills in gaps of missing data during transform + # of multiple stocks. E.g. we may be missing + # minute data because of illiquidity of one stock + data = data.fillna(method='ffill') + + # Hold on to a reference to the data, + # so that it's easier to find the current data when stepping + # through with a debugger + self._curr_data = data + + return data + + def get_value(self, *args, **kwargs): + raise NotImplementedError( + "Either overwrite get_value or provide a func argument.") + + def __call__(self, f): + self.compute_transform_value = f + return self.handle_data + def _extract_field_names(self, event): # extract field names from sids (price, volume etc), make sure # every sid has the same fields. 
@@ -448,129 +546,12 @@ class BatchTransform(EventWindow): 'datetime', 'source_id']) return union - unwanted_fields - def handle_add(self, event): - if not self.last_dt: - self.last_dt = event.dt - - if self.initial_field_names is None: + def _get_field_names(self, event): + if self.initial_field_names is not None: + return self.initial_field_names + else: self.latest_names = self._extract_field_names(event) - if self.field_names: - self.field_names = \ - set.union(self.field_names, self.latest_names) - else: - self.field_names = self.latest_names - else: - self.field_names = self.initial_field_names - - if not self.static_sids: - if self.sids: - event_sids = set(event.data.keys()) - self.sids = set.union(self.sids, event_sids) - else: - self.sids = set(event.data.keys()) - - # update trading day counters - if self.last_dt.day != event.dt.day: - self.last_dt = event.dt - self.trading_days_since_update += 1 - self.trading_days_total += 1 - - if self.trading_days_total >= self.window_length: - self.full = True - - if self.trading_days_since_update >= self.refresh_period: - # Setting updated to True will cause get_transform_value() - # to call the user-defined batch-transform with the most - # recent datapanel - self.updated = True - else: - self.updated = False - - def get_data(self): - """Create a pandas.Panel (i.e. 3d DataFrame) from the - events in the current window. - - Returns: - The resulting panel looks like this: - index : field_name (e.g. price) - major axis/rows : dt - minor axis/colums : sid - """ - # This Panel data structure ultimately gets passed to the - # user-overloaded get_value() method. 
- data_dict = {tick['dt']: tick['data'] for tick in self.ticks} - data = pd.Panel(data_dict, major_axis=self.field_names, - minor_axis=self.sids, - dtype='float') - - if self.supplemental_data: - # item will be a date stamp - for item in data.items: - try: - data[item] = self.supplemental_data[item].combine_first( - data[item]) - except KeyError: - # Only filling in data available in supplemental data. - pass - - data = data.swapaxes(0, 1) - - if self.clean_nans: - # Fills in gaps of missing data during transform - # of multiple stocks. E.g. we may be missing - # minute data because of illiquidity of one stock - data = data.fillna(method='ffill') - - # Hold on to a reference to the data, - # so that it's easier to find the current data when stepping - # through with a debugger - self.curr_data = data - - return data - - def handle_remove(self, event): - pass - - def get_value(self, *args, **kwargs): - raise NotImplementedError( - "Either overwrite get_value or provide a func argument.") - - def get_transform_value(self, *args, **kwargs): - """Call user-defined batch-transform function passing all - arguments. - - Note that this will only call the transform if the datapanel - has actually been updated. Otherwise, the previously, cached - value will be returned. 
- """ - if self.compute_only_full and not self.full: - return None - - recalculate_needed = False - if self.updated: - # Create new pandas panel - self.window = self.get_data() - # reset our counter for refresh_period - self.trading_days_since_update = 0 - recalculate_needed = True - else: - recalculate_needed = \ - args != self.last_args or kwargs != self.last_kwargs - - if recalculate_needed: - self.cached = self.compute_transform_value( - self.window, - *args, - **kwargs - ) - - self.last_args = args - self.last_kwargs = kwargs - return self.cached - - def __call__(self, f): - self.compute_transform_value = f - return self.handle_data + return set.union(self.field_names, self.latest_names) def batch_transform(func): diff --git a/zipline/utils/data.py b/zipline/utils/data.py index ca9457b6..542ea394 100644 --- a/zipline/utils/data.py +++ b/zipline/utils/data.py @@ -15,6 +15,7 @@ import numpy as np import pandas as pd +from copy import deepcopy def _ensure_index(x): @@ -43,19 +44,17 @@ class RollingPanel(object): self.cap = cap_multiple * window self.dtype = dtype - self.buffer = np.empty((len(items), self.cap, len(minor_axis)), - dtype=dtype) self.index_buf = np.empty(self.cap, dtype='M8[ns]') + self.buffer = pd.Panel(items=items, minor_axis=minor_axis, + major_axis=range(self.cap), + dtype=dtype) def add_frame(self, tick, frame): """ - - TODO: this assumes the DataFrame has the right shape """ if self.pos == self.cap: self._roll_data() - - self.buffer[:, self.pos, :] = frame.values + self.buffer.values[:, self.pos, :] = frame.ix[self.items].values self.index_buf[self.pos] = tick self.pos += 1 @@ -66,16 +65,19 @@ class RollingPanel(object): these objects because internal data might change """ where = slice(max(self.pos - self.window, 0), self.pos) - major_axis = pd.DatetimeIndex(self.index_buf[where], tz='utc') - return pd.Panel(self.buffer[:, where, :], self.items, major_axis, - self.minor_axis) + major_axis = pd.DatetimeIndex(deepcopy(self.index_buf[where]), + 
tz='utc') + + return pd.Panel(self.buffer.values[:, where, :], self.items, + major_axis, self.minor_axis) def _roll_data(self): """ Roll window worth of data up to position zero. Save the effort of having to expensively roll at each iteration """ - self.buffer[:, :self.window, :] = self.buffer[:, -self.window:] + self.buffer.values[:, :self.window, :] = \ + self.buffer.values[:, -self.window:] self.index_buf[:self.window] = self.index_buf[-self.window:] self.pos = self.window From c12102d7b1e28585a32477b7731704e907cf1d1a Mon Sep 17 00:00:00 2001 From: Thomas Wiecki Date: Thu, 18 Apr 2013 13:44:50 -0700 Subject: [PATCH 3/5] MAINT: Removed self.sids from batchtransform. Externally sets static_sids. --- zipline/transforms/utils.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index 2c436822..d9799e51 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -353,16 +353,13 @@ class BatchTransform(object): # to operate on the data, but to also allow new symbols to # enter the batch transform's window IFF a sid filter is not # specified. - self.sids = set() if sids is not None: - self.static_sids = True - self.sids = sids if isinstance(sids, (basestring, Integral)): - self.sids = set([sids]) + self.static_sids = set([sids]) else: - self.sids = set(sids) + self.static_sids = set(sids) else: - self.static_sids = False + self.static_sids = None self.initial_field_names = fields if isinstance(self.initial_field_names, basestring): @@ -396,7 +393,6 @@ class BatchTransform(object): Point of entry. Process an event frame. """ # extract dates - #dts = [data[sid].datetime for sid in self.sids] dts = [event.datetime for event in data.itervalues()] # we have to provide the event with a dt. 
This is only for # checking if the event is outside the window or not so a @@ -427,20 +423,21 @@ class BatchTransform(object): def _append_to_window(self, event): self.field_names = self._get_field_names(event) - if not self.static_sids: - event_sids = set(event.data.keys()) - self.sids = set.union(self.sids, event_sids) + if self.static_sids is None: + sids = set(event.data.keys()) + else: + sids = self.static_sids # Create rolling panel if not existant if self.rolling_panel is None: self.rolling_panel = RollingPanel(self.window_length, - self.field_names, self.sids) + self.field_names, sids) # Store event in rolling frame self.rolling_panel.add_frame(event.dt, pd.DataFrame(event.data, index=self.field_names, - columns=self.sids)) + columns=sids)) # update trading day counters if self.last_dt.day != event.dt.day: From 38ae8bbb675c33ab95009acace9fd1c7eb6c46c0 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Mon, 29 Apr 2013 14:34:50 -0400 Subject: [PATCH 4/5] BUG: Ensure that window length value is sanity checked. When moving BatchTransform off of EventWindow as a base object, the checking of window length was lost, restore that check using the same function as EventWindow. --- zipline/transforms/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index d9799e51..c8a431e1 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -367,7 +367,10 @@ class BatchTransform(object): self.field_names = set() self.refresh_period = refresh_period + + check_window_length(window_length) self.window_length = window_length + self.trading_days_total = 0 self.window = None From 06a01b146957ea333ca4b6c27b7652dfab069a2a Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Mon, 29 Apr 2013 14:50:52 -0400 Subject: [PATCH 5/5] BUG: Explicitly add support for refresh period of 0. 
In the previous implementation of batch transform it happened that a refresh_period of `0` caused the transform to update on every bar. For the time being that behavior should be retained, though the new rolling implementation more correctly aligns to the term of 'period' so a period of 1 would achieve the same effect. --- zipline/transforms/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index c8a431e1..59f40cb8 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -463,11 +463,17 @@ class BatchTransform(object): ################################################# # Determine whether we should call the transform + # 0. Support historical/legacy usage of '0' signaling, + # 'update on every bar' + if self.refresh_period == 0: + period_signals_update = True + else: # 1. Is the refresh period over? - period_over = self.trading_days_total % self.refresh_period == 0 + period_signals_update = ( + self.trading_days_total % self.refresh_period == 0) # 2. Have the args or kwargs been changed since last time? args_updated = args != self.last_args or kwargs != self.last_kwargs - recalculate_needed = args_updated or (period_over and + recalculate_needed = args_updated or (period_signals_update and self.updated) if recalculate_needed: