diff --git a/tests/test_batchtransform.py b/tests/test_batchtransform.py index 35e8e2cc..ad73239d 100644 --- a/tests/test_batchtransform.py +++ b/tests/test_batchtransform.py @@ -1,3 +1,18 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections import deque import pytz @@ -9,9 +24,105 @@ from unittest import TestCase from zipline.utils.test_utils import setup_logger +from zipline.sources.data_source import DataSource import zipline.utils.factory as factory -from zipline.test_algorithms import BatchTransformAlgorithm +from zipline.test_algorithms import (BatchTransformAlgorithm, + batch_transform, + ReturnPriceBatchTransform) + +from zipline.algorithm import TradingAlgorithm +from zipline.utils.tradingcalendar import trading_days +from copy import deepcopy + + +@batch_transform +def return_price(data): + return data.price + + +class BatchTransformAlgorithmSetSid(TradingAlgorithm): + def initialize(self, sids): + self.history = [] + + self.batch_transform = return_price( + refresh_period=1, + window_length=10, + clean_nans=False, + sids=sids, + compute_only_full=False + ) + + def handle_data(self, data): + self.history.append( + deepcopy(self.batch_transform.handle_data(data))) + + +class DifferentSidSource(DataSource): + def __init__(self): + self.dates = pd.date_range('1990-01-01', periods=180, tz='utc') + self.start = self.dates[0] + self.end = self.dates[-1] + self._raw_data = None + self.sids = range(90) + self.sid = 0 + self.trading_days = [] + + @property + def instance_hash(self): + return '1234' + + @property + def raw_data(self): + if not self._raw_data: + self._raw_data = self.raw_data_gen() + return self._raw_data + + @property + def mapping(self): + return { + 'dt': (lambda x: x, 'dt'), + 'sid': (lambda x: x, 'sid'), + 'price': (float, 'price'), + 'volume': (int, 'volume'), + } + + def raw_data_gen(self): + # Create differente sid for each event + for date in self.dates: + if date not in trading_days: + continue + event = {'dt': date, + 'sid': self.sid, + 'price': self.sid, + 'volume': self.sid} + self.sid += 1 + self.trading_days.append(date) + yield event + + +class TestChangeOfSids(TestCase): + def setUp(self): + self.sids = range(90) + self.sim_params = factory.create_simulation_parameters( + start=datetime(1990, 1, 1, tzinfo=pytz.utc), + end=datetime(1990, 1, 8, tzinfo=pytz.utc) + ) + + def test_all_sids_passed(self): + algo = BatchTransformAlgorithmSetSid(self.sids, + sim_params=self.sim_params) + source = DifferentSidSource() + algo.run(source) + for df, date in zip(algo.history, source.trading_days): + self.assertEqual(df.index[-1], date, "Newest event doesn't \ + match.") + + for sid in self.sids: + self.assertIn(sid, df.columns) + + last_elem = len(df) - 1 + self.assertEqual(df[last_elem][last_elem], last_elem) class TestBatchTransform(TestCase): @@ -24,20 +135,23 @@ class TestBatchTransform(TestCase): self.source, self.df = \ factory.create_test_df_source(self.sim_params) - def test_event_window(self): + def test_core_functionality(self): algo = BatchTransformAlgorithm(sim_params=self.sim_params) algo.run(self.source) wl = algo.window_length # The following assertion depend on window length of 3 self.assertEqual(wl, 3) - self.assertEqual(algo.history_return_price_class[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + + # If window_length is 3, there should be 2 None events, as the + # window fills up on the 3rd day. + n_none_events = 2 + self.assertEqual(algo.history_return_price_class[:n_none_events], + [None] * n_none_events, + "First two iterations should return None." + "\n" + "i.e. no returned values until window is full'" + "%s" % (algo.history_return_price_class,)) - self.assertEqual(algo.history_return_price_decorator[:wl], - [None] * wl, - "First three iterations should return None." + "\n" + + self.assertEqual(algo.history_return_price_decorator[:n_none_events], + [None] * n_none_events, + "First two iterations should return None." + "\n" + "i.e. no returned values until window is full'" + "%s" % (algo.history_return_price_decorator,)) @@ -90,8 +204,8 @@ class TestBatchTransform(TestCase): ) def test_passing_of_args(self): - algo = BatchTransformAlgorithm(1, - kwarg='str', sim_params=self.sim_params) + algo = BatchTransformAlgorithm(1, kwarg='str', + sim_params=self.sim_params) self.assertEqual(algo.args, (1,)) self.assertEqual(algo.kwargs, {'kwarg': 'str'}) @@ -105,10 +219,29 @@ class TestBatchTransform(TestCase): None, # 1990-01-03 - window not full None, - # 1990-01-04 - window not full, 3rd event - None, + # 1990-01-04 - window now full, 3rd event + expected_item, # 1990-01-05 - window now full expected_item, # 1990-01-08 - window now full expected_item ]) + + +def run_batchtransform(window_length=10): + sim_params = factory.create_simulation_parameters( + start=datetime(1990, 1, 1, tzinfo=pytz.utc), + end=datetime(1995, 1, 8, tzinfo=pytz.utc) + ) + source, df = factory.create_test_df_source(sim_params) + + return_price_class = ReturnPriceBatchTransform( + refresh_period=1, + window_length=window_length, + clean_nans=False + ) + + for raw_event in source: + raw_event['datetime'] = raw_event.dt + event = {0: raw_event} + return_price_class.handle_data(event) diff --git a/tests/test_data_util.py b/tests/test_data_util.py new file mode 100644 index 00000000..ded2c8ca --- /dev/null +++ b/tests/test_data_util.py @@ -0,0 +1,108 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from collections import deque + +import numpy as np + +import pandas as pd +import pandas.util.testing as tm + +from zipline.utils.data import RollingPanel + + +class TestRollingPanel(unittest.TestCase): + + def test_basics(self): + items = ['foo', 'bar', 'baz'] + minor = ['A', 'B', 'C', 'D'] + + window = 10 + + rp = RollingPanel(window, items, minor, cap_multiple=2) + + dates = pd.date_range('2000-01-01', periods=30, tz='utc') + + major_deque = deque() + + frames = {} + + for i in range(30): + frame = pd.DataFrame(np.random.randn(3, 4), index=items, + columns=minor) + date = dates[i] + + rp.add_frame(date, frame) + + frames[date] = frame + major_deque.append(date) + + if i >= window: + major_deque.popleft() + + result = rp.get_current() + expected = pd.Panel(frames, items=list(major_deque), + major_axis=items, minor_axis=minor) + tm.assert_panel_equal(result, expected.swapaxes(0, 1)) + + +def f(option='clever', n=500, copy=False): + items = range(5) + minor = range(20) + window = 100 + periods = n + + dates = pd.date_range('2000-01-01', periods=periods, tz='utc') + frames = {} + + if option == 'clever': + rp = RollingPanel(window, items, minor, cap_multiple=2) + major_deque = deque() + dummy = pd.DataFrame(np.random.randn(len(items), len(minor)), + index=items, columns=minor) + + for i in range(periods): + frame = dummy * (1 + 0.001 * i) + date = dates[i] + + rp.add_frame(date, frame) + + frames[date] = frame + major_deque.append(date) + + if i >= window: + del frames[major_deque.popleft()] + + result = rp.get_current() + if copy: + result = result.copy() + else: + major_deque = deque() + dummy = pd.DataFrame(np.random.randn(len(items), len(minor)), + index=items, columns=minor) + + for i in range(periods): + frame = dummy * (1 + 0.001 * i) + date = dates[i] + frames[date] = frame + major_deque.append(date) + + if i >= window: + del frames[major_deque.popleft()] + + result = pd.Panel(frames, items=list(major_deque), + major_axis=items, minor_axis=minor) diff --git a/zipline/test_algorithms.py b/zipline/test_algorithms.py index c5ba42be..785df560 100644 --- a/zipline/test_algorithms.py +++ b/zipline/test_algorithms.py @@ -340,7 +340,7 @@ class BatchTransformAlgorithm(TradingAlgorithm): ) self.return_not_full = return_data( - refresh_period=0, + refresh_period=1, window_length=self.window_length, compute_only_full=False ) @@ -378,9 +378,9 @@ class BatchTransformAlgorithm(TradingAlgorithm): self.price_multiple.handle_data(data, 1, extra_arg=1) if self.price_multiple.full: - pre = len(self.price_multiple.ticks) + pre = self.price_multiple.rolling_panel.get_current().shape[0] result1 = self.price_multiple.handle_data(data, 1, extra_arg=1) - post = len(self.price_multiple.ticks) + post = self.price_multiple.rolling_panel.get_current().shape[0] assert pre == post, "batch transform is appending redundant events" result2 = self.price_multiple.handle_data(data, 1, extra_arg=1) assert result1 is result2, "batch transform is not idempotent" diff --git a/zipline/transforms/utils.py b/zipline/transforms/utils.py index dabad6d8..59f40cb8 100644 --- a/zipline/transforms/utils.py +++ b/zipline/transforms/utils.py @@ -20,17 +20,22 @@ Generator versions of transforms. import functools import types import logbook + import numpy +from numbers import Integral + +import pandas as pd + +from zipline.utils.data import RollingPanel +from zipline.protocol import Event + from copy import deepcopy from datetime import datetime from collections import deque from abc import ABCMeta, abstractmethod -from numbers import Integral -import pandas as pd - -from zipline.protocol import Event, DATASOURCE_TYPE +from zipline.protocol import DATASOURCE_TYPE from zipline.gens.utils import assert_sort_unframe_protocol, hash_args import zipline.finance.trading as trading @@ -263,7 +268,7 @@ class EventWindow(object): (event, self.ticks[0]) -class BatchTransform(EventWindow): +class BatchTransform(object): """Base class for batch transforms with a trailing window of variable length. As opposed to pure EventWindows that get a stream of events and are bound to a single SID, this class creates stream @@ -336,9 +341,6 @@ class BatchTransform(EventWindow): Only call the user-defined function once the window is full. Returns None if window is not full yet. """ - - super(BatchTransform, self).__init__(True, window_length=window_length) - if func is not None: self.compute_transform_value = func else: @@ -351,16 +353,13 @@ class BatchTransform(EventWindow): # to operate on the data, but to also allow new symbols to # enter the batch transform's window IFF a sid filter is not # specified. - self.sids = None - if sids: - self.static_sids = True - self.sids = sids + if sids is not None: if isinstance(sids, (basestring, Integral)): - self.sids = set([sids]) + self.static_sids = set([sids]) else: - self.sids = set(sids) + self.static_sids = set(sids) else: - self.static_sids = False + self.static_sids = None self.initial_field_names = fields if isinstance(self.initial_field_names, basestring): @@ -368,13 +367,16 @@ class BatchTransform(EventWindow): self.field_names = set() self.refresh_period = refresh_period + + check_window_length(window_length) self.window_length = window_length - self.trading_days_since_update = 0 + self.trading_days_total = 0 self.window = None self.full = False - self.last_dt = None + # Set to -inf essentially to cause update on first attempt. + self.last_dt = pd.Timestamp('1900-1-1', tz='UTC') self.updated = False self.cached = None @@ -387,13 +389,13 @@ class BatchTransform(EventWindow): # set of stocks per quarter self.supplemental_data = None + self.rolling_panel = None + def handle_data(self, data, *args, **kwargs): """ - New method to handle a data frame as sent to the algorithm's - handle_data method. + Point of entry. Process an event frame. """ # extract dates - #dts = [data[sid].datetime for sid in self.sids] dts = [event.datetime for event in data.itervalues()] # we have to provide the event with a dt. This is only for # checking if the event is outside the window or not so a @@ -412,20 +414,122 @@ class BatchTransform(EventWindow): # only modify the trailing window if this is # a new event. This is intended to make handle_data # idempotent. - if event not in self.ticks: - # append data frame to window. update() will call handle_add() and - # handle_remove() appropriately, and self.updated - # will be modified based on the refresh_period - self.update(event) + if self.last_dt < event.dt: + self.updated = True + self._append_to_window(event) else: - # we are recalculating based on an old event, so - # there is no change in the contents of the trailing - # window self.updated = False # return newly computed or cached value return self.get_transform_value(*args, **kwargs) + def _append_to_window(self, event): + self.field_names = self._get_field_names(event) + + if self.static_sids is None: + sids = set(event.data.keys()) + else: + sids = self.static_sids + + # Create rolling panel if not existant + if self.rolling_panel is None: + self.rolling_panel = RollingPanel(self.window_length, + self.field_names, sids) + + # Store event in rolling frame + self.rolling_panel.add_frame(event.dt, + pd.DataFrame(event.data, + index=self.field_names, + columns=sids)) + + # update trading day counters + if self.last_dt.day != event.dt.day: + self.last_dt = event.dt + self.trading_days_total += 1 + + if self.trading_days_total >= self.window_length: + self.full = True + + def get_transform_value(self, *args, **kwargs): + """Call user-defined batch-transform function passing all + arguments. + + Note that this will only call the transform if the datapanel + has actually been updated. Otherwise, the previously, cached + value will be returned. + """ + if self.compute_only_full and not self.full: + return None + + ################################################# + # Determine whether we should call the transform + # 0. Support historical/legacy usage of '0' signaling, + # 'update on every bar' + if self.refresh_period == 0: + period_signals_update = True + else: + # 1. Is the refresh period over? + period_signals_update = ( + self.trading_days_total % self.refresh_period == 0) + # 2. Have the args or kwargs been changed since last time? + args_updated = args != self.last_args or kwargs != self.last_kwargs + recalculate_needed = args_updated or (period_signals_update and + self.updated) + + if recalculate_needed: + self.cached = self.compute_transform_value( + self.get_data(), + *args, + **kwargs + ) + + self.last_args = args + self.last_kwargs = kwargs + return self.cached + + def get_data(self): + """Create a pandas.Panel (i.e. 3d DataFrame) from the + events in the current window. + + Returns: + The resulting panel looks like this: + index : field_name (e.g. price) + major axis/rows : dt + minor axis/colums : sid + """ + data = self.rolling_panel.get_current() + + if self.supplemental_data: + # item will be a date stamp + for item in data.items: + try: + data[item] = self.supplemental_data[item].combine_first( + data[item]) + except KeyError: + # Only filling in data available in supplemental data. + pass + + if self.clean_nans: + # Fills in gaps of missing data during transform + # of multiple stocks. E.g. we may be missing + # minute data because of illiquidity of one stock + data = data.fillna(method='ffill') + + # Hold on to a reference to the data, + # so that it's easier to find the current data when stepping + # through with a debugger + self._curr_data = data + + return data + + def get_value(self, *args, **kwargs): + raise NotImplementedError( + "Either overwrite get_value or provide a func argument.") + + def __call__(self, f): + self.compute_transform_value = f + return self.handle_data + def _extract_field_names(self, event): # extract field names from sids (price, volume etc), make sure # every sid has the same fields. @@ -448,129 +552,12 @@ class BatchTransform(EventWindow): 'datetime', 'source_id']) return union - unwanted_fields - def handle_add(self, event): - if not self.last_dt: - self.last_dt = event.dt - - if self.initial_field_names is None: + def _get_field_names(self, event): + if self.initial_field_names is not None: + return self.initial_field_names + else: self.latest_names = self._extract_field_names(event) - if self.field_names: - self.field_names = \ - set.union(self.field_names, self.latest_names) - else: - self.field_names = self.latest_names - else: - self.field_names = self.initial_field_names - - if not self.static_sids: - if self.sids: - event_sids = set(event.data.keys()) - self.sids = set.union(self.sids, event_sids) - else: - self.sids = set(event.data.keys()) - - # update trading day counters - if self.last_dt.day != event.dt.day: - self.last_dt = event.dt - self.trading_days_since_update += 1 - self.trading_days_total += 1 - - if self.trading_days_total >= self.window_length: - self.full = True - - if self.trading_days_since_update >= self.refresh_period: - # Setting updated to True will cause get_transform_value() - # to call the user-defined batch-transform with the most - # recent datapanel - self.updated = True - else: - self.updated = False - - def get_data(self): - """Create a pandas.Panel (i.e. 3d DataFrame) from the - events in the current window. - - Returns: - The resulting panel looks like this: - index : field_name (e.g. price) - major axis/rows : dt - minor axis/colums : sid - """ - # This Panel data structure ultimately gets passed to the - # user-overloaded get_value() method. - data_dict = {tick['dt']: tick['data'] for tick in self.ticks} - data = pd.Panel(data_dict, major_axis=self.field_names, - minor_axis=self.sids, - dtype='float') - - if self.supplemental_data: - # item will be a date stamp - for item in data.items: - try: - data[item] = self.supplemental_data[item].combine_first( - data[item]) - except KeyError: - # Only filling in data available in supplemental data. - pass - - data = data.swapaxes(0, 1) - - if self.clean_nans: - # Fills in gaps of missing data during transform - # of multiple stocks. E.g. we may be missing - # minute data because of illiquidity of one stock - data = data.fillna(method='ffill') - - # Hold on to a reference to the data, - # so that it's easier to find the current data when stepping - # through with a debugger - self.curr_data = data - - return data - - def handle_remove(self, event): - pass - - def get_value(self, *args, **kwargs): - raise NotImplementedError( - "Either overwrite get_value or provide a func argument.") - - def get_transform_value(self, *args, **kwargs): - """Call user-defined batch-transform function passing all - arguments. - - Note that this will only call the transform if the datapanel - has actually been updated. Otherwise, the previously, cached - value will be returned. - """ - if self.compute_only_full and not self.full: - return None - - recalculate_needed = False - if self.updated: - # Create new pandas panel - self.window = self.get_data() - # reset our counter for refresh_period - self.trading_days_since_update = 0 - recalculate_needed = True - else: - recalculate_needed = \ - args != self.last_args or kwargs != self.last_kwargs - - if recalculate_needed: - self.cached = self.compute_transform_value( - self.window, - *args, - **kwargs - ) - - self.last_args = args - self.last_kwargs = kwargs - return self.cached - - def __call__(self, f): - self.compute_transform_value = f - return self.handle_data + return set.union(self.field_names, self.latest_names) def batch_transform(func): diff --git a/zipline/utils/data.py b/zipline/utils/data.py new file mode 100644 index 00000000..542ea394 --- /dev/null +++ b/zipline/utils/data.py @@ -0,0 +1,88 @@ +# +# Copyright 2013 Quantopian, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +from copy import deepcopy + + +def _ensure_index(x): + if not isinstance(x, pd.Index): + x = pd.Index(x) + + return x + + +class RollingPanel(object): + """ + Preallocation strategies for rolling window over expanding data set + + Restrictions: major_axis can only be a DatetimeIndex for now + """ + + def __init__(self, window, items, minor_axis, cap_multiple=2, + dtype=np.float64): + self.pos = 0 + self.window = window + + self.items = _ensure_index(items) + self.minor_axis = _ensure_index(minor_axis) + + self.cap_multiple = cap_multiple + self.cap = cap_multiple * window + + self.dtype = dtype + self.index_buf = np.empty(self.cap, dtype='M8[ns]') + self.buffer = pd.Panel(items=items, minor_axis=minor_axis, + major_axis=range(self.cap), + dtype=dtype) + + def add_frame(self, tick, frame): + """ + """ + if self.pos == self.cap: + self._roll_data() + self.buffer.values[:, self.pos, :] = frame.ix[self.items].values + self.index_buf[self.pos] = tick + + self.pos += 1 + + def get_current(self): + """ + Get a Panel that is the current data in view. It is not safe to persist + these objects because internal data might change + """ + where = slice(max(self.pos - self.window, 0), self.pos) + major_axis = pd.DatetimeIndex(deepcopy(self.index_buf[where]), + tz='utc') + + return pd.Panel(self.buffer.values[:, where, :], self.items, + major_axis, self.minor_axis) + + def _roll_data(self): + """ + Roll window worth of data up to position zero. + Save the effort of having to expensively roll at each iteration + """ + self.buffer.values[:, :self.window, :] = \ + self.buffer.values[:, -self.window:] + self.index_buf[:self.window] = self.index_buf[-self.window:] + self.pos = self.window + + +class NaiveRollingPanel(object): + + def __init__(self, window, items, minor_axis, cap_multiple=2): + pass