mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 01:53:27 +08:00
ENH: Add downsampling to BatchTransform.
So that with minute data, 2.5 orders of magnitude of data can be cut, allowing for longer window_lenghts, when the daily values are what are desired for a signal.
This commit is contained in:
@@ -380,7 +380,6 @@ class PerformancePeriod(object):
|
||||
positions = self._positions_store
|
||||
|
||||
for sid, pos in self.positions.iteritems():
|
||||
|
||||
if sid not in positions:
|
||||
positions[sid] = zp.Position(sid)
|
||||
position = positions[sid]
|
||||
|
||||
@@ -34,6 +34,42 @@ import zipline.finance.trading as trading
|
||||
from . utils import check_window_length
|
||||
|
||||
log = logbook.Logger('BatchTransform')
|
||||
func_map = {'open_price': 'first',
|
||||
'close_price': 'last',
|
||||
'low': 'min',
|
||||
'high': 'max',
|
||||
'volume': 'sum'
|
||||
}
|
||||
|
||||
|
||||
def get_sample_func(item):
|
||||
if item in func_map:
|
||||
return func_map[item]
|
||||
else:
|
||||
return 'last'
|
||||
|
||||
|
||||
def downsample_panel(minute_rp, daily_rp, dt):
|
||||
"""
|
||||
@minute_rp is a rolling panel, which should have minutely rows
|
||||
@daily_rp is a rolling panel, which should have daily rows
|
||||
@dt is the timestamp to use when adding a frame to daily_rp
|
||||
|
||||
Using the history in minute_rp, a new daily bar is created by
|
||||
downsampling. The data from the daily bar is then added to the
|
||||
daily rolling panel using add_frame.
|
||||
"""
|
||||
|
||||
cur_panel = minute_rp.get_current()
|
||||
sids = minute_rp.minor_axis
|
||||
day_frame = pd.DataFrame(columns=sids, index=cur_panel.items)
|
||||
for item in minute_rp.items:
|
||||
frame = cur_panel[item]
|
||||
func = get_sample_func(item)
|
||||
dframe = frame.groupby(lambda d: d.date()).resample('1d', how=func)
|
||||
for stock in sids:
|
||||
day_frame[stock][item] = dframe[stock][dframe.index[-1][0]]
|
||||
daily_rp.add_frame(dt, day_frame)
|
||||
|
||||
|
||||
class BatchTransform(object):
|
||||
@@ -83,7 +119,8 @@ class BatchTransform(object):
|
||||
sids=None,
|
||||
fields=None,
|
||||
compute_only_full=True,
|
||||
bars='daily'):
|
||||
bars='daily',
|
||||
downsample=False):
|
||||
|
||||
"""Instantiate new batch_transform object.
|
||||
|
||||
@@ -109,6 +146,8 @@ class BatchTransform(object):
|
||||
compute_only_full : bool <default=True>
|
||||
Only call the user-defined function once the window is
|
||||
full. Returns None if window is not full yet.
|
||||
downsample : bool <default=False>
|
||||
If true, downsample bars to daily bars. Otherwise, do nothing.
|
||||
"""
|
||||
if func is not None:
|
||||
self.compute_transform_value = func
|
||||
@@ -117,6 +156,8 @@ class BatchTransform(object):
|
||||
|
||||
self.clean_nans = clean_nans
|
||||
self.compute_only_full = compute_only_full
|
||||
# no need to down sample if the bars are already daily
|
||||
self.downsample = downsample and (bars == 'minute')
|
||||
|
||||
# How many bars are in a day
|
||||
self.bars = bars
|
||||
@@ -168,6 +209,7 @@ class BatchTransform(object):
|
||||
self.supplemental_data = None
|
||||
|
||||
self.rolling_panel = None
|
||||
self.daily_rolling_panel = None
|
||||
|
||||
def handle_data(self, data, *args, **kwargs):
|
||||
"""
|
||||
@@ -201,6 +243,18 @@ class BatchTransform(object):
|
||||
# return newly computed or cached value
|
||||
return self.get_transform_value(*args, **kwargs)
|
||||
|
||||
def _init_panels(self, sids):
|
||||
if self.downsample:
|
||||
self.rolling_panel = RollingPanel(self.bars_in_day,
|
||||
self.field_names, sids)
|
||||
|
||||
self.daily_rolling_panel = RollingPanel(self.window_length,
|
||||
self.field_names, sids)
|
||||
else:
|
||||
self.rolling_panel = RollingPanel(self.window_length *
|
||||
self.bars_in_day,
|
||||
self.field_names, sids)
|
||||
|
||||
def _append_to_window(self, event):
|
||||
self.field_names = self._get_field_names(event)
|
||||
|
||||
@@ -211,15 +265,7 @@ class BatchTransform(object):
|
||||
|
||||
# Create rolling panel if not existant
|
||||
if self.rolling_panel is None:
|
||||
self.rolling_panel = RollingPanel(self.window_length *
|
||||
self.bars_in_day,
|
||||
self.field_names, sids)
|
||||
|
||||
# Store event in rolling frame
|
||||
self.rolling_panel.add_frame(event.dt,
|
||||
pd.DataFrame(event.data,
|
||||
index=self.field_names,
|
||||
columns=sids))
|
||||
self._init_panels(sids)
|
||||
|
||||
# update trading day counters
|
||||
_, mkt_close = trading.environment.get_open_and_close(event.dt)
|
||||
@@ -227,8 +273,18 @@ class BatchTransform(object):
|
||||
# Daily bars have their dt set to midnight.
|
||||
mkt_close = mkt_close.replace(hour=0, minute=0, second=0)
|
||||
if event.dt >= mkt_close:
|
||||
if self.downsample:
|
||||
downsample_panel(self.rolling_panel,
|
||||
self.daily_rolling_panel,
|
||||
mkt_close)
|
||||
self.trading_days_total += 1
|
||||
|
||||
# Store event in rolling frame
|
||||
self.rolling_panel.add_frame(event.dt,
|
||||
pd.DataFrame(event.data,
|
||||
index=self.field_names,
|
||||
columns=sids))
|
||||
|
||||
self.last_dt = event.dt
|
||||
|
||||
if self.trading_days_total >= self.window_length:
|
||||
@@ -281,7 +337,10 @@ class BatchTransform(object):
|
||||
major axis/rows : dt
|
||||
minor axis/colums : sid
|
||||
"""
|
||||
data = self.rolling_panel.get_current()
|
||||
if self.downsample:
|
||||
data = self.daily_rolling_panel.get_current()
|
||||
else:
|
||||
data = self.rolling_panel.get_current()
|
||||
|
||||
if self.supplemental_data:
|
||||
for item in data.items:
|
||||
|
||||
Reference in New Issue
Block a user