BUG: batch_transform was wrongly updating when days > refresh_period.

This commit is contained in:
Thomas Wiecki
2012-11-19 10:20:24 -05:00
parent d5697cdf0a
commit 4c02fea6e2
3 changed files with 69 additions and 34 deletions
+8
View File
@@ -327,6 +327,14 @@ class TestBatchTransform(TestCase):
self.assertEqual(algo.history_return_price_market_aware[:2],
[None, None],
"First two iterations should return None")
self.assertEqual(algo.history_return_more_days_than_refresh[:3],
[None, None, None],
"First five iterations should return None")
self.assertTrue(isinstance(
algo.history_return_more_days_than_refresh[4],
pd.DataFrame),
"Sixth iteration should not be None"
)
# test overloaded class
for test_history in [algo.history_return_price_class,
+9
View File
@@ -260,6 +260,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
self.history_return_price_decorator = []
self.history_return_args = []
self.history_return_price_market_aware = []
self.history_return_more_days_than_refresh = []
self.return_price_class = ReturnPriceBatchTransform(
market_aware=False,
@@ -285,6 +286,12 @@ class BatchTransformAlgorithm(TradingAlgorithm):
days=self.days
)
self.return_price_more_days_than_refresh = ReturnPriceBatchTransform(
market_aware=True,
refresh_period=1,
days=3
)
self.set_slippage(FixedSlippage())
def handle_data(self, data):
@@ -297,6 +304,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
data, *self.args, **self.kwargs))
self.history_return_price_market_aware.append(
self.return_price_market_aware.handle_data(data))
self.history_return_more_days_than_refresh.append(
self.return_price_more_days_than_refresh.handle_data(data))
class SetPortfolioAlgorithm(TradingAlgorithm):
+52 -34
View File
@@ -361,6 +361,7 @@ class BatchTransform(EventWindow):
self.refresh_period = refresh_period
self.days = days
self.trading_days_since_update = 0
self.trading_days_total = 0
self.full = False
self.last_dt = None
@@ -397,49 +398,59 @@ class BatchTransform(EventWindow):
self.last_dt = event.dt
return
# update trading day counters
if self.last_dt.day != event.dt.day:
self.last_dt = event.dt
self.trading_days_since_update += 1
self.trading_days_total += 1
if self.trading_days_since_update >= self.refresh_period:
# Create a pandas.Panel (i.e. 3d DataFrame) from the
# events in the current window.
#
# The resulting panel looks like this:
# index : field_name (e.g. price)
# major axis/rows : dt
# minor axis/colums : sid
#
# This Panel data structure ultimately gets passed to the
# user-overloaded get_value() method.
#
# self.ticks contains ndicts with data, dt keys.
# event parameter is an ndict with data, dt keys.
fields = {}
for field_name in ['price', 'volume']:
sids = self.ticks[0].data.keys()
# Skip non-existant fields
if field_name not in self.ticks[0].data[sids[0]]:
continue
values_per_sid = {}
for sid in sids:
values_per_sid[sid] = pd.Series(
{tick.data[sid].dt: tick.data[sid][field_name]
for tick in self.ticks}
)
# concatenate different sids into one df
fields[field_name] = pd.DataFrame.from_dict(values_per_sid)
self.data = pd.Panel.from_dict(fields, orient='items')
if self.trading_days_since_update >= self.refresh_period and\
self.trading_days_total >= self.days:
# Create datapanel of running event window.
self.data = self.get_data()
# Setting updated to True will cause get_transform_value()
# to call the user-defined batch-transform with the most
# recent datapanel
self.updated = True
self.trading_days_since_update = 0
else:
self.updated = False
def get_data(self):
    """Create a pandas.Panel (i.e. 3d DataFrame) from the events in
    the current window.

    The resulting panel looks like this:
        index             : field_name (e.g. price)
        major axis/rows   : dt
        minor axis/columns: sid

    This Panel data structure ultimately gets passed to the
    user-overloaded get_value() method.

    self.ticks contains ndicts with data, dt keys. The sids and the
    available fields are taken from the first tick in the window, so
    the window must hold at least one tick with at least one sid
    (indexing an empty window fails, as it did before).

    Returns a pandas.Panel holding one DataFrame per field that is
    present ('price' and/or 'volume').
    """
    # The sids do not depend on the field, so look them up once.
    # list() keeps the sids[0] lookup below working on Python 3, where
    # dict.keys() returns a view that cannot be indexed.
    first_tick_data = self.ticks[0].data
    sids = list(first_tick_data.keys())

    fields = {}
    for field_name in ['price', 'volume']:
        # Skip non-existent fields
        if field_name not in first_tick_data[sids[0]]:
            continue

        values_per_sid = {}
        for sid in sids:
            values_per_sid[sid] = pd.Series(
                {tick.data[sid].dt: tick.data[sid][field_name]
                 for tick in self.ticks}
            )

        # concatenate different sids into one df
        fields[field_name] = pd.DataFrame.from_dict(values_per_sid)

    return pd.Panel.from_dict(fields, orient='items')
def handle_remove(self, event):
# since an event is expiring, we know the window is full
self.full = True
@@ -449,6 +460,13 @@ class BatchTransform(EventWindow):
"Either overwrite get_value or provide a func argument.")
def get_transform_value(self, *args, **kwargs):
"""Call user-defined batch-transform function passing all
arguments.
Note that this will only call the transform if the datapanel
has actually been updated. Otherwise, the previously, cached
value will be returned.
"""
if self.data is None:
return None