mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-28 04:53:20 +08:00
MAINT: Remove environment as an argument to benchmark source. (#1816)
MAINT: Remove environment as an argument to benchmark source. To allow the BenchmarkSource class to be more easily used in contexts other than a TradingAlgorithm, remove the TradingEnvironment as an argument to the benchmark source. Instead: - Pass a benchmark Asset, instead of a bencmark sid; so that the asset_finder does not need to be passed to the benchmark source. - Pass the pre-calculated benchmark_returns instead of an env, which contains the benchmark_returns; a consumer can let the benchmark_returns stay as the default of `None` when using an asset. We may want to further refactor and make two different classes, instead of relying on a combination of existence/non-existence of benchmark_asset and benchmark_returns. That refactoring should be easier to do with this change.
This commit is contained in:
+17
-14
@@ -96,7 +96,10 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
|
||||
days_to_use = self.sim_params.sessions[1:]
|
||||
|
||||
source = BenchmarkSource(
|
||||
1, self.env, self.trading_calendar, days_to_use, self.data_portal
|
||||
self.env.asset_finder.retrieve_asset(1),
|
||||
self.trading_calendar,
|
||||
days_to_use,
|
||||
self.data_portal
|
||||
)
|
||||
|
||||
# should be the equivalent of getting the price history, then doing
|
||||
@@ -131,30 +134,28 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
|
||||
|
||||
with self.assertRaises(BenchmarkAssetNotAvailableTooEarly) as exc:
|
||||
BenchmarkSource(
|
||||
3,
|
||||
self.env,
|
||||
benchmark,
|
||||
self.trading_calendar,
|
||||
self.sim_params.sessions[1:],
|
||||
self.data_portal
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
'3 does not exist on %s. It started trading on %s.' %
|
||||
'Equity(3 [C]) does not exist on %s. It started trading on %s.' %
|
||||
(self.sim_params.sessions[1], benchmark_start),
|
||||
exc.exception.message
|
||||
)
|
||||
|
||||
with self.assertRaises(BenchmarkAssetNotAvailableTooLate) as exc2:
|
||||
BenchmarkSource(
|
||||
3,
|
||||
self.env,
|
||||
benchmark,
|
||||
self.trading_calendar,
|
||||
self.sim_params.sessions[120:],
|
||||
self.data_portal
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
'3 does not exist on %s. It stopped trading on %s.' %
|
||||
'Equity(3 [C]) does not exist on %s. It stopped trading on %s.' %
|
||||
(self.sim_params.sessions[-1], benchmark_end),
|
||||
exc2.exception.message
|
||||
)
|
||||
@@ -182,8 +183,7 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
|
||||
)
|
||||
|
||||
source = BenchmarkSource(
|
||||
2,
|
||||
self.env,
|
||||
self.env.asset_finder.retrieve_asset(2),
|
||||
self.trading_calendar,
|
||||
self.sim_params.sessions,
|
||||
data_portal
|
||||
@@ -214,11 +214,14 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
|
||||
|
||||
with self.assertRaises(InvalidBenchmarkAsset) as exc:
|
||||
BenchmarkSource(
|
||||
4, self.env, self.trading_calendar,
|
||||
self.sim_params.sessions, self.data_portal
|
||||
self.env.asset_finder.retrieve_asset(4),
|
||||
self.trading_calendar,
|
||||
self.sim_params.sessions,
|
||||
self.data_portal
|
||||
)
|
||||
|
||||
self.assertEqual("4 cannot be used as the benchmark because it has a "
|
||||
"stock dividend on 2006-03-16 00:00:00. Choose "
|
||||
"another asset to use as the benchmark.",
|
||||
self.assertEqual("Equity(4 [D]) cannot be used as the benchmark "
|
||||
"because it has a stock dividend on 2006-03-16 "
|
||||
"00:00:00. Choose another asset to use as the "
|
||||
"benchmark.",
|
||||
exc.exception.message)
|
||||
|
||||
+11
-2
@@ -541,13 +541,22 @@ class TradingAlgorithm(object):
|
||||
)
|
||||
|
||||
def _create_benchmark_source(self):
|
||||
if self.benchmark_sid is not None:
|
||||
benchmark_asset = self.asset_finder.retrieve_asset(
|
||||
self.benchmark_sid)
|
||||
benchmark_returns = None
|
||||
else:
|
||||
benchmark_asset = None
|
||||
# get benchmark info from trading environment, which defaults to
|
||||
# downloading data from Yahoo.
|
||||
benchmark_returns = self.trading_environment.benchmark_returns
|
||||
return BenchmarkSource(
|
||||
benchmark_sid=self.benchmark_sid,
|
||||
env=self.trading_environment,
|
||||
benchmark_asset=benchmark_asset,
|
||||
trading_calendar=self.trading_calendar,
|
||||
sessions=self.sim_params.sessions,
|
||||
data_portal=self.data_portal,
|
||||
emission_rate=self.sim_params.emission_rate,
|
||||
benchmark_returns=benchmark_returns,
|
||||
)
|
||||
|
||||
def _create_generator(self, sim_params):
|
||||
|
||||
@@ -23,19 +23,21 @@ from zipline.errors import (
|
||||
|
||||
|
||||
class BenchmarkSource(object):
|
||||
def __init__(self, benchmark_sid, env, trading_calendar, sessions,
|
||||
data_portal, emission_rate="daily"):
|
||||
self.benchmark_sid = benchmark_sid
|
||||
self.env = env
|
||||
def __init__(self,
|
||||
benchmark_asset,
|
||||
trading_calendar,
|
||||
sessions,
|
||||
data_portal,
|
||||
emission_rate="daily",
|
||||
benchmark_returns=None):
|
||||
self.benchmark_asset = benchmark_asset
|
||||
self.sessions = sessions
|
||||
self.emission_rate = emission_rate
|
||||
self.data_portal = data_portal
|
||||
|
||||
if len(sessions) == 0:
|
||||
self._precalculated_series = pd.Series()
|
||||
elif self.benchmark_sid:
|
||||
benchmark_asset = self.env.asset_finder.retrieve_asset(
|
||||
self.benchmark_sid)
|
||||
elif benchmark_asset is not None:
|
||||
|
||||
self._validate_benchmark(benchmark_asset)
|
||||
|
||||
@@ -46,11 +48,8 @@ class BenchmarkSource(object):
|
||||
self.sessions,
|
||||
self.data_portal
|
||||
)
|
||||
else:
|
||||
# get benchmark info from trading environment, which defaults to
|
||||
# downloading data from Yahoo.
|
||||
daily_series = \
|
||||
env.benchmark_returns[sessions[0]:sessions[-1]]
|
||||
elif benchmark_returns is not None:
|
||||
daily_series = benchmark_returns[sessions[0]:sessions[-1]]
|
||||
|
||||
if self.emission_rate == "minute":
|
||||
# we need to take the env's benchmark returns, which are daily,
|
||||
@@ -68,6 +67,9 @@ class BenchmarkSource(object):
|
||||
self._precalculated_series = minute_series
|
||||
else:
|
||||
self._precalculated_series = daily_series
|
||||
else:
|
||||
raise Exception("Must provide either benchmark_asset or "
|
||||
"benchmark_returns.")
|
||||
|
||||
def get_value(self, dt):
|
||||
return self._precalculated_series.loc[dt]
|
||||
@@ -80,19 +82,19 @@ class BenchmarkSource(object):
|
||||
# error suggesting that the user pick a different asset to use
|
||||
# as benchmark.
|
||||
stock_dividends = \
|
||||
self.data_portal.get_stock_dividends(self.benchmark_sid,
|
||||
self.data_portal.get_stock_dividends(self.benchmark_asset,
|
||||
self.sessions)
|
||||
|
||||
if len(stock_dividends) > 0:
|
||||
raise InvalidBenchmarkAsset(
|
||||
sid=str(self.benchmark_sid),
|
||||
sid=str(self.benchmark_asset),
|
||||
dt=stock_dividends[0]["ex_date"]
|
||||
)
|
||||
|
||||
if benchmark_asset.start_date > self.sessions[0]:
|
||||
# the asset started trading after the first simulation day
|
||||
raise BenchmarkAssetNotAvailableTooEarly(
|
||||
sid=str(self.benchmark_sid),
|
||||
sid=str(self.benchmark_asset),
|
||||
dt=self.sessions[0],
|
||||
start_dt=benchmark_asset.start_date
|
||||
)
|
||||
@@ -100,7 +102,7 @@ class BenchmarkSource(object):
|
||||
if benchmark_asset.end_date < self.sessions[-1]:
|
||||
# the asset stopped trading before the last simulation day
|
||||
raise BenchmarkAssetNotAvailableTooLate(
|
||||
sid=str(self.benchmark_sid),
|
||||
sid=str(self.benchmark_asset),
|
||||
dt=self.sessions[-1],
|
||||
end_dt=benchmark_asset.end_date
|
||||
)
|
||||
@@ -157,7 +159,7 @@ class BenchmarkSource(object):
|
||||
else:
|
||||
start_date = asset.start_date
|
||||
if start_date < trading_days[0]:
|
||||
# get the window of close prices for benchmark_sid from the
|
||||
# get the window of close prices for benchmark_asset from the
|
||||
# last trading day of the simulation, going up to one day
|
||||
# before the simulation start day (so that we can get the %
|
||||
# change on day 1)
|
||||
|
||||
Reference in New Issue
Block a user