[rllib] Make observation filter optional (#940)

* make observation filter optional

* fix linting
This commit is contained in:
Philipp Moritz
2017-09-14 17:37:19 -07:00
committed by Robert Nishihara
parent 413140df38
commit 6601bb5f9e
2 changed files with 11 additions and 3 deletions
+2
View File
@@ -55,6 +55,8 @@ DEFAULT_CONFIG = {
"kl_target": 0.01,
# Config params to pass to the model
"model": {"free_log_std": False},
# Which observation filter to apply to the observation
"observation_filter": "MeanStdFilter",
# If >1, adds frameskip
"extra_frameskip": 1,
# Number of timesteps collected in each outer loop
+9 -3
View File
@@ -16,7 +16,7 @@ from ray.rllib.parallel import LocalSyncParallelOptimizer
from ray.rllib.models import ModelCatalog
from ray.rllib.ppo.env import BatchedEnv
from ray.rllib.ppo.loss import ProximalPolicyLoss
from ray.rllib.ppo.filter import MeanStdFilter
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
from ray.rllib.ppo.rollout import (
rollouts, add_return_values, add_advantage_values)
from ray.rllib.ppo.utils import flatten, concatenate
@@ -140,8 +140,14 @@ class Runner(object):
self.common_policy = self.par_opt.get_common_loss()
self.variables = ray.experimental.TensorFlowVariables(
self.common_policy.loss, self.sess)
self.observation_filter = MeanStdFilter(
self.preprocessor_shape, clip=None)
if config["observation_filter"] == "MeanStdFilter":
self.observation_filter = MeanStdFilter(
self.preprocessor_shape, clip=None)
elif config["observation_filter"] == "NoFilter":
self.observation_filter = NoFilter()
else:
raise Exception("Unknown observation_filter: " +
str(config["observation_filter"]))
self.reward_filter = MeanStdFilter((), clip=5.0)
self.sess.run(tf.global_variables_initializer())