mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 08:02:33 +08:00
[RLlib] Fix for issue https://github.com/ray-project/ray/issues/8191 (#8200)
Fix attribute error when missing exploration in Policy. Issue #8191
This commit is contained in:
+1
-1
@@ -1388,7 +1388,7 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "examples/multi_agent_custom_policy",
|
||||
tags = ["examples", "examples_M_xxx"],
|
||||
tags = ["examples", "examples_M"],
|
||||
size = "medium",
|
||||
srcs = ["examples/multi_agent_custom_policy.py"],
|
||||
)
|
||||
|
||||
@@ -152,8 +152,10 @@ class MultiAgentSampleBatchBuilder:
|
||||
post_batches[agent_id] = policy.postprocess_trajectory(
|
||||
pre_batch, other_batches, episode)
|
||||
# Call the Policy's Exploration's postprocess method.
|
||||
policy.exploration.postprocess_trajectory(
|
||||
policy, post_batches[agent_id], getattr(policy, "_sess", None))
|
||||
if getattr(policy, "exploration", None) is not None:
|
||||
policy.exploration.postprocess_trajectory(
|
||||
policy, post_batches[agent_id],
|
||||
getattr(policy, "_sess", None))
|
||||
|
||||
if log_once("after_post"):
|
||||
logger.info(
|
||||
|
||||
+12
-10
@@ -313,11 +313,12 @@ def _env_runner(worker, base_env, extra_batch_callback, policies,
|
||||
get_batch_builder, extra_batch_callback)
|
||||
# Call each policy's Exploration.on_episode_start method.
|
||||
for p in policies.values():
|
||||
p.exploration.on_episode_start(
|
||||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
if getattr(p, "exploration", None) is not None:
|
||||
p.exploration.on_episode_start(
|
||||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
callbacks.on_episode_start(
|
||||
worker=worker,
|
||||
base_env=base_env,
|
||||
@@ -505,11 +506,12 @@ def _process_observations(worker, base_env, policies, batch_builder_pool,
|
||||
batch_builder_pool.append(episode.batch_builder)
|
||||
# Call each policy's Exploration.on_episode_end method.
|
||||
for p in policies.values():
|
||||
p.exploration.on_episode_end(
|
||||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
if getattr(p, "exploration", None) is not None:
|
||||
p.exploration.on_episode_end(
|
||||
policy=p,
|
||||
environment=base_env,
|
||||
episode=episode,
|
||||
tf_sess=getattr(p, "_sess", None))
|
||||
# Call custom on_episode_end callback.
|
||||
callbacks.on_episode_end(
|
||||
worker=worker,
|
||||
|
||||
@@ -28,6 +28,8 @@ class Policy(metaclass=ABCMeta):
|
||||
Attributes:
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
exploration (Exploration): The exploration object to use for
|
||||
computing actions, or None.
|
||||
"""
|
||||
|
||||
@DeveloperAPI
|
||||
@@ -42,8 +44,6 @@ class Policy(metaclass=ABCMeta):
|
||||
observation_space (gym.Space): Observation space of the policy.
|
||||
action_space (gym.Space): Action space of the policy.
|
||||
config (dict): Policy-specific configuration data.
|
||||
exploration (Exploration): The exploration object to use for
|
||||
computing actions.
|
||||
"""
|
||||
self.observation_space = observation_space
|
||||
self.action_space = action_space
|
||||
|
||||
Reference in New Issue
Block a user