diff --git a/tests/test_restrictions.py b/tests/test_restrictions.py index 9ec90e66..0d6020f4 100644 --- a/tests/test_restrictions.py +++ b/tests/test_restrictions.py @@ -60,44 +60,37 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): pd.Timedelta('1 minute'), pd.Timedelta('15 hours 5 minutes') ), - check_unordered=(False, True), + restriction_order=( + list(range(6)), # Keep restrictions in order. + [0, 2, 1, 3, 5, 4], # Re-order within asset. + [0, 3, 1, 4, 2, 5], # Scramble assets, maintain per-asset order. + [0, 5, 2, 3, 1, 4], # Scramble assets and per-asset order. + ), __fail_fast=True, ) - def test_historical_restrictions(self, date_offset, check_unordered): + def test_historical_restrictions(self, date_offset, restriction_order): """ Test historical restrictions for both interday and intraday restrictions, as well as restrictions defined in/not in order, for both single- and multi-asset queries """ - if check_unordered: - def maybe_scramble(rs): - # Swap the first two restrictions to check that we don't care - # that the restriction dates are ordered. - tmp = rs[0] - rs[0] = rs[1] - rs[1] = tmp - return rs - else: - maybe_scramble = lambda r: r - def rdate(s): """Convert a date string into a restriction for that date.""" # Add date_offset to check that we handle intraday changes. return str_to_ts(s) + date_offset - all_restrictions = ( - maybe_scramble([ - Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN), - Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED), - Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN), - ]) - + - maybe_scramble([ - Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN), - Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED), - Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN), - ]) - ) + base_restrictions = [ + Restriction(self.ASSET1, rdate('2011-01-04'), FROZEN), + Restriction(self.ASSET1, rdate('2011-01-05'), ALLOWED), + Restriction(self.ASSET1, rdate('2011-01-06'), FROZEN), + Restriction(self.ASSET2, rdate('2011-01-05'), FROZEN), + Restriction(self.ASSET2, rdate('2011-01-06'), ALLOWED), + Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN), + ] + # Scramble the restrictions based on restriction_order to check that we + # don't depend on the order in which restrictions are provided to us. + all_restrictions = [base_restrictions[i] for i in restriction_order] + restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions) rl = HistoricalRestrictions(all_restrictions) diff --git a/zipline/finance/asset_restrictions.py b/zipline/finance/asset_restrictions.py index c22d5086..8724eae0 100644 --- a/zipline/finance/asset_restrictions.py +++ b/zipline/finance/asset_restrictions.py @@ -3,9 +3,9 @@ from numpy import vectorize from functools import partial, reduce import operator import pandas as pd -from six import with_metaclass +from six import with_metaclass, iteritems from collections import namedtuple -from itertools import groupby +from toolz import groupby from zipline.utils.enum import enum from zipline.utils.numpy_utils import vectorized_is_element @@ -171,7 +171,7 @@ class HistoricalRestrictions(Restrictions): restrictions_for_asset, key=lambda x: x.effective_date ) for asset, restrictions_for_asset - in groupby(restrictions, lambda x: x.asset) + in iteritems(groupby(lambda x: x.asset, restrictions)) } def is_restricted(self, assets, dt):