mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-06 04:26:17 +08:00
BUG: batch_transform was wrongly updating when days > refresh_period.
This commit is contained in:
@@ -327,6 +327,14 @@ class TestBatchTransform(TestCase):
|
||||
self.assertEqual(algo.history_return_price_market_aware[:2],
|
||||
[None, None],
|
||||
"First two iterations should return None")
|
||||
self.assertEqual(algo.history_return_more_days_than_refresh[:3],
|
||||
[None, None, None],
|
||||
"First five iterations should return None")
|
||||
self.assertTrue(isinstance(
|
||||
algo.history_return_more_days_than_refresh[4],
|
||||
pd.DataFrame),
|
||||
"Sixth iteration should not be None"
|
||||
)
|
||||
|
||||
# test overloaded class
|
||||
for test_history in [algo.history_return_price_class,
|
||||
|
||||
@@ -260,6 +260,7 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
self.history_return_price_decorator = []
|
||||
self.history_return_args = []
|
||||
self.history_return_price_market_aware = []
|
||||
self.history_return_more_days_than_refresh = []
|
||||
|
||||
self.return_price_class = ReturnPriceBatchTransform(
|
||||
market_aware=False,
|
||||
@@ -285,6 +286,12 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
days=self.days
|
||||
)
|
||||
|
||||
self.return_price_more_days_than_refresh = ReturnPriceBatchTransform(
|
||||
market_aware=True,
|
||||
refresh_period=1,
|
||||
days=3
|
||||
)
|
||||
|
||||
self.set_slippage(FixedSlippage())
|
||||
|
||||
def handle_data(self, data):
|
||||
@@ -297,6 +304,8 @@ class BatchTransformAlgorithm(TradingAlgorithm):
|
||||
data, *self.args, **self.kwargs))
|
||||
self.history_return_price_market_aware.append(
|
||||
self.return_price_market_aware.handle_data(data))
|
||||
self.history_return_more_days_than_refresh.append(
|
||||
self.return_price_more_days_than_refresh.handle_data(data))
|
||||
|
||||
|
||||
class SetPortfolioAlgorithm(TradingAlgorithm):
|
||||
|
||||
+52
-34
@@ -361,6 +361,7 @@ class BatchTransform(EventWindow):
|
||||
self.refresh_period = refresh_period
|
||||
self.days = days
|
||||
self.trading_days_since_update = 0
|
||||
self.trading_days_total = 0
|
||||
|
||||
self.full = False
|
||||
self.last_dt = None
|
||||
@@ -397,49 +398,59 @@ class BatchTransform(EventWindow):
|
||||
self.last_dt = event.dt
|
||||
return
|
||||
|
||||
# update trading day counters
|
||||
if self.last_dt.day != event.dt.day:
|
||||
self.last_dt = event.dt
|
||||
self.trading_days_since_update += 1
|
||||
self.trading_days_total += 1
|
||||
|
||||
if self.trading_days_since_update >= self.refresh_period:
|
||||
# Create a pandas.Panel (i.e. 3d DataFrame) from the
|
||||
# events in the current window.
|
||||
#
|
||||
# The resulting panel looks like this:
|
||||
# index : field_name (e.g. price)
|
||||
# major axis/rows : dt
|
||||
# minor axis/columns : sid
|
||||
#
|
||||
# This Panel data structure ultimately gets passed to the
|
||||
# user-overloaded get_value() method.
|
||||
#
|
||||
# self.ticks contains ndicts with data, dt keys.
|
||||
# event parameter is an ndict with data, dt keys.
|
||||
fields = {}
|
||||
for field_name in ['price', 'volume']:
|
||||
sids = self.ticks[0].data.keys()
|
||||
# Skip non-existent fields
|
||||
if field_name not in self.ticks[0].data[sids[0]]:
|
||||
continue
|
||||
|
||||
values_per_sid = {}
|
||||
|
||||
for sid in sids:
|
||||
values_per_sid[sid] = pd.Series(
|
||||
{tick.data[sid].dt: tick.data[sid][field_name]
|
||||
for tick in self.ticks}
|
||||
)
|
||||
|
||||
# concatenate different sids into one df
|
||||
fields[field_name] = pd.DataFrame.from_dict(values_per_sid)
|
||||
|
||||
self.data = pd.Panel.from_dict(fields, orient='items')
|
||||
|
||||
if self.trading_days_since_update >= self.refresh_period and\
|
||||
self.trading_days_total >= self.days:
|
||||
# Create datapanel of running event window.
|
||||
self.data = self.get_data()
|
||||
# Setting updated to True will cause get_transform_value()
|
||||
# to call the user-defined batch-transform with the most
|
||||
# recent datapanel
|
||||
self.updated = True
|
||||
self.trading_days_since_update = 0
|
||||
else:
|
||||
self.updated = False
|
||||
|
||||
def get_data(self):
|
||||
# Create a pandas.Panel (i.e. 3d DataFrame) from the
|
||||
# events in the current window.
|
||||
#
|
||||
# The resulting panel looks like this:
|
||||
# index : field_name (e.g. price)
|
||||
# major axis/rows : dt
|
||||
# minor axis/columns : sid
|
||||
#
|
||||
# This Panel data structure ultimately gets passed to the
|
||||
# user-overloaded get_value() method.
|
||||
#
|
||||
# self.ticks contains ndicts with data, dt keys.
|
||||
# event parameter is an ndict with data, dt keys.
|
||||
fields = {}
|
||||
for field_name in ['price', 'volume']:
|
||||
sids = self.ticks[0].data.keys()
|
||||
# Skip non-existent fields
|
||||
if field_name not in self.ticks[0].data[sids[0]]:
|
||||
continue
|
||||
|
||||
values_per_sid = {}
|
||||
|
||||
for sid in sids:
|
||||
values_per_sid[sid] = pd.Series(
|
||||
{tick.data[sid].dt: tick.data[sid][field_name]
|
||||
for tick in self.ticks}
|
||||
)
|
||||
|
||||
# concatenate different sids into one df
|
||||
fields[field_name] = pd.DataFrame.from_dict(values_per_sid)
|
||||
|
||||
data = pd.Panel.from_dict(fields, orient='items')
|
||||
return data
|
||||
|
||||
def handle_remove(self, event):
|
||||
# since an event is expiring, we know the window is full
|
||||
self.full = True
|
||||
@@ -449,6 +460,13 @@ class BatchTransform(EventWindow):
|
||||
"Either overwrite get_value or provide a func argument.")
|
||||
|
||||
def get_transform_value(self, *args, **kwargs):
|
||||
"""Call user-defined batch-transform function passing all
|
||||
arguments.
|
||||
|
||||
Note that this will only call the transform if the datapanel
|
||||
has actually been updated. Otherwise, the previously cached
|
||||
value will be returned.
|
||||
"""
|
||||
if self.data is None:
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user