BUG: Restrictions passed into HistoricalRestrictions not sorted correctly

This commit is contained in:
Andrew Liang
2016-10-05 14:09:26 -04:00
parent 08df8dfbf2
commit b0aba20a6e
2 changed files with 20 additions and 4 deletions
+17 -1
View File
@@ -61,9 +61,11 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
pd.Timedelta('15 hours 5 minutes')
),
check_unordered=(False, True),
check_mixed=(False, True),
__fail_fast=True,
)
def test_historical_restrictions(self, date_offset, check_unordered):
def test_historical_restrictions(self, date_offset, check_unordered,
check_mixed):
"""
Test historical restrictions for both interday and intraday
restrictions, as well as restrictions defined in/not in order, for both
@@ -80,6 +82,19 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
else:
maybe_scramble = lambda r: r
if check_mixed:
def maybe_mix_assets(rl):
# In the final restrictions list containing restrictions of
# both assets, swap the first and last restrictions so that
# the restrictions are not contiguous by asset
tmp = rl[0]
rl[0] = rl[-1]
rl[-1] = tmp
return rl
else:
def maybe_mix_assets(rl):
return rl
def rdate(s):
"""Convert a date string into a restriction for that date."""
# Add date_offset to check that we handle intraday changes.
@@ -98,6 +113,7 @@ class RestrictionsTestCase(WithDataPortal, ZiplineTestCase):
Restriction(self.ASSET2, rdate('2011-01-07'), FROZEN),
])
)
all_restrictions = maybe_mix_assets(all_restrictions)
restrictions_by_asset = groupby(lambda r: r.asset, all_restrictions)
rl = HistoricalRestrictions(all_restrictions)
+3 -3
View File
@@ -3,9 +3,9 @@ from numpy import vectorize
from functools import partial, reduce
import operator
import pandas as pd
from six import with_metaclass
from six import with_metaclass, iteritems
from collections import namedtuple
from itertools import groupby
from toolz import groupby
from zipline.utils.enum import enum
from zipline.utils.numpy_utils import vectorized_is_element
@@ -171,7 +171,7 @@ class HistoricalRestrictions(Restrictions):
restrictions_for_asset, key=lambda x: x.effective_date
)
for asset, restrictions_for_asset
in groupby(restrictions, lambda x: x.asset)
in iteritems(groupby(lambda x: x.asset, restrictions))
}
def is_restricted(self, assets, dt):