diff --git a/tests/test_labelarray.py b/tests/test_labelarray.py index c9df9a4f..c08a58e4 100644 --- a/tests/test_labelarray.py +++ b/tests/test_labelarray.py @@ -7,6 +7,7 @@ from toolz import take from zipline.lib.labelarray import LabelArray from zipline.testing import check_arrays, parameter_space, ZiplineTestCase +from zipline.testing.predicates import assert_equal from zipline.utils.compat import unicode @@ -341,6 +342,15 @@ class LabelArrayTestCase(ZiplineTestCase): check_arrays(arr, orig_arr) def test_narrow_code_storage(self): + def check_roundtrip(arr): + assert_equal( + arr.as_string_array(), + LabelArray( + arr.as_string_array(), + arr.missing_value, + ).as_string_array(), + ) + def create_categories(width, plus_one): length = int(width / 8) + plus_one return [ @@ -359,10 +369,12 @@ class LabelArrayTestCase(ZiplineTestCase): categories=categories, ) self.assertEqual(arr.itemsize, 1) + check_roundtrip(arr) # uint8 inference arr = LabelArray(categories, missing_value=categories[0]) self.assertEqual(arr.itemsize, 1) + check_roundtrip(arr) # just over uint8 categories = create_categories(8, plus_one=True) @@ -372,10 +384,12 @@ class LabelArrayTestCase(ZiplineTestCase): categories=categories, ) self.assertEqual(arr.itemsize, 2) + check_roundtrip(arr) # uint16 inference arr = LabelArray(categories, missing_value=categories[0]) self.assertEqual(arr.itemsize, 2) + check_roundtrip(arr) # fits in uint16 categories = create_categories(16, plus_one=False) @@ -384,10 +398,12 @@ class LabelArrayTestCase(ZiplineTestCase): categories=categories, ) self.assertEqual(arr.itemsize, 2) + check_roundtrip(arr) # uint16 inference arr = LabelArray(categories, missing_value=categories[0]) self.assertEqual(arr.itemsize, 2) + check_roundtrip(arr) # just over uint16 categories = create_categories(16, plus_one=True) @@ -397,10 +413,12 @@ class LabelArrayTestCase(ZiplineTestCase): categories=categories, ) self.assertEqual(arr.itemsize, 4) + check_roundtrip(arr) # uint32 inference arr = LabelArray(categories, missing_value=categories[0]) self.assertEqual(arr.itemsize, 4) + check_roundtrip(arr) # NOTE: we could do this for 32 and 64; however, no one has enough RAM # or time for that.