From d11e62f9e61a2eb2c5ce9c8d437b3d0d9cae6511 Mon Sep 17 00:00:00 2001 From: Saeid Date: Thu, 21 Jan 2021 15:36:11 +0000 Subject: [PATCH] [RLlib] Fix problem in preprocessing nested MultiDiscrete (#13308) --- rllib/models/preprocessors.py | 2 +- rllib/models/tests/test_preprocessors.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 2b0bcb092..44312a807 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -174,7 +174,7 @@ class OneHotPreprocessor(Preprocessor): @override(Preprocessor) def write(self, observation: TensorType, array: np.ndarray, offset: int) -> None: - array[offset + observation] = 1 + array[offset:offset + self.size] = self.transform(observation) class NoPreprocessor(Preprocessor): diff --git a/rllib/models/tests/test_preprocessors.py b/rllib/models/tests/test_preprocessors.py index 5515b6fea..4ce7b73e7 100644 --- a/rllib/models/tests/test_preprocessors.py +++ b/rllib/models/tests/test_preprocessors.py @@ -71,6 +71,17 @@ class TestPreprocessors(unittest.TestCase): pp.transform(np.array([0, 1, 3])), [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + def test_nested_multidiscrete_one_hot_preprocessor(self): + space = Tuple((MultiDiscrete([2, 3, 4]), )) + pp = get_preprocessor(space)(space) + self.assertTrue(pp.shape == (9, )) + check( + pp.transform((np.array([1, 2, 0]), )), + [0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0]) + check( + pp.transform((np.array([0, 1, 3]), )), + [1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]) + if __name__ == "__main__": import pytest