mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 14:51:47 +08:00
Merge pull request #1168 from quantopian/fix_crashing_benchmark
FIX: Crashing on calculating benchmarking when no trading days
This commit is contained in:
@@ -155,6 +155,7 @@ from zipline.test_algorithms import (
|
||||
call_with_good_kwargs_get_open_orders,
|
||||
call_with_no_kwargs_get_open_orders,
|
||||
empty_positions,
|
||||
set_benchmark_algo,
|
||||
no_handle_data,
|
||||
)
|
||||
from zipline.utils.api_support import ZiplineAPI, set_algo_instance
|
||||
@@ -1819,6 +1820,28 @@ def handle_data(context, data):
|
||||
self.assertTrue(all(num_positions == 0))
|
||||
self.assertTrue(all(amounts == 0))
|
||||
|
||||
@parameterized.expand([
|
||||
('noop_algo', noop_algo),
|
||||
('with_benchmark_set', set_benchmark_algo)]
|
||||
)
|
||||
def test_zero_trading_days(self, name, algocode):
|
||||
"""
|
||||
Test that when we run a simulation with no trading days (e.g. beginning
|
||||
and ending the same weekend), we don't crash on calculating the
|
||||
benchmark
|
||||
"""
|
||||
sim_params = factory.create_simulation_parameters(
|
||||
start=pd.Timestamp('2006-01-14', tz='UTC'),
|
||||
end=pd.Timestamp('2006-01-15', tz='UTC')
|
||||
)
|
||||
|
||||
algo = TradingAlgorithm(
|
||||
script=algocode,
|
||||
sim_params=sim_params,
|
||||
env=self.env
|
||||
)
|
||||
algo.run(self.data_portal)
|
||||
|
||||
|
||||
class TestGetDatetime(WithLogger,
|
||||
WithSimParams,
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from zipline.errors import (
|
||||
InvalidBenchmarkAsset,
|
||||
BenchmarkAssetNotAvailableTooEarly,
|
||||
@@ -29,15 +31,17 @@ class BenchmarkSource(object):
|
||||
self.emission_rate = emission_rate
|
||||
self.data_portal = data_portal
|
||||
|
||||
if self.benchmark_sid:
|
||||
self.benchmark_asset = self.env.asset_finder.retrieve_asset(
|
||||
if len(trading_days) == 0:
|
||||
self._precalculated_series = pd.Series()
|
||||
elif self.benchmark_sid:
|
||||
benchmark_asset = self.env.asset_finder.retrieve_asset(
|
||||
self.benchmark_sid)
|
||||
|
||||
self._validate_benchmark()
|
||||
self._validate_benchmark(benchmark_asset)
|
||||
|
||||
self._precalculated_series = \
|
||||
self._initialize_precalculated_series(
|
||||
self.benchmark_asset,
|
||||
benchmark_asset,
|
||||
self.env,
|
||||
self.trading_days,
|
||||
self.data_portal
|
||||
@@ -68,7 +72,7 @@ class BenchmarkSource(object):
|
||||
def get_value(self, dt):
|
||||
return self._precalculated_series.loc[dt]
|
||||
|
||||
def _validate_benchmark(self):
|
||||
def _validate_benchmark(self, benchmark_asset):
|
||||
# check if this security has a stock dividend. if so, raise an
|
||||
# error suggesting that the user pick a different asset to use
|
||||
# as benchmark.
|
||||
@@ -82,20 +86,20 @@ class BenchmarkSource(object):
|
||||
dt=stock_dividends[0]["ex_date"]
|
||||
)
|
||||
|
||||
if self.benchmark_asset.start_date > self.trading_days[0]:
|
||||
if benchmark_asset.start_date > self.trading_days[0]:
|
||||
# the asset started trading after the first simulation day
|
||||
raise BenchmarkAssetNotAvailableTooEarly(
|
||||
sid=str(self.benchmark_sid),
|
||||
dt=self.trading_days[0],
|
||||
start_dt=self.benchmark_asset.start_date
|
||||
start_dt=benchmark_asset.start_date
|
||||
)
|
||||
|
||||
if self.benchmark_asset.end_date < self.trading_days[-1]:
|
||||
if benchmark_asset.end_date < self.trading_days[-1]:
|
||||
# the asset stopped trading before the last simulation day
|
||||
raise BenchmarkAssetNotAvailableTooLate(
|
||||
sid=str(self.benchmark_sid),
|
||||
dt=self.trading_days[0],
|
||||
end_dt=self.benchmark_asset.end_date
|
||||
end_dt=benchmark_asset.end_date
|
||||
)
|
||||
|
||||
def _initialize_precalculated_series(self, asset, env, trading_days,
|
||||
@@ -146,7 +150,7 @@ class BenchmarkSource(object):
|
||||
|
||||
return benchmark_series.pct_change()[1:]
|
||||
else:
|
||||
start_date = self.benchmark_asset.start_date
|
||||
start_date = asset.start_date
|
||||
if start_date < trading_days[0]:
|
||||
# get the window of close prices for benchmark_sid from the
|
||||
# last trading day of the simulation, going up to one day
|
||||
|
||||
@@ -1140,3 +1140,14 @@ def test_history(context,data):
|
||||
record(amounts=context.portfolio.positions[context.sid].amount)
|
||||
record(num_positions=len(context.portfolio.positions))
|
||||
"""
|
||||
|
||||
set_benchmark_algo = """
|
||||
from zipline.api import symbol, set_benchmark
|
||||
|
||||
def initialize(context):
|
||||
set_benchmark(symbol('TEST'))
|
||||
context.sid = symbol('TEST')
|
||||
|
||||
def handle_data(context, data):
|
||||
pass
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user