diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 5d847a230..2df5e51f6 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -132,7 +132,8 @@ class ModelCatalog(object): print("Observation shape is {}".format(obs_shape)) if env_name in cls._registered_preprocessor: - return cls._registered_preprocessor[env_name](options) + return cls._registered_preprocessor[env_name]( + env.observation_space, options) if obs_shape == (): print("Using one-hot preprocessor for discrete envs.") diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py index 671e88f72..f6760e8e1 100644 --- a/python/ray/rllib/test/test_catalog.py +++ b/python/ray/rllib/test/test_catalog.py @@ -3,7 +3,7 @@ from ray.rllib.models.preprocessors import Preprocessor class FakePreprocessor(Preprocessor): - def __init__(self, options): + def _init(self): pass