From c1280daaa309d3079f16646fa58eedbe3b480720 Mon Sep 17 00:00:00 2001 From: Eddie Hebert Date: Thu, 25 May 2017 16:11:25 -0400 Subject: [PATCH] MAINT: Remove environment as an argument to benchmark source. (#1816) MAINT: Remove environment as an argument to benchmark source. To allow the BenchmarkSource class to be more easily used in contexts other than a TradingAlgorithm, remove the TradingEnvironment as an argument to the benchmark source. Instead: - Pass a benchmark Asset, instead of a bencmark sid; so that the asset_finder does not need to be passed to the benchmark source. - Pass the pre-calculated benchmark_returns instead of an env, which contains the benchmark_returns; a consumer can let the benchmark_returns stay as the default of `None` when using an asset. We may want to further refactor and make two different classes, instead of relying on a combination of existence/non-existence of benchmark_asset and benchmark_returns. That refactoring should be easier to do with this change. --- tests/test_benchmark.py | 31 ++++++++++++++----------- zipline/algorithm.py | 13 +++++++++-- zipline/sources/benchmark_source.py | 36 +++++++++++++++-------------- 3 files changed, 47 insertions(+), 33 deletions(-) diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 948baa05..f45632e2 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -96,7 +96,10 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars, days_to_use = self.sim_params.sessions[1:] source = BenchmarkSource( - 1, self.env, self.trading_calendar, days_to_use, self.data_portal + self.env.asset_finder.retrieve_asset(1), + self.trading_calendar, + days_to_use, + self.data_portal ) # should be the equivalent of getting the price history, then doing @@ -131,30 +134,28 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars, with self.assertRaises(BenchmarkAssetNotAvailableTooEarly) as exc: BenchmarkSource( - 3, - self.env, + benchmark, self.trading_calendar, self.sim_params.sessions[1:], self.data_portal ) self.assertEqual( - '3 does not exist on %s. It started trading on %s.' % + 'Equity(3 [C]) does not exist on %s. It started trading on %s.' % (self.sim_params.sessions[1], benchmark_start), exc.exception.message ) with self.assertRaises(BenchmarkAssetNotAvailableTooLate) as exc2: BenchmarkSource( - 3, - self.env, + benchmark, self.trading_calendar, self.sim_params.sessions[120:], self.data_portal ) self.assertEqual( - '3 does not exist on %s. It stopped trading on %s.' % + 'Equity(3 [C]) does not exist on %s. It stopped trading on %s.' % (self.sim_params.sessions[-1], benchmark_end), exc2.exception.message ) @@ -182,8 +183,7 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars, ) source = BenchmarkSource( - 2, - self.env, + self.env.asset_finder.retrieve_asset(2), self.trading_calendar, self.sim_params.sessions, data_portal @@ -214,11 +214,14 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars, with self.assertRaises(InvalidBenchmarkAsset) as exc: BenchmarkSource( - 4, self.env, self.trading_calendar, - self.sim_params.sessions, self.data_portal + self.env.asset_finder.retrieve_asset(4), + self.trading_calendar, + self.sim_params.sessions, + self.data_portal ) - self.assertEqual("4 cannot be used as the benchmark because it has a " - "stock dividend on 2006-03-16 00:00:00. Choose " - "another asset to use as the benchmark.", + self.assertEqual("Equity(4 [D]) cannot be used as the benchmark " + "because it has a stock dividend on 2006-03-16 " + "00:00:00. Choose another asset to use as the " + "benchmark.", exc.exception.message) diff --git a/zipline/algorithm.py b/zipline/algorithm.py index d6730089..44f2a1df 100644 --- a/zipline/algorithm.py +++ b/zipline/algorithm.py @@ -541,13 +541,22 @@ class TradingAlgorithm(object): ) def _create_benchmark_source(self): + if self.benchmark_sid is not None: + benchmark_asset = self.asset_finder.retrieve_asset( + self.benchmark_sid) + benchmark_returns = None + else: + benchmark_asset = None + # get benchmark info from trading environment, which defaults to + # downloading data from Yahoo. + benchmark_returns = self.trading_environment.benchmark_returns return BenchmarkSource( - benchmark_sid=self.benchmark_sid, - env=self.trading_environment, + benchmark_asset=benchmark_asset, trading_calendar=self.trading_calendar, sessions=self.sim_params.sessions, data_portal=self.data_portal, emission_rate=self.sim_params.emission_rate, + benchmark_returns=benchmark_returns, ) def _create_generator(self, sim_params): diff --git a/zipline/sources/benchmark_source.py b/zipline/sources/benchmark_source.py index 309ca8d2..ec477665 100644 --- a/zipline/sources/benchmark_source.py +++ b/zipline/sources/benchmark_source.py @@ -23,19 +23,21 @@ from zipline.errors import ( class BenchmarkSource(object): - def __init__(self, benchmark_sid, env, trading_calendar, sessions, - data_portal, emission_rate="daily"): - self.benchmark_sid = benchmark_sid - self.env = env + def __init__(self, + benchmark_asset, + trading_calendar, + sessions, + data_portal, + emission_rate="daily", + benchmark_returns=None): + self.benchmark_asset = benchmark_asset self.sessions = sessions self.emission_rate = emission_rate self.data_portal = data_portal if len(sessions) == 0: self._precalculated_series = pd.Series() - elif self.benchmark_sid: - benchmark_asset = self.env.asset_finder.retrieve_asset( - self.benchmark_sid) + elif benchmark_asset is not None: self._validate_benchmark(benchmark_asset) @@ -46,11 +48,8 @@ class BenchmarkSource(object): self.sessions, self.data_portal ) - else: - # get benchmark info from trading environment, which defaults to - # downloading data from Yahoo. - daily_series = \ - env.benchmark_returns[sessions[0]:sessions[-1]] + elif benchmark_returns is not None: + daily_series = benchmark_returns[sessions[0]:sessions[-1]] if self.emission_rate == "minute": # we need to take the env's benchmark returns, which are daily, @@ -68,6 +67,9 @@ class BenchmarkSource(object): self._precalculated_series = minute_series else: self._precalculated_series = daily_series + else: + raise Exception("Must provide either benchmark_asset or " + "benchmark_returns.") def get_value(self, dt): return self._precalculated_series.loc[dt] @@ -80,19 +82,19 @@ class BenchmarkSource(object): # error suggesting that the user pick a different asset to use # as benchmark. stock_dividends = \ - self.data_portal.get_stock_dividends(self.benchmark_sid, + self.data_portal.get_stock_dividends(self.benchmark_asset, self.sessions) if len(stock_dividends) > 0: raise InvalidBenchmarkAsset( - sid=str(self.benchmark_sid), + sid=str(self.benchmark_asset), dt=stock_dividends[0]["ex_date"] ) if benchmark_asset.start_date > self.sessions[0]: # the asset started trading after the first simulation day raise BenchmarkAssetNotAvailableTooEarly( - sid=str(self.benchmark_sid), + sid=str(self.benchmark_asset), dt=self.sessions[0], start_dt=benchmark_asset.start_date ) @@ -100,7 +102,7 @@ class BenchmarkSource(object): if benchmark_asset.end_date < self.sessions[-1]: # the asset stopped trading before the last simulation day raise BenchmarkAssetNotAvailableTooLate( - sid=str(self.benchmark_sid), + sid=str(self.benchmark_asset), dt=self.sessions[-1], end_dt=benchmark_asset.end_date ) @@ -157,7 +159,7 @@ class BenchmarkSource(object): else: start_date = asset.start_date if start_date < trading_days[0]: - # get the window of close prices for benchmark_sid from the + # get the window of close prices for benchmark_asset from the # last trading day of the simulation, going up to one day # before the simulation start day (so that we can get the % # change on day 1)