mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:10:40 +08:00
[rllib] user defined preprocessor (#985)
* add register_preprocessor to ModelCatalog * add pytest * make staticmethod a classmethod * update * install gym on travis * fix linting * fix
This commit is contained in:
committed by
Robert Nishihara
parent
29ac95d87a
commit
73f40bd844
@@ -32,6 +32,8 @@ class ModelCatalog(object):
|
||||
action_op = dist.sample()
|
||||
"""
|
||||
|
||||
_registered_preprocessor = dict()
|
||||
|
||||
@staticmethod
|
||||
def get_action_dist(action_space, dist_type=None):
|
||||
"""Returns action distribution class and size for the given action space.
|
||||
@@ -76,8 +78,8 @@ class ModelCatalog(object):
|
||||
|
||||
return FullyConnectedNetwork(inputs, num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessor(env_name, obs_shape, options=dict()):
|
||||
@classmethod
|
||||
def get_preprocessor(cls, env_name, obs_shape, options=dict()):
|
||||
"""Returns a suitable processor for the given environment.
|
||||
|
||||
Args:
|
||||
@@ -97,6 +99,9 @@ class ModelCatalog(object):
|
||||
"Unknown config key `{}`, all keys: {}".format(
|
||||
k, MODEL_CONFIGS))
|
||||
|
||||
if env_name in cls._registered_preprocessor:
|
||||
return cls._registered_preprocessor[env_name](options)
|
||||
|
||||
if obs_shape == ATARI_OBS_SHAPE:
|
||||
print("Assuming Atari pixel env, using AtariPixelPreprocessor.")
|
||||
return AtariPixelPreprocessor(options)
|
||||
@@ -106,3 +111,15 @@ class ModelCatalog(object):
|
||||
|
||||
print("Non-atari env, not using any observation preprocessor.")
|
||||
return NoPreprocessor(options)
|
||||
|
||||
@classmethod
|
||||
def register_preprocessor(cls, env_name, preprocessor_class):
|
||||
"""Register a preprocessor class for a specific environment.
|
||||
|
||||
Args:
|
||||
env_name (str): Name of the gym env we register the
|
||||
preprocessor for.
|
||||
preprocessor_class (type):
|
||||
Python class of the distribution.
|
||||
"""
|
||||
cls._registered_preprocessor[env_name] = preprocessor_class
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.models.preprocessors import Preprocessor
|
||||
|
||||
|
||||
class FakePreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self, options):
|
||||
pass
|
||||
|
||||
|
||||
def test_preprocessor():
|
||||
ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor)
|
||||
preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1))
|
||||
assert type(preprocessor) == FakePreprocessor
|
||||
@@ -86,6 +86,7 @@ setup(name="ray",
|
||||
"click",
|
||||
"colorama",
|
||||
"psutil",
|
||||
"pytest",
|
||||
"redis",
|
||||
"cloudpickle >= 0.2.2",
|
||||
# The six module is required by pyarrow.
|
||||
|
||||
Reference in New Issue
Block a user