mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-03 20:55:01 +08:00
BUG: fix label array code dtype condense
This commit is contained in:
+45
-22
@@ -341,25 +341,30 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
arr[:] = orig_arr
|
||||
check_arrays(arr, orig_arr)
|
||||
|
||||
def test_narrow_code_storage(self):
|
||||
def check_roundtrip(arr):
|
||||
assert_equal(
|
||||
@staticmethod
|
||||
def check_roundtrip(arr):
|
||||
assert_equal(
|
||||
arr.as_string_array(),
|
||||
LabelArray(
|
||||
arr.as_string_array(),
|
||||
LabelArray(
|
||||
arr.as_string_array(),
|
||||
arr.missing_value,
|
||||
).as_string_array(),
|
||||
)
|
||||
arr.missing_value,
|
||||
).as_string_array(),
|
||||
)
|
||||
|
||||
def create_categories(width, plus_one):
|
||||
length = int(width / 8) + plus_one
|
||||
return [
|
||||
''.join(cs)
|
||||
for cs in take(
|
||||
2 ** width + plus_one,
|
||||
product([chr(c) for c in range(256)], repeat=length),
|
||||
)
|
||||
]
|
||||
@staticmethod
|
||||
def create_categories(width, plus_one):
|
||||
length = int(width / 8) + plus_one
|
||||
return [
|
||||
''.join(cs)
|
||||
for cs in take(
|
||||
2 ** width + plus_one,
|
||||
product([chr(c) for c in range(256)], repeat=length),
|
||||
)
|
||||
]
|
||||
|
||||
def test_narrow_code_storage(self):
|
||||
create_categories = self.create_categories
|
||||
check_roundtrip = self.check_roundtrip
|
||||
|
||||
# uint8
|
||||
categories = create_categories(8, plus_one=False)
|
||||
@@ -386,11 +391,6 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# uint16 inference
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
self.assertEqual(arr.itemsize, 2)
|
||||
check_roundtrip(arr)
|
||||
|
||||
# fits in uint16
|
||||
categories = create_categories(16, plus_one=False)
|
||||
arr = LabelArray(
|
||||
@@ -422,3 +422,26 @@ class LabelArrayTestCase(ZiplineTestCase):
|
||||
|
||||
# NOTE: we could do this for 32 and 64; however, no one has enough RAM
|
||||
# or time for that.
|
||||
|
||||
def test_narrow_condense_back_to_valid_size(self):
|
||||
categories = ['a'] * (2 ** 8 + 1)
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
assert_equal(arr.itemsize, 1)
|
||||
self.check_roundtrip(arr)
|
||||
|
||||
# more values than uint16 can index, but still fits in uint16 when deduped
|
||||
categories = self.create_categories(16, plus_one=False)
|
||||
categories.append(categories[0])
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
assert_equal(arr.itemsize, 2)
|
||||
self.check_roundtrip(arr)
|
||||
|
||||
def manual_narrow_condense_back_to_valid_size_slow(self):
|
||||
"""This test is really slow so we don't want it run by default.
|
||||
"""
|
||||
# tests that we don't try to create an 'int24' (which is meaningless)
|
||||
categories = self.create_categories(24, plus_one=False)
|
||||
categories.append(categories[0])
|
||||
arr = LabelArray(categories, missing_value=categories[0])
|
||||
assert_equal(arr.itemsize, 4)
|
||||
self.check_roundtrip(arr)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
Factorization algorithms.
|
||||
"""
|
||||
from libc.math cimport floor, log
|
||||
from libc.math cimport log
|
||||
cimport numpy as np
|
||||
import numpy as np
|
||||
|
||||
@@ -144,6 +144,9 @@ cdef factorize_strings_impl(np.ndarray[object] values,
|
||||
return codes, categories_array, reverse_categories
|
||||
|
||||
|
||||
cdef list _int_sizes = [1, 1, 2, 4, 4, 8, 8, 8, 8]
|
||||
|
||||
|
||||
cpdef factorize_strings(np.ndarray[object] values,
|
||||
object missing_value,
|
||||
int sort):
|
||||
@@ -209,11 +212,22 @@ cpdef factorize_strings(np.ndarray[object] values,
|
||||
# unreachable
|
||||
raise ValueError('nvalues larger than uint64')
|
||||
|
||||
if len(categories_array) < 2 ** codes.dtype.itemsize:
|
||||
# if there are a lot of duplicates in the values we may need to shrink
|
||||
# the width of the ``codes`` array
|
||||
codes = codes.astype(unsigned_int_dtype_with_size_in_bytes(
|
||||
floor(log2(len(categories_array))),
|
||||
))
|
||||
length = len(categories_array)
|
||||
if length < 1:
|
||||
# lim x -> 0 log2(x) == -infinity so we floor at uint8
|
||||
narrowest_dtype = np.uint8
|
||||
else:
|
||||
# The number of bits required to hold the codes up to ``length`` is
|
||||
# log2(length). The number of bits per byte is 8. We cannot have
|
||||
# fractional bytes so we need to round up. Finally, we can only have
|
||||
# integers with widths of 1, 2, 4, or 8 bytes, so we need to round up to the
|
||||
# next value by looking up the next largest size in ``_int_sizes``.
|
||||
narrowest_dtype = unsigned_int_dtype_with_size_in_bytes(
|
||||
_int_sizes[int(np.ceil(log2(length) / 8))]
|
||||
)
|
||||
|
||||
if codes.dtype != narrowest_dtype:
|
||||
# condense the codes down to the narrowest dtype possible
|
||||
codes = codes.astype(narrowest_dtype)
|
||||
|
||||
return codes, categories_array, reverse_categories
|
||||
|
||||
Reference in New Issue
Block a user