MAINT: Refactor AlgorithmSimulator.transform.

Breaks out the main snapshot processing loop into its own function, and does
some minor variable renaming-shuffling.

Adds `TradingAlgorithm.on_dt_changed`, a function to be called when the
simulation dt changes, prior to processing any events.

There should be no difference in behavior as a result of this change.
This commit is contained in:
Scott Sanderson
2014-07-08 13:51:17 -04:00
parent 8d16efc5c4
commit 0176279404
2 changed files with 101 additions and 62 deletions
+11 -1
View File
@@ -631,12 +631,22 @@ class TradingAlgorithm(object):
def set_logger(self, logger):
self.logger = logger
def set_datetime(self, dt):
def on_dt_changed(self, dt):
"""
Callback triggered by the simulation loop whenever the current dt
changes.
Any logic that should happen exactly once at the start of each datetime
group should happen here.
"""
assert isinstance(dt, datetime), \
"Attempt to set algorithm's current time with non-datetime"
assert dt.tzinfo == pytz.utc, \
"Algorithm expects a utc datetime"
self.datetime = dt
self.perf_tracker.set_date(dt)
self.blotter.set_date(dt)
@api_method
def get_datetime(self):
+90 -61
View File
@@ -103,13 +103,11 @@ class AlgorithmSimulator(object):
# inject the current algo
# snapshot time to any log record generated.
with self.processor.threadbound():
updated = False
bm_updated = False
for date, snapshot in stream_in:
self.algo.set_datetime(date)
self.simulation_dt = date
self.algo.perf_tracker.set_date(date)
self.algo.blotter.set_date(date)
self.algo.on_dt_changed(date)
# If we're still in the warmup period. Use the event to
# update our universe, but don't yield any perf messages,
# and don't send a snapshot to handle_data.
@@ -122,59 +120,16 @@ class AlgorithmSimulator(object):
DATASOURCE_TYPE.CUSTOM):
self.update_universe(event)
self.algo.perf_tracker.process_event(event)
else:
if self.algo.instant_fill:
events = []
for event in snapshot:
if event.type == DATASOURCE_TYPE.TRADE:
self.update_universe(event)
updated = True
elif event.type == DATASOURCE_TYPE.BENCHMARK:
self.algo.set_datetime(event.dt)
bm_updated = True
elif event.type == DATASOURCE_TYPE.CUSTOM:
self.update_universe(event)
elif event.type == DATASOURCE_TYPE.SPLIT:
self.algo.blotter.process_split(event)
# If we are instantly filling orders we process
# them after handle_data().
if not self.algo.instant_fill:
self.process_event(event)
else:
events.append(event)
# Send the current state of the universe
# to the user's algo.
if updated:
self.algo.handle_data(self.current_data)
updated = False
# run orders placed in the algorithm call
# above through perf tracker before emitting
# the perf packet, so that the perf includes
# placed orders
for order in self.algo.blotter.new_orders:
self.algo.perf_tracker.process_event(order)
self.algo.blotter.new_orders = []
# If we are instantly filling we execute orders
# in this iteration rather than the next.
if self.algo.instant_fill:
for event in events:
self.process_event(event)
# The benchmark is our internal clock. When it
# updates, we need to emit a performance message.
if bm_updated:
bm_updated = False
self.algo.updated_portfolio()
yield self.get_message(date)
message = self._process_snapshot(
date,
snapshot,
self.algo.instant_fill,
)
# Perf messages are only emitted if the snapshot contained
# a benchmark event.
if message is not None:
yield message
# When emitting minutely, we re-iterate the day as a
# packet with the entire days performance rolled up.
@@ -191,10 +146,9 @@ class AlgorithmSimulator(object):
if mkt_close <= self.algo.perf_tracker.last_close:
try:
mkt_open, mkt_close = \
trading.environment.\
next_open_and_close(
mkt_close
)
trading.environment \
.next_open_and_close(mkt_close)
except trading.NoFurtherDataError:
# If at the end of backtest history,
# skip advancing market close.
@@ -207,6 +161,81 @@ class AlgorithmSimulator(object):
risk_message = self.algo.perf_tracker.handle_simulation_end()
yield risk_message
def _process_snapshot(self, dt, snapshot, instant_fill):
"""
Process a stream of events corresponding to a single datetime, possibly
returning a perf message to be yielded.
If @instant_fill = True, we delay processing of events until after the
user's call to handle_data, and we process the user's placed orders
before the snapshot's events. Note that this introduces a lookahead
bias, since the user effectively is effectively placing orders that are
filled based on trades that happened prior to the call the handle_data.
If @instant_fill = False, we process Trade events before calling
handle_data. This means that orders are filled based on trades
occurring in the next snapshot. This is the more conservative model,
and as such it is the default behavior in TradingAlgorithm.
"""
# Flags indicating whether we saw any events of type TRADE and type
# BENCHMARK. Respectively, these control whether or not handle_data is
# called for this snapshot and whether we emit a perf message for this
# snapshot.
any_trade_occurred = False
benchmark_event_occurred = False
if instant_fill:
events_to_be_processed = []
for event in snapshot:
if event.type == DATASOURCE_TYPE.TRADE:
self.update_universe(event)
any_trade_occurred = True
elif event.type == DATASOURCE_TYPE.BENCHMARK:
benchmark_event_occurred = True
elif event.type == DATASOURCE_TYPE.CUSTOM:
self.update_universe(event)
elif event.type == DATASOURCE_TYPE.SPLIT:
self.algo.blotter.process_split(event)
if not self.algo.instant_fill:
self.process_event(event)
else:
events_to_be_processed.append(event)
if any_trade_occurred:
new_orders = self._call_handle_data()
for order in new_orders:
self.algo.perf_tracker.process_event(order)
if instant_fill:
# Now that handle_data has been called and orders have been placed,
# process the event stream to fill user orders based on the events
# from this snapshot.
for event in events_to_be_processed:
self.process_event(event)
if benchmark_event_occurred:
self.algo.updated_portfolio()
return self.get_message(dt)
else:
return None
def _call_handle_data(self):
"""
Call the user's handle_data, returning any orders placed by the algo
during the call.
"""
self.algo.handle_data(self.current_data)
orders = self.algo.blotter.new_orders
self.algo.blotter.new_orders = []
return orders
def get_message(self, date):
rvars = self.algo.recorded_vars
if self.algo.perf_tracker.emission_rate == 'daily':