Merge pull request #54 from quantopian/enhance_batch_full_panel

Enhance batch full panel
This commit is contained in:
Eddie Hebert
2013-01-07 12:13:31 -08:00
3 changed files with 150 additions and 19 deletions
+20
View File
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import pytz
import numpy as np
import pandas as pd
@@ -342,6 +344,24 @@ class TestBatchTransform(TestCase):
'arbitrary dataframe should contain only "test"'
)
for data in algo.history_return_sid_filter[wl:]:
self.assertIn(0, data.columns)
self.assertNotIn(1, data.columns)
for data in algo.history_return_field_filter[wl:]:
self.assertIn('price', data.items)
self.assertNotIn('ignore', data.items)
for data in algo.history_return_field_no_filter[wl:]:
self.assertIn('price', data.items)
self.assertIn('ignore', data.items)
for data in algo.history_return_ticks[wl:]:
self.assertTrue(isinstance(data, deque))
for data in algo.history_return_not_full:
self.assertIsNot(data, None)
# test overloaded class
for test_history in [algo.history_return_price_class,
algo.history_return_price_decorator]:
+59 -1
View File
@@ -72,6 +72,7 @@ The algorithm must expose methods:
"""
from copy import deepcopy
import numpy as np
from zipline.algorithm import TradingAlgorithm
from zipline.finance.slippage import FixedSlippage
@@ -268,6 +269,11 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_return_args = []
self.history_return_arbitrary_fields = []
self.history_return_nan = []
self.history_return_sid_filter = []
self.history_return_field_filter = []
self.history_return_field_no_filter = []
self.history_return_ticks = []
self.history_return_not_full = []
self.return_price_class = ReturnPriceBatchTransform(
refresh_period=self.refresh_period,
@@ -305,6 +311,38 @@ class BatchTransformAlgorithm(TradingAlgorithm):
clean_nans=True
)
self.return_sid_filter = return_price_batch_decorator(
refresh_period=self.refresh_period,
window_length=self.window_length,
clean_nans=True,
sids=[0]
)
self.return_field_filter = return_data(
refresh_period=self.refresh_period,
window_length=self.window_length,
clean_nans=True,
fields=['price']
)
self.return_field_no_filter = return_data(
refresh_period=self.refresh_period,
window_length=self.window_length,
clean_nans=True
)
self.return_ticks = return_data(
refresh_period=self.refresh_period,
window_length=self.window_length,
create_panel=False
)
self.return_not_full = return_data(
refresh_period=0,
window_length=self.window_length,
compute_only_full=False
)
self.iter = 0
self.set_slippage(FixedSlippage())
@@ -317,6 +355,10 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_return_args.append(
self.return_args_batch.handle_data(
data, *self.args, **self.kwargs))
self.history_return_ticks.append(
self.return_ticks.handle_data(data))
self.history_return_not_full.append(
self.return_not_full.handle_data(data))
new_data = deepcopy(data)
for sid in new_data:
@@ -331,7 +373,6 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.return_nan.handle_data(data))
else:
nan_data = deepcopy(data)
import numpy as np
for sid in nan_data.iterkeys():
nan_data[sid].price = np.nan
self.history_return_nan.append(
@@ -339,6 +380,23 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.iter += 1
# Add a new sid to check that it does not get included
extra_sid_data = deepcopy(data)
extra_sid_data[1] = extra_sid_data[0]
self.history_return_sid_filter.append(
self.return_sid_filter.handle_data(extra_sid_data)
)
# Add a field to check that it does not get included
extra_field_data = deepcopy(data)
extra_field_data[0]['ignore'] = extra_sid_data[0]['price']
self.history_return_field_filter.append(
self.return_field_filter.handle_data(extra_field_data)
)
self.history_return_field_no_filter.append(
self.return_field_no_filter.handle_data(extra_field_data)
)
class SetPortfolioAlgorithm(TradingAlgorithm):
"""
+71 -18
View File
@@ -24,6 +24,7 @@ from copy import deepcopy
from datetime import datetime
from collections import deque
from abc import ABCMeta, abstractmethod
from numbers import Integral
import pandas as pd
@@ -343,7 +344,43 @@ class BatchTransform(EventWindow):
func=None,
refresh_period=None,
window_length=None,
clean_nans=True):
clean_nans=True,
sids=None,
fields=None,
create_panel=True,
compute_only_full=True):
"""Instantiate new batch_transform object.
:Arguments:
func : python function <optional>
If supplied will be called after each refresh_period
with the data panel and all args and kwargs supplied
to the handle_data() call.
refresh_period : int
Interval to call batch_transform function.
window_length : int
How many days the trailing window should have.
clean_nans : bool <default=True>
Whether to (forward) fill in nans.
sids : list <optional>
Which sids to include in the moving window. If not
supplied sids will be extracted from incoming
events.
fields : list <optional>
Which fields to include in the moving window
(e.g. 'price'). If not supplied, fields will be
extracted from incoming events.
create_panel : bool <default=True>
If True, will create a pandas panel every refresh
period and pass it to the user-defined function.
If False, will pass the underlying deque reference
directly to the function which will be significantly
faster.
compute_only_full : bool <default=True>
Only call the user-defined function once the window is
full. Returns None if window is not full yet.
"""
super(BatchTransform, self).__init__(True,
window_length=window_length)
@@ -354,6 +391,16 @@ class BatchTransform(EventWindow):
self.compute_transform_value = self.get_value
self.clean_nans = clean_nans
self.create_panel = create_panel
self.compute_only_full = compute_only_full
self.sids = sids
if isinstance(self.sids, (basestring, Integral)):
self.sids = [self.sids]
self.field_names = fields
if isinstance(self.field_names, str):
self.field_names = [self.field_names]
self.refresh_period = refresh_period
self.window_length = window_length
@@ -366,8 +413,6 @@ class BatchTransform(EventWindow):
self.updated = False
self.cached = None
self.field_names = None
def handle_data(self, data, *args, **kwargs):
"""
New method to handle a data frame as sent to the algorithm's
@@ -410,9 +455,9 @@ class BatchTransform(EventWindow):
def handle_add(self, event):
if not self.last_dt:
self.field_names = self._extract_field_names(event)
if self.field_names is None:
self.field_names = self._extract_field_names(event)
self.last_dt = event.dt
return
# update trading day counters
if self.last_dt.day != event.dt.day:
@@ -420,15 +465,14 @@ class BatchTransform(EventWindow):
self.trading_days_since_update += 1
self.trading_days_total += 1
if (
self.trading_days_total >= self.window_length and
self.trading_days_since_update >= self.refresh_period
):
if self.trading_days_total >= self.window_length:
self.full = True
if self.trading_days_since_update >= self.refresh_period:
# Setting updated to True will cause get_transform_value()
# to call the user-defined batch-transform with the most
# recent datapanel
self.updated = True
self.full = True
self.trading_days_since_update = 0
else:
self.updated = False
@@ -445,7 +489,13 @@ class BatchTransform(EventWindow):
"""
# This Panel data structure ultimately gets passed to the
# user-overloaded get_value() method.
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
# If sids are set, use those. Otherwise extract.
if self.sids is not None:
sids = self.sids
else:
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
dts = [tick.dt for tick in self.ticks]
data = pd.Panel(items=self.field_names, major_axis=dts,
@@ -454,9 +504,10 @@ class BatchTransform(EventWindow):
# Fill data panel
for tick in self.ticks:
dt = tick.dt
for sid, fields in tick.data.iteritems():
for sid in sids:
fields = tick.data[sid]
for field_name in self.field_names:
data[field_name][sid].ix[dt] = fields[field_name]
data[field_name][sid].ix[dt] = fields[field_name]
if self.clean_nans:
# Fills in gaps of missing data during transform
@@ -471,8 +522,7 @@ class BatchTransform(EventWindow):
return data
def handle_remove(self, event):
# since an event is expiring, we know the window is full
self.full = True
pass
def get_value(self, *args, **kwargs):
raise NotImplementedError(
@@ -486,12 +536,15 @@ class BatchTransform(EventWindow):
has actually been updated. Otherwise, the previously, cached
value will be returned.
"""
if not self.full:
if self.compute_only_full and not self.full:
return None
if self.updated:
self.cached = self.compute_transform_value(self.get_data(),
*args, **kwargs)
# Either create new pandas panel or pass ticks dequeue
# directly
data = self.get_data() if self.create_panel else self.ticks
self.cached = self.compute_transform_value(data, *args,
**kwargs)
return self.cached