diff --git a/.travis.yml b/.travis.yml index 3a3f7aac3..619284ed7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -117,3 +117,5 @@ script: - python test/component_failures_test.py - python test/multi_node_test.py - python test/recursion_test.py + + - python -m pytest python/ray/rllib/test/test_catalog.py diff --git a/.travis/install-dependencies.sh b/.travis/install-dependencies.sh index 2d056de9a..cba8a1a47 100755 --- a/.travis/install-dependencies.sh +++ b/.travis/install-dependencies.sh @@ -24,7 +24,7 @@ if [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-Linux-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow flatbuffers + pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow gym flatbuffers elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then sudo apt-get update sudo apt-get install -y cmake pkg-config python-dev python-numpy build-essential autoconf curl libtool libboost-dev libboost-filesystem-dev libboost-system-dev unzip @@ -32,7 +32,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "linux" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow flatbuffers + pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow gym flatbuffers elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -48,7 +48,7 @@ elif [[ "$PYTHON" == "2.7" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda2-latest-MacOSX-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow flatbuffers + pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow gym flatbuffers elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then # check that brew is installed which -s brew @@ -64,7 +64,7 @@ elif [[ "$PYTHON" == "3.5" ]] && [[ "$platform" == "macosx" ]]; then wget https://repo.continuum.io/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -O miniconda.sh bash miniconda.sh -b -p $HOME/miniconda export PATH="$HOME/miniconda/bin:$PATH" - pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow flatbuffers + pip install numpy cloudpickle==0.3.0 cython cmake funcsigs click colorama psutil redis tensorflow gym flatbuffers elif [[ "$LINT" == "1" ]]; then sudo apt-get update sudo apt-get install -y cmake build-essential autoconf curl libtool libboost-dev libboost-filesystem-dev libboost-system-dev unzip diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index ec1332747..f5df8100b 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -32,6 +32,8 @@ class ModelCatalog(object): action_op = dist.sample() """ + _registered_preprocessor = dict() + @staticmethod def get_action_dist(action_space, dist_type=None): """Returns action distribution class and size for the given action space. @@ -76,8 +78,8 @@ class ModelCatalog(object): return FullyConnectedNetwork(inputs, num_outputs, options) - @staticmethod - def get_preprocessor(env_name, obs_shape, options=dict()): + @classmethod + def get_preprocessor(cls, env_name, obs_shape, options=dict()): """Returns a suitable processor for the given environment. Args: @@ -97,6 +99,9 @@ class ModelCatalog(object): "Unknown config key `{}`, all keys: {}".format( k, MODEL_CONFIGS)) + if env_name in cls._registered_preprocessor: + return cls._registered_preprocessor[env_name](options) + if obs_shape == ATARI_OBS_SHAPE: print("Assuming Atari pixel env, using AtariPixelPreprocessor.") return AtariPixelPreprocessor(options) @@ -106,3 +111,15 @@ class ModelCatalog(object): print("Non-atari env, not using any observation preprocessor.") return NoPreprocessor(options) + + @classmethod + def register_preprocessor(cls, env_name, preprocessor_class): + """Register a preprocessor class for a specific environment. + + Args: + env_name (str): Name of the gym env we register the + preprocessor for. + preprocessor_class (type): + Python class of the distribution. + """ + cls._registered_preprocessor[env_name] = preprocessor_class diff --git a/python/ray/rllib/test/test_catalog.py b/python/ray/rllib/test/test_catalog.py new file mode 100644 index 000000000..e229e2616 --- /dev/null +++ b/python/ray/rllib/test/test_catalog.py @@ -0,0 +1,14 @@ +from ray.rllib.models import ModelCatalog +from ray.rllib.models.preprocessors import Preprocessor + + +class FakePreprocessor(Preprocessor): + + def __init__(self, options): + pass + + +def test_preprocessor(): + ModelCatalog.register_preprocessor("FakeEnv-v0", FakePreprocessor) + preprocessor = ModelCatalog.get_preprocessor("FakeEnv-v0", (1, 1)) + assert type(preprocessor) == FakePreprocessor diff --git a/python/setup.py b/python/setup.py index 2684b55be..49533a8ef 100644 --- a/python/setup.py +++ b/python/setup.py @@ -86,6 +86,7 @@ setup(name="ray", "click", "colorama", "psutil", + "pytest", "redis", "cloudpickle >= 0.2.2", # The six module is required by pyarrow.