TEST: Clarify test_restrictions a bit.

- Use parameter_space instead of `parameterized.expand`.
- Use a timedelta instead of concatenating strings.
- Use a (possibly no-op) scramble function instead of reordering list
  literals.
- Use `freeze_dt, unfreeze_dt, re_freeze_dt` instead of `dts[n]`.
- Rename `assert_vectorized_results` to `assert_all_restrictions`.
This commit is contained in:
Scott Sanderson
2016-09-28 18:56:03 -04:00
committed by Andrew Liang
parent 99f6ecab3f
commit 8f150eb6ce
+103 -90
View File
@@ -1,9 +1,10 @@
import pandas as pd
from pandas.util.testing import assert_series_equal
from nose_parameterized import parameterized
from six import iteritems
from functools import partial
from toolz import groupby
from zipline.finance.restrictions import (
RESTRICTION_STATES,
Restriction,
@@ -13,6 +14,7 @@ from zipline.finance.restrictions import (
NoopRestrictions,
)
from zipline.testing import parameter_space
from zipline.testing.fixtures import (
WithDataPortal,
ZiplineTestCase,
@@ -34,6 +36,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
cls.ASSET1 = cls.asset_finder.retrieve_asset(1)
cls.ASSET2 = cls.asset_finder.retrieve_asset(2)
cls.ASSET3 = cls.asset_finder.retrieve_asset(3)
cls.ALL_ASSETS = [cls.ASSET1, cls.ASSET2, cls.ASSET3]
def assert_is_restricted(self, rl, asset, dt):
self.assertTrue(rl.is_restricted(asset, dt))
@@ -41,119 +44,122 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
def assert_not_restricted(self, rl, asset, dt):
self.assertFalse(rl.is_restricted(asset, dt))
def assert_vectorized_results(self, rl, expected, dt):
def assert_all_restrictions(self, rl, expected, dt):
self.assert_many_restrictions(rl, self.ALL_ASSETS, expected, dt)
def assert_many_restrictions(self, rl, assets, expected, dt):
assert_series_equal(
rl.is_restricted([self.ASSET1, self.ASSET2, self.ASSET3], dt),
pd.Series(
index=pd.Index([self.ASSET1, self.ASSET2, self.ASSET3]),
data=expected
)
rl.is_restricted(assets, dt),
pd.Series(index=pd.Index(assets), data=expected),
)
@parameterized.expand([
('_'.join([timing, ordering]),
timing == 'intraday',
ordering == 'ordered')
for timing in ['intraday', 'interday']
for ordering in ['ordered', 'unordered']
])
def test_historical_restrictions(self, name, is_intraday, is_ordered):
@parameter_space(
date_offset=(
pd.Timedelta(0),
pd.Timedelta('1 minute'),
pd.Timedelta('15 hours 5 minutes')
),
check_unordered=(False, True),
__fail_fast=True,
)
def test_historical_restrictions(self, date_offset, check_unordered):
"""
Test historical restrictions for both interday and intraday
restrictions, as well as restrictions defined in/not in order, for both
single- and multi-asset queries
"""
hour_of_day = ' 15:00' if is_intraday else ''
if is_ordered:
restriction_dates = {
self.ASSET1: [
(str_to_ts('2011-01-04' + hour_of_day), FROZEN),
(str_to_ts('2011-01-05' + hour_of_day), ALLOWED),
(str_to_ts('2011-01-06' + hour_of_day), FROZEN),
],
self.ASSET2: [
(str_to_ts('2011-01-05' + hour_of_day), FROZEN),
(str_to_ts('2011-01-06' + hour_of_day), ALLOWED),
(str_to_ts('2011-01-07' + hour_of_day), FROZEN),
],
}
if check_unordered:
def maybe_scramble(rs):
# Swap the first two restrictions to check that we don't care
# that the restriction dates are ordered.
tmp = rs[0]
rs[0] = rs[1]
rs[1] = tmp
return rs
else:
restriction_dates = {
self.ASSET1: [
(str_to_ts('2011-01-05' + hour_of_day), ALLOWED),
(str_to_ts('2011-01-06' + hour_of_day), FROZEN),
(str_to_ts('2011-01-04' + hour_of_day), FROZEN),
],
self.ASSET2: [
(str_to_ts('2011-01-06' + hour_of_day), ALLOWED),
(str_to_ts('2011-01-05' + hour_of_day), FROZEN),
(str_to_ts('2011-01-07' + hour_of_day), FROZEN),
],
}
maybe_scramble = lambda r: r
restrictions = sum([
[Restriction(asset, info[0], info[1]) for info in r_history]
for asset, r_history in iteritems(restriction_dates)
], [])
rl = HistoricalRestrictions(restrictions)
def rdate(s):
"""Convert a date string into a restriction for that date."""
# Add date_offset to check that we handle intraday changes.
return str_to_ts(s) + date_offset
all_restrictions = (
maybe_scramble([
Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN),
Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED),
Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN),
])
+
maybe_scramble([
Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN),
Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED),
Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN),
])
)
restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions)
rl = HistoricalRestrictions(all_restrictions)
assert_not_restricted = partial(self.assert_not_restricted, rl)
assert_is_restricted = partial(self.assert_is_restricted, rl)
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
for asset, r_history in iteritems(restriction_dates):
dts = sorted([info[0] for info in r_history])
# Check individual restrictions.
for asset, r_history in iteritems(restrictions_by_asset):
freeze_dt, unfreeze_dt, re_freeze_dt = (
sorted([r.effective_date for r in r_history])
)
# Not restricted until on or after the freeze
assert_not_restricted(asset, dts[0] - MINUTE)
assert_is_restricted(asset, dts[0])
assert_is_restricted(asset, dts[0] + MINUTE)
# Starts implicitly unrestricted. Restricted on or after the freeze
assert_not_restricted(asset, freeze_dt - MINUTE)
assert_is_restricted(asset, freeze_dt)
assert_is_restricted(asset, freeze_dt + MINUTE)
# Unrestricted on or after the unfreeze
assert_is_restricted(asset, dts[1] - MINUTE)
assert_not_restricted(asset, dts[1])
assert_not_restricted(asset, dts[1] + MINUTE)
assert_is_restricted(asset, unfreeze_dt - MINUTE)
assert_not_restricted(asset, unfreeze_dt)
assert_not_restricted(asset, unfreeze_dt + MINUTE)
# Restricted again on or after the freeze
assert_not_restricted(asset, dts[2] - MINUTE)
assert_is_restricted(asset, dts[2])
assert_is_restricted(asset, dts[2] + MINUTE)
assert_not_restricted(asset, re_freeze_dt - MINUTE)
assert_is_restricted(asset, re_freeze_dt)
assert_is_restricted(asset, re_freeze_dt + MINUTE)
# Should stay restricted for the rest of time
assert_is_restricted(asset, dts[2] + MINUTE * 1000000)
dts = [str_to_ts(ts + hour_of_day) for ts in ['2011-01-04',
'2011-01-05',
'2011-01-06',
'2011-01-07']]
assert_is_restricted(asset, re_freeze_dt + MINUTE * 1000000)
# Check vectorized restrictions.
# Expected results for [self.ASSET1, self.ASSET2, self.ASSET3],
# ASSET3 is always False as it has no defined restrictions
# 01/04 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: ALLOWED
assert_vectorized_results([False, False, False], dts[0] - MINUTE)
assert_vectorized_results([True, False, False], dts[0])
assert_vectorized_results([True, False, False], dts[0] + MINUTE)
d0 = rdate('2011-01-04')
assert_all_restrictions([False, False, False], d0 - MINUTE)
assert_all_restrictions([True, False, False], d0)
assert_all_restrictions([True, False, False], d0 + MINUTE)
# 01/05 XX:00 ASSET1: FROZEN --> ALLOWED; ASSET2: ALLOWED --> FROZEN
assert_vectorized_results([True, False, False], dts[1] - MINUTE)
assert_vectorized_results([False, True, False], dts[1])
assert_vectorized_results([False, True, False], dts[1] + MINUTE)
d1 = rdate('2011-01-05')
assert_all_restrictions([True, False, False], d1 - MINUTE)
assert_all_restrictions([False, True, False], d1)
assert_all_restrictions([False, True, False], d1 + MINUTE)
# 01/06 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: FROZEN --> ALLOWED
assert_vectorized_results([False, True, False], dts[2] - MINUTE)
assert_vectorized_results([True, False, False], dts[2])
assert_vectorized_results([True, False, False], dts[2] + MINUTE)
d2 = rdate('2011-01-06')
assert_all_restrictions([False, True, False], d2 - MINUTE)
assert_all_restrictions([True, False, False], d2)
assert_all_restrictions([True, False, False], d2 + MINUTE)
# 01/07 XX:00 ASSET1: FROZEN; ASSET2: ALLOWED --> FROZEN
assert_vectorized_results([True, False, False], dts[3] - MINUTE)
assert_vectorized_results([True, True, False], dts[3])
assert_vectorized_results([True, True, False], dts[3] + MINUTE)
d3 = rdate('2011-01-07')
assert_all_restrictions([True, False, False], d3 - MINUTE)
assert_all_restrictions([True, True, False], d3)
assert_all_restrictions([True, True, False], d3 + MINUTE)
# Should stay restricted for the rest of time
assert_vectorized_results(
assert_all_restrictions(
[True, True, False],
dts[3] + MINUTE * 10000000
d3 + (MINUTE * 10000000)
)
def test_historical_restrictions_consecutive_states(self):
@@ -202,16 +208,17 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
rl = StaticRestrictions([restricted_a1, restricted_a2])
assert_not_restricted = partial(self.assert_not_restricted, rl)
assert_is_restricted = partial(self.assert_is_restricted, rl)
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03',
'2011-01-04',
'2011-01-04 1:01',
'2020-01-04')]:
assert_is_restricted(restricted_a1, dt)
assert_is_restricted(restricted_a2, dt)
assert_not_restricted(unrestricted_a3, dt)
assert_vectorized_results([True, True, False], dt)
assert_all_restrictions([True, True, False], dt)
def test_security_list_restrictions(self):
"""
@@ -237,23 +244,29 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
assert_not_restricted = partial(self.assert_not_restricted, rl)
assert_is_restricted = partial(self.assert_is_restricted, rl)
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03'))
assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03'))
assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03'))
assert_vectorized_results(
[True, False, False], str_to_ts('2011-01-03'))
assert_all_restrictions(
[True, False, False], str_to_ts('2011-01-03')
)
assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04'))
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04'))
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04'))
assert_vectorized_results([False, True, True], str_to_ts('2011-01-04'))
assert_all_restrictions(
[False, True, True], str_to_ts('2011-01-04')
)
assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05'))
assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05'))
assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05'))
assert_vectorized_results([True, True, True], str_to_ts('2011-01-05'))
assert_all_restrictions(
[True, True, True],
str_to_ts('2011-01-05')
)
def test_noop_restrictions(self):
"""
@@ -262,7 +275,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
rl = NoopRestrictions()
assert_not_restricted = partial(self.assert_not_restricted, rl)
assert_vectorized_results = partial(self.assert_vectorized_results, rl)
assert_all_restrictions = partial(self.assert_all_restrictions, rl)
for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03',
'2011-01-04',
@@ -270,4 +283,4 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
assert_not_restricted(self.ASSET1, dt)
assert_not_restricted(self.ASSET2, dt)
assert_not_restricted(self.ASSET3, dt)
assert_vectorized_results([False, False, False], dt)
assert_all_restrictions([False, False, False], dt)