MAINT: Remove environment as an argument to benchmark source. (#1816)

MAINT: Remove environment as an argument to benchmark source.

To allow the BenchmarkSource class to be more easily used in contexts other than
a TradingAlgorithm, remove the TradingEnvironment as an argument to the
benchmark source.

Instead:
- Pass a benchmark Asset, instead of a bencmark sid; so that the asset_finder
does not need to be passed to the benchmark source.
- Pass the pre-calculated benchmark_returns instead of an env,
which contains the benchmark_returns; a consumer can let the benchmark_returns
stay as the default of `None` when using an asset.

We may want to further refactor and make two different classes, instead of
relying on a combination of existence/non-existence of benchmark_asset and
benchmark_returns. That refactoring should be easier to do with this change.
This commit is contained in:
Eddie Hebert
2017-05-25 16:11:25 -04:00
committed by GitHub
parent a3fe687b15
commit c1280daaa3
3 changed files with 47 additions and 33 deletions
+17 -14
View File
@@ -96,7 +96,10 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
days_to_use = self.sim_params.sessions[1:]
source = BenchmarkSource(
1, self.env, self.trading_calendar, days_to_use, self.data_portal
self.env.asset_finder.retrieve_asset(1),
self.trading_calendar,
days_to_use,
self.data_portal
)
# should be the equivalent of getting the price history, then doing
@@ -131,30 +134,28 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
with self.assertRaises(BenchmarkAssetNotAvailableTooEarly) as exc:
BenchmarkSource(
3,
self.env,
benchmark,
self.trading_calendar,
self.sim_params.sessions[1:],
self.data_portal
)
self.assertEqual(
'3 does not exist on %s. It started trading on %s.' %
'Equity(3 [C]) does not exist on %s. It started trading on %s.' %
(self.sim_params.sessions[1], benchmark_start),
exc.exception.message
)
with self.assertRaises(BenchmarkAssetNotAvailableTooLate) as exc2:
BenchmarkSource(
3,
self.env,
benchmark,
self.trading_calendar,
self.sim_params.sessions[120:],
self.data_portal
)
self.assertEqual(
'3 does not exist on %s. It stopped trading on %s.' %
'Equity(3 [C]) does not exist on %s. It stopped trading on %s.' %
(self.sim_params.sessions[-1], benchmark_end),
exc2.exception.message
)
@@ -182,8 +183,7 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
)
source = BenchmarkSource(
2,
self.env,
self.env.asset_finder.retrieve_asset(2),
self.trading_calendar,
self.sim_params.sessions,
data_portal
@@ -214,11 +214,14 @@ class TestBenchmark(WithDataPortal, WithSimParams, WithTradingCalendars,
with self.assertRaises(InvalidBenchmarkAsset) as exc:
BenchmarkSource(
4, self.env, self.trading_calendar,
self.sim_params.sessions, self.data_portal
self.env.asset_finder.retrieve_asset(4),
self.trading_calendar,
self.sim_params.sessions,
self.data_portal
)
self.assertEqual("4 cannot be used as the benchmark because it has a "
"stock dividend on 2006-03-16 00:00:00. Choose "
"another asset to use as the benchmark.",
self.assertEqual("Equity(4 [D]) cannot be used as the benchmark "
"because it has a stock dividend on 2006-03-16 "
"00:00:00. Choose another asset to use as the "
"benchmark.",
exc.exception.message)
+11 -2
View File
@@ -541,13 +541,22 @@ class TradingAlgorithm(object):
)
def _create_benchmark_source(self):
if self.benchmark_sid is not None:
benchmark_asset = self.asset_finder.retrieve_asset(
self.benchmark_sid)
benchmark_returns = None
else:
benchmark_asset = None
# get benchmark info from trading environment, which defaults to
# downloading data from Yahoo.
benchmark_returns = self.trading_environment.benchmark_returns
return BenchmarkSource(
benchmark_sid=self.benchmark_sid,
env=self.trading_environment,
benchmark_asset=benchmark_asset,
trading_calendar=self.trading_calendar,
sessions=self.sim_params.sessions,
data_portal=self.data_portal,
emission_rate=self.sim_params.emission_rate,
benchmark_returns=benchmark_returns,
)
def _create_generator(self, sim_params):
+19 -17
View File
@@ -23,19 +23,21 @@ from zipline.errors import (
class BenchmarkSource(object):
def __init__(self, benchmark_sid, env, trading_calendar, sessions,
data_portal, emission_rate="daily"):
self.benchmark_sid = benchmark_sid
self.env = env
def __init__(self,
benchmark_asset,
trading_calendar,
sessions,
data_portal,
emission_rate="daily",
benchmark_returns=None):
self.benchmark_asset = benchmark_asset
self.sessions = sessions
self.emission_rate = emission_rate
self.data_portal = data_portal
if len(sessions) == 0:
self._precalculated_series = pd.Series()
elif self.benchmark_sid:
benchmark_asset = self.env.asset_finder.retrieve_asset(
self.benchmark_sid)
elif benchmark_asset is not None:
self._validate_benchmark(benchmark_asset)
@@ -46,11 +48,8 @@ class BenchmarkSource(object):
self.sessions,
self.data_portal
)
else:
# get benchmark info from trading environment, which defaults to
# downloading data from Yahoo.
daily_series = \
env.benchmark_returns[sessions[0]:sessions[-1]]
elif benchmark_returns is not None:
daily_series = benchmark_returns[sessions[0]:sessions[-1]]
if self.emission_rate == "minute":
# we need to take the env's benchmark returns, which are daily,
@@ -68,6 +67,9 @@ class BenchmarkSource(object):
self._precalculated_series = minute_series
else:
self._precalculated_series = daily_series
else:
raise Exception("Must provide either benchmark_asset or "
"benchmark_returns.")
def get_value(self, dt):
return self._precalculated_series.loc[dt]
@@ -80,19 +82,19 @@ class BenchmarkSource(object):
# error suggesting that the user pick a different asset to use
# as benchmark.
stock_dividends = \
self.data_portal.get_stock_dividends(self.benchmark_sid,
self.data_portal.get_stock_dividends(self.benchmark_asset,
self.sessions)
if len(stock_dividends) > 0:
raise InvalidBenchmarkAsset(
sid=str(self.benchmark_sid),
sid=str(self.benchmark_asset),
dt=stock_dividends[0]["ex_date"]
)
if benchmark_asset.start_date > self.sessions[0]:
# the asset started trading after the first simulation day
raise BenchmarkAssetNotAvailableTooEarly(
sid=str(self.benchmark_sid),
sid=str(self.benchmark_asset),
dt=self.sessions[0],
start_dt=benchmark_asset.start_date
)
@@ -100,7 +102,7 @@ class BenchmarkSource(object):
if benchmark_asset.end_date < self.sessions[-1]:
# the asset stopped trading before the last simulation day
raise BenchmarkAssetNotAvailableTooLate(
sid=str(self.benchmark_sid),
sid=str(self.benchmark_asset),
dt=self.sessions[-1],
end_dt=benchmark_asset.end_date
)
@@ -157,7 +159,7 @@ class BenchmarkSource(object):
else:
start_date = asset.start_date
if start_date < trading_days[0]:
# get the window of close prices for benchmark_sid from the
# get the window of close prices for benchmark_asset from the
# last trading day of the simulation, going up to one day
# before the simulation start day (so that we can get the %
# change on day 1)