diff --git a/zipline/finance/risk.py b/zipline/finance/risk.py index 04c40b1d..3d828d0c 100644 --- a/zipline/finance/risk.py +++ b/zipline/finance/risk.py @@ -178,6 +178,21 @@ def alpha(algorithm_period_return, treasury_period_return, ########################### +def get_treasury_rate(treasury_curves, treasury_duration, day): + rate = None + + curve = treasury_curves[day] + # 1month note data begins in 8/2001, + # so we can use 3month instead. + idx = TREASURY_DURATIONS.index(treasury_duration) + for duration in TREASURY_DURATIONS[idx:]: + rate = curve[duration] + if rate is not None: + break + + return rate + + def search_day_distance(end_date, dt): tdd = trading.environment.trading_day_distance(dt, end_date) if tdd is None: @@ -435,7 +450,9 @@ class RiskMetricsBase(object): search_day = None if end_day in self.treasury_curves: - rate = self.get_treasury_rate(end_day) + rate = get_treasury_rate(self.treasury_curves, + self.treasury_duration, + end_day) if rate is not None: search_day = end_day @@ -447,7 +464,9 @@ class RiskMetricsBase(object): # Find rightmost value less than or equal to end_day i = bisect.bisect_right(search_days, end_day) for prev_day in search_days[i - 1::-1]: - rate = self.get_treasury_rate(prev_day) + rate = get_treasury_rate(self.treasury_curves, + self.treasury_duration, + prev_day) if rate is not None: search_day = prev_day search_dist = search_day_distance(self.end_date, prev_day) @@ -476,20 +495,6 @@ that date doesn't exceed treasury history range." ) raise Exception(message) - def get_treasury_rate(self, day): - rate = None - - curve = self.treasury_curves[day] - # 1month note data begins in 8/2001, - # so we can use 3month instead. - idx = TREASURY_DURATIONS.index(self.treasury_duration) - for duration in TREASURY_DURATIONS[idx:]: - rate = curve[duration] - if rate is not None: - break - - return rate - class RiskMetricsIterative(RiskMetricsBase): """Iterative version of RiskMetrics.