diff --git a/tests/pipeline/test_classifier.py b/tests/pipeline/test_classifier.py index 5bfc31b0..fa47df87 100644 --- a/tests/pipeline/test_classifier.py +++ b/tests/pipeline/test_classifier.py @@ -513,6 +513,58 @@ class ClassifierTestCase(BasePipelineTestCase): mask=self.build_mask(self.ones_mask(shape=data.shape)), ) + @parameter_space( + __fail_fast=True, + missing_value=[None, 'M'], + ) + def test_relabel_missing_value_interactions(self, missing_value): + + mv = missing_value + + class C(Classifier): + inputs = () + dtype = categorical_dtype + missing_value = mv + window_length = 0 + + c = C() + + def relabel_func(s): + if s == 'B': + return mv + return ''.join([s, s]) + + raw = np.asarray( + [['A', 'B', 'C', mv], + [mv, 'A', 'B', 'C'], + ['C', mv, 'A', 'B'], + ['B', 'C', mv, 'A']], + dtype=categorical_dtype, + ) + data = LabelArray(raw, missing_value=mv) + + expected_relabeled_raw = np.asarray( + [['AA', mv, 'CC', mv], + [mv, 'AA', mv, 'CC'], + ['CC', mv, 'AA', mv], + [mv, 'CC', mv, 'AA']], + dtype=categorical_dtype, + ) + + terms = { + 'relabeled': c.relabel(relabel_func), + } + expected_results = { + 'relabeled': LabelArray(expected_relabeled_raw, missing_value=mv), + } + + self.check_terms( + terms, + expected_results, + initial_workspace={c: data}, + mask=self.build_mask(self.ones_mask(shape=data.shape)), + ) + def test_relabel_int_classifier_not_yet_supported(self): class C(Classifier): inputs = ()