mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 11:52:58 +08:00
ENH: batch_transform now supports field filtering.
This commit is contained in:
@@ -346,6 +346,14 @@ class TestBatchTransform(TestCase):
|
||||
self.assertIn(0, data.columns)
|
||||
self.assertNotIn(1, data.columns)
|
||||
|
||||
for data in algo.history_return_field_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertNotIn('ignore', data.items)
|
||||
|
||||
for data in algo.history_return_field_no_filter[wl:]:
|
||||
self.assertIn('price', data.items)
|
||||
self.assertIn('ignore', data.items)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
algo.history_return_price_decorator]:
|
||||
|
||||
@@ -269,6 +269,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.history_return_arbitrary_fields = []
|
||||
self.history_return_nan = []
|
||||
self.history_return_sid_filter = []
|
||||
self.history_return_field_filter = []
|
||||
self.history_return_field_no_filter = []
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(
|
||||
refresh_period=self.refresh_period,
|
||||
@@ -313,6 +315,19 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
sids=[0]
|
||||
)
|
||||
|
||||
self.return_field_filter = return_data(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=True,
|
||||
fields=['price']
|
||||
)
|
||||
|
||||
self.return_field_no_filter = return_data(
|
||||
refresh_period=self.refresh_period,
|
||||
window_length=self.window_length,
|
||||
clean_nans=True
|
||||
)
|
||||
|
||||
self.iter = 0
|
||||
|
||||
self.set_slippage(FixedSlippage())
|
||||
@@ -354,6 +369,16 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.return_sid_filter.handle_data(extra_sid_data)
|
||||
)
|
||||
|
||||
# Add a field to check that it does not get included
|
||||
extra_field_data = deepcopy(data)
|
||||
extra_field_data[0]['ignore'] = extra_sid_data[0]['price']
|
||||
self.history_return_field_filter.append(
|
||||
self.return_field_filter.handle_data(extra_field_data)
|
||||
)
|
||||
self.history_return_field_no_filter.append(
|
||||
self.return_field_no_filter.handle_data(extra_field_data)
|
||||
)
|
||||
|
||||
|
||||
class SetPortfolioAlgorithm(TradingAlgorithm):
|
||||
"""
|
||||
|
||||
@@ -344,7 +344,8 @@ class BatchTransform(EventWindow):
|
||||
refresh_period=None,
|
||||
window_length=None,
|
||||
clean_nans=True,
|
||||
sids=None):
|
||||
sids=None,
|
||||
fields=None):
|
||||
"""Instantiate new batch_transform object.
|
||||
|
||||
:Arguments:
|
||||
@@ -362,7 +363,10 @@ class BatchTransform(EventWindow):
|
||||
Which sids to include in the moving window. If not
|
||||
supplied sids will be extracted from incoming
|
||||
events.
|
||||
|
||||
fields : list <optional>
|
||||
Which fields to include in the moving window
|
||||
(e.g. 'price'). If not supplied, fields will be
|
||||
extracted from incoming events.
|
||||
"""
|
||||
|
||||
super(BatchTransform, self).__init__(True,
|
||||
@@ -388,7 +392,7 @@ class BatchTransform(EventWindow):
|
||||
self.updated = False
|
||||
self.cached = None
|
||||
|
||||
self.field_names = None
|
||||
self.field_names = fields
|
||||
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
@@ -432,7 +436,8 @@ class BatchTransform(EventWindow):
|
||||
|
||||
def handle_add(self, event):
|
||||
if not self.last_dt:
|
||||
self.field_names = self._extract_field_names(event)
|
||||
if self.field_names is None:
|
||||
self.field_names = self._extract_field_names(event)
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
|
||||
Reference in New Issue
Block a user