mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 10:31:21 +08:00
Merge pull request #54 from quantopian/enhance_batch_full_panel
Enhance batch full panel
This commit is contained in:
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections import deque
|
||||
|
||||
import pytz
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -342,6 +344,24 @@ class TestBatchTransform(TestCase):
|
||||
'arbitrary dataframe should contain only "test"'
|
||||
)
|
||||
|
||||
for data in algo.history_return_sid_filter[wl:]:
|
||||
self.assertIn(0, data.columns)
|
||||
self.assertNotIn(1, data.columns)
|
||||
|
||||
for data in algo.history_return_field_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertNotIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_field_no_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_ticks[wl:]:
|
||||
self.assertTrue(isinstance(data, deque))
|
||||
|
||||
for data in algo.history_return_not_full:
|
||||
self.assertIsNot(data, None)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
algo.history_return_price_decorator]:
|
||||
|
||||
@@ -72,6 +72,7 @@ The algorithm must expose methods:
|
||||
|
||||
"""
|
||||
from copy import deepcopy
|
||||
import numpy as np
|
||||
|
||||
from zipline.algorithm import TradingAlgorithm
|
||||
from zipline.finance.slippage import FixedSlippage
|
||||
@@ -268,6 +269,11 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.history_return_args = []
|
||||
self.history_return_arbitrary_fields = []
|
||||
self.history_return_nan = []
|
||||
self.history_return_sid_filter = []
|
||||
self.history_return_field_filter = []
|
||||
self.history_return_field_no_filter = []
|
||||
self.history_return_ticks = []
|
||||
self.history_return_not_full = []
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(
|
||||
refresh_period=self.refresh_period,
|
||||
@@ -305,6 +311,38 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
clean_nans=True
|
||||
)
|
||||
|
||||
self.return_sid_filter = return_price_batch_decorator(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=True,
|
||||
sids=[0]
|
||||
)
|
||||
|
||||
self.return_field_filter = return_data(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=True,
|
||||
fields=['price']
|
||||
)
|
||||
|
||||
self.return_field_no_filter = return_data(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=True
|
||||
)
|
||||
|
||||
self.return_ticks = return_data(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
create_panel=False
|
||||
)
|
||||
|
||||
self.return_not_full = return_data(
|
||||
refresh_period=0,
|
||||
window_length=self.window_length,
|
||||
compute_only_full=False
|
||||
)
|
||||
|
||||
self.iter = 0
|
||||
|
||||
self.set_slippage(FixedSlippage())
|
||||
@@ -317,6 +355,10 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.history_return_args.append(
|
||||
self.return_args_batch.handle_data(
|
||||
data, *self.args, **self.kwargs))
|
||||
self.history_return_ticks.append(
|
||||
self.return_ticks.handle_data(data))
|
||||
self.history_return_not_full.append(
|
||||
self.return_not_full.handle_data(data))
|
||||
|
||||
new_data = deepcopy(data)
|
||||
for sid in new_data:
|
||||
@@ -331,7 +373,6 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.return_nan.handle_data(data))
|
||||
else:
|
||||
nan_data = deepcopy(data)
|
||||
import numpy as np
|
||||
for sid in nan_data.iterkeys():
|
||||
nan_data[sid].price = np.nan
|
||||
self.history_return_nan.append(
|
||||
@@ -339,6 +380,23 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
|
||||
self.iter += 1
|
||||
|
||||
# Add a new sid to check that it does not get included
|
||||
extra_sid_data = deepcopy(data)
|
||||
extra_sid_data[1] = extra_sid_data[0]
|
||||
self.history_return_sid_filter.append(
|
||||
self.return_sid_filter.handle_data(extra_sid_data)
|
||||
)
|
||||
|
||||
# Add a field to check that it does not get included
|
||||
extra_field_data = deepcopy(data)
|
||||
extra_field_data[0]['ignore'] = extra_sid_data[0]['price']
|
||||
self.history_return_field_filter.append(
|
||||
self.return_field_filter.handle_data(extra_field_data)
|
||||
)
|
||||
self.history_return_field_no_filter.append(
|
||||
self.return_field_no_filter.handle_data(extra_field_data)
|
||||
)
|
||||
|
||||
|
||||
class SetPortfolioAlgorithm(TradingAlgorithm):
|
||||
"""
|
||||
|
||||
+71
-18
@@ -24,6 +24,7 @@ from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from numbers import Integral
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@@ -343,7 +344,43 @@ class BatchTransform(EventWindow):
|
||||
func=None,
|
||||
refresh_period=None,
|
||||
window_length=None,
|
||||
clean_nans=True):
|
||||
clean_nans=True,
|
||||
sids=None,
|
||||
fields=None,
|
||||
create_panel=True,
|
||||
compute_only_full=True):
|
||||
|
||||
"""Instantiate new batch_transform object.
|
||||
|
||||
:Arguments:
|
||||
func : python function <optional>
|
||||
If supplied will be called after each refresh_period
|
||||
with the data panel and all args and kwargs supplied
|
||||
to the handle_data() call.
|
||||
refresh_period : int
|
||||
Interval to call batch_transform function.
|
||||
window_length : int
|
||||
How many days the trailing window should have.
|
||||
clean_nans : bool <default=True>
|
||||
Whether to (forward) fill in nans.
|
||||
sids : list <optional>
|
||||
Which sids to include in the moving window. If not
|
||||
supplied sids will be extracted from incoming
|
||||
events.
|
||||
fields : list <optional>
|
||||
Which fields to include in the moving window
|
||||
(e.g. 'price'). If not supplied, fields will be
|
||||
extracted from incoming events.
|
||||
create_panel : bool <default=True>
|
||||
If True, will create a pandas panel every refresh
|
||||
period and pass it to the user-defined function.
|
||||
If False, will pass the underlying deque reference
|
||||
directly to the function which will be significantly
|
||||
faster.
|
||||
compute_only_full : bool <default=True>
|
||||
Only call the user-defined function once the window is
|
||||
full. Returns None if window is not full yet.
|
||||
"""
|
||||
|
||||
super(BatchTransform, self).__init__(True,
|
||||
window_length=window_length)
|
||||
@@ -354,6 +391,16 @@ class BatchTransform(EventWindow):
|
||||
self.compute_transform_value = self.get_value
|
||||
|
||||
self.clean_nans = clean_nans
|
||||
self.create_panel = create_panel
|
||||
self.compute_only_full = compute_only_full
|
||||
|
||||
self.sids = sids
|
||||
if isinstance(self.sids, (basestring, Integral)):
|
||||
self.sids = [self.sids]
|
||||
|
||||
self.field_names = fields
|
||||
if isinstance(self.field_names, str):
|
||||
self.field_names = [self.field_names]
|
||||
|
||||
self.refresh_period = refresh_period
|
||||
self.window_length = window_length
|
||||
@@ -366,8 +413,6 @@ class BatchTransform(EventWindow):
|
||||
self.updated = False
|
||||
self.cached = None
|
||||
|
||||
self.field_names = None
|
||||
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
New method to handle a data frame as sent to the algorithm's
|
||||
@@ -410,9 +455,9 @@ class BatchTransform(EventWindow):
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_dt:
|
||||
self.field_names = self._extract_field_names(event)
|
||||
if self.field_names is None:
|
||||
self.field_names = self._extract_field_names(event)
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
# update trading day counters
|
||||
if self.last_dt.day != event.dt.day:
|
||||
@@ -420,15 +465,14 @@ class BatchTransform(EventWindow):
|
||||
self.trading_days_since_update += 1
|
||||
self.trading_days_total += 1
|
||||
|
||||
if (
|
||||
self.trading_days_total >= self.window_length and
|
||||
self.trading_days_since_update >= self.refresh_period
|
||||
):
|
||||
if self.trading_days_total >= self.window_length:
|
||||
self.full = True
|
||||
|
||||
if self.trading_days_since_update >= self.refresh_period:
|
||||
# Setting updated to True will cause get_transform_value()
|
||||
# to call the user-defined batch-transform with the most
|
||||
# recent datapanel
|
||||
self.updated = True
|
||||
self.full = True
|
||||
self.trading_days_since_update = 0
|
||||
else:
|
||||
self.updated = False
|
||||
@@ -445,7 +489,13 @@ class BatchTransform(EventWindow):
|
||||
"""
|
||||
# This Panel data structure ultimately gets passed to the
|
||||
# user-overloaded get_value() method.
|
||||
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
|
||||
|
||||
# If sids are set, use those. Otherwise extract.
|
||||
if self.sids is not None:
|
||||
sids = self.sids
|
||||
else:
|
||||
sids = set.union(*[set(tick.data.keys()) for tick in self.ticks])
|
||||
|
||||
dts = [tick.dt for tick in self.ticks]
|
||||
|
||||
data = pd.Panel(items=self.field_names, major_axis=dts,
|
||||
@@ -454,9 +504,10 @@ class BatchTransform(EventWindow):
|
||||
# Fill data panel
|
||||
for tick in self.ticks:
|
||||
dt = tick.dt
|
||||
for sid, fields in tick.data.iteritems():
|
||||
for sid in sids:
|
||||
fields = tick.data[sid]
|
||||
for field_name in self.field_names:
|
||||
data[field_name][sid].ix[dt] = fields[field_name]
|
||||
data[field_name][sid].ix[dt] = fields[field_name]
|
||||
|
||||
if self.clean_nans:
|
||||
# Fills in gaps of missing data during transform
|
||||
@@ -471,8 +522,7 @@ class BatchTransform(EventWindow):
|
||||
return data
|
||||
|
||||
def handle_remove(self, event):
|
||||
# since an event is expiring, we know the window is full
|
||||
self.full = True
|
||||
pass
|
||||
|
||||
def get_value(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
@@ -486,12 +536,15 @@ class BatchTransform(EventWindow):
|
||||
has actually been updated. Otherwise, the previously, cached
|
||||
value will be returned.
|
||||
"""
|
||||
if not self.full:
|
||||
if self.compute_only_full and not self.full:
|
||||
return None
|
||||
|
||||
if self.updated:
|
||||
self.cached = self.compute_transform_value(self.get_data(),
|
||||
*args, **kwargs)
|
||||
# Either create new pandas panel or pass ticks dequeue
|
||||
# directly
|
||||
data = self.get_data() if self.create_panel else self.ticks
|
||||
self.cached = self.compute_transform_value(data, *args,
|
||||
**kwargs)
|
||||
|
||||
return self.cached
|
||||
|
||||
|
||||
Reference in New Issue
Block a user