mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[rllib] Support continuous action distributions in IMPALA/APPO (#4771)
This commit is contained in:
@@ -34,6 +34,7 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
tf = try_import_tf()
|
||||
@@ -48,12 +49,15 @@ VTraceFromLogitsReturns = collections.namedtuple("VTraceFromLogitsReturns", [
|
||||
VTraceReturns = collections.namedtuple("VTraceReturns", "vs pg_advantages")
|
||||
|
||||
|
||||
def log_probs_from_logits_and_actions(policy_logits, actions):
|
||||
return multi_log_probs_from_logits_and_actions([policy_logits],
|
||||
[actions])[0]
|
||||
def log_probs_from_logits_and_actions(policy_logits,
|
||||
actions,
|
||||
dist_class=Categorical):
|
||||
return multi_log_probs_from_logits_and_actions([policy_logits], [actions],
|
||||
dist_class)[0]
|
||||
|
||||
|
||||
def multi_log_probs_from_logits_and_actions(policy_logits, actions):
|
||||
def multi_log_probs_from_logits_and_actions(policy_logits, actions,
|
||||
dist_class):
|
||||
"""Computes action log-probs from policy logits and actions.
|
||||
|
||||
In the notation used throughout documentation and comments, T refers to the
|
||||
@@ -68,11 +72,11 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions):
|
||||
...,
|
||||
[T, B, ACTION_SPACE[-1]]
|
||||
with un-normalized log-probabilities parameterizing a softmax policy.
|
||||
actions: A list with length of ACTION_SPACE of int32
|
||||
actions: A list with length of ACTION_SPACE of
|
||||
tensors of shapes
|
||||
[T, B],
|
||||
[T, B, ...],
|
||||
...,
|
||||
[T, B]
|
||||
[T, B, ...]
|
||||
with actions.
|
||||
|
||||
Returns:
|
||||
@@ -87,8 +91,16 @@ def multi_log_probs_from_logits_and_actions(policy_logits, actions):
|
||||
|
||||
log_probs = []
|
||||
for i in range(len(policy_logits)):
|
||||
log_probs.append(-tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=policy_logits[i], labels=actions[i]))
|
||||
p_shape = tf.shape(policy_logits[i])
|
||||
a_shape = tf.shape(actions[i])
|
||||
policy_logits_flat = tf.reshape(policy_logits[i],
|
||||
tf.concat([[-1], p_shape[2:]], axis=0))
|
||||
actions_flat = tf.reshape(actions[i],
|
||||
tf.concat([[-1], a_shape[2:]], axis=0))
|
||||
log_probs.append(
|
||||
tf.reshape(
|
||||
dist_class(policy_logits_flat).logp(actions_flat),
|
||||
a_shape[:2]))
|
||||
|
||||
return log_probs
|
||||
|
||||
@@ -100,6 +112,7 @@ def from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class=Categorical,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
name="vtrace_from_logits"):
|
||||
@@ -111,6 +124,7 @@ def from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold,
|
||||
name=name)
|
||||
@@ -133,6 +147,7 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
clip_rho_threshold=1.0,
|
||||
clip_pg_rho_threshold=1.0,
|
||||
name="vtrace_from_logits"):
|
||||
@@ -168,11 +183,11 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
[T, B, ACTION_SPACE[-1]]
|
||||
with un-normalized log-probabilities parameterizing the softmax target
|
||||
policy.
|
||||
actions: A list with length of ACTION_SPACE of int32
|
||||
actions: A list with length of ACTION_SPACE of
|
||||
tensors of shapes
|
||||
[T, B],
|
||||
[T, B, ...],
|
||||
...,
|
||||
[T, B]
|
||||
[T, B, ...]
|
||||
with actions sampled from the behaviour policy.
|
||||
discounts: A float32 tensor of shape [T, B] with the discount encountered
|
||||
when following the behaviour policy.
|
||||
@@ -182,6 +197,7 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
wrt. the target policy.
|
||||
bootstrap_value: A float32 of shape [B] with the value function estimate at
|
||||
time T.
|
||||
dist_class: action distribution class for the logits.
|
||||
clip_rho_threshold: A scalar float32 tensor with the clipping threshold for
|
||||
importance weights (rho) when calculating the baseline targets (vs).
|
||||
rho^bar in the paper.
|
||||
@@ -208,13 +224,11 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
behaviour_policy_logits[i], dtype=tf.float32)
|
||||
target_policy_logits[i] = tf.convert_to_tensor(
|
||||
target_policy_logits[i], dtype=tf.float32)
|
||||
actions[i] = tf.convert_to_tensor(actions[i], dtype=tf.int32)
|
||||
|
||||
# Make sure tensor ranks are as expected.
|
||||
# The rest will be checked by from_action_log_probs.
|
||||
behaviour_policy_logits[i].shape.assert_has_rank(3)
|
||||
target_policy_logits[i].shape.assert_has_rank(3)
|
||||
actions[i].shape.assert_has_rank(2)
|
||||
|
||||
with tf.name_scope(
|
||||
name,
|
||||
@@ -223,9 +237,9 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
discounts, rewards, values, bootstrap_value
|
||||
]):
|
||||
target_action_log_probs = multi_log_probs_from_logits_and_actions(
|
||||
target_policy_logits, actions)
|
||||
target_policy_logits, actions, dist_class)
|
||||
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
|
||||
behaviour_policy_logits, actions)
|
||||
behaviour_policy_logits, actions, dist_class)
|
||||
|
||||
log_rhos = get_log_rhos(target_action_log_probs,
|
||||
behaviour_action_log_probs)
|
||||
|
||||
@@ -18,7 +18,6 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.utils import try_import_tf
|
||||
|
||||
@@ -40,6 +39,7 @@ class VTraceLoss(object):
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
@@ -52,7 +52,7 @@ class VTraceLoss(object):
|
||||
handle episode cut boundaries.
|
||||
|
||||
Args:
|
||||
actions: An int32 tensor of shape [T, B, ACTION_SPACE].
|
||||
actions: An int|float32 tensor of shape [T, B, ACTION_SPACE].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
dones: A bool tensor of shape [T, B].
|
||||
@@ -70,6 +70,7 @@ class VTraceLoss(object):
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
"""
|
||||
|
||||
@@ -78,11 +79,12 @@ class VTraceLoss(object):
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
target_policy_logits=target_logits,
|
||||
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
|
||||
actions=tf.unstack(actions, axis=2),
|
||||
discounts=tf.to_float(~dones) * discount,
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
dist_class=dist_class,
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
tf.float32))
|
||||
@@ -140,30 +142,28 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
|
||||
if isinstance(action_space, gym.spaces.Discrete):
|
||||
is_multidiscrete = False
|
||||
actions_shape = [None]
|
||||
output_hidden_shape = [action_space.n]
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
actions_shape = [None, len(action_space.nvec)]
|
||||
output_hidden_shape = action_space.nvec.astype(np.int32)
|
||||
else:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for IMPALA.".format(
|
||||
action_space))
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
|
||||
# Create input placeholders
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
if existing_inputs:
|
||||
actions, dones, behaviour_logits, rewards, observations, \
|
||||
prev_actions, prev_rewards = existing_inputs[:7]
|
||||
existing_state_in = existing_inputs[7:-1]
|
||||
existing_seq_lens = existing_inputs[-1]
|
||||
else:
|
||||
actions = tf.placeholder(tf.int64, actions_shape, name="ac")
|
||||
actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
dones = tf.placeholder(tf.bool, [None], name="dones")
|
||||
rewards = tf.placeholder(tf.float32, [None], name="rewards")
|
||||
behaviour_logits = tf.placeholder(
|
||||
tf.float32, [None, sum(output_hidden_shape)],
|
||||
name="behaviour_logits")
|
||||
tf.float32, [None, logit_dim], name="behaviour_logits")
|
||||
observations = tf.placeholder(
|
||||
tf.float32, [None] + list(observation_space.shape))
|
||||
existing_state_in = None
|
||||
@@ -174,8 +174,6 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
|
||||
# Setup the policy
|
||||
dist_class, logit_dim = ModelCatalog.get_action_dist(
|
||||
action_space, self.config["model"])
|
||||
prev_actions = ModelCatalog.get_action_placeholder(action_space)
|
||||
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
|
||||
self.model = ModelCatalog.get_model(
|
||||
@@ -261,6 +259,7 @@ class VTracePolicyGraph(LearningRateSchedule, VTracePostprocessing,
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
|
||||
@@ -18,7 +18,6 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.evaluation.postprocessing import compute_advantages
|
||||
@@ -94,6 +93,7 @@ class VTraceSurrogateLoss(object):
|
||||
rewards,
|
||||
values,
|
||||
bootstrap_value,
|
||||
dist_class,
|
||||
valid_mask,
|
||||
vf_loss_coeff=0.5,
|
||||
entropy_coeff=0.01,
|
||||
@@ -107,18 +107,19 @@ class VTraceSurrogateLoss(object):
|
||||
handle episode cut boundaries.
|
||||
|
||||
Arguments:
|
||||
actions: An int32 tensor of shape [T, B, NUM_ACTIONS].
|
||||
actions: An int|float32 tensor of shape [T, B, logit_dim].
|
||||
prev_actions_logp: A float32 tensor of shape [T, B].
|
||||
actions_logp: A float32 tensor of shape [T, B].
|
||||
action_kl: A float32 tensor of shape [T, B].
|
||||
actions_entropy: A float32 tensor of shape [T, B].
|
||||
dones: A bool tensor of shape [T, B].
|
||||
behaviour_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
|
||||
target_logits: A float32 tensor of shape [T, B, NUM_ACTIONS].
|
||||
behaviour_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
target_logits: A float32 tensor of shape [T, B, logit_dim].
|
||||
discount: A float32 scalar.
|
||||
rewards: A float32 tensor of shape [T, B].
|
||||
values: A float32 tensor of shape [T, B].
|
||||
bootstrap_value: A float32 tensor of shape [B].
|
||||
dist_class: action distribution class for logits.
|
||||
valid_mask: A bool tensor of valid RNN input elements (#2992).
|
||||
"""
|
||||
|
||||
@@ -127,11 +128,12 @@ class VTraceSurrogateLoss(object):
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
target_policy_logits=target_logits,
|
||||
actions=tf.unstack(tf.cast(actions, tf.int32), axis=2),
|
||||
actions=tf.unstack(actions, axis=2),
|
||||
discounts=tf.to_float(~dones) * discount,
|
||||
rewards=rewards,
|
||||
values=values,
|
||||
bootstrap_value=bootstrap_value,
|
||||
dist_class=dist_class,
|
||||
clip_rho_threshold=tf.cast(clip_rho_threshold, tf.float32),
|
||||
clip_pg_rho_threshold=tf.cast(clip_pg_rho_threshold,
|
||||
tf.float32))
|
||||
@@ -218,10 +220,6 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing,
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
is_multidiscrete = True
|
||||
output_hidden_shape = action_space.nvec.astype(np.int32)
|
||||
elif self.config["vtrace"]:
|
||||
raise UnsupportedSpaceException(
|
||||
"Action space {} is not supported for APPO + VTrace.",
|
||||
format(action_space))
|
||||
else:
|
||||
is_multidiscrete = False
|
||||
output_hidden_shape = 1
|
||||
@@ -365,6 +363,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, APPOPostprocessing,
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
pendulum-appo-vt:
|
||||
env: Pendulum-v0
|
||||
run: APPO
|
||||
stop:
|
||||
episode_reward_mean: -900 # just check it learns a bit
|
||||
timesteps_total: 500000
|
||||
config:
|
||||
num_gpus: 0
|
||||
num_workers: 1
|
||||
gamma: 0.95
|
||||
train_batch_size: 50
|
||||
vtrace: true
|
||||
Reference in New Issue
Block a user