Fix attribute error when missing exploration in Policy.
Issue #8191
This commit is contained in:
Sven Mika
2020-04-27 23:19:26 +02:00
committed by GitHub
parent 48250217ac
commit 4e713152e9
4 changed files with 19 additions and 15 deletions
+1 -1
View File
@@ -1388,7 +1388,7 @@ py_test(
py_test(
name = "examples/multi_agent_custom_policy",
tags = ["examples", "examples_M_xxx"],
tags = ["examples", "examples_M"],
size = "medium",
srcs = ["examples/multi_agent_custom_policy.py"],
)
+4 -2
View File
@@ -152,8 +152,10 @@ class MultiAgentSampleBatchBuilder:
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches, episode)
# Call the Policy's Exploration's postprocess method.
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id], getattr(policy, "_sess", None))
if getattr(policy, "exploration", None) is not None:
policy.exploration.postprocess_trajectory(
policy, post_batches[agent_id],
getattr(policy, "_sess", None))
if log_once("after_post"):
logger.info(
+12 -10
View File
@@ -313,11 +313,12 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
get_batch_builder, extra_batch_callback)
# Call each policy's Exploration.on_episode_start method.
for p in policies.values():
p.exploration.on_episode_start(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_start(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
callbacks.on_episode_start(
worker=worker,
base_env=base_env,
@@ -505,11 +506,12 @@ def _process_observations(worker, base_env, policies, batch_builder_pool,
batch_builder_pool.append(episode.batch_builder)
# Call each policy's Exploration.on_episode_end method.
for p in policies.values():
p.exploration.on_episode_end(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
if getattr(p, "exploration", None) is not None:
p.exploration.on_episode_end(
policy=p,
environment=base_env,
episode=episode,
tf_sess=getattr(p, "_sess", None))
# Call custom on_episode_end callback.
callbacks.on_episode_end(
worker=worker,
+2 -2
View File
@@ -28,6 +28,8 @@ class Policy(metaclass=ABCMeta):
Attributes:
observation_space (gym.Space): Observation space of the policy.
action_space (gym.Space): Action space of the policy.
exploration (Exploration): The exploration object to use for
computing actions, or None.
"""
@DeveloperAPI
@@ -42,8 +44,6 @@ class Policy(metaclass=ABCMeta):
observation_space (gym.Space): Observation space of the policy.
action_space (gym.Space): Action space of the policy.
config (dict): Policy-specific configuration data.
exploration (Exploration): The exploration object to use for
computing actions.
"""
self.observation_space = observation_space
self.action_space = action_space