Merge batch transform and rolling panel enhancements.

Branch provides a rolling pandas data panel, and converts
batch transform to use the new panel type.
This commit is contained in:
Eddie Hebert
2013-04-29 15:31:38 -04:00
5 changed files with 481 additions and 165 deletions
+145 -12
View File
@@ -1,3 +1,18 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import pytz
@@ -9,9 +24,105 @@ from unittest import TestCase
from zipline.utils.test_utils import setup_logger
from zipline.sources.data_source import DataSource
import zipline.utils.factory as factory
from zipline.test_algorithms import BatchTransformAlgorithm
from zipline.test_algorithms import (BatchTransformAlgorithm,
batch_transform,
ReturnPriceBatchTransform)
from zipline.algorithm import TradingAlgorithm
from zipline.utils.tradingcalendar import trading_days
from copy import deepcopy
@batch_transform
def return_price(data):
return data.price
class BatchTransformAlgorithmSetSid(TradingAlgorithm):
def initialize(self, sids):
self.history = []
self.batch_transform = return_price(
refresh_period=1,
window_length=10,
clean_nans=False,
sids=sids,
compute_only_full=False
)
def handle_data(self, data):
self.history.append(
deepcopy(self.batch_transform.handle_data(data)))
class DifferentSidSource(DataSource):
def __init__(self):
self.dates = pd.date_range('1990-01-01', periods=180, tz='utc')
self.start = self.dates[0]
self.end = self.dates[-1]
self._raw_data = None
self.sids = range(90)
self.sid = 0
self.trading_days = []
@property
def instance_hash(self):
return '1234'
@property
def raw_data(self):
if not self._raw_data:
self._raw_data = self.raw_data_gen()
return self._raw_data
@property
def mapping(self):
return {
'dt': (lambda x: x, 'dt'),
'sid': (lambda x: x, 'sid'),
'price': (float, 'price'),
'volume': (int, 'volume'),
}
def raw_data_gen(self):
# Create differente sid for each event
for date in self.dates:
if date not in trading_days:
continue
event = {'dt': date,
'sid': self.sid,
'price': self.sid,
'volume': self.sid}
self.sid += 1
self.trading_days.append(date)
yield event
class TestChangeOfSids(TestCase):
def setUp(self):
self.sids = range(90)
self.sim_params = factory.create_simulation_parameters(
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
end=datetime(1990, 1, 8, tzinfo=pytz.utc)
)
def test_all_sids_passed(self):
algo = BatchTransformAlgorithmSetSid(self.sids,
sim_params=self.sim_params)
source = DifferentSidSource()
algo.run(source)
for df, date in zip(algo.history, source.trading_days):
self.assertEqual(df.index[-1], date, "Newest event doesn't \
match.")
for sid in self.sids:
self.assertIn(sid, df.columns)
last_elem = len(df) - 1
self.assertEqual(df[last_elem][last_elem], last_elem)
class TestBatchTransform(TestCase):
@@ -24,20 +135,23 @@ class TestBatchTransform(TestCase):
self.source, self.df = \
factory.create_test_df_source(self.sim_params)
def test_event_window(self):
def test_core_functionality(self):
algo = BatchTransformAlgorithm(sim_params=self.sim_params)
algo.run(self.source)
wl = algo.window_length
# The following assertion depend on window length of 3
self.assertEqual(wl, 3)
self.assertEqual(algo.history_return_price_class[:wl],
[None] * wl,
"First three iterations should return None." + "\n" +
# If window_length is 3, there should be 2 None events, as the
# window fills up on the 3rd day.
n_none_events = 2
self.assertEqual(algo.history_return_price_class[:n_none_events],
[None] * n_none_events,
"First two iterations should return None." + "\n" +
"i.e. no returned values until window is full'" +
"%s" % (algo.history_return_price_class,))
self.assertEqual(algo.history_return_price_decorator[:wl],
[None] * wl,
"First three iterations should return None." + "\n" +
self.assertEqual(algo.history_return_price_decorator[:n_none_events],
[None] * n_none_events,
"First two iterations should return None." + "\n" +
"i.e. no returned values until window is full'" +
"%s" % (algo.history_return_price_decorator,))
@@ -90,8 +204,8 @@ class TestBatchTransform(TestCase):
)
def test_passing_of_args(self):
algo = BatchTransformAlgorithm(1,
kwarg='str', sim_params=self.sim_params)
algo = BatchTransformAlgorithm(1, kwarg='str',
sim_params=self.sim_params)
self.assertEqual(algo.args, (1,))
self.assertEqual(algo.kwargs, {'kwarg': 'str'})
@@ -105,10 +219,29 @@ class TestBatchTransform(TestCase):
None,
# 1990-01-03 - window not full
None,
# 1990-01-04 - window not full, 3rd event
None,
# 1990-01-04 - window now full, 3rd event
expected_item,
# 1990-01-05 - window now full
expected_item,
# 1990-01-08 - window now full
expected_item
])
def run_batchtransform(window_length=10):
sim_params = factory.create_simulation_parameters(
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
end=datetime(1995, 1, 8, tzinfo=pytz.utc)
)
source, df = factory.create_test_df_source(sim_params)
return_price_class = ReturnPriceBatchTransform(
refresh_period=1,
window_length=window_length,
clean_nans=False
)
for raw_event in source:
raw_event['datetime'] = raw_event.dt
event = {0: raw_event}
return_price_class.handle_data(event)
+108
View File
@@ -0,0 +1,108 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from collections import deque
import numpy as np
import pandas as pd
import pandas.util.testing as tm
from zipline.utils.data import RollingPanel
class TestRollingPanel(unittest.TestCase):
def test_basics(self):
items = ['foo', 'bar', 'baz']
minor = ['A', 'B', 'C', 'D']
window = 10
rp = RollingPanel(window, items, minor, cap_multiple=2)
dates = pd.date_range('2000-01-01', periods=30, tz='utc')
major_deque = deque()
frames = {}
for i in range(30):
frame = pd.DataFrame(np.random.randn(3, 4), index=items,
columns=minor)
date = dates[i]
rp.add_frame(date, frame)
frames[date] = frame
major_deque.append(date)
if i >= window:
major_deque.popleft()
result = rp.get_current()
expected = pd.Panel(frames, items=list(major_deque),
major_axis=items, minor_axis=minor)
tm.assert_panel_equal(result, expected.swapaxes(0, 1))
def f(option='clever', n=500, copy=False):
items = range(5)
minor = range(20)
window = 100
periods = n
dates = pd.date_range('2000-01-01', periods=periods, tz='utc')
frames = {}
if option == 'clever':
rp = RollingPanel(window, items, minor, cap_multiple=2)
major_deque = deque()
dummy = pd.DataFrame(np.random.randn(len(items), len(minor)),
index=items, columns=minor)
for i in range(periods):
frame = dummy * (1 + 0.001 * i)
date = dates[i]
rp.add_frame(date, frame)
frames[date] = frame
major_deque.append(date)
if i >= window:
del frames[major_deque.popleft()]
result = rp.get_current()
if copy:
result = result.copy()
else:
major_deque = deque()
dummy = pd.DataFrame(np.random.randn(len(items), len(minor)),
index=items, columns=minor)
for i in range(periods):
frame = dummy * (1 + 0.001 * i)
date = dates[i]
frames[date] = frame
major_deque.append(date)
if i >= window:
del frames[major_deque.popleft()]
result = pd.Panel(frames, items=list(major_deque),
major_axis=items, minor_axis=minor)
+3 -3
View File
@@ -340,7 +340,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
)
self.return_not_full = return_data(
refresh_period=0,
refresh_period=1,
window_length=self.window_length,
compute_only_full=False
)
@@ -378,9 +378,9 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.price_multiple.handle_data(data, 1, extra_arg=1)
if self.price_multiple.full:
pre = len(self.price_multiple.ticks)
pre = self.price_multiple.rolling_panel.get_current().shape[0]
result1 = self.price_multiple.handle_data(data, 1, extra_arg=1)
post = len(self.price_multiple.ticks)
post = self.price_multiple.rolling_panel.get_current().shape[0]
assert pre == post, "batch transform is appending redundant events"
result2 = self.price_multiple.handle_data(data, 1, extra_arg=1)
assert result1 is result2, "batch transform is not idempotent"
+137 -150
View File
@@ -20,17 +20,22 @@ Generator versions of transforms.
import functools
import types
import logbook
import numpy
from numbers import Integral
import pandas as pd
from zipline.utils.data import RollingPanel
from zipline.protocol import Event
from copy import deepcopy
from datetime import datetime
from collections import deque
from abc import ABCMeta, abstractmethod
from numbers import Integral
import pandas as pd
from zipline.protocol import Event, DATASOURCE_TYPE
from zipline.protocol import DATASOURCE_TYPE
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
import zipline.finance.trading as trading
@@ -263,7 +268,7 @@ class EventWindow(object):
(event, self.ticks[0])
class BatchTransform(EventWindow):
class BatchTransform(object):
"""Base class for batch transforms with a trailing window of
variable length. As opposed to pure EventWindows that get a stream
of events and are bound to a single SID, this class creates stream
@@ -336,9 +341,6 @@ class BatchTransform(EventWindow):
Only call the user-defined function once the window is
full. Returns None if window is not full yet.
"""
super(BatchTransform, self).__init__(True, window_length=window_length)
if func is not None:
self.compute_transform_value = func
else:
@@ -351,16 +353,13 @@ class BatchTransform(EventWindow):
# to operate on the data, but to also allow new symbols to
# enter the batch transform's window IFF a sid filter is not
# specified.
self.sids = None
if sids:
self.static_sids = True
self.sids = sids
if sids is not None:
if isinstance(sids, (basestring, Integral)):
self.sids = set([sids])
self.static_sids = set([sids])
else:
self.sids = set(sids)
self.static_sids = set(sids)
else:
self.static_sids = False
self.static_sids = None
self.initial_field_names = fields
if isinstance(self.initial_field_names, basestring):
@@ -368,13 +367,16 @@ class BatchTransform(EventWindow):
self.field_names = set()
self.refresh_period = refresh_period
check_window_length(window_length)
self.window_length = window_length
self.trading_days_since_update = 0
self.trading_days_total = 0
self.window = None
self.full = False
self.last_dt = None
# Set to -inf essentially to cause update on first attempt.
self.last_dt = pd.Timestamp('1900-1-1', tz='UTC')
self.updated = False
self.cached = None
@@ -387,13 +389,13 @@ class BatchTransform(EventWindow):
# set of stocks per quarter
self.supplemental_data = None
self.rolling_panel = None
def handle_data(self, data, *args, **kwargs):
"""
New method to handle a data frame as sent to the algorithm's
handle_data method.
Point of entry. Process an event frame.
"""
# extract dates
#dts = [data[sid].datetime for sid in self.sids]
dts = [event.datetime for event in data.itervalues()]
# we have to provide the event with a dt. This is only for
# checking if the event is outside the window or not so a
@@ -412,20 +414,122 @@ class BatchTransform(EventWindow):
# only modify the trailing window if this is
# a new event. This is intended to make handle_data
# idempotent.
if event not in self.ticks:
# append data frame to window. update() will call handle_add() and
# handle_remove() appropriately, and self.updated
# will be modified based on the refresh_period
self.update(event)
if self.last_dt < event.dt:
self.updated = True
self._append_to_window(event)
else:
# we are recalculating based on an old event, so
# there is no change in the contents of the trailing
# window
self.updated = False
# return newly computed or cached value
return self.get_transform_value(*args, **kwargs)
def _append_to_window(self, event):
self.field_names = self._get_field_names(event)
if self.static_sids is None:
sids = set(event.data.keys())
else:
sids = self.static_sids
# Create rolling panel if not existant
if self.rolling_panel is None:
self.rolling_panel = RollingPanel(self.window_length,
self.field_names, sids)
# Store event in rolling frame
self.rolling_panel.add_frame(event.dt,
pd.DataFrame(event.data,
index=self.field_names,
columns=sids))
# update trading day counters
if self.last_dt.day != event.dt.day:
self.last_dt = event.dt
self.trading_days_total += 1
if self.trading_days_total >= self.window_length:
self.full = True
def get_transform_value(self, *args, **kwargs):
"""Call user-defined batch-transform function passing all
arguments.
Note that this will only call the transform if the datapanel
has actually been updated. Otherwise, the previously, cached
value will be returned.
"""
if self.compute_only_full and not self.full:
return None
#################################################
# Determine whether we should call the transform
# 0. Support historical/legacy usage of '0' signaling,
# 'update on every bar'
if self.refresh_period == 0:
period_signals_update = True
else:
# 1. Is the refresh period over?
period_signals_update = (
self.trading_days_total % self.refresh_period == 0)
# 2. Have the args or kwargs been changed since last time?
args_updated = args != self.last_args or kwargs != self.last_kwargs
recalculate_needed = args_updated or (period_signals_update and
self.updated)
if recalculate_needed:
self.cached = self.compute_transform_value(
self.get_data(),
*args,
**kwargs
)
self.last_args = args
self.last_kwargs = kwargs
return self.cached
def get_data(self):
"""Create a pandas.Panel (i.e. 3d DataFrame) from the
events in the current window.
Returns:
The resulting panel looks like this:
index : field_name (e.g. price)
major axis/rows : dt
minor axis/colums : sid
"""
data = self.rolling_panel.get_current()
if self.supplemental_data:
# item will be a date stamp
for item in data.items:
try:
data[item] = self.supplemental_data[item].combine_first(
data[item])
except KeyError:
# Only filling in data available in supplemental data.
pass
if self.clean_nans:
# Fills in gaps of missing data during transform
# of multiple stocks. E.g. we may be missing
# minute data because of illiquidity of one stock
data = data.fillna(method='ffill')
# Hold on to a reference to the data,
# so that it's easier to find the current data when stepping
# through with a debugger
self._curr_data = data
return data
def get_value(self, *args, **kwargs):
raise NotImplementedError(
"Either overwrite get_value or provide a func argument.")
def __call__(self, f):
self.compute_transform_value = f
return self.handle_data
def _extract_field_names(self, event):
# extract field names from sids (price, volume etc), make sure
# every sid has the same fields.
@@ -448,129 +552,12 @@ class BatchTransform(EventWindow):
'datetime', 'source_id'])
return union - unwanted_fields
def handle_add(self, event):
if not self.last_dt:
self.last_dt = event.dt
if self.initial_field_names is None:
def _get_field_names(self, event):
if self.initial_field_names is not None:
return self.initial_field_names
else:
self.latest_names = self._extract_field_names(event)
if self.field_names:
self.field_names = \
set.union(self.field_names, self.latest_names)
else:
self.field_names = self.latest_names
else:
self.field_names = self.initial_field_names
if not self.static_sids:
if self.sids:
event_sids = set(event.data.keys())
self.sids = set.union(self.sids, event_sids)
else:
self.sids = set(event.data.keys())
# update trading day counters
if self.last_dt.day != event.dt.day:
self.last_dt = event.dt
self.trading_days_since_update += 1
self.trading_days_total += 1
if self.trading_days_total >= self.window_length:
self.full = True
if self.trading_days_since_update >= self.refresh_period:
# Setting updated to True will cause get_transform_value()
# to call the user-defined batch-transform with the most
# recent datapanel
self.updated = True
else:
self.updated = False
def get_data(self):
"""Create a pandas.Panel (i.e. 3d DataFrame) from the
events in the current window.
Returns:
The resulting panel looks like this:
index : field_name (e.g. price)
major axis/rows : dt
minor axis/colums : sid
"""
# This Panel data structure ultimately gets passed to the
# user-overloaded get_value() method.
data_dict = {tick['dt']: tick['data'] for tick in self.ticks}
data = pd.Panel(data_dict, major_axis=self.field_names,
minor_axis=self.sids,
dtype='float')
if self.supplemental_data:
# item will be a date stamp
for item in data.items:
try:
data[item] = self.supplemental_data[item].combine_first(
data[item])
except KeyError:
# Only filling in data available in supplemental data.
pass
data = data.swapaxes(0, 1)
if self.clean_nans:
# Fills in gaps of missing data during transform
# of multiple stocks. E.g. we may be missing
# minute data because of illiquidity of one stock
data = data.fillna(method='ffill')
# Hold on to a reference to the data,
# so that it's easier to find the current data when stepping
# through with a debugger
self.curr_data = data
return data
def handle_remove(self, event):
pass
def get_value(self, *args, **kwargs):
raise NotImplementedError(
"Either overwrite get_value or provide a func argument.")
def get_transform_value(self, *args, **kwargs):
"""Call user-defined batch-transform function passing all
arguments.
Note that this will only call the transform if the datapanel
has actually been updated. Otherwise, the previously, cached
value will be returned.
"""
if self.compute_only_full and not self.full:
return None
recalculate_needed = False
if self.updated:
# Create new pandas panel
self.window = self.get_data()
# reset our counter for refresh_period
self.trading_days_since_update = 0
recalculate_needed = True
else:
recalculate_needed = \
args != self.last_args or kwargs != self.last_kwargs
if recalculate_needed:
self.cached = self.compute_transform_value(
self.window,
*args,
**kwargs
)
self.last_args = args
self.last_kwargs = kwargs
return self.cached
def __call__(self, f):
self.compute_transform_value = f
return self.handle_data
return set.union(self.field_names, self.latest_names)
def batch_transform(func):
+88
View File
@@ -0,0 +1,88 @@
#
# Copyright 2013 Quantopian, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pandas as pd
from copy import deepcopy
def _ensure_index(x):
if not isinstance(x, pd.Index):
x = pd.Index(x)
return x
class RollingPanel(object):
"""
Preallocation strategies for rolling window over expanding data set
Restrictions: major_axis can only be a DatetimeIndex for now
"""
def __init__(self, window, items, minor_axis, cap_multiple=2,
dtype=np.float64):
self.pos = 0
self.window = window
self.items = _ensure_index(items)
self.minor_axis = _ensure_index(minor_axis)
self.cap_multiple = cap_multiple
self.cap = cap_multiple * window
self.dtype = dtype
self.index_buf = np.empty(self.cap, dtype='M8[ns]')
self.buffer = pd.Panel(items=items, minor_axis=minor_axis,
major_axis=range(self.cap),
dtype=dtype)
def add_frame(self, tick, frame):
"""
"""
if self.pos == self.cap:
self._roll_data()
self.buffer.values[:, self.pos, :] = frame.ix[self.items].values
self.index_buf[self.pos] = tick
self.pos += 1
def get_current(self):
"""
Get a Panel that is the current data in view. It is not safe to persist
these objects because internal data might change
"""
where = slice(max(self.pos - self.window, 0), self.pos)
major_axis = pd.DatetimeIndex(deepcopy(self.index_buf[where]),
tz='utc')
return pd.Panel(self.buffer.values[:, where, :], self.items,
major_axis, self.minor_axis)
def _roll_data(self):
"""
Roll window worth of data up to position zero.
Save the effort of having to expensively roll at each iteration
"""
self.buffer.values[:, :self.window, :] = \
self.buffer.values[:, -self.window:]
self.index_buf[:self.window] = self.index_buf[-self.window:]
self.pos = self.window
class NaiveRollingPanel(object):
def __init__(self, window, items, minor_axis, cap_multiple=2):
pass