mirror of
https://github.com/wassname/catalyst.git
synced 2026-07-04 21:02:39 +08:00
TST: Add test for missing values in relabel.
This commit is contained in:
@@ -513,6 +513,58 @@ class ClassifierTestCase(BasePipelineTestCase):
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
@parameter_space(
|
||||
__fail_fast=True,
|
||||
missing_value=[None, 'M'],
|
||||
)
|
||||
def test_relabel_missing_value_interactions(self, missing_value):
|
||||
|
||||
mv = missing_value
|
||||
|
||||
class C(Classifier):
|
||||
inputs = ()
|
||||
dtype = categorical_dtype
|
||||
missing_value = mv
|
||||
window_length = 0
|
||||
|
||||
c = C()
|
||||
|
||||
def relabel_func(s):
|
||||
if s == 'B':
|
||||
return mv
|
||||
return ''.join([s, s])
|
||||
|
||||
raw = np.asarray(
|
||||
[['A', 'B', 'C', mv],
|
||||
[mv, 'A', 'B', 'C'],
|
||||
['C', mv, 'A', 'B'],
|
||||
['B', 'C', mv, 'A']],
|
||||
dtype=categorical_dtype,
|
||||
)
|
||||
data = LabelArray(raw, missing_value=mv)
|
||||
|
||||
expected_relabeled_raw = np.asarray(
|
||||
[['AA', mv, 'CC', mv],
|
||||
[mv, 'AA', mv, 'CC'],
|
||||
['CC', mv, 'AA', mv],
|
||||
[mv, 'CC', mv, 'AA']],
|
||||
dtype=categorical_dtype,
|
||||
)
|
||||
|
||||
terms = {
|
||||
'relabeled': c.relabel(relabel_func),
|
||||
}
|
||||
expected_results = {
|
||||
'relabeled': LabelArray(expected_relabeled_raw, missing_value=mv),
|
||||
}
|
||||
|
||||
self.check_terms(
|
||||
terms,
|
||||
expected_results,
|
||||
initial_workspace={c: data},
|
||||
mask=self.build_mask(self.ones_mask(shape=data.shape)),
|
||||
)
|
||||
|
||||
def test_relabel_int_classifier_not_yet_supported(self):
|
||||
class C(Classifier):
|
||||
inputs = ()
|
||||
|
||||
Reference in New Issue
Block a user