diff --git a/python/ray/rllib/models/preprocessors.py b/python/ray/rllib/models/preprocessors.py
index e68ffbefc..dd133f2aa 100644
--- a/python/ray/rllib/models/preprocessors.py
+++ b/python/ray/rllib/models/preprocessors.py
@@ -30,6 +30,8 @@ class Preprocessor(object):
         self._obs_space = obs_space
         self._options = options or {}
         self.shape = self._init_shape(obs_space, options)
+        # Cached so size and write() avoid recomputing the product per call.
+        self._size = int(np.prod(self.shape))
 
     @PublicAPI
     def _init_shape(self, obs_space, options):
@@ -41,10 +43,19 @@ class Preprocessor(object):
         """Returns the preprocessed observation."""
         raise NotImplementedError
 
+    def write(self, observation, array, offset):
+        """Alternative to transform for more efficient flattening.
+
+        Writes the transformed observation, flattened to self.size
+        elements, into array[offset:offset + self.size] in place.
+        """
+        array[offset:offset + self._size] = np.reshape(
+            self.transform(observation), [self._size])
+
     @property
     @PublicAPI
     def size(self):
-        return int(np.product(self.shape))
+        return self._size
 
     @property
     @PublicAPI
@@ -123,6 +134,11 @@ class OneHotPreprocessor(Preprocessor):
         arr[observation] = 1
         return arr
 
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
+        # Sets only the hot index; the target slice must already be zeros.
+        array[offset + observation] = 1
+
 
 class NoPreprocessor(Preprocessor):
     @override(Preprocessor)
@@ -133,6 +149,11 @@ class NoPreprocessor(Preprocessor):
     def transform(self, observation):
         return observation
 
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
+        # asarray reuses ndarray inputs as-is and converts anything else.
+        array[offset:offset + self._size] = np.asarray(observation).ravel()
+
 
 class TupleFlatteningPreprocessor(Preprocessor):
     """Preprocesses each tuple element, then flattens it all into a vector.
@@ -155,11 +176,16 @@ class TupleFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        array = np.zeros(self.shape)
+        self.write(observation, array, 0)
+        return array
+
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
         assert len(observation) == len(self.preprocessors), observation
-        return np.concatenate([
-            np.reshape(p.transform(o), [p.size])
-            for (o, p) in zip(observation, self.preprocessors)
-        ])
+        for o, p in zip(observation, self.preprocessors):
+            p.write(o, array, offset)
+            offset += p.size
 
 
 class DictFlatteningPreprocessor(Preprocessor):
@@ -182,14 +208,19 @@ class DictFlatteningPreprocessor(Preprocessor):
 
     @override(Preprocessor)
     def transform(self, observation):
+        array = np.zeros(self.shape)
+        self.write(observation, array, 0)
+        return array
+
+    @override(Preprocessor)
+    def write(self, observation, array, offset):
         if not isinstance(observation, OrderedDict):
             observation = OrderedDict(sorted(list(observation.items())))
         assert len(observation) == len(self.preprocessors), \
             (len(observation), len(self.preprocessors))
-        return np.concatenate([
-            np.reshape(p.transform(o), [p.size])
-            for (o, p) in zip(observation.values(), self.preprocessors)
-        ])
+        for o, p in zip(observation.values(), self.preprocessors):
+            p.write(o, array, offset)
+            offset += p.size
 
 
 @PublicAPI
diff --git a/python/ray/rllib/tests/test_catalog.py b/python/ray/rllib/tests/test_catalog.py
index 9346e1064..fc9b71d2c 100644
--- a/python/ray/rllib/tests/test_catalog.py
+++ b/python/ray/rllib/tests/test_catalog.py
@@ -16,12 +16,12 @@ from ray.rllib.models.visionnet import VisionNetwork
 
 
 class CustomPreprocessor(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return None
+        return [1]
 
 
 class CustomPreprocessor2(Preprocessor):
     def _init_shape(self, obs_space, options):
-        return None
+        return [1]
 
 
 class CustomModel(Model):