mirror of
https://github.com/wassname/catalyst.git
synced 2026-06-30 02:54:16 +08:00
TEST: Clarify test_restrictions a bit.
- Use parameter_space instead of `parameterized.expand`. - Use a timedelta instead of concatenating strings. - Use a (possibly no-op) scramble function instead of reordering list literals. - Use `freeze_dt, unfreeze_dt, re_freeze_dt` instead of `dts[n]`. - Rename `assert_vectorized_results` to `assert_all_restrictions`.
This commit is contained in:
committed by
Andrew Liang
parent
99f6ecab3f
commit
8f150eb6ce
+103
-90
@@ -1,9 +1,10 @@
|
||||
import pandas as pd
|
||||
from pandas.util.testing import assert_series_equal
|
||||
from nose_parameterized import parameterized
|
||||
from six import iteritems
|
||||
from functools import partial
|
||||
|
||||
from toolz import groupby
|
||||
|
||||
from zipline.finance.restrictions import (
|
||||
RESTRICTION_STATES,
|
||||
Restriction,
|
||||
@@ -13,6 +14,7 @@ from zipline.finance.restrictions import (
|
||||
NoopRestrictions,
|
||||
)
|
||||
|
||||
from zipline.testing import parameter_space
|
||||
from zipline.testing.fixtures import (
|
||||
WithDataPortal,
|
||||
ZiplineTestCase,
|
||||
@@ -34,6 +36,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
cls.ASSET1 = cls.asset_finder.retrieve_asset(1)
|
||||
cls.ASSET2 = cls.asset_finder.retrieve_asset(2)
|
||||
cls.ASSET3 = cls.asset_finder.retrieve_asset(3)
|
||||
cls.ALL_ASSETS = [cls.ASSET1, cls.ASSET2, cls.ASSET3]
|
||||
|
||||
def assert_is_restricted(self, rl, asset, dt):
|
||||
self.assertTrue(rl.is_restricted(asset, dt))
|
||||
@@ -41,119 +44,122 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
def assert_not_restricted(self, rl, asset, dt):
|
||||
self.assertFalse(rl.is_restricted(asset, dt))
|
||||
|
||||
def assert_vectorized_results(self, rl, expected, dt):
|
||||
def assert_all_restrictions(self, rl, expected, dt):
|
||||
self.assert_many_restrictions(rl, self.ALL_ASSETS, expected, dt)
|
||||
|
||||
def assert_many_restrictions(self, rl, assets, expected, dt):
|
||||
assert_series_equal(
|
||||
rl.is_restricted([self.ASSET1, self.ASSET2, self.ASSET3], dt),
|
||||
pd.Series(
|
||||
index=pd.Index([self.ASSET1, self.ASSET2, self.ASSET3]),
|
||||
data=expected
|
||||
)
|
||||
rl.is_restricted(assets, dt),
|
||||
pd.Series(index=pd.Index(assets), data=expected),
|
||||
)
|
||||
|
||||
@parameterized.expand([
|
||||
('_'.join([timing, ordering]),
|
||||
timing == 'intraday',
|
||||
ordering == 'ordered')
|
||||
for timing in ['intraday', 'interday']
|
||||
for ordering in ['ordered', 'unordered']
|
||||
])
|
||||
def test_historical_restrictions(self, name, is_intraday, is_ordered):
|
||||
@parameter_space(
|
||||
date_offset=(
|
||||
pd.Timedelta(0),
|
||||
pd.Timedelta('1 minute'),
|
||||
pd.Timedelta('15 hours 5 minutes')
|
||||
),
|
||||
check_unordered=(False, True),
|
||||
__fail_fast=True,
|
||||
)
|
||||
def test_historical_restrictions(self, date_offset, check_unordered):
|
||||
"""
|
||||
Test historical restrictions for both interday and intraday
|
||||
restrictions, as well as restrictions defined in/not in order, for both
|
||||
single- and multi-asset queries
|
||||
"""
|
||||
|
||||
hour_of_day = ' 15:00' if is_intraday else ''
|
||||
|
||||
if is_ordered:
|
||||
restriction_dates = {
|
||||
self.ASSET1: [
|
||||
(str_to_ts('2011-01-04' + hour_of_day), FROZEN),
|
||||
(str_to_ts('2011-01-05' + hour_of_day), ALLOWED),
|
||||
(str_to_ts('2011-01-06' + hour_of_day), FROZEN),
|
||||
],
|
||||
self.ASSET2: [
|
||||
(str_to_ts('2011-01-05' + hour_of_day), FROZEN),
|
||||
(str_to_ts('2011-01-06' + hour_of_day), ALLOWED),
|
||||
(str_to_ts('2011-01-07' + hour_of_day), FROZEN),
|
||||
],
|
||||
}
|
||||
if check_unordered:
|
||||
def maybe_scramble(rs):
|
||||
# Swap the first two restrictions to check that we don't care
|
||||
# that the restriction dates are ordered.
|
||||
tmp = rs[0]
|
||||
rs[0] = rs[1]
|
||||
rs[1] = tmp
|
||||
return rs
|
||||
else:
|
||||
restriction_dates = {
|
||||
self.ASSET1: [
|
||||
(str_to_ts('2011-01-05' + hour_of_day), ALLOWED),
|
||||
(str_to_ts('2011-01-06' + hour_of_day), FROZEN),
|
||||
(str_to_ts('2011-01-04' + hour_of_day), FROZEN),
|
||||
],
|
||||
self.ASSET2: [
|
||||
(str_to_ts('2011-01-06' + hour_of_day), ALLOWED),
|
||||
(str_to_ts('2011-01-05' + hour_of_day), FROZEN),
|
||||
(str_to_ts('2011-01-07' + hour_of_day), FROZEN),
|
||||
],
|
||||
}
|
||||
maybe_scramble = lambda r: r
|
||||
|
||||
restrictions = sum([
|
||||
[Restriction(asset, info[0], info[1]) for info in r_history]
|
||||
for asset, r_history in iteritems(restriction_dates)
|
||||
], [])
|
||||
rl = HistoricalRestrictions(restrictions)
|
||||
def rdate(s):
|
||||
"""Convert a date string into a restriction for that date."""
|
||||
# Add date_offset to check that we handle intraday changes.
|
||||
return str_to_ts(s) + date_offset
|
||||
|
||||
all_restrictions = (
|
||||
maybe_scramble([
|
||||
Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN),
|
||||
Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED),
|
||||
Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN),
|
||||
])
|
||||
+
|
||||
maybe_scramble([
|
||||
Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN),
|
||||
Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED),
|
||||
Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN),
|
||||
])
|
||||
)
|
||||
restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions)
|
||||
|
||||
rl = HistoricalRestrictions(all_restrictions)
|
||||
assert_not_restricted = partial(self.assert_not_restricted, rl)
|
||||
assert_is_restricted = partial(self.assert_is_restricted, rl)
|
||||
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
|
||||
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
|
||||
|
||||
for asset, r_history in iteritems(restriction_dates):
|
||||
dts = sorted([info[0] for info in r_history])
|
||||
# Check individual restrictions.
|
||||
for asset, r_history in iteritems(restrictions_by_asset):
|
||||
freeze_dt, unfreeze_dt, re_freeze_dt = (
|
||||
sorted([r.effective_date for r in r_history])
|
||||
)
|
||||
|
||||
# Not restricted until on or after the freeze
|
||||
assert_not_restricted(asset, dts[0] - MINUTE)
|
||||
assert_is_restricted(asset, dts[0])
|
||||
assert_is_restricted(asset, dts[0] + MINUTE)
|
||||
# Starts implicitly unrestricted. Restricted on or after the freeze
|
||||
assert_not_restricted(asset, freeze_dt - MINUTE)
|
||||
assert_is_restricted(asset, freeze_dt)
|
||||
assert_is_restricted(asset, freeze_dt + MINUTE)
|
||||
|
||||
# Unrestricted on or after the unfreeze
|
||||
assert_is_restricted(asset, dts[1] - MINUTE)
|
||||
assert_not_restricted(asset, dts[1])
|
||||
assert_not_restricted(asset, dts[1] + MINUTE)
|
||||
assert_is_restricted(asset, unfreeze_dt - MINUTE)
|
||||
assert_not_restricted(asset, unfreeze_dt)
|
||||
assert_not_restricted(asset, unfreeze_dt + MINUTE)
|
||||
|
||||
# Restricted again on or after the freeze
|
||||
assert_not_restricted(asset, dts[2] - MINUTE)
|
||||
assert_is_restricted(asset, dts[2])
|
||||
assert_is_restricted(asset, dts[2] + MINUTE)
|
||||
assert_not_restricted(asset, re_freeze_dt - MINUTE)
|
||||
assert_is_restricted(asset, re_freeze_dt)
|
||||
assert_is_restricted(asset, re_freeze_dt + MINUTE)
|
||||
|
||||
# Should stay restricted for the rest of time
|
||||
assert_is_restricted(asset, dts[2] + MINUTE * 1000000)
|
||||
|
||||
dts = [str_to_ts(ts + hour_of_day) for ts in ['2011-01-04',
|
||||
'2011-01-05',
|
||||
'2011-01-06',
|
||||
'2011-01-07']]
|
||||
assert_is_restricted(asset, re_freeze_dt + MINUTE * 1000000)
|
||||
|
||||
# Check vectorized restrictions.
|
||||
# Expected results for [self.ASSET1, self.ASSET2, self.ASSET3],
|
||||
# ASSET3 is always False as it has no defined restrictions
|
||||
|
||||
# 01/04 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: ALLOWED
|
||||
assert_vectorized_results([False, False, False], dts[0] - MINUTE)
|
||||
assert_vectorized_results([True, False, False], dts[0])
|
||||
assert_vectorized_results([True, False, False], dts[0] + MINUTE)
|
||||
d0 = rdate('2011-01-04')
|
||||
assert_all_restrictions([False, False, False], d0 - MINUTE)
|
||||
assert_all_restrictions([True, False, False], d0)
|
||||
assert_all_restrictions([True, False, False], d0 + MINUTE)
|
||||
|
||||
# 01/05 XX:00 ASSET1: FROZEN --> ALLOWED; ASSET2: ALLOWED --> FROZEN
|
||||
assert_vectorized_results([True, False, False], dts[1] - MINUTE)
|
||||
assert_vectorized_results([False, True, False], dts[1])
|
||||
assert_vectorized_results([False, True, False], dts[1] + MINUTE)
|
||||
d1 = rdate('2011-01-05')
|
||||
assert_all_restrictions([True, False, False], d1 - MINUTE)
|
||||
assert_all_restrictions([False, True, False], d1)
|
||||
assert_all_restrictions([False, True, False], d1 + MINUTE)
|
||||
|
||||
# 01/06 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: FROZEN --> ALLOWED
|
||||
assert_vectorized_results([False, True, False], dts[2] - MINUTE)
|
||||
assert_vectorized_results([True, False, False], dts[2])
|
||||
assert_vectorized_results([True, False, False], dts[2] + MINUTE)
|
||||
d2 = rdate('2011-01-06')
|
||||
assert_all_restrictions([False, True, False], d2 - MINUTE)
|
||||
assert_all_restrictions([True, False, False], d2)
|
||||
assert_all_restrictions([True, False, False], d2 + MINUTE)
|
||||
|
||||
# 01/07 XX:00 ASSET1: FROZEN; ASSET2: ALLOWED --> FROZEN
|
||||
assert_vectorized_results([True, False, False], dts[3] - MINUTE)
|
||||
assert_vectorized_results([True, True, False], dts[3])
|
||||
assert_vectorized_results([True, True, False], dts[3] + MINUTE)
|
||||
d3 = rdate('2011-01-07')
|
||||
assert_all_restrictions([True, False, False], d3 - MINUTE)
|
||||
assert_all_restrictions([True, True, False], d3)
|
||||
assert_all_restrictions([True, True, False], d3 + MINUTE)
|
||||
|
||||
# Should stay restricted for the rest of time
|
||||
assert_vectorized_results(
|
||||
assert_all_restrictions(
|
||||
[True, True, False],
|
||||
dts[3] + MINUTE * 10000000
|
||||
d3 + (MINUTE * 10000000)
|
||||
)
|
||||
|
||||
def test_historical_restrictions_consecutive_states(self):
|
||||
@@ -202,16 +208,17 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
rl = StaticRestrictions([restricted_a1, restricted_a2])
|
||||
assert_not_restricted = partial(self.assert_not_restricted, rl)
|
||||
assert_is_restricted = partial(self.assert_is_restricted, rl)
|
||||
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
|
||||
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
|
||||
|
||||
for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03',
|
||||
'2011-01-04',
|
||||
'2011-01-04 1:01',
|
||||
'2020-01-04')]:
|
||||
assert_is_restricted(restricted_a1, dt)
|
||||
assert_is_restricted(restricted_a2, dt)
|
||||
assert_not_restricted(unrestricted_a3, dt)
|
||||
|
||||
assert_vectorized_results([True, True, False], dt)
|
||||
assert_all_restrictions([True, True, False], dt)
|
||||
|
||||
def test_security_list_restrictions(self):
|
||||
"""
|
||||
@@ -237,23 +244,29 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
|
||||
assert_not_restricted = partial(self.assert_not_restricted, rl)
|
||||
assert_is_restricted = partial(self.assert_is_restricted, rl)
|
||||
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
|
||||
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
|
||||
|
||||
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03'))
|
||||
assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03'))
|
||||
assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03'))
|
||||
assert_vectorized_results(
|
||||
[True, False, False], str_to_ts('2011-01-03'))
|
||||
assert_all_restrictions(
|
||||
[True, False, False], str_to_ts('2011-01-03')
|
||||
)
|
||||
|
||||
assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04'))
|
||||
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04'))
|
||||
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04'))
|
||||
assert_vectorized_results([False, True, True], str_to_ts('2011-01-04'))
|
||||
assert_all_restrictions(
|
||||
[False, True, True], str_to_ts('2011-01-04')
|
||||
)
|
||||
|
||||
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05'))
|
||||
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05'))
|
||||
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05'))
|
||||
assert_vectorized_results([True, True, True], str_to_ts('2011-01-05'))
|
||||
assert_all_restrictions(
|
||||
[True, True, True],
|
||||
str_to_ts('2011-01-05')
|
||||
)
|
||||
|
||||
def test_noop_restrictions(self):
|
||||
"""
|
||||
@@ -262,7 +275,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
|
||||
rl = NoopRestrictions()
|
||||
assert_not_restricted = partial(self.assert_not_restricted, rl)
|
||||
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
|
||||
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
|
||||
|
||||
for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03',
|
||||
'2011-01-04',
|
||||
@@ -270,4 +283,4 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
|
||||
assert_not_restricted(self.ASSET1, dt)
|
||||
assert_not_restricted(self.ASSET2, dt)
|
||||
assert_not_restricted(self.ASSET3, dt)
|
||||
assert_vectorized_results([False, False, False], dt)
|
||||
assert_all_restrictions([False, False, False], dt)
|
||||
|
||||
Reference in New Issue
Block a user