diff --git a/tests/test_restrictions.py b/tests/test_restrictions.py index 9ec90e66..af6e681e 100644 --- a/tests/test_restrictions.py +++ b/tests/test_restrictions.py @@ -61,9 +61,11 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): pd.Timedelta('15 hours 5 minutes') ), check_unordered=(False, True), + check_mixed=(False, True), __fail_fast=True, ) - def test_historical_restrictions(self, date_offset, check_unordered): + def test_historical_restrictions(self, date_offset, check_unordered, + check_mixed): """ Test historical restrictions for both interday and intraday restrictions, as well as restrictions defined in/not in order, for both @@ -80,6 +82,19 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): else: maybe_scramble = lambda r: r + if check_mixed: + def maybe_mix_assets(rl): + # In the final restrictions list containing restrictions of + # both assets, swap the first and last restrictions so that + # the restrictions are not contiguous by asset + tmp = rl[0] + rl[0] = rl[-1] + rl[-1] = tmp + return rl + else: + def maybe_mix_assets(rl): + return rl + def rdate(s): """Convert a date string into a restriction for that date.""" # Add date_offset to check that we handle intraday changes. @@ -98,6 +113,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase): Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN), ]) ) + all_restrictions = maybe_mix_assets(all_restrictions) restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions) rl = HistoricalRestrictions(all_restrictions) diff --git a/zipline/finance/asset_restrictions.py b/zipline/finance/asset_restrictions.py index c22d5086..8724eae0 100644 --- a/zipline/finance/asset_restrictions.py +++ b/zipline/finance/asset_restrictions.py @@ -3,9 +3,9 @@ from numpy import vectorize from functools import partial, reduce import operator import pandas as pd -from six import with_metaclass +from six import with_metaclass, iteritems from collections import namedtuple -from itertools import groupby +from toolz import groupby from zipline.utils.enum import enum from zipline.utils.numpy_utils import vectorized_is_element @@ -171,7 +171,7 @@ class HistoricalRestrictions(Restrictions): restrictions_for_asset, key=lambda x: x.effective_date ) for asset, restrictions_for_asset - in groupby(restrictions, lambda x: x.asset) + in iteritems(groupby(lambda x: x.asset, restrictions)) } def is_restricted(self, assets, dt):