mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 00:33:03 +08:00
[rllib] Make observation filter optional (#940)
* make observation filter optional
* fix linting
This commit is contained in:
committed by
Robert Nishihara
parent
413140df38
commit
6601bb5f9e
@@ -55,6 +55,8 @@ DEFAULT_CONFIG = {
|
||||
"kl_target": 0.01,
|
||||
# Config params to pass to the model
|
||||
"model": {"free_log_std": False},
|
||||
# Which observation filter to apply to the observation
|
||||
"observation_filter": "MeanStdFilter",
|
||||
# If >1, adds frameskip
|
||||
"extra_frameskip": 1,
|
||||
# Number of timesteps collected in each outer loop
|
||||
|
||||
@@ -16,7 +16,7 @@ from ray.rllib.parallel import LocalSyncParallelOptimizer
|
||||
from ray.rllib.models import ModelCatalog
|
||||
from ray.rllib.ppo.env import BatchedEnv
|
||||
from ray.rllib.ppo.loss import ProximalPolicyLoss
|
||||
from ray.rllib.ppo.filter import MeanStdFilter
|
||||
from ray.rllib.ppo.filter import NoFilter, MeanStdFilter
|
||||
from ray.rllib.ppo.rollout import (
|
||||
rollouts, add_return_values, add_advantage_values)
|
||||
from ray.rllib.ppo.utils import flatten, concatenate
|
||||
@@ -140,8 +140,14 @@ class Runner(object):
|
||||
self.common_policy = self.par_opt.get_common_loss()
|
||||
self.variables = ray.experimental.TensorFlowVariables(
|
||||
self.common_policy.loss, self.sess)
|
||||
self.observation_filter = MeanStdFilter(
|
||||
self.preprocessor_shape, clip=None)
|
||||
if config["observation_filter"] == "MeanStdFilter":
|
||||
self.observation_filter = MeanStdFilter(
|
||||
self.preprocessor_shape, clip=None)
|
||||
elif config["observation_filter"] == "NoFilter":
|
||||
self.observation_filter = NoFilter()
|
||||
else:
|
||||
raise Exception("Unknown observation_filter: " +
|
||||
str(config["observation_filter"]))
|
||||
self.reward_filter = MeanStdFilter((), clip=5.0)
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user