[rllib] user defined preprocessor (#985)

* add register_preprocessor to ModelCatalog

* add pytest

* make staticmethod a classmethod

* update

* install gym on travis

* fix linting

* fix
This commit is contained in:
Philipp Moritz
2017-09-16 15:53:19 -07:00
committed by Robert Nishihara
parent 29ac95d87a
commit 73f40bd844
5 changed files with 40 additions and 6 deletions
+19 -2
View File
@@ -32,6 +32,8 @@ class ModelCatalog(object):
action_op = dist.sample()
"""
_registered_preprocessor = dict()
@staticmethod
def get_action_dist(action_space, dist_type=None):
"""Returns action distribution class and size for the given action space.
@@ -76,8 +78,8 @@ class ModelCatalog(object):
return FullyConnectedNetwork(inputs, num_outputs, options)
@staticmethod
def get_preprocessor(env_name, obs_shape, options=dict()):
@classmethod
def get_preprocessor(cls, env_name, obs_shape, options=dict()):
"""Returns a suitable processor for the given environment.
Args:
@@ -97,6 +99,9 @@ class ModelCatalog(object):
"Unknown config key `{}`, all keys: {}".format(
k, MODEL_CONFIGS))
if env_name in cls._registered_preprocessor:
return cls._registered_preprocessor[env_name](options)
if obs_shape == ATARI_OBS_SHAPE:
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
return AtariPixelPreprocessor(options)
@@ -106,3 +111,15 @@ class ModelCatalog(object):
print("Non-atari env, not using any observation preprocessor.")
return NoPreprocessor(options)
@classmethod
def register_preprocessor(cls, env_name, preprocessor_class):
"""Register a preprocessor class for a specific environment.
Args:
env_name (str): Name of the gym env we register the
preprocessor for.
preprocessor_class (type):
Python class of the distribution.
"""
cls._registered_preprocessor[env_name] = preprocessor_class
+14
View File
@@ -0,0 +1,14 @@
from ray.rllib.models import ModelCatalog
from ray.rllib.models.preprocessors import Preprocessor
class FakePreprocessor(Preprocessor):
def __init__(self, options):
pass
def test_preprocessor():
ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor)
preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1))
assert type(preprocessor) == FakePreprocessor
+1
View File
@@ -86,6 +86,7 @@ setup(name="ray",
"click",
"colorama",
"psutil",
"pytest",
"redis",
"cloudpickle >= 0.2.2",
# The six module is required by pyarrow.