ENH: batch_transform now supports field filtering.

This commit is contained in:
Thomas Wiecki
2012-12-20 11:56:49 -05:00
parent b68c51afb4
commit b815a57430
3 changed files with 42 additions and 4 deletions
+8
View File
@@ -346,6 +346,14 @@ class TestBatchTransform(TestCase):
self.assertIn(0, data.columns)
self.assertNotIn(1, data.columns)
for data in algo.history_return_field_filter[wl:]:
self.assertIn('price', data.items)
self.assertNotIn('ignore', data.items)
for data in algo.history_return_field_no_filter[wl:]:
self.assertIn('price', data.items)
self.assertIn('ignore', data.items)
# test overloaded class
for test_history in [algo.history_return_price_class,
algo.history_return_price_decorator]:
+25
View File
@@ -269,6 +269,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_return_arbitrary_fields = []
self.history_return_nan = []
self.history_return_sid_filter = []
self.history_return_field_filter = []
self.history_return_field_no_filter = []
self.return_price_class = ReturnPriceBatchTransform(
refresh_period=self.refresh_period,
@@ -313,6 +315,19 @@ class BatchTransformAlgorithm(TradingAlgorithm):
sids=[0]
)
self.return_field_filter = return_data(
refresh_period=self.refresh_period,
window_length=self.window_length,
clean_nans=True,
fields=['price']
)
self.return_field_no_filter = return_data(
refresh_period=self.refresh_period,
window_length=self.window_length,
clean_nans=True
)
self.iter = 0
self.set_slippage(FixedSlippage())
@@ -354,6 +369,16 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.return_sid_filter.handle_data(extra_sid_data)
)
# Add a field to check that it does not get included
extra_field_data = deepcopy(data)
extra_field_data[0]['ignore'] = extra_sid_data[0]['price']
self.history_return_field_filter.append(
self.return_field_filter.handle_data(extra_field_data)
)
self.history_return_field_no_filter.append(
self.return_field_no_filter.handle_data(extra_field_data)
)
class SetPortfolioAlgorithm(TradingAlgorithm):
"""
+9 -4
View File
@@ -344,7 +344,8 @@ class BatchTransform(EventWindow):
refresh_period=None,
window_length=None,
clean_nans=True,
sids=None):
sids=None,
fields=None):
"""Instantiate new batch_transform object.
:Arguments:
@@ -362,7 +363,10 @@ class BatchTransform(EventWindow):
Which sids to include in the moving window. If not
supplied sids will be extracted from incoming
events.
fields : list <optional>
Which fields to include in the moving window
(e.g. 'price'). If not supplied, fields will be
extracted from incoming events.
"""
super(BatchTransform, self).__init__(True,
@@ -388,7 +392,7 @@ class BatchTransform(EventWindow):
self.updated = False
self.cached = None
self.field_names = None
self.field_names = fields
def handle_data(self, data, *args, **kwargs):
"""
@@ -432,7 +436,8 @@ class BatchTransform(EventWindow):
def handle_add(self, event):
if not self.last_dt:
self.field_names = self._extract_field_names(event)
if self.field_names is None:
self.field_names = self._extract_field_names(event)
self.last_dt = event.dt
return