From 6601bb5f9e73449667192926229ac4c6b1e083f9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 14 Sep 2017 17:37:19 -0700 Subject: [PATCH] [rllib] Make observation filter optional (#940) * make observation filter optional * fix linting --- python/ray/rllib/ppo/ppo.py | 2 ++ python/ray/rllib/ppo/runner.py | 12 +++++++++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index c53c49c31..212150110 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -55,6 +55,8 @@ DEFAULT_CONFIG = { "kl_target": 0.01, # Config params to pass to the model "model": {"free_log_std": False}, + # Which observation filter to apply to the observation + "observation_filter": "MeanStdFilter", # If >1, adds frameskip "extra_frameskip": 1, # Number of timesteps collected in each outer loop diff --git a/python/ray/rllib/ppo/runner.py b/python/ray/rllib/ppo/runner.py index 444f63d49..eb395f0cd 100644 --- a/python/ray/rllib/ppo/runner.py +++ b/python/ray/rllib/ppo/runner.py @@ -16,7 +16,7 @@ from ray.rllib.parallel import LocalSyncParallelOptimizer from ray.rllib.models import ModelCatalog from ray.rllib.ppo.env import BatchedEnv from ray.rllib.ppo.loss import ProximalPolicyLoss -from ray.rllib.ppo.filter import MeanStdFilter +from ray.rllib.ppo.filter import NoFilter, MeanStdFilter from ray.rllib.ppo.rollout import ( rollouts, add_return_values, add_advantage_values) from ray.rllib.ppo.utils import flatten, concatenate @@ -140,8 +140,14 @@ class Runner(object): self.common_policy = self.par_opt.get_common_loss() self.variables = ray.experimental.TensorFlowVariables( self.common_policy.loss, self.sess) - self.observation_filter = MeanStdFilter( - self.preprocessor_shape, clip=None) + if config["observation_filter"] == "MeanStdFilter": + self.observation_filter = MeanStdFilter( + self.preprocessor_shape, clip=None) + elif config["observation_filter"] == "NoFilter": + self.observation_filter = NoFilter() + else: + raise Exception("Unknown observation_filter: " + + str(config["observation_filter"])) self.reward_filter = MeanStdFilter((), clip=5.0) self.sess.run(tf.global_variables_initializer())