BUG: fix label array code dtype condense

This commit is contained in:
Joe Jevnik
2017-03-08 16:41:12 -05:00
committed by Joe Jevnik
parent fcfc06ef0a
commit 153f6636c7
2 changed files with 66 additions and 29 deletions
+45 -22
View File
@@ -341,25 +341,30 @@ class LabelArrayTestCase(ZiplineTestCase):
arr[:] = orig_arr
check_arrays(arr, orig_arr)
def test_narrow_code_storage(self):
def check_roundtrip(arr):
assert_equal(
@staticmethod
def check_roundtrip(arr):
assert_equal(
arr.as_string_array(),
LabelArray(
arr.as_string_array(),
LabelArray(
arr.as_string_array(),
arr.missing_value,
).as_string_array(),
)
arr.missing_value,
).as_string_array(),
)
def create_categories(width, plus_one):
length = int(width / 8) + plus_one
return [
''.join(cs)
for cs in take(
2 ** width + plus_one,
product([chr(c) for c in range(256)], repeat=length),
)
]
@staticmethod
def create_categories(width, plus_one):
    """Build a list of distinct strings to use as label categories.

    Each string is the concatenation of one character tuple produced by
    the cartesian product of ``chr(0)`` .. ``chr(255)``, repeated
    ``int(width / 8) + plus_one`` times.

    Parameters
    ----------
    width : int
        Bit width of the code storage being exercised (e.g. 8 or 16).
    plus_one : bool or int
        Added both to the string length and to the number of strings
        requested, so a truthy value asks for one more category than
        ``2 ** width``.

    Returns
    -------
    list of str
        The strings built from the first ``2 ** width + plus_one`` tuples
        handed back by ``take``.
    """
    chars_per_label = int(width / 8) + plus_one
    n_labels = 2 ** width + plus_one
    alphabet = [chr(c) for c in range(256)]
    combos = product(alphabet, repeat=chars_per_label)
    return list(map(''.join, take(n_labels, combos)))
def test_narrow_code_storage(self):
create_categories = self.create_categories
check_roundtrip = self.check_roundtrip
# uint8
categories = create_categories(8, plus_one=False)
@@ -386,11 +391,6 @@ class LabelArrayTestCase(ZiplineTestCase):
self.assertEqual(arr.itemsize, 2)
check_roundtrip(arr)
# uint16 inference
arr = LabelArray(categories, missing_value=categories[0])
self.assertEqual(arr.itemsize, 2)
check_roundtrip(arr)
# fits in uint16
categories = create_categories(16, plus_one=False)
arr = LabelArray(
@@ -422,3 +422,26 @@ class LabelArrayTestCase(ZiplineTestCase):
# NOTE: we could do this for 32 and 64; however, no one has enough RAM
# or time for that.
def test_narrow_condense_back_to_valid_size(self):
    """Check that ``LabelArray`` itemsize follows the deduplicated labels.

    Both inputs are longer than the expected code width could index, but
    contain few enough distinct values to fit once duplicates collapse.
    """
    # 2 ** 8 + 1 copies of a single label: expect one-byte codes
    repeated = ['a'] * (2 ** 8 + 1)
    first = repeated[0]
    narrow = LabelArray(repeated, missing_value=first)
    assert_equal(narrow.itemsize, 1)
    self.check_roundtrip(narrow)

    # one entry longer than uint16 can index, but the extra entry is a
    # duplicate so it still fits in two bytes when deduped
    labels = self.create_categories(16, plus_one=False)
    first = labels[0]
    labels.append(first)
    wide = LabelArray(labels, missing_value=first)
    assert_equal(wide.itemsize, 2)
    self.check_roundtrip(wide)
def manual_narrow_condense_back_to_valid_size_slow(self):
    """Check the itemsize chosen for ``create_categories(24, ...)`` labels.

    This check is really slow, so we don't want it run by default.
    """
    # guards against trying to create an 'int24' (which is meaningless):
    # the expected itemsize here is 4
    labels = self.create_categories(24, plus_one=False)
    first = labels[0]
    labels.append(first)

    condensed = LabelArray(labels, missing_value=first)
    assert_equal(condensed.itemsize, 4)
    self.check_roundtrip(condensed)
+21 -7
View File
@@ -1,7 +1,7 @@
"""
Factorization algorithms.
"""
from libc.math cimport floor, log
from libc.math cimport log
cimport numpy as np
import numpy as np
@@ -144,6 +144,9 @@ cdef factorize_strings_impl(np.ndarray[object] values,
return codes, categories_array, reverse_categories
# Lookup table indexed by the number of bytes needed to hold the codes
# (0 through 8). Each entry is the itemsize, in bytes, of the narrowest
# unsigned integer at least that wide. Only widths of 1, 2, 4 and 8 bytes
# appear, so for example index 3 maps to 4 and index 5 maps to 8.
cdef list _int_sizes = [1, 1, 2, 4, 4, 8, 8, 8, 8]
cpdef factorize_strings(np.ndarray[object] values,
object missing_value,
int sort):
@@ -209,11 +212,22 @@ cpdef factorize_strings(np.ndarray[object] values,
# unreachable
raise ValueError('nvalues larger than uint64')
if len(categories_array) < 2 ** codes.dtype.itemsize:
# if there are a lot of duplicates in the values we may need to shrink
# the width of the ``codes`` array
codes = codes.astype(unsigned_int_dtype_with_size_in_bytes(
floor(log2(len(categories_array))),
))
length = len(categories_array)
if length < 1:
# lim x -> 0 log2(x) == -infinity so we floor at uint8
narrowest_dtype = np.uint8
else:
# The number of bits required to hold the codes up to ``length`` is
# log2(length). The number of bits per byte is 8. We cannot have
# fractional bytes so we need to round up. Finally, we can only have
# integers with widths 1, 2, 4, or 8 so we need to round up to the
# next value by looking up the next largest size in ``_int_sizes``.
narrowest_dtype = unsigned_int_dtype_with_size_in_bytes(
_int_sizes[int(np.ceil(log2(length) / 8))]
)
if codes.dtype != narrowest_dtype:
# condense the codes down to the narrowest dtype possible
codes = codes.astype(narrowest_dtype)
return codes, categories_array, reverse_categories