mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 23:24:02 +08:00
ENH: Change simulation loop to use benchmarks as simulation 'clock'.
Refactor PerformanceTracker, Blotter, and AlgorithmSimulator to work with handling the end of a bar at the AlgorithmSimulator level instead of within PerformanceTracker. - PerforamnceTracker and Blotter are longer generators, both provide functions to process events instead. - AlgorithmSimulator calls each from within the loop running over the data generator. - Change test_perf_tracker utility to be compatible with change away from PerformanceTracker as a generator. Has the effect of: - Fixing the timing of order emission. - Allow minutely emission of benchmarks, which was prevented by the extra grouping previously caused by Blotter. Minutely emission also depends on work for streaming benchmarks through performance and risk at a minute granularity.
This commit is contained in:
+53
-57
@@ -22,7 +22,6 @@ from nose_parameterized import parameterized
|
||||
import datetime
|
||||
import pytz
|
||||
import itertools
|
||||
from operator import attrgetter
|
||||
|
||||
import zipline.utils.factory as factory
|
||||
import zipline.finance.performance as perf
|
||||
@@ -63,23 +62,26 @@ def calculate_results(host, events):
|
||||
|
||||
perf_tracker = perf.PerformanceTracker(host.sim_params)
|
||||
|
||||
all_events = (msg[1] for msg in heapq.merge(
|
||||
all_events = heapq.merge(
|
||||
((event.dt, event) for event in events),
|
||||
((event.dt, event) for event in host.benchmark_events)))
|
||||
((event.dt, event) for event in host.benchmark_events))
|
||||
|
||||
transformed_events = list(perf_tracker.transform(
|
||||
itertools.groupby(all_events, attrgetter('dt'))))
|
||||
|
||||
#flatten the list of events
|
||||
filtered_events = [(date, filt_event) for (date, filt_event)
|
||||
in all_events if date <= events[-1].dt]
|
||||
filtered_events.sort(key=lambda x: x[0])
|
||||
grouped_events = itertools.groupby(filtered_events, lambda x: x[0])
|
||||
results = []
|
||||
for te in transformed_events:
|
||||
for event in te[1]:
|
||||
for message in event.perf_messages:
|
||||
results.append(message)
|
||||
|
||||
perf_messages, risk = perf_tracker.handle_simulation_end()
|
||||
results.append(perf_messages[0])
|
||||
|
||||
bm_updated = False
|
||||
for date, group in grouped_events:
|
||||
for _, event in group:
|
||||
perf_tracker.process_event(event)
|
||||
if event.type == DATASOURCE_TYPE.BENCHMARK:
|
||||
bm_updated = True
|
||||
if bm_updated:
|
||||
msg = perf_tracker.handle_market_close()
|
||||
results.append(msg)
|
||||
bm_updated = False
|
||||
return results
|
||||
|
||||
|
||||
@@ -239,9 +241,9 @@ class TestDividendPerformance(unittest.TestCase):
|
||||
)
|
||||
|
||||
buy_txn = create_txn(1, 10.0, 100, events[1].dt)
|
||||
events.insert(2, buy_txn)
|
||||
events.insert(1, buy_txn)
|
||||
sell_txn = create_txn(1, 10.0, -100, events[3].dt)
|
||||
events.insert(4, sell_txn)
|
||||
events.insert(3, sell_txn)
|
||||
events.insert(1, dividend)
|
||||
results = calculate_results(self, events)
|
||||
|
||||
@@ -267,12 +269,16 @@ class TestDividendPerformance(unittest.TestCase):
|
||||
self.sim_params
|
||||
)
|
||||
|
||||
pay_date = self.sim_params.first_open
|
||||
# find pay date that is much later.
|
||||
for i in xrange(30):
|
||||
pay_date = factory.get_next_trading_dt(pay_date, oneday)
|
||||
dividend = factory.create_dividend(
|
||||
1,
|
||||
10.00,
|
||||
events[0].dt,
|
||||
events[1].dt,
|
||||
events[-1].dt + 10 * oneday
|
||||
pay_date
|
||||
)
|
||||
|
||||
buy_txn = create_txn(1, 10.0, 100, events[1].dt)
|
||||
@@ -308,9 +314,11 @@ class TestDividendPerformance(unittest.TestCase):
|
||||
dividend = factory.create_dividend(
|
||||
1,
|
||||
10.00,
|
||||
# declare at open of test
|
||||
events[0].dt,
|
||||
events[1].dt,
|
||||
events[2].dt
|
||||
# ex_date same as trade 2
|
||||
events[2].dt,
|
||||
events[3].dt
|
||||
)
|
||||
|
||||
txn = create_txn(1, 10.0, -100, events[1].dt)
|
||||
@@ -321,14 +329,14 @@ class TestDividendPerformance(unittest.TestCase):
|
||||
self.assertEqual(len(results), 5)
|
||||
cumulative_returns = \
|
||||
[event['cumulative_perf']['returns'] for event in results]
|
||||
self.assertEqual(cumulative_returns, [0.0, 0.0, -0.1, -0.1, -0.1])
|
||||
self.assertEqual(cumulative_returns, [0.0, 0.0, 0.0, -0.1, -0.1])
|
||||
daily_returns = [event['daily_perf']['returns'] for event in results]
|
||||
self.assertEqual(daily_returns, [0.0, 0.0, -0.1, 0.0, 0.0])
|
||||
self.assertEqual(daily_returns, [0.0, 0.0, 0.0, -0.1, 0.0])
|
||||
cash_flows = [event['daily_perf']['capital_used'] for event in results]
|
||||
self.assertEqual(cash_flows, [1000, 0, -1000, 0, 0])
|
||||
self.assertEqual(cash_flows, [0, 1000, 0, -1000, 0])
|
||||
cumulative_cash_flows = \
|
||||
[event['cumulative_perf']['capital_used'] for event in results]
|
||||
self.assertEqual(cumulative_cash_flows, [1000, 1000, 0, 0, 0])
|
||||
self.assertEqual(cumulative_cash_flows, [0, 1000, 1000, 0, 0])
|
||||
|
||||
def test_no_position_receives_no_dividend(self):
|
||||
#post some trades in the market
|
||||
@@ -349,24 +357,7 @@ class TestDividendPerformance(unittest.TestCase):
|
||||
)
|
||||
|
||||
events.insert(1, dividend)
|
||||
perf_tracker = perf.PerformanceTracker(self.sim_params)
|
||||
|
||||
all_events = (msg[1] for msg in heapq.merge(
|
||||
((event.dt, event) for event in events),
|
||||
((event.dt, event) for event in self.benchmark_events)))
|
||||
|
||||
transformed_events = list(perf_tracker.transform(
|
||||
itertools.groupby(all_events, attrgetter('dt'))))
|
||||
|
||||
#flatten the list of events
|
||||
results = []
|
||||
for te in transformed_events:
|
||||
for event in te[1]:
|
||||
for message in event.perf_messages:
|
||||
results.append(message)
|
||||
|
||||
perf_messages, risk = perf_tracker.handle_simulation_end()
|
||||
results.append(perf_messages[0])
|
||||
results = calculate_results(self, events)
|
||||
|
||||
self.assertEqual(len(results), 5)
|
||||
cumulative_returns = \
|
||||
@@ -972,19 +963,18 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
((event.dt, event) for event in events),
|
||||
((event.dt, event) for event in benchmark_events)))
|
||||
|
||||
# Extract events with transactions to use for verification.
|
||||
perf_messages = \
|
||||
[m for date, snapshot in
|
||||
perf_tracker.transform(
|
||||
itertools.groupby(all_events, attrgetter('dt')))
|
||||
for e in snapshot
|
||||
for m in e.perf_messages]
|
||||
filtered_events = [filt_event for filt_event
|
||||
in all_events if event.dt <= end_dt]
|
||||
filtered_events.sort(key=lambda x: x.dt)
|
||||
grouped_events = itertools.groupby(filtered_events, lambda x: x.dt)
|
||||
perf_messages = []
|
||||
|
||||
end_perf_messages, risk_message = perf_tracker.handle_simulation_end()
|
||||
for date, group in grouped_events:
|
||||
for event in group:
|
||||
perf_tracker.process_event(event)
|
||||
msg = perf_tracker.handle_market_close()
|
||||
perf_messages.append(msg)
|
||||
|
||||
perf_messages.extend(end_perf_messages)
|
||||
|
||||
#we skip two trades, to test case of None transaction
|
||||
self.assertEqual(perf_tracker.txn_count, len(txns))
|
||||
self.assertEqual(perf_tracker.txn_count, len(orders))
|
||||
|
||||
@@ -1074,11 +1064,17 @@ class TestPerformanceTracker(unittest.TestCase):
|
||||
bar_event_2,
|
||||
]
|
||||
|
||||
messages = {date: snapshot[-1].perf_messages[0] for date, snapshot in
|
||||
tracker.transform(
|
||||
itertools.groupby(
|
||||
events,
|
||||
operator.attrgetter('dt')))}
|
||||
grouped_events = itertools.groupby(
|
||||
events, operator.attrgetter('dt'))
|
||||
|
||||
messages = {}
|
||||
for date, group in grouped_events:
|
||||
tracker.set_date(date)
|
||||
for event in group:
|
||||
tracker.process_event(event)
|
||||
tracker.handle_minute_close(date)
|
||||
msg = tracker.to_dict()
|
||||
messages[date] = msg
|
||||
|
||||
self.assertEquals(2, len(messages))
|
||||
|
||||
|
||||
@@ -221,42 +221,10 @@ class PerformanceTracker(object):
|
||||
elif self.emission_rate == 'daily':
|
||||
return self.day_count / self.total_days
|
||||
|
||||
def transform(self, stream_in):
|
||||
"""
|
||||
Main generator work loop.
|
||||
"""
|
||||
for date, snapshot in stream_in:
|
||||
new_snapshot = []
|
||||
|
||||
if self.emission_rate == 'daily':
|
||||
for event in snapshot:
|
||||
messages = self.process_event(event)
|
||||
if messages is not None:
|
||||
event.perf_messages = messages
|
||||
event.portfolio = self.get_portfolio()
|
||||
|
||||
new_snapshot.append(event)
|
||||
|
||||
elif self.emission_rate == 'minute':
|
||||
self.saved_dt = date
|
||||
self.todays_performance.period_close = self.saved_dt
|
||||
|
||||
for event in snapshot:
|
||||
self.process_event(event)
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
event.perf_messages = []
|
||||
event.portfolio = None
|
||||
new_snapshot.append(event)
|
||||
|
||||
|
||||
self.handle_minute_close(date)
|
||||
|
||||
if new_snapshot:
|
||||
new_snapshot[-1].perf_messages = [self.to_dict()]
|
||||
new_snapshot[-1].portfolio = self.get_portfolio()
|
||||
|
||||
if new_snapshot:
|
||||
yield date, new_snapshot
|
||||
def set_date(self, date):
|
||||
if self.emission_rate == 'minute':
|
||||
self.saved_dt = date
|
||||
self.todays_performance.period_close = self.saved_dt
|
||||
|
||||
def get_portfolio(self):
|
||||
return self.cumulative_performance.as_portfolio()
|
||||
@@ -286,6 +254,8 @@ class PerformanceTracker(object):
|
||||
# its own configuration down the line.
|
||||
# Naming as intraday to make clear that these results are
|
||||
# being updated per minute
|
||||
_dict['intraday_risk_metrics'] = \
|
||||
self.cumulative_risk_metrics.to_dict()
|
||||
_dict['intraday_perf'] = self.todays_performance.to_dict(
|
||||
self.saved_dt)
|
||||
|
||||
@@ -293,25 +263,14 @@ class PerformanceTracker(object):
|
||||
|
||||
def process_event(self, event):
|
||||
|
||||
messages = None
|
||||
self.event_count += 1
|
||||
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
messages = []
|
||||
|
||||
# This switch could also be handled by an inheritance
|
||||
# with a DailyPerformanceTracker and a MinutePerformanceTracker
|
||||
if self.emission_rate == 'daily':
|
||||
while (event.dt > self.market_close and
|
||||
event.dt < self.last_close):
|
||||
messages.append(self.handle_market_close())
|
||||
|
||||
#update last sale
|
||||
self.cumulative_performance.update_last_sale(event)
|
||||
self.todays_performance.update_last_sale(event)
|
||||
|
||||
elif event.type == zp.DATASOURCE_TYPE.TRANSACTION:
|
||||
|
||||
# Trade simulation always follows a transaction with the
|
||||
# TRADE event that was used to simulate it, so we don't
|
||||
# check for end of day rollover messages here.
|
||||
@@ -320,26 +279,17 @@ class PerformanceTracker(object):
|
||||
event
|
||||
)
|
||||
self.todays_performance.execute_transaction(event)
|
||||
# Transactions are consumed by performance, and not
|
||||
# relayed to the next element in the generator chain.
|
||||
messages = None
|
||||
|
||||
elif event.type == zp.DATASOURCE_TYPE.DIVIDEND:
|
||||
self.cumulative_performance.add_dividend(event)
|
||||
self.todays_performance.add_dividend(event)
|
||||
# Dividends are consumed by performance, and not
|
||||
# relayed to the next element in the generator chain.
|
||||
messages = None
|
||||
|
||||
elif event.type == zp.DATASOURCE_TYPE.ORDER:
|
||||
self.cumulative_performance.record_order(event)
|
||||
self.todays_performance.record_order(event)
|
||||
messages = None
|
||||
|
||||
elif event.type == zp.DATASOURCE_TYPE.CUSTOM:
|
||||
# we just want to relay this event unchanged.
|
||||
messages = []
|
||||
return messages
|
||||
pass
|
||||
elif event.type == zp.DATASOURCE_TYPE.BENCHMARK:
|
||||
self.all_benchmark_returns[event.dt] = event.returns
|
||||
|
||||
@@ -347,8 +297,6 @@ class PerformanceTracker(object):
|
||||
self.cumulative_performance.calculate_performance()
|
||||
self.todays_performance.calculate_performance()
|
||||
|
||||
return messages
|
||||
|
||||
def handle_minute_close(self, dt):
|
||||
#update risk metrics for cumulative performance
|
||||
algorithm_returns = pd.Series({dt: self.todays_performance.returns})
|
||||
@@ -421,14 +369,6 @@ class PerformanceTracker(object):
|
||||
When the simulation is complete, run the full period risk report
|
||||
and send it out on the results socket.
|
||||
"""
|
||||
# the stream will end on the last trading day, but will
|
||||
# not trigger an end of day, so we trigger the final
|
||||
# market close(s) here
|
||||
perf_messages = []
|
||||
while self.last_close > self.market_close:
|
||||
perf_messages.append(self.handle_market_close())
|
||||
|
||||
perf_messages.append(self.handle_market_close())
|
||||
|
||||
log_msg = "Simulated {n} trading days out of {m}."
|
||||
log.info(log_msg.format(n=int(self.day_count), m=self.total_days))
|
||||
@@ -440,7 +380,7 @@ class PerformanceTracker(object):
|
||||
self.risk_report = risk.RiskReport(self.returns, self.sim_params)
|
||||
|
||||
risk_dict = self.risk_report.to_dict()
|
||||
return perf_messages, risk_dict
|
||||
return risk_dict
|
||||
|
||||
|
||||
class Position(object):
|
||||
|
||||
@@ -68,11 +68,6 @@ class Blotter(object):
|
||||
Main generator work loop.
|
||||
"""
|
||||
for date, snapshot in stream_in:
|
||||
# relay any orders placed in prior snapshot
|
||||
# handling and reset the internal holding pen
|
||||
if self.new_orders:
|
||||
yield date, self.new_orders
|
||||
self.new_orders = []
|
||||
results = []
|
||||
|
||||
for event in snapshot:
|
||||
@@ -85,6 +80,9 @@ class Blotter(object):
|
||||
yield date, results
|
||||
|
||||
def process_trade(self, trade_event):
|
||||
if trade_event.type != DATASOURCE_TYPE.TRADE:
|
||||
return [], []
|
||||
|
||||
if zp_math.tolerant_equals(trade_event.volume, 0):
|
||||
# there are zero volume trade_events bc some stocks trade
|
||||
# less frequently than once per minute.
|
||||
@@ -103,7 +101,8 @@ class Blotter(object):
|
||||
txns = self.transact(trade_event, current_orders)
|
||||
for txn in txns:
|
||||
self.orders[txn.order_id].filled += txn.amount
|
||||
# mark the date of the order to match the txn
|
||||
# mark the date of the order to match the transaction
|
||||
# that is filling it.
|
||||
self.orders[txn.order_id].dt = txn.dt
|
||||
|
||||
modified_orders = [order for order
|
||||
@@ -262,23 +261,10 @@ class TradeSimulationClient(object):
|
||||
"""
|
||||
Main generator work loop.
|
||||
"""
|
||||
|
||||
# Simulate filling any open orders made by the previous run of
|
||||
# the user's algorithm. Fills the Transaction field on any
|
||||
# event that results in a filled order.
|
||||
with_filled_orders = self.blotter.transform(stream_in)
|
||||
|
||||
# Pipe the events with transactions to perf. This will remove
|
||||
# the TRANSACTION field added by TransactionSimulator and replace it
|
||||
# with a portfolio field to be passed to the user's
|
||||
# algorithm. Also adds a perf_messages field which is usually
|
||||
# empty, but contains update messages once per day.
|
||||
with_portfolio = self.perf_tracker.transform(with_filled_orders)
|
||||
|
||||
# Pass the messages from perf to the user's algorithm for simulation.
|
||||
# Events are batched by dt so that the algo handles all events for a
|
||||
# given timestamp at one one go.
|
||||
performance_messages = self.algo_sim.transform(with_portfolio)
|
||||
performance_messages = self.algo_sim.transform(stream_in)
|
||||
|
||||
# The algorithm will yield a daily_results message (as
|
||||
# calculated by the performance tracker) at the end of each
|
||||
@@ -407,41 +393,57 @@ class AlgorithmSimulator(object):
|
||||
# snapshot time to any log record generated.
|
||||
with self.processor.threadbound():
|
||||
|
||||
updated = False
|
||||
bm_updated = False
|
||||
for date, snapshot in stream:
|
||||
# We're still in the warmup period. Use the event to
|
||||
self.perf_tracker.set_date(date)
|
||||
# If we're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
# and don't send a snapshot to handle_data.
|
||||
if date < self.algo_start:
|
||||
for event in snapshot:
|
||||
del event['perf_messages']
|
||||
self.update_universe(event)
|
||||
if event.type in (DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CUSTOM):
|
||||
self.update_universe(event)
|
||||
self.perf_tracker.process_event(event)
|
||||
|
||||
# Regular snapshot. Update the universe and send a snapshot
|
||||
# to handle data.
|
||||
else:
|
||||
for event in snapshot:
|
||||
for perf_message in event.perf_messages:
|
||||
# append current values of recorded vars
|
||||
# to emitted message
|
||||
perf_message[self.perf_key]['recorded_vars'] =\
|
||||
self.algo.recorded_vars
|
||||
yield perf_message
|
||||
del event['perf_messages']
|
||||
|
||||
self.update_universe(event)
|
||||
for event in snapshot:
|
||||
if event.type in (DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CUSTOM):
|
||||
self.update_universe(event)
|
||||
updated = True
|
||||
if event.type == DATASOURCE_TYPE.BENCHMARK:
|
||||
bm_updated = True
|
||||
txns, orders = self.blotter.process_trade(event)
|
||||
for data in chain([event], txns, orders):
|
||||
self.perf_tracker.process_event(data)
|
||||
|
||||
# Update our portfolio.
|
||||
self.algo.set_portfolio(self.perf_tracker.get_portfolio())
|
||||
|
||||
# Send the current state of the universe
|
||||
# to the user's algo.
|
||||
self.simulate_snapshot(date)
|
||||
if updated:
|
||||
self.simulate_snapshot(date)
|
||||
updated = False
|
||||
|
||||
perf_messages, risk_message = \
|
||||
self.perf_tracker.handle_simulation_end()
|
||||
# run orders placed in the algorithm call
|
||||
# above through perf tracker before emitting
|
||||
# the perf packet, so that the perf includes
|
||||
# placed orders
|
||||
for order in self.blotter.new_orders:
|
||||
self.perf_tracker.process_event(order)
|
||||
self.blotter.new_orders = []
|
||||
|
||||
if self.perf_tracker.emission_rate == 'daily':
|
||||
for message in perf_messages:
|
||||
message[self.perf_key]['recorded_vars'] =\
|
||||
self.algo.recorded_vars
|
||||
yield message
|
||||
# The benchmark is our internal clock. When it
|
||||
# updates, we need to emit a performance message.
|
||||
if bm_updated:
|
||||
bm_updated = False
|
||||
yield self.get_message(date)
|
||||
|
||||
risk_message = self.perf_tracker.handle_simulation_end()
|
||||
|
||||
# When emitting minutely, it is still useful to have a final
|
||||
# packet with the entire days performance rolled up.
|
||||
@@ -455,20 +457,24 @@ class AlgorithmSimulator(object):
|
||||
|
||||
yield risk_message
|
||||
|
||||
def get_message(self, date):
|
||||
rvars = self.algo.recorded_vars
|
||||
if self.perf_tracker.emission_rate == 'daily':
|
||||
perf_message = \
|
||||
self.perf_tracker.handle_market_close()
|
||||
perf_message['daily_perf']['recorded_vars'] = rvars
|
||||
return perf_message
|
||||
|
||||
elif self.perf_tracker.emission_rate == 'minute':
|
||||
self.perf_tracker.handle_minute_close(date)
|
||||
perf_message = self.perf_tracker.to_dict()
|
||||
perf_message['intraday_perf']['recorded_vars'] = rvars
|
||||
return perf_message
|
||||
|
||||
def update_universe(self, event):
|
||||
"""
|
||||
Update the universe with new event information.
|
||||
"""
|
||||
# Update our portfolio.
|
||||
self.algo.set_portfolio(event.portfolio)
|
||||
# the portfolio is modified by each event passed into the
|
||||
# performance tracker (prices and amounts can change).
|
||||
# Performance tracker sends back an up-to-date portfolio
|
||||
# with each event. However, we provide the portfolio to
|
||||
# the algorithm via a setter method, rather than as part
|
||||
# of the event data sent to handle_data. To avoid
|
||||
# confusion, we remove it from the event here.
|
||||
del event.portfolio
|
||||
# Update our knowledge of this event's sid
|
||||
sid_data = self.universe[event.sid]
|
||||
sid_data.__dict__.update(event.__dict__)
|
||||
@@ -482,7 +488,6 @@ class AlgorithmSimulator(object):
|
||||
# log/print lines.
|
||||
self.snapshot_dt = date
|
||||
self.algo.set_datetime(self.snapshot_dt)
|
||||
self.algo.handle_data(self.universe)
|
||||
|
||||
# Update the simulation time.
|
||||
self.simulation_dt = date
|
||||
self.algo.handle_data(self.universe)
|
||||
|
||||
Reference in New Issue
Block a user