mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 19:47:43 +08:00
MAINT: Refactor performance tracker as part of algorithm.
Instead of having the performance tracker as part of the tradesimulation class, hold on to it inside of the algorithm object, so that the perf_tracker is more easily accessed for reset behavior, etc.
This commit is contained in:
@@ -29,6 +29,7 @@ from zipline.errors import (
|
||||
UnsupportedCommissionModel,
|
||||
OverrideCommissionPostInit
|
||||
)
|
||||
from zipline.finance.performance import PerformanceTracker
|
||||
from zipline.sources import DataFrameSource, DataPanelSource
|
||||
from zipline.utils.factory import create_simulation_parameters
|
||||
from zipline.transforms.utils import StatefulTransform
|
||||
@@ -171,7 +172,7 @@ class TradingAlgorithm(object):
|
||||
skipped.
|
||||
"""
|
||||
self.data_gen = self._create_data_generator(source_filter, sim_params)
|
||||
|
||||
self.perf_tracker = PerformanceTracker(sim_params)
|
||||
self.trading_client = AlgorithmSimulator(self, sim_params)
|
||||
|
||||
transact_method = transact_partial(self.slippage, self.commission)
|
||||
|
||||
@@ -17,7 +17,6 @@ from itertools import chain
|
||||
from logbook import Logger, Processor
|
||||
|
||||
from zipline.protocol import BarData, DATASOURCE_TYPE
|
||||
from zipline.finance.performance import PerformanceTracker
|
||||
from zipline.gens.utils import hash_args
|
||||
|
||||
log = Logger('Trade Simulation')
|
||||
@@ -45,15 +44,6 @@ class AlgorithmSimulator(object):
|
||||
# ==============
|
||||
self.sim_params = sim_params
|
||||
|
||||
# ==============
|
||||
# Perf Tracker
|
||||
# Setup
|
||||
# ==============
|
||||
self.perf_tracker = PerformanceTracker(self.sim_params)
|
||||
|
||||
self.perf_key = self.EMISSION_TO_PERF_KEY_MAP[
|
||||
self.perf_tracker.emission_rate]
|
||||
|
||||
# ==============
|
||||
# Algo Setup
|
||||
# ==============
|
||||
@@ -63,6 +53,9 @@ class AlgorithmSimulator(object):
|
||||
second=0,
|
||||
microsecond=0)
|
||||
|
||||
self.perf_key = self.EMISSION_TO_PERF_KEY_MAP[
|
||||
self.algo.perf_tracker.emission_rate]
|
||||
|
||||
# ==============
|
||||
# Snapshot Setup
|
||||
# ==============
|
||||
@@ -108,7 +101,7 @@ class AlgorithmSimulator(object):
|
||||
updated = False
|
||||
bm_updated = False
|
||||
for date, snapshot in stream:
|
||||
self.perf_tracker.set_date(date)
|
||||
self.algo.perf_tracker.set_date(date)
|
||||
self.algo.blotter.set_date(date)
|
||||
# If we're still in the warmup period. Use the event to
|
||||
# update our universe, but don't yield any perf messages,
|
||||
@@ -118,7 +111,7 @@ class AlgorithmSimulator(object):
|
||||
if event.type in (DATASOURCE_TYPE.TRADE,
|
||||
DATASOURCE_TYPE.CUSTOM):
|
||||
self.update_universe(event)
|
||||
self.perf_tracker.process_event(event)
|
||||
self.algo.perf_tracker.process_event(event)
|
||||
|
||||
else:
|
||||
|
||||
@@ -131,10 +124,12 @@ class AlgorithmSimulator(object):
|
||||
bm_updated = True
|
||||
txns, orders = self.algo.blotter.process_trade(event)
|
||||
for data in chain([event], txns, orders):
|
||||
self.perf_tracker.process_event(data)
|
||||
self.algo.perf_tracker.process_event(data)
|
||||
|
||||
# Update our portfolio.
|
||||
self.algo.set_portfolio(self.perf_tracker.get_portfolio())
|
||||
self.algo.set_portfolio(
|
||||
self.algo.perf_tracker.get_portfolio()
|
||||
)
|
||||
|
||||
# Send the current state of the universe
|
||||
# to the user's algo.
|
||||
@@ -147,7 +142,7 @@ class AlgorithmSimulator(object):
|
||||
# the perf packet, so that the perf includes
|
||||
# placed orders
|
||||
for order in self.algo.blotter.new_orders:
|
||||
self.perf_tracker.process_event(order)
|
||||
self.algo.perf_tracker.process_event(order)
|
||||
self.algo.blotter.new_orders = []
|
||||
|
||||
# The benchmark is our internal clock. When it
|
||||
@@ -156,12 +151,12 @@ class AlgorithmSimulator(object):
|
||||
bm_updated = False
|
||||
yield self.get_message(date)
|
||||
|
||||
risk_message = self.perf_tracker.handle_simulation_end()
|
||||
risk_message = self.algo.perf_tracker.handle_simulation_end()
|
||||
|
||||
# When emitting minutely, it is still useful to have a final
|
||||
# packet with the entire days performance rolled up.
|
||||
if self.perf_tracker.emission_rate == 'minute':
|
||||
daily_rollup = self.perf_tracker.to_dict(
|
||||
if self.algo.perf_tracker.emission_rate == 'minute':
|
||||
daily_rollup = self.algo.perf_tracker.to_dict(
|
||||
emission_type='daily'
|
||||
)
|
||||
daily_rollup['daily_perf']['recorded_vars'] = \
|
||||
@@ -172,15 +167,15 @@ class AlgorithmSimulator(object):
|
||||
|
||||
def get_message(self, date):
|
||||
rvars = self.algo.recorded_vars
|
||||
if self.perf_tracker.emission_rate == 'daily':
|
||||
if self.algo.perf_tracker.emission_rate == 'daily':
|
||||
perf_message = \
|
||||
self.perf_tracker.handle_market_close()
|
||||
self.algo.perf_tracker.handle_market_close()
|
||||
perf_message['daily_perf']['recorded_vars'] = rvars
|
||||
return perf_message
|
||||
|
||||
elif self.perf_tracker.emission_rate == 'minute':
|
||||
self.perf_tracker.handle_minute_close(date)
|
||||
perf_message = self.perf_tracker.to_dict()
|
||||
elif self.algo.perf_tracker.emission_rate == 'minute':
|
||||
self.algo.perf_tracker.handle_minute_close(date)
|
||||
perf_message = self.algo.perf_tracker.to_dict()
|
||||
perf_message['intraday_perf']['recorded_vars'] = rvars
|
||||
return perf_message
|
||||
|
||||
|
||||
Reference in New Issue
Block a user