diff --git a/rllib/models/preprocessors.py b/rllib/models/preprocessors.py index 44312a807..0abfb8658 100644 --- a/rllib/models/preprocessors.py +++ b/rllib/models/preprocessors.py @@ -140,7 +140,7 @@ class AtariRamPreprocessor(Preprocessor): @override(Preprocessor) def transform(self, observation: TensorType) -> np.ndarray: self.check_shape(observation) - return (observation - 128) / 128 + return (observation.astype("float32") - 128) / 128 class OneHotPreprocessor(Preprocessor):