diff --git a/tests/test_restrictions.py b/tests/test_restrictions.py index 74327ce9..078b9699 100644 --- a/tests/test_restrictions.py +++ b/tests/test_restrictions.py @@ -1,9 +1,10 @@ import pandas as pd from pandas.util.testing import assert_series_equal -from nose_parameterized import parameterized from six import iteritems from functools import partial +from toolz import groupby + from zipline.finance.restrictions import ( RESTRICTION_STATES, Restriction, @@ -13,6 +14,7 @@ from zipline.finance.restrictions import ( NoopRestrictions, ) +from zipline.testing import parameter_space from zipline.testing.fixtures import ( WithDataPortal, ZiplineTestCase, @@ -34,6 +36,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): cls.ASSET1 = cls.asset_finder.retrieve_asset(1) cls.ASSET2 = cls.asset_finder.retrieve_asset(2) cls.ASSET3 = cls.asset_finder.retrieve_asset(3) + cls.ALL_ASSETS = [cls.ASSET1, cls.ASSET2, cls.ASSET3] def assert_is_restricted(self, rl, asset, dt): self.assertTrue(rl.is_restricted(asset, dt)) @@ -41,119 +44,122 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): def assert_not_restricted(self, rl, asset, dt): self.assertFalse(rl.is_restricted(asset, dt)) - def assert_vectorized_results(self, rl, expected, dt): + def assert_all_restrictions(self, rl, expected, dt): + self.assert_many_restrictions(rl, self.ALL_ASSETS, expected, dt) + + def assert_many_restrictions(self, rl, assets, expected, dt): assert_series_equal( - rl.is_restricted([self.ASSET1, self.ASSET2, self.ASSET3], dt), - pd.Series( - index=pd.Index([self.ASSET1, self.ASSET2, self.ASSET3]), - data=expected - ) + rl.is_restricted(assets, dt), + pd.Series(index=pd.Index(assets), data=expected), ) - @parameterized.expand([ - ('_'.join([timing, ordering]), - timing == 'intraday', - ordering == 'ordered') - for timing in ['intraday', 'interday'] - for ordering in ['ordered', 'unordered'] - ]) - def test_historical_restrictions(self, name, is_intraday, is_ordered): + @parameter_space( + date_offset=( + pd.Timedelta(0), + pd.Timedelta('1 minute'), + pd.Timedelta('15 hours 5 minutes') + ), + check_unordered=(False, True), + __fail_fast=True, + ) + def test_historical_restrictions(self, date_offset, check_unordered): """ Test historical restrictions for both interday and intraday restrictions, as well as restrictions defined in/not in order, for both single- and multi-asset queries """ - - hour_of_day = ' 15:00' if is_intraday else '' - - if is_ordered: - restriction_dates = { - self.ASSET1: [ - (str_to_ts('2011-01-04' + hour_of_day), FROZEN), - (str_to_ts('2011-01-05' + hour_of_day), ALLOWED), - (str_to_ts('2011-01-06' + hour_of_day), FROZEN), - ], - self.ASSET2: [ - (str_to_ts('2011-01-05' + hour_of_day), FROZEN), - (str_to_ts('2011-01-06' + hour_of_day), ALLOWED), - (str_to_ts('2011-01-07' + hour_of_day), FROZEN), - ], - } + if check_unordered: + def maybe_scramble(rs): + # Swap the first two restrictions to check that we don't care + # that the restriction dates are ordered. + tmp = rs[0] + rs[0] = rs[1] + rs[1] = tmp + return rs else: - restriction_dates = { - self.ASSET1: [ - (str_to_ts('2011-01-05' + hour_of_day), ALLOWED), - (str_to_ts('2011-01-06' + hour_of_day), FROZEN), - (str_to_ts('2011-01-04' + hour_of_day), FROZEN), - ], - self.ASSET2: [ - (str_to_ts('2011-01-06' + hour_of_day), ALLOWED), - (str_to_ts('2011-01-05' + hour_of_day), FROZEN), - (str_to_ts('2011-01-07' + hour_of_day), FROZEN), - ], - } + maybe_scramble = lambda r: r - restrictions = sum([ - [Restriction(asset, info[0], info[1]) for info in r_history] - for asset, r_history in iteritems(restriction_dates) - ], []) - rl = HistoricalRestrictions(restrictions) + def rdate(s): + """Convert a date string into a restriction for that date.""" + # Add date_offset to check that we handle intraday changes. + return str_to_ts(s) + date_offset + all_restrictions = ( + maybe_scramble([ + Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN), + Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED), + Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN), + ]) + + + maybe_scramble([ + Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN), + Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED), + Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN), + ]) + ) + restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions) + + rl = HistoricalRestrictions(all_restrictions) assert_not_restricted = partial(self.assert_not_restricted, rl) assert_is_restricted = partial(self.assert_is_restricted, rl) - assert_vectorized_results = partial(self.assert_vectorized_results, rl) + assert_all_restrictions = partial(self.assert_all_restrictions, rl) - for asset, r_history in iteritems(restriction_dates): - dts = sorted([info[0] for info in r_history]) + # Check individual restrictions. + for asset, r_history in iteritems(restrictions_by_asset): + freeze_dt, unfreeze_dt, re_freeze_dt = ( + sorted([r.effective_date for r in r_history]) + ) - # Not restricted until on or after the freeze - assert_not_restricted(asset, dts[0] - MINUTE) - assert_is_restricted(asset, dts[0]) - assert_is_restricted(asset, dts[0] + MINUTE) + # Starts implicitly unrestricted. Restricted on or after the freeze + assert_not_restricted(asset, freeze_dt - MINUTE) + assert_is_restricted(asset, freeze_dt) + assert_is_restricted(asset, freeze_dt + MINUTE) # Unrestricted on or after the unfreeze - assert_is_restricted(asset, dts[1] - MINUTE) - assert_not_restricted(asset, dts[1]) - assert_not_restricted(asset, dts[1] + MINUTE) + assert_is_restricted(asset, unfreeze_dt - MINUTE) + assert_not_restricted(asset, unfreeze_dt) + assert_not_restricted(asset, unfreeze_dt + MINUTE) # Restricted again on or after the freeze - assert_not_restricted(asset, dts[2] - MINUTE) - assert_is_restricted(asset, dts[2]) - assert_is_restricted(asset, dts[2] + MINUTE) + assert_not_restricted(asset, re_freeze_dt - MINUTE) + assert_is_restricted(asset, re_freeze_dt) + assert_is_restricted(asset, re_freeze_dt + MINUTE) + # Should stay restricted for the rest of time - assert_is_restricted(asset, dts[2] + MINUTE * 1000000) - - dts = [str_to_ts(ts + hour_of_day) for ts in ['2011-01-04', - '2011-01-05', - '2011-01-06', - '2011-01-07']] + assert_is_restricted(asset, re_freeze_dt + MINUTE * 1000000) + # Check vectorized restrictions. # Expected results for [self.ASSET1, self.ASSET2, self.ASSET3], # ASSET3 is always False as it has no defined restrictions # 01/04 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: ALLOWED - assert_vectorized_results([False, False, False], dts[0] - MINUTE) - assert_vectorized_results([True, False, False], dts[0]) - assert_vectorized_results([True, False, False], dts[0] + MINUTE) + d0 = rdate('2011-01-04') + assert_all_restrictions([False, False, False], d0 - MINUTE) + assert_all_restrictions([True, False, False], d0) + assert_all_restrictions([True, False, False], d0 + MINUTE) # 01/05 XX:00 ASSET1: FROZEN --> ALLOWED; ASSET2: ALLOWED --> FROZEN - assert_vectorized_results([True, False, False], dts[1] - MINUTE) - assert_vectorized_results([False, True, False], dts[1]) - assert_vectorized_results([False, True, False], dts[1] + MINUTE) + d1 = rdate('2011-01-05') + assert_all_restrictions([True, False, False], d1 - MINUTE) + assert_all_restrictions([False, True, False], d1) + assert_all_restrictions([False, True, False], d1 + MINUTE) # 01/06 XX:00 ASSET1: ALLOWED --> FROZEN; ASSET2: FROZEN --> ALLOWED - assert_vectorized_results([False, True, False], dts[2] - MINUTE) - assert_vectorized_results([True, False, False], dts[2]) - assert_vectorized_results([True, False, False], dts[2] + MINUTE) + d2 = rdate('2011-01-06') + assert_all_restrictions([False, True, False], d2 - MINUTE) + assert_all_restrictions([True, False, False], d2) + assert_all_restrictions([True, False, False], d2 + MINUTE) # 01/07 XX:00 ASSET1: FROZEN; ASSET2: ALLOWED --> FROZEN - assert_vectorized_results([True, False, False], dts[3] - MINUTE) - assert_vectorized_results([True, True, False], dts[3]) - assert_vectorized_results([True, True, False], dts[3] + MINUTE) + d3 = rdate('2011-01-07') + assert_all_restrictions([True, False, False], d3 - MINUTE) + assert_all_restrictions([True, True, False], d3) + assert_all_restrictions([True, True, False], d3 + MINUTE) + # Should stay restricted for the rest of time - assert_vectorized_results( + assert_all_restrictions( [True, True, False], - dts[3] + MINUTE * 10000000 + d3 + (MINUTE * 10000000) ) def test_historical_restrictions_consecutive_states(self): @@ -202,16 +208,17 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): rl = StaticRestrictions([restricted_a1, restricted_a2]) assert_not_restricted = partial(self.assert_not_restricted, rl) assert_is_restricted = partial(self.assert_is_restricted, rl) - assert_vectorized_results = partial(self.assert_vectorized_results, rl) + assert_all_restrictions = partial(self.assert_all_restrictions, rl) for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03', '2011-01-04', + '2011-01-04 1:01', '2020-01-04')]: assert_is_restricted(restricted_a1, dt) assert_is_restricted(restricted_a2, dt) assert_not_restricted(unrestricted_a3, dt) - assert_vectorized_results([True, True, False], dt) + assert_all_restrictions([True, True, False], dt) def test_security_list_restrictions(self): """ @@ -237,23 +244,29 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): assert_not_restricted = partial(self.assert_not_restricted, rl) assert_is_restricted = partial(self.assert_is_restricted, rl) - assert_vectorized_results = partial(self.assert_vectorized_results, rl) + assert_all_restrictions = partial(self.assert_all_restrictions, rl) assert_is_restricted(self.ASSET1, str_to_ts('2011-01-03')) assert_not_restricted(self.ASSET2, str_to_ts('2011-01-03')) assert_not_restricted(self.ASSET3, str_to_ts('2011-01-03')) - assert_vectorized_results( - [True, False, False], str_to_ts('2011-01-03')) + assert_all_restrictions( + [True, False, False], str_to_ts('2011-01-03') + ) assert_not_restricted(self.ASSET1, str_to_ts('2011-01-04')) assert_is_restricted(self.ASSET2, str_to_ts('2011-01-04')) assert_is_restricted(self.ASSET3, str_to_ts('2011-01-04')) - assert_vectorized_results([False, True, True], str_to_ts('2011-01-04')) + assert_all_restrictions( + [False, True, True], str_to_ts('2011-01-04') + ) assert_is_restricted(self.ASSET1, str_to_ts('2011-01-05')) assert_is_restricted(self.ASSET2, str_to_ts('2011-01-05')) assert_is_restricted(self.ASSET3, str_to_ts('2011-01-05')) - assert_vectorized_results([True, True, True], str_to_ts('2011-01-05')) + assert_all_restrictions( + [True, True, True], + str_to_ts('2011-01-05') + ) def test_noop_restrictions(self): """ @@ -262,7 +275,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): rl = NoopRestrictions() assert_not_restricted = partial(self.assert_not_restricted, rl) - assert_vectorized_results = partial(self.assert_vectorized_results, rl) + assert_all_restrictions = partial(self.assert_all_restrictions, rl) for dt in [str_to_ts(dt_str) for dt_str in ('2011-01-03', '2011-01-04', @@ -270,4 +283,4 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): assert_not_restricted(self.ASSET1, dt) assert_not_restricted(self.ASSET2, dt) assert_not_restricted(self.ASSET3, dt) - assert_vectorized_results([False, False, False], dt) + assert_all_restrictions([False, False, False], dt)