mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 07:51:59 +08:00
Merge lazy loading of portfolio and related values.
This commit is contained in:
@@ -321,7 +321,7 @@ class SlippageTestCase(TestCase):
|
||||
@parameterized.expand([
|
||||
(name, case['order'], case['event'], case['expected'])
|
||||
for name, case in STOP_ORDER_CASES.items()
|
||||
])
|
||||
])
|
||||
def test_orders_stop(self, name, order_data, event_data, expected):
|
||||
order = Order(**order_data)
|
||||
event = Event(initial_values=event_data)
|
||||
|
||||
+25
-5
@@ -90,7 +90,6 @@ class TradingAlgorithm(object):
|
||||
capital_base : float <default: 1.0e5>
|
||||
How much capital to start with.
|
||||
"""
|
||||
self._portfolio = None
|
||||
self.datetime = None
|
||||
|
||||
self.registered_transforms = {}
|
||||
@@ -102,6 +101,7 @@ class TradingAlgorithm(object):
|
||||
self.logger = None
|
||||
|
||||
self.benchmark_return_source = None
|
||||
self.perf_tracker = None
|
||||
|
||||
# default components for transact
|
||||
self.slippage = VolumeShareSlippage()
|
||||
@@ -124,11 +124,15 @@ class TradingAlgorithm(object):
|
||||
self.sim_params = kwargs.pop('sim_params', None)
|
||||
if self.sim_params:
|
||||
self.sim_params.data_frequency = self.data_frequency
|
||||
self.perf_tracker = PerformanceTracker(self.sim_params)
|
||||
|
||||
self.blotter = kwargs.pop('blotter', None)
|
||||
if not self.blotter:
|
||||
self.blotter = Blotter()
|
||||
|
||||
self.portfolio_needs_update = True
|
||||
self._portfolio = None
|
||||
|
||||
# an algorithm subclass needs to set initialized to True when
|
||||
# it is fully initialized.
|
||||
self.initialized = False
|
||||
@@ -213,9 +217,14 @@ class TradingAlgorithm(object):
|
||||
"""
|
||||
sim_params.data_frequency = self.data_frequency
|
||||
|
||||
# perf_tracker will be instantiated in __init__ if a sim_params
|
||||
# is passed to the constructor. If not, we instantiate here.
|
||||
if self.perf_tracker is None:
|
||||
self.perf_tracker = PerformanceTracker(sim_params)
|
||||
|
||||
self.data_gen = self._create_data_generator(source_filter,
|
||||
sim_params)
|
||||
self.perf_tracker = PerformanceTracker(sim_params)
|
||||
|
||||
self.trading_client = AlgorithmSimulator(self, sim_params)
|
||||
|
||||
transact_method = transact_partial(self.slippage, self.commission)
|
||||
@@ -304,6 +313,10 @@ class TradingAlgorithm(object):
|
||||
|
||||
self.transforms.append(sf)
|
||||
|
||||
# force a reset of the performance tracker, in case
|
||||
# this is a repeat run of the algorithm.
|
||||
self.perf_tracker = None
|
||||
|
||||
# create transforms and zipline
|
||||
self.gen = self._create_generator(sim_params)
|
||||
|
||||
@@ -380,10 +393,17 @@ class TradingAlgorithm(object):
|
||||
|
||||
@property
|
||||
def portfolio(self):
|
||||
return self._portfolio
|
||||
# internally this will cause a refresh of the
|
||||
# period performance calculations.
|
||||
return self.perf_tracker.get_portfolio()
|
||||
|
||||
def set_portfolio(self, portfolio):
|
||||
self._portfolio = portfolio
|
||||
def updated_portfolio(self):
|
||||
# internally this will cause a refresh of the
|
||||
# period performance calculations.
|
||||
if self.portfolio_needs_update:
|
||||
self._portfolio = self.perf_tracker.get_portfolio()
|
||||
self.portfolio_needs_update = False
|
||||
return self._portfolio
|
||||
|
||||
def set_logger(self, logger):
|
||||
self.logger = logger
|
||||
|
||||
@@ -255,10 +255,14 @@ class PerformancePeriod(object):
|
||||
return np.dot(self._position_amounts, self._position_last_sale_prices)
|
||||
|
||||
def update_last_sale(self, event):
|
||||
is_trade = event.type == zp.DATASOURCE_TYPE.TRADE
|
||||
has_price = not np.isnan(event.price)
|
||||
if event.sid not in self.positions:
|
||||
return
|
||||
|
||||
if event.type != zp.DATASOURCE_TYPE.TRADE:
|
||||
return
|
||||
|
||||
if not pd.isnull(event.price):
|
||||
# isnan check will keep the last price if its not present
|
||||
if (event.sid in self.positions) and is_trade and has_price:
|
||||
self.update_position(event.sid, last_sale_price=event.price,
|
||||
last_sale_date=event.dt)
|
||||
|
||||
|
||||
@@ -188,7 +188,13 @@ class PerformanceTracker(object):
|
||||
self.saved_dt = date
|
||||
self.todays_performance.period_close = self.saved_dt
|
||||
|
||||
def update_performance(self):
|
||||
# calculate performance as of last trade
|
||||
for perf_period in self.perf_periods:
|
||||
perf_period.calculate_performance()
|
||||
|
||||
def get_portfolio(self):
|
||||
self.update_performance()
|
||||
return self.cumulative_performance.as_portfolio()
|
||||
|
||||
def to_dict(self, emission_type=None):
|
||||
@@ -217,7 +223,6 @@ class PerformanceTracker(object):
|
||||
return _dict
|
||||
|
||||
def process_event(self, event):
|
||||
|
||||
self.event_count += 1
|
||||
|
||||
if event.type == zp.DATASOURCE_TYPE.TRADE:
|
||||
@@ -271,11 +276,8 @@ class PerformanceTracker(object):
|
||||
|
||||
self.all_benchmark_returns[midnight] = event.returns
|
||||
|
||||
# calculate performance as of last trade
|
||||
for perf_period in self.perf_periods:
|
||||
perf_period.calculate_performance()
|
||||
|
||||
def handle_minute_close(self, dt):
|
||||
self.update_performance()
|
||||
todays_date = normalize_date(dt)
|
||||
|
||||
minute_returns = self.minute_performance.returns
|
||||
@@ -304,6 +306,8 @@ class PerformanceTracker(object):
|
||||
self.returns[todays_date] = self.todays_performance.returns
|
||||
|
||||
def handle_intraday_close(self):
|
||||
# update_performance should have been called in handle_minute_close
|
||||
# so it is not repeated here.
|
||||
self.intraday_risk_metrics = \
|
||||
risk.RiskMetricsCumulative(self.sim_params)
|
||||
# increment the day counter before we move markers forward.
|
||||
@@ -316,6 +320,7 @@ class PerformanceTracker(object):
|
||||
self.market_close = self.sim_params.last_close
|
||||
|
||||
def handle_market_close(self):
|
||||
self.update_performance()
|
||||
# add the return results from today to the returns series
|
||||
todays_date = normalize_date(self.market_close)
|
||||
self.cumulative_performance.update_dividends(todays_date)
|
||||
|
||||
@@ -102,7 +102,6 @@ class AlgorithmSimulator(object):
|
||||
# inject the current algo
|
||||
# snapshot time to any log record generated.
|
||||
with self.processor.threadbound():
|
||||
|
||||
updated = False
|
||||
bm_updated = False
|
||||
for date, snapshot in stream_in:
|
||||
@@ -150,11 +149,6 @@ class AlgorithmSimulator(object):
|
||||
else:
|
||||
events.append(event)
|
||||
|
||||
# Update our portfolio.
|
||||
self.algo.set_portfolio(
|
||||
self.algo.perf_tracker.get_portfolio()
|
||||
)
|
||||
|
||||
# Send the current state of the universe
|
||||
# to the user's algo.
|
||||
if updated:
|
||||
@@ -179,6 +173,7 @@ class AlgorithmSimulator(object):
|
||||
# updates, we need to emit a performance message.
|
||||
if bm_updated:
|
||||
bm_updated = False
|
||||
self.algo.updated_portfolio()
|
||||
yield self.get_message(date)
|
||||
|
||||
# When emitting minutely, we re-iterate the day as a
|
||||
@@ -200,6 +195,8 @@ class AlgorithmSimulator(object):
|
||||
)
|
||||
self.algo.perf_tracker.handle_intraday_close()
|
||||
|
||||
self.portfolio_needs_update = True
|
||||
|
||||
risk_message = self.algo.perf_tracker.handle_simulation_end()
|
||||
yield risk_message
|
||||
|
||||
|
||||
Reference in New Issue
Block a user