From e19f02a2ecb24baebddbeb17060d7b068c710e4d Mon Sep 17 00:00:00 2001 From: Jonathan Kamens Date: Tue, 10 Mar 2015 20:36:52 -0400 Subject: [PATCH] BUG: Handle all possible types of Security object __richcmp__ args A cython __richcmp__ function isn't allowed to assume that its first argument is the same as the type of the class to which it belongs, so our code needs to account for either of its two arguments being of the wrong type. Furthermore, the correct way for __richcmp__ to handle when it doesn't know how to do a comparison is to return NotImplemented. --- tests/test_security_object.py | 4 ++++ zipline/assets/_securities.pyx | 27 +++++++++++++++++---------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/tests/test_security_object.py b/tests/test_security_object.py index 1c02c822..d6605fb6 100644 --- a/tests/test_security_object.py +++ b/tests/test_security_object.py @@ -27,3 +27,7 @@ class TestSecurityRichCmp(TestCase): self.assertFalse(Security(3) > Security(4)) self.assertFalse(Security(4) > Security(4)) self.assertTrue(Security(5) > Security(4)) + + def test_type_mismatch(self): + self.assertIsNotNone(Security(3) < 'a') + self.assertIsNotNone('a' < Security(3)) diff --git a/zipline/assets/_securities.pyx b/zipline/assets/_securities.pyx index 8f0cd400..30f94401 100644 --- a/zipline/assets/_securities.pyx +++ b/zipline/assets/_securities.pyx @@ -78,7 +78,7 @@ cdef class Security: def __get__(self): return self.end_date - def __richcmp__(self, other, int op): + def __richcmp__(x, y, int op): """ Cython rich comparison method. This is used in place of various equality checkers in pure python. @@ -90,16 +90,23 @@ cdef class Security: > 4 >= 5 """ - cdef int other_as_int - if isinstance(other, Security): - other_as_int = other.sid - elif isinstance(other, int): - other_as_int = other - else: - retvals = [True, True, False, True, False, False] - return retvals[op] + cdef int x_as_int, y_as_int - compared = self.sid - other_as_int + if isinstance(x, Security): + x_as_int = x.sid + elif isinstance(x, int): + x_as_int = x + else: + return NotImplemented + + if isinstance(y, Security): + y_as_int = y.sid + elif isinstance(y, int): + y_as_int = y + else: + return NotImplemented + + compared = x_as_int - y_as_int # Handle == and != first because they're significantly more common # operations.