mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-01 17:26:10 +08:00
TST: add roundtrip check
This commit is contained in:
@@ -7,6 +7,7 @@ from toolz import take
|
||||
|
||||
from zipline.lib.labelarray import LabelArray
|
||||
from zipline.testing import check_arrays, parameter_space, ZiplineTestCase
|
||||
from zipline.testing.predicates import assert_equal
|
||||
from zipline.utils.compat import unicode
|
||||
|
||||
|
||||
@@ -341,6 +342,15 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
check_arrays(arr, orig_arr)
|
||||
|
||||
def test_narrow_code_storage(self):
|
||||
def check_roundtrip(arr):
|
||||
assert_equal(
|
||||
arr.as_string_array(),
|
||||
LabelArray(
|
||||
arr.as_string_array(),
|
||||
arr.missing_value,
|
||||
).as_string_array(),
|
||||
)
|
||||
|
||||
def create_categories(width, plus_one):
|
||||
length = int(width / 8) + plus_one
|
||||
return [
|
||||
@@ -359,10 +369,12 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
categories=categories,
|
||||
)
|
||||
self.assertEqual(arr.itemsize, 1)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# uint8 inference
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
self.assertEqual(arr.itemsize, 1)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# just over uint8
|
||||
categories = create_categories(8, plus_one=True)
|
||||
@@ -372,10 +384,12 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
categories=categories,
|
||||
)
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# uint16 inference
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# fits in uint16
|
||||
categories = create_categories(16, plus_one=False)
|
||||
@@ -384,10 +398,12 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
categories=categories,
|
||||
)
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# uint16 inference
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# just over uint16
|
||||
categories = create_categories(16, plus_one=True)
|
||||
@@ -397,10 +413,12 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
categories=categories,
|
||||
)
|
||||
self.assertEqual(arr.itemsize, 4)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# uint32 inference
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
self.assertEqual(arr.itemsize, 4)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# NOTE: we could do this for 32 and 64; however, no one has enough RAM
|
||||
# or time for that.
|
||||
|
||||
Reference in New Issue
Block a user