From 4e713152e99b8524df2cabd0e8faf5080ca03a20 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 27 Apr 2020 23:19:26 +0200 Subject: [PATCH] [RLlib] Fix for issue https://github.com/ray-project/ray/issues/8191 (#8200) Fix attribute error when missing exploration in Policy. Issue #8191 --- rllib/BUILD | 2 +- rllib/evaluation/sample_batch_builder.py | 6 ++++-- rllib/evaluation/sampler.py | 22 ++++++++++++---------- rllib/policy/policy.py | 4 ++-- 4 files changed, 19 insertions(+), 15 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 6c6704acc..de1064d0a 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1388,7 +1388,7 @@ py_test( py_test( name = "examples/multi_agent_custom_policy", - tags = ["examples", "examples_M_xxx"], + tags = ["examples", "examples_M"], size = "medium", srcs = ["examples/multi_agent_custom_policy.py"], ) diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 48b3e0331..580cbcbac 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -152,8 +152,10 @@ class MultiAgentSampleBatchBuilder: post_batches[agent_id] = policy.postprocess_trajectory( pre_batch, other_batches, episode) # Call the Policy's Exploration's postprocess method. - policy.exploration.postprocess_trajectory( - policy, post_batches[agent_id], getattr(policy, "_sess", None)) + if getattr(policy, "exploration", None) is not None: + policy.exploration.postprocess_trajectory( + policy, post_batches[agent_id], + getattr(policy, "_sess", None)) if log_once("after_post"): logger.info( diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index ee302cdb9..373f61551 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -313,11 +313,12 @@ def _env_runner(worker, base_env, extra_batch_callback, policies, get_batch_builder, extra_batch_callback) # Call each policy's Exploration.on_episode_start method. for p in policies.values(): - p.exploration.on_episode_start( - policy=p, - environment=base_env, - episode=episode, - tf_sess=getattr(p, "_sess", None)) + if getattr(p, "exploration", None) is not None: + p.exploration.on_episode_start( + policy=p, + environment=base_env, + episode=episode, + tf_sess=getattr(p, "_sess", None)) callbacks.on_episode_start( worker=worker, base_env=base_env, @@ -505,11 +506,12 @@ def _process_observations(worker, base_env, policies, batch_builder_pool, batch_builder_pool.append(episode.batch_builder) # Call each policy's Exploration.on_episode_end method. for p in policies.values(): - p.exploration.on_episode_end( - policy=p, - environment=base_env, - episode=episode, - tf_sess=getattr(p, "_sess", None)) + if getattr(p, "exploration", None) is not None: + p.exploration.on_episode_end( + policy=p, + environment=base_env, + episode=episode, + tf_sess=getattr(p, "_sess", None)) # Call custom on_episode_end callback. callbacks.on_episode_end( worker=worker, diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 660cadc6c..09fc96b5b 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -28,6 +28,8 @@ class Policy(metaclass=ABCMeta): Attributes: observation_space (gym.Space): Observation space of the policy. action_space (gym.Space): Action space of the policy. + exploration (Exploration): The exploration object to use for + computing actions, or None. """ @DeveloperAPI @@ -42,8 +44,6 @@ class Policy(metaclass=ABCMeta): observation_space (gym.Space): Observation space of the policy. action_space (gym.Space): Action space of the policy. config (dict): Policy-specific configuration data. - exploration (Exploration): The exploration object to use for - computing actions. """ self.observation_space = observation_space self.action_space = action_space