mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 03:48:58 +08:00
MAINT: Refactor AlgorithmSimulator.transform.
Breaks out the main snapshot processing loop into its own function, and does some minor variable renaming-shuffling. Adds `TradingAlgorithm.on_dt_changed`, a function to be called when the simulation dt changes, prior to processing any events. There should be no difference in behavior as a result of this change.
This commit is contained in:
+11
-1
@@ -631,12 +631,22 @@ class TradingAlgorithm(object):
|
||||
def set_logger(self, logger):
|
||||
self.logger = logger
|
||||
|
||||
def set_datetime(self, dt):
|
||||
def on_dt_changed(self, dt):
|
||||
"""
|
||||
Callback triggered by the simulation loop whenever the current dt
|
||||
changes.
|
||||
|
||||
Any logic that should happen exactly once at the start of each datetime
|
||||
group should happen here.
|
||||
"""
|
||||
assert isinstance(dt, datetime), \
|
||||
"Attempt to set algorithm's current time with non-datetime"
|
||||
assert dt.tzinfo == pytz.utc, \
|
||||
"Algorithm expects a utc datetime"
|
||||
|
||||
self.datetime = dt
|
||||
self.perf_tracker.set_date(dt)
|
||||
self.blotter.set_date(dt)
|
||||
|
||||
@api_method
|
||||
def get_datetime(self):
|
||||
|
||||
@@ -103,13 +103,11 @@ class AlgorithmSimulator(object):
|
||||
# inject the current algo
|
||||
# snapshot time to any log record generated.
|
||||
with self.processor.threadbound():
|
||||
updated = False
|
||||
bm_updated = False
|
||||
for date, snapshot in stream_in:
|
||||
self.algo.set_datetime(date)
|
||||
|
||||
self.simulation_dt = date
|
||||
self.algo.perf_tracker.set_date(date)
|
||||
self.algo.blotter.set_date(date)
|
||||
self.algo.on_dt_changed(date)
|
||||
|
||||
# If we're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
# and don't send a snapshot to handle_data.
|
||||
@@ -122,59 +120,16 @@ class AlgorithmSimulator(object):
|
||||
DATASOURCE_TYPE.CUSTOM):
|
||||
self.update_universe(event)
|
||||
self.algo.perf_tracker.process_event(event)
|
||||
|
||||
else:
|
||||
if self.algo.instant_fill:
|
||||
events = []
|
||||
|
||||
for event in snapshot:
|
||||
if event.type == DATASOURCE_TYPE.TRADE:
|
||||
self.update_universe(event)
|
||||
updated = True
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.BENCHMARK:
|
||||
self.algo.set_datetime(event.dt)
|
||||
bm_updated = True
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.CUSTOM:
|
||||
self.update_universe(event)
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.SPLIT:
|
||||
self.algo.blotter.process_split(event)
|
||||
|
||||
# If we are instantly filling orders we process
|
||||
# them after handle_data().
|
||||
if not self.algo.instant_fill:
|
||||
self.process_event(event)
|
||||
else:
|
||||
events.append(event)
|
||||
|
||||
# Send the current state of the universe
|
||||
# to the user's algo.
|
||||
if updated:
|
||||
self.algo.handle_data(self.current_data)
|
||||
updated = False
|
||||
|
||||
# run orders placed in the algorithm call
|
||||
# above through perf tracker before emitting
|
||||
# the perf packet, so that the perf includes
|
||||
# placed orders
|
||||
for order in self.algo.blotter.new_orders:
|
||||
self.algo.perf_tracker.process_event(order)
|
||||
self.algo.blotter.new_orders = []
|
||||
|
||||
# If we are instantly filling we execute orders
|
||||
# in this iteration rather than the next.
|
||||
if self.algo.instant_fill:
|
||||
for event in events:
|
||||
self.process_event(event)
|
||||
|
||||
# The benchmark is our internal clock. When it
|
||||
# updates, we need to emit a performance message.
|
||||
if bm_updated:
|
||||
bm_updated = False
|
||||
self.algo.updated_portfolio()
|
||||
yield self.get_message(date)
|
||||
message = self._process_snapshot(
|
||||
date,
|
||||
snapshot,
|
||||
self.algo.instant_fill,
|
||||
)
|
||||
# Perf messages are only emitted if the snapshot contained
|
||||
# a benchmark event.
|
||||
if message is not None:
|
||||
yield message
|
||||
|
||||
# When emitting minutely, we re-iterate the day as a
|
||||
# packet with the entire days performance rolled up.
|
||||
@@ -191,10 +146,9 @@ class AlgorithmSimulator(object):
|
||||
if mkt_close <= self.algo.perf_tracker.last_close:
|
||||
try:
|
||||
mkt_open, mkt_close = \
|
||||
trading.environment.\
|
||||
next_open_and_close(
|
||||
mkt_close
|
||||
)
|
||||
trading.environment \
|
||||
.next_open_and_close(mkt_close)
|
||||
|
||||
except trading.NoFurtherDataError:
|
||||
# If at the end of backtest history,
|
||||
# skip advancing market close.
|
||||
@@ -207,6 +161,81 @@ class AlgorithmSimulator(object):
|
||||
risk_message = self.algo.perf_tracker.handle_simulation_end()
|
||||
yield risk_message
|
||||
|
||||
def _process_snapshot(self, dt, snapshot, instant_fill):
|
||||
"""
|
||||
Process a stream of events corresponding to a single datetime, possibly
|
||||
returning a perf message to be yielded.
|
||||
|
||||
If @instant_fill = True, we delay processing of events until after the
|
||||
user's call to handle_data, and we process the user's placed orders
|
||||
before the snapshot's events. Note that this introduces a lookahead
|
||||
bias, since the user effectively is effectively placing orders that are
|
||||
filled based on trades that happened prior to the call the handle_data.
|
||||
|
||||
If @instant_fill = False, we process Trade events before calling
|
||||
handle_data. This means that orders are filled based on trades
|
||||
occurring in the next snapshot. This is the more conservative model,
|
||||
and as such it is the default behavior in TradingAlgorithm.
|
||||
"""
|
||||
|
||||
# Flags indicating whether we saw any events of type TRADE and type
|
||||
# BENCHMARK. Respectively, these control whether or not handle_data is
|
||||
# called for this snapshot and whether we emit a perf message for this
|
||||
# snapshot.
|
||||
any_trade_occurred = False
|
||||
benchmark_event_occurred = False
|
||||
|
||||
if instant_fill:
|
||||
events_to_be_processed = []
|
||||
|
||||
for event in snapshot:
|
||||
|
||||
if event.type == DATASOURCE_TYPE.TRADE:
|
||||
self.update_universe(event)
|
||||
any_trade_occurred = True
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.BENCHMARK:
|
||||
benchmark_event_occurred = True
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.CUSTOM:
|
||||
self.update_universe(event)
|
||||
|
||||
elif event.type == DATASOURCE_TYPE.SPLIT:
|
||||
self.algo.blotter.process_split(event)
|
||||
|
||||
if not self.algo.instant_fill:
|
||||
self.process_event(event)
|
||||
else:
|
||||
events_to_be_processed.append(event)
|
||||
|
||||
if any_trade_occurred:
|
||||
new_orders = self._call_handle_data()
|
||||
for order in new_orders:
|
||||
self.algo.perf_tracker.process_event(order)
|
||||
|
||||
if instant_fill:
|
||||
# Now that handle_data has been called and orders have been placed,
|
||||
# process the event stream to fill user orders based on the events
|
||||
# from this snapshot.
|
||||
for event in events_to_be_processed:
|
||||
self.process_event(event)
|
||||
|
||||
if benchmark_event_occurred:
|
||||
self.algo.updated_portfolio()
|
||||
return self.get_message(dt)
|
||||
else:
|
||||
return None
|
||||
|
||||
def _call_handle_data(self):
|
||||
"""
|
||||
Call the user's handle_data, returning any orders placed by the algo
|
||||
during the call.
|
||||
"""
|
||||
self.algo.handle_data(self.current_data)
|
||||
orders = self.algo.blotter.new_orders
|
||||
self.algo.blotter.new_orders = []
|
||||
return orders
|
||||
|
||||
def get_message(self, date):
|
||||
rvars = self.algo.recorded_vars
|
||||
if self.algo.perf_tracker.emission_rate == 'daily':
|
||||
|
||||
Reference in New Issue
Block a user