mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 01:53:27 +08:00
Merge batch transform and rolling panel enhancements.
Branch provides a rolling pandas data panel, and converts batch transform to use the new panel type.
This commit is contained in:
+145
-12
@@ -1,3 +1,18 @@
|
||||
#
|
||||
# Copyright 2013 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import pytz
|
||||
@@ -9,9 +24,105 @@ from unittest import TestCase
|
||||
|
||||
from zipline.utils.test_utils import setup_logger
|
||||
|
||||
from zipline.sources.data_source import DataSource
|
||||
import zipline.utils.factory as factory
|
||||
|
||||
from zipline.test_algorithms import BatchTransformAlgorithm
|
||||
from zipline.test_algorithms import (BatchTransformAlgorithm,
|
||||
batch_transform,
|
||||
ReturnPriceBatchTransform)
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.utils.tradingcalendar import trading_days
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@batch_transform
|
||||
def return_price(data):
|
||||
return data.price
|
||||
|
||||
|
||||
class BatchTransformAlgorithmSetSid(TradingAlgorithm):
|
||||
def initialize(self, sids):
|
||||
self.history = []
|
||||
|
||||
self.batch_transform = return_price(
|
||||
refresh_period=1,
|
||||
window_length=10,
|
||||
clean_nans=False,
|
||||
sids=sids,
|
||||
compute_only_full=False
|
||||
)
|
||||
|
||||
def handle_data(self, data):
|
||||
self.history.append(
|
||||
deepcopy(self.batch_transform.handle_data(data)))
|
||||
|
||||
|
||||
class DifferentSidSource(DataSource):
|
||||
def __init__(self):
|
||||
self.dates = pd.date_range('1990-01-01', periods=180, tz='utc')
|
||||
self.start = self.dates[0]
|
||||
self.end = self.dates[-1]
|
||||
self._raw_data = None
|
||||
self.sids = range(90)
|
||||
self.sid = 0
|
||||
self.trading_days = []
|
||||
|
||||
@property
|
||||
def instance_hash(self):
|
||||
return '1234'
|
||||
|
||||
@property
|
||||
def raw_data(self):
|
||||
if not self._raw_data:
|
||||
self._raw_data = self.raw_data_gen()
|
||||
return self._raw_data
|
||||
|
||||
@property
|
||||
def mapping(self):
|
||||
return {
|
||||
'dt': (lambda x: x, 'dt'),
|
||||
'sid': (lambda x: x, 'sid'),
|
||||
'price': (float, 'price'),
|
||||
'volume': (int, 'volume'),
|
||||
}
|
||||
|
||||
def raw_data_gen(self):
|
||||
# Create differente sid for each event
|
||||
for date in self.dates:
|
||||
if date not in trading_days:
|
||||
continue
|
||||
event = {'dt': date,
|
||||
'sid': self.sid,
|
||||
'price': self.sid,
|
||||
'volume': self.sid}
|
||||
self.sid += 1
|
||||
self.trading_days.append(date)
|
||||
yield event
|
||||
|
||||
|
||||
class TestChangeOfSids(TestCase):
|
||||
def setUp(self):
|
||||
self.sids = range(90)
|
||||
self.sim_params = factory.create_simulation_parameters(
|
||||
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
|
||||
end=datetime(1990, 1, 8, tzinfo=pytz.utc)
|
||||
)
|
||||
|
||||
def test_all_sids_passed(self):
|
||||
algo = BatchTransformAlgorithmSetSid(self.sids,
|
||||
sim_params=self.sim_params)
|
||||
source = DifferentSidSource()
|
||||
algo.run(source)
|
||||
for df, date in zip(algo.history, source.trading_days):
|
||||
self.assertEqual(df.index[-1], date, "Newest event doesn't \
|
||||
match.")
|
||||
|
||||
for sid in self.sids:
|
||||
self.assertIn(sid, df.columns)
|
||||
|
||||
last_elem = len(df) - 1
|
||||
self.assertEqual(df[last_elem][last_elem], last_elem)
|
||||
|
||||
|
||||
class TestBatchTransform(TestCase):
|
||||
@@ -24,20 +135,23 @@ class TestBatchTransform(TestCase):
|
||||
self.source, self.df = \
|
||||
factory.create_test_df_source(self.sim_params)
|
||||
|
||||
def test_event_window(self):
|
||||
def test_core_functionality(self):
|
||||
algo = BatchTransformAlgorithm(sim_params=self.sim_params)
|
||||
algo.run(self.source)
|
||||
wl = algo.window_length
|
||||
# The following assertion depend on window length of 3
|
||||
self.assertEqual(wl, 3)
|
||||
self.assertEqual(algo.history_return_price_class[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
# If window_length is 3, there should be 2 None events, as the
|
||||
# window fills up on the 3rd day.
|
||||
n_none_events = 2
|
||||
self.assertEqual(algo.history_return_price_class[:n_none_events],
|
||||
[None] * n_none_events,
|
||||
"First two iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_class,))
|
||||
self.assertEqual(algo.history_return_price_decorator[:wl],
|
||||
[None] * wl,
|
||||
"First three iterations should return None." + "\n" +
|
||||
self.assertEqual(algo.history_return_price_decorator[:n_none_events],
|
||||
[None] * n_none_events,
|
||||
"First two iterations should return None." + "\n" +
|
||||
"i.e. no returned values until window is full'" +
|
||||
"%s" % (algo.history_return_price_decorator,))
|
||||
|
||||
@@ -90,8 +204,8 @@ class TestBatchTransform(TestCase):
|
||||
)
|
||||
|
||||
def test_passing_of_args(self):
|
||||
algo = BatchTransformAlgorithm(1,
|
||||
kwarg='str', sim_params=self.sim_params)
|
||||
algo = BatchTransformAlgorithm(1, kwarg='str',
|
||||
sim_params=self.sim_params)
|
||||
self.assertEqual(algo.args, (1,))
|
||||
self.assertEqual(algo.kwargs, {'kwarg': 'str'})
|
||||
|
||||
@@ -105,10 +219,29 @@ class TestBatchTransform(TestCase):
|
||||
None,
|
||||
# 1990-01-03 - window not full
|
||||
None,
|
||||
# 1990-01-04 - window not full, 3rd event
|
||||
None,
|
||||
# 1990-01-04 - window now full, 3rd event
|
||||
expected_item,
|
||||
# 1990-01-05 - window now full
|
||||
expected_item,
|
||||
# 1990-01-08 - window now full
|
||||
expected_item
|
||||
])
|
||||
|
||||
|
||||
def run_batchtransform(window_length=10):
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=datetime(1990, 1, 1, tzinfo=pytz.utc),
|
||||
end=datetime(1995, 1, 8, tzinfo=pytz.utc)
|
||||
)
|
||||
source, df = factory.create_test_df_source(sim_params)
|
||||
|
||||
return_price_class = ReturnPriceBatchTransform(
|
||||
refresh_period=1,
|
||||
window_length=window_length,
|
||||
clean_nans=False
|
||||
)
|
||||
|
||||
for raw_event in source:
|
||||
raw_event['datetime'] = raw_event.dt
|
||||
event = {0: raw_event}
|
||||
return_price_class.handle_data(event)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright 2013 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
import pandas as pd
|
||||
import pandas.util.testing as tm
|
||||
|
||||
from zipline.utils.data import RollingPanel
|
||||
|
||||
|
||||
class TestRollingPanel(unittest.TestCase):
|
||||
|
||||
def test_basics(self):
|
||||
items = ['foo', 'bar', 'baz']
|
||||
minor = ['A', 'B', 'C', 'D']
|
||||
|
||||
window = 10
|
||||
|
||||
rp = RollingPanel(window, items, minor, cap_multiple=2)
|
||||
|
||||
dates = pd.date_range('2000-01-01', periods=30, tz='utc')
|
||||
|
||||
major_deque = deque()
|
||||
|
||||
frames = {}
|
||||
|
||||
for i in range(30):
|
||||
frame = pd.DataFrame(np.random.randn(3, 4), index=items,
|
||||
columns=minor)
|
||||
date = dates[i]
|
||||
|
||||
rp.add_frame(date, frame)
|
||||
|
||||
frames[date] = frame
|
||||
major_deque.append(date)
|
||||
|
||||
if i >= window:
|
||||
major_deque.popleft()
|
||||
|
||||
result = rp.get_current()
|
||||
expected = pd.Panel(frames, items=list(major_deque),
|
||||
major_axis=items, minor_axis=minor)
|
||||
tm.assert_panel_equal(result, expected.swapaxes(0, 1))
|
||||
|
||||
|
||||
def f(option='clever', n=500, copy=False):
|
||||
items = range(5)
|
||||
minor = range(20)
|
||||
window = 100
|
||||
periods = n
|
||||
|
||||
dates = pd.date_range('2000-01-01', periods=periods, tz='utc')
|
||||
frames = {}
|
||||
|
||||
if option == 'clever':
|
||||
rp = RollingPanel(window, items, minor, cap_multiple=2)
|
||||
major_deque = deque()
|
||||
dummy = pd.DataFrame(np.random.randn(len(items), len(minor)),
|
||||
index=items, columns=minor)
|
||||
|
||||
for i in range(periods):
|
||||
frame = dummy * (1 + 0.001 * i)
|
||||
date = dates[i]
|
||||
|
||||
rp.add_frame(date, frame)
|
||||
|
||||
frames[date] = frame
|
||||
major_deque.append(date)
|
||||
|
||||
if i >= window:
|
||||
del frames[major_deque.popleft()]
|
||||
|
||||
result = rp.get_current()
|
||||
if copy:
|
||||
result = result.copy()
|
||||
else:
|
||||
major_deque = deque()
|
||||
dummy = pd.DataFrame(np.random.randn(len(items), len(minor)),
|
||||
index=items, columns=minor)
|
||||
|
||||
for i in range(periods):
|
||||
frame = dummy * (1 + 0.001 * i)
|
||||
date = dates[i]
|
||||
frames[date] = frame
|
||||
major_deque.append(date)
|
||||
|
||||
if i >= window:
|
||||
del frames[major_deque.popleft()]
|
||||
|
||||
result = pd.Panel(frames, items=list(major_deque),
|
||||
major_axis=items, minor_axis=minor)
|
||||
@@ -340,7 +340,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
)
|
||||
|
||||
self.return_not_full = return_data(
|
||||
refresh_period=0,
|
||||
refresh_period=1,
|
||||
window_length=self.window_length,
|
||||
compute_only_full=False
|
||||
)
|
||||
@@ -378,9 +378,9 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.price_multiple.handle_data(data, 1, extra_arg=1)
|
||||
|
||||
if self.price_multiple.full:
|
||||
pre = len(self.price_multiple.ticks)
|
||||
pre = self.price_multiple.rolling_panel.get_current().shape[0]
|
||||
result1 = self.price_multiple.handle_data(data, 1, extra_arg=1)
|
||||
post = len(self.price_multiple.ticks)
|
||||
post = self.price_multiple.rolling_panel.get_current().shape[0]
|
||||
assert pre == post, "batch transform is appending redundant events"
|
||||
result2 = self.price_multiple.handle_data(data, 1, extra_arg=1)
|
||||
assert result1 is result2, "batch transform is not idempotent"
|
||||
|
||||
+137
-150
@@ -20,17 +20,22 @@ Generator versions of transforms.
|
||||
import functools
|
||||
import types
|
||||
import logbook
|
||||
|
||||
import numpy
|
||||
|
||||
from numbers import Integral
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from zipline.utils.data import RollingPanel
|
||||
from zipline.protocol import Event
|
||||
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from numbers import Integral
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from zipline.protocol import Event, DATASOURCE_TYPE
|
||||
from zipline.protocol import DATASOURCE_TYPE
|
||||
from zipline.gens.utils import assert_sort_unframe_protocol, hash_args
|
||||
import zipline.finance.trading as trading
|
||||
|
||||
@@ -263,7 +268,7 @@ class EventWindow(object):
|
||||
(event, self.ticks[0])
|
||||
|
||||
|
||||
class BatchTransform(EventWindow):
|
||||
class BatchTransform(object):
|
||||
"""Base class for batch transforms with a trailing window of
|
||||
variable length. As opposed to pure EventWindows that get a stream
|
||||
of events and are bound to a single SID, this class creates stream
|
||||
@@ -336,9 +341,6 @@ class BatchTransform(EventWindow):
|
||||
Only call the user-defined function once the window is
|
||||
full. Returns None if window is not full yet.
|
||||
"""
|
||||
|
||||
super(BatchTransform, self).__init__(True, window_length=window_length)
|
||||
|
||||
if func is not None:
|
||||
self.compute_transform_value = func
|
||||
else:
|
||||
@@ -351,16 +353,13 @@ class BatchTransform(EventWindow):
|
||||
# to operate on the data, but to also allow new symbols to
|
||||
# enter the batch transform's window IFF a sid filter is not
|
||||
# specified.
|
||||
self.sids = None
|
||||
if sids:
|
||||
self.static_sids = True
|
||||
self.sids = sids
|
||||
if sids is not None:
|
||||
if isinstance(sids, (basestring, Integral)):
|
||||
self.sids = set([sids])
|
||||
self.static_sids = set([sids])
|
||||
else:
|
||||
self.sids = set(sids)
|
||||
self.static_sids = set(sids)
|
||||
else:
|
||||
self.static_sids = False
|
||||
self.static_sids = None
|
||||
|
||||
self.initial_field_names = fields
|
||||
if isinstance(self.initial_field_names, basestring):
|
||||
@@ -368,13 +367,16 @@ class BatchTransform(EventWindow):
|
||||
self.field_names = set()
|
||||
|
||||
self.refresh_period = refresh_period
|
||||
|
||||
check_window_length(window_length)
|
||||
self.window_length = window_length
|
||||
self.trading_days_since_update = 0
|
||||
|
||||
self.trading_days_total = 0
|
||||
self.window = None
|
||||
|
||||
self.full = False
|
||||
self.last_dt = None
|
||||
# Set to -inf essentially to cause update on first attempt.
|
||||
self.last_dt = pd.Timestamp('1900-1-1', tz='UTC')
|
||||
|
||||
self.updated = False
|
||||
self.cached = None
|
||||
@@ -387,13 +389,13 @@ class BatchTransform(EventWindow):
|
||||
# set of stocks per quarter
|
||||
self.supplemental_data = None
|
||||
|
||||
self.rolling_panel = None
|
||||
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
New method to handle a data frame as sent to the algorithm's
|
||||
handle_data method.
|
||||
Point of entry. Process an event frame.
|
||||
"""
|
||||
# extract dates
|
||||
#dts = [data[sid].datetime for sid in self.sids]
|
||||
dts = [event.datetime for event in data.itervalues()]
|
||||
# we have to provide the event with a dt. This is only for
|
||||
# checking if the event is outside the window or not so a
|
||||
@@ -412,20 +414,122 @@ class BatchTransform(EventWindow):
|
||||
# only modify the trailing window if this is
|
||||
# a new event. This is intended to make handle_data
|
||||
# idempotent.
|
||||
if event not in self.ticks:
|
||||
# append data frame to window. update() will call handle_add() and
|
||||
# handle_remove() appropriately, and self.updated
|
||||
# will be modified based on the refresh_period
|
||||
self.update(event)
|
||||
if self.last_dt < event.dt:
|
||||
self.updated = True
|
||||
self._append_to_window(event)
|
||||
else:
|
||||
# we are recalculating based on an old event, so
|
||||
# there is no change in the contents of the trailing
|
||||
# window
|
||||
self.updated = False
|
||||
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def _append_to_window(self, event):
|
||||
self.field_names = self._get_field_names(event)
|
||||
|
||||
if self.static_sids is None:
|
||||
sids = set(event.data.keys())
|
||||
else:
|
||||
sids = self.static_sids
|
||||
|
||||
# Create rolling panel if not existant
|
||||
if self.rolling_panel is None:
|
||||
self.rolling_panel = RollingPanel(self.window_length,
|
||||
self.field_names, sids)
|
||||
|
||||
# Store event in rolling frame
|
||||
self.rolling_panel.add_frame(event.dt,
|
||||
pd.DataFrame(event.data,
|
||||
index=self.field_names,
|
||||
columns=sids))
|
||||
|
||||
# update trading day counters
|
||||
if self.last_dt.day != event.dt.day:
|
||||
self.last_dt = event.dt
|
||||
self.trading_days_total += 1
|
||||
|
||||
if self.trading_days_total >= self.window_length:
|
||||
self.full = True
|
||||
|
||||
def get_transform_value(self, *args, **kwargs):
|
||||
"""Call user-defined batch-transform function passing all
|
||||
arguments.
|
||||
|
||||
Note that this will only call the transform if the datapanel
|
||||
has actually been updated. Otherwise, the previously, cached
|
||||
value will be returned.
|
||||
"""
|
||||
if self.compute_only_full and not self.full:
|
||||
return None
|
||||
|
||||
#################################################
|
||||
# Determine whether we should call the transform
|
||||
# 0. Support historical/legacy usage of '0' signaling,
|
||||
# 'update on every bar'
|
||||
if self.refresh_period == 0:
|
||||
period_signals_update = True
|
||||
else:
|
||||
# 1. Is the refresh period over?
|
||||
period_signals_update = (
|
||||
self.trading_days_total % self.refresh_period == 0)
|
||||
# 2. Have the args or kwargs been changed since last time?
|
||||
args_updated = args != self.last_args or kwargs != self.last_kwargs
|
||||
recalculate_needed = args_updated or (period_signals_update and
|
||||
self.updated)
|
||||
|
||||
if recalculate_needed:
|
||||
self.cached = self.compute_transform_value(
|
||||
self.get_data(),
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.last_args = args
|
||||
self.last_kwargs = kwargs
|
||||
return self.cached
|
||||
|
||||
def get_data(self):
|
||||
"""Create a pandas.Panel (i.e. 3d DataFrame) from the
|
||||
events in the current window.
|
||||
|
||||
Returns:
|
||||
The resulting panel looks like this:
|
||||
index : field_name (e.g. price)
|
||||
major axis/rows : dt
|
||||
minor axis/colums : sid
|
||||
"""
|
||||
data = self.rolling_panel.get_current()
|
||||
|
||||
if self.supplemental_data:
|
||||
# item will be a date stamp
|
||||
for item in data.items:
|
||||
try:
|
||||
data[item] = self.supplemental_data[item].combine_first(
|
||||
data[item])
|
||||
except KeyError:
|
||||
# Only filling in data available in supplemental data.
|
||||
pass
|
||||
|
||||
if self.clean_nans:
|
||||
# Fills in gaps of missing data during transform
|
||||
# of multiple stocks. E.g. we may be missing
|
||||
# minute data because of illiquidity of one stock
|
||||
data = data.fillna(method='ffill')
|
||||
|
||||
# Hold on to a reference to the data,
|
||||
# so that it's easier to find the current data when stepping
|
||||
# through with a debugger
|
||||
self._curr_data = data
|
||||
|
||||
return data
|
||||
|
||||
def get_value(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Either overwrite get_value or provide a func argument.")
|
||||
|
||||
def __call__(self, f):
|
||||
self.compute_transform_value = f
|
||||
return self.handle_data
|
||||
|
||||
def _extract_field_names(self, event):
|
||||
# extract field names from sids (price, volume etc), make sure
|
||||
# every sid has the same fields.
|
||||
@@ -448,129 +552,12 @@ class BatchTransform(EventWindow):
|
||||
'datetime', 'source_id'])
|
||||
return union - unwanted_fields
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_dt:
|
||||
self.last_dt = event.dt
|
||||
|
||||
if self.initial_field_names is None:
|
||||
def _get_field_names(self, event):
|
||||
if self.initial_field_names is not None:
|
||||
return self.initial_field_names
|
||||
else:
|
||||
self.latest_names = self._extract_field_names(event)
|
||||
if self.field_names:
|
||||
self.field_names = \
|
||||
set.union(self.field_names, self.latest_names)
|
||||
else:
|
||||
self.field_names = self.latest_names
|
||||
else:
|
||||
self.field_names = self.initial_field_names
|
||||
|
||||
if not self.static_sids:
|
||||
if self.sids:
|
||||
event_sids = set(event.data.keys())
|
||||
self.sids = set.union(self.sids, event_sids)
|
||||
else:
|
||||
self.sids = set(event.data.keys())
|
||||
|
||||
# update trading day counters
|
||||
if self.last_dt.day != event.dt.day:
|
||||
self.last_dt = event.dt
|
||||
self.trading_days_since_update += 1
|
||||
self.trading_days_total += 1
|
||||
|
||||
if self.trading_days_total >= self.window_length:
|
||||
self.full = True
|
||||
|
||||
if self.trading_days_since_update >= self.refresh_period:
|
||||
# Setting updated to True will cause get_transform_value()
|
||||
# to call the user-defined batch-transform with the most
|
||||
# recent datapanel
|
||||
self.updated = True
|
||||
else:
|
||||
self.updated = False
|
||||
|
||||
def get_data(self):
|
||||
"""Create a pandas.Panel (i.e. 3d DataFrame) from the
|
||||
events in the current window.
|
||||
|
||||
Returns:
|
||||
The resulting panel looks like this:
|
||||
index : field_name (e.g. price)
|
||||
major axis/rows : dt
|
||||
minor axis/colums : sid
|
||||
"""
|
||||
# This Panel data structure ultimately gets passed to the
|
||||
# user-overloaded get_value() method.
|
||||
data_dict = {tick['dt']: tick['data'] for tick in self.ticks}
|
||||
data = pd.Panel(data_dict, major_axis=self.field_names,
|
||||
minor_axis=self.sids,
|
||||
dtype='float')
|
||||
|
||||
if self.supplemental_data:
|
||||
# item will be a date stamp
|
||||
for item in data.items:
|
||||
try:
|
||||
data[item] = self.supplemental_data[item].combine_first(
|
||||
data[item])
|
||||
except KeyError:
|
||||
# Only filling in data available in supplemental data.
|
||||
pass
|
||||
|
||||
data = data.swapaxes(0, 1)
|
||||
|
||||
if self.clean_nans:
|
||||
# Fills in gaps of missing data during transform
|
||||
# of multiple stocks. E.g. we may be missing
|
||||
# minute data because of illiquidity of one stock
|
||||
data = data.fillna(method='ffill')
|
||||
|
||||
# Hold on to a reference to the data,
|
||||
# so that it's easier to find the current data when stepping
|
||||
# through with a debugger
|
||||
self.curr_data = data
|
||||
|
||||
return data
|
||||
|
||||
def handle_remove(self, event):
|
||||
pass
|
||||
|
||||
def get_value(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Either overwrite get_value or provide a func argument.")
|
||||
|
||||
def get_transform_value(self, *args, **kwargs):
|
||||
"""Call user-defined batch-transform function passing all
|
||||
arguments.
|
||||
|
||||
Note that this will only call the transform if the datapanel
|
||||
has actually been updated. Otherwise, the previously, cached
|
||||
value will be returned.
|
||||
"""
|
||||
if self.compute_only_full and not self.full:
|
||||
return None
|
||||
|
||||
recalculate_needed = False
|
||||
if self.updated:
|
||||
# Create new pandas panel
|
||||
self.window = self.get_data()
|
||||
# reset our counter for refresh_period
|
||||
self.trading_days_since_update = 0
|
||||
recalculate_needed = True
|
||||
else:
|
||||
recalculate_needed = \
|
||||
args != self.last_args or kwargs != self.last_kwargs
|
||||
|
||||
if recalculate_needed:
|
||||
self.cached = self.compute_transform_value(
|
||||
self.window,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.last_args = args
|
||||
self.last_kwargs = kwargs
|
||||
return self.cached
|
||||
|
||||
def __call__(self, f):
|
||||
self.compute_transform_value = f
|
||||
return self.handle_data
|
||||
return set.union(self.field_names, self.latest_names)
|
||||
|
||||
|
||||
def batch_transform(func):
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
#
|
||||
# Copyright 2013 Quantopian, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
def _ensure_index(x):
|
||||
if not isinstance(x, pd.Index):
|
||||
x = pd.Index(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class RollingPanel(object):
|
||||
"""
|
||||
Preallocation strategies for rolling window over expanding data set
|
||||
|
||||
Restrictions: major_axis can only be a DatetimeIndex for now
|
||||
"""
|
||||
|
||||
def __init__(self, window, items, minor_axis, cap_multiple=2,
|
||||
dtype=np.float64):
|
||||
self.pos = 0
|
||||
self.window = window
|
||||
|
||||
self.items = _ensure_index(items)
|
||||
self.minor_axis = _ensure_index(minor_axis)
|
||||
|
||||
self.cap_multiple = cap_multiple
|
||||
self.cap = cap_multiple * window
|
||||
|
||||
self.dtype = dtype
|
||||
self.index_buf = np.empty(self.cap, dtype='M8[ns]')
|
||||
self.buffer = pd.Panel(items=items, minor_axis=minor_axis,
|
||||
major_axis=range(self.cap),
|
||||
dtype=dtype)
|
||||
|
||||
def add_frame(self, tick, frame):
|
||||
"""
|
||||
"""
|
||||
if self.pos == self.cap:
|
||||
self._roll_data()
|
||||
self.buffer.values[:, self.pos, :] = frame.ix[self.items].values
|
||||
self.index_buf[self.pos] = tick
|
||||
|
||||
self.pos += 1
|
||||
|
||||
def get_current(self):
|
||||
"""
|
||||
Get a Panel that is the current data in view. It is not safe to persist
|
||||
these objects because internal data might change
|
||||
"""
|
||||
where = slice(max(self.pos - self.window, 0), self.pos)
|
||||
major_axis = pd.DatetimeIndex(deepcopy(self.index_buf[where]),
|
||||
tz='utc')
|
||||
|
||||
return pd.Panel(self.buffer.values[:, where, :], self.items,
|
||||
major_axis, self.minor_axis)
|
||||
|
||||
def _roll_data(self):
|
||||
"""
|
||||
Roll window worth of data up to position zero.
|
||||
Save the effort of having to expensively roll at each iteration
|
||||
"""
|
||||
self.buffer.values[:, :self.window, :] = \
|
||||
self.buffer.values[:, -self.window:]
|
||||
self.index_buf[:self.window] = self.index_buf[-self.window:]
|
||||
self.pos = self.window
|
||||
|
||||
|
||||
class NaiveRollingPanel(object):
|
||||
|
||||
def __init__(self, window, items, minor_axis, cap_multiple=2):
|
||||
pass
|
||||
Reference in New Issue
Block a user