mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-02 07:19:44 +08:00
MAINT: Use sim_params for risk metrics init.
Prepare for adding emission_rate in risk metrics logic.
This commit is contained in:
@@ -23,6 +23,7 @@ import pandas as pd
|
||||
|
||||
import zipline.finance.risk as risk
|
||||
import zipline.finance.trading as trading
|
||||
from zipline.finance.trading import SimulationParameters
|
||||
from zipline.protocol import DailyReturn
|
||||
|
||||
from test_risk import RETURNS
|
||||
@@ -63,7 +64,9 @@ class RiskCompareIterativeToBatch(unittest.TestCase):
|
||||
end_date = trading.environment.trading_days[
|
||||
start_index + len(RETURNS)]
|
||||
|
||||
risk_metrics_refactor = risk.RiskMetricsIterative(start_date, end_date)
|
||||
sim_params = SimulationParameters(start_date, end_date)
|
||||
|
||||
risk_metrics_refactor = risk.RiskMetricsIterative(sim_params)
|
||||
todays_date = start_date
|
||||
|
||||
cur_returns = []
|
||||
|
||||
@@ -161,8 +161,9 @@ class PerformanceTracker(object):
|
||||
trading.environment.get_open_and_close(first_day)
|
||||
self.total_days = self.sim_params.days_in_period
|
||||
self.capital_base = self.sim_params.capital_base
|
||||
self.emission_rate = sim_params.emission_rate
|
||||
self.cumulative_risk_metrics = \
|
||||
risk.RiskMetricsIterative(self.period_start, self.period_end)
|
||||
risk.RiskMetricsIterative(self.sim_params)
|
||||
self.emission_rate = sim_params.emission_rate
|
||||
|
||||
# Temporarily hold these here as we work on streaming benchmarks.
|
||||
|
||||
@@ -522,12 +522,14 @@ class RiskMetricsIterative(RiskMetricsBase):
|
||||
Call update() method on each dt to update the metrics.
|
||||
"""
|
||||
|
||||
def __init__(self, start_date, end_date):
|
||||
def __init__(self, sim_params):
|
||||
self.treasury_curves = trading.environment.treasury_curves
|
||||
self.start_date = start_date.replace(hour=0, minute=0, second=0,
|
||||
microsecond=0)
|
||||
self.end_date = end_date.replace(hour=0, minute=0, second=0,
|
||||
microsecond=0)
|
||||
self.start_date = sim_params.period_start.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
self.end_date = sim_params.period_end.replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
|
||||
all_trading_days = trading.environment.trading_days
|
||||
mask = ((all_trading_days >= self.start_date) &
|
||||
|
||||
Reference in New Issue
Block a user