mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:34:51 +08:00
[rllib] Fix Multidiscrete support (#4869)
This commit is contained in:
@@ -15,7 +15,7 @@ from ray.rllib.policy.policy import Policy
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy import TFPolicy, \
|
||||
LearningRateSchedule
|
||||
from ray.rllib.models.action_dist import MultiCategorical
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
from ray.rllib.models.catalog import ModelCatalog
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.explained_variance import explained_variance
|
||||
@@ -191,9 +191,7 @@ class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
unpacked_outputs = tf.split(
|
||||
self.model.outputs, output_hidden_shape, axis=1)
|
||||
|
||||
dist_inputs = unpacked_outputs if is_multidiscrete else \
|
||||
self.model.outputs
|
||||
action_dist = dist_class(dist_inputs)
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
|
||||
values = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
@@ -258,32 +256,13 @@ class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=dist_class,
|
||||
dist_class=Categorical if is_multidiscrete else dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=self.config["vf_loss_coeff"],
|
||||
entropy_coeff=self.config["entropy_coeff"],
|
||||
clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
|
||||
clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"])
|
||||
|
||||
# KL divergence between worker and learner logits for debugging
|
||||
model_dist = MultiCategorical(unpacked_outputs)
|
||||
behaviour_dist = MultiCategorical(unpacked_behaviour_logits)
|
||||
|
||||
kls = model_dist.kl(behaviour_dist)
|
||||
if len(kls) > 1:
|
||||
self.KL_stats = {}
|
||||
|
||||
for i, kl in enumerate(kls):
|
||||
self.KL_stats.update({
|
||||
"mean_KL_{}".format(i): tf.reduce_mean(kl),
|
||||
"max_KL_{}".format(i): tf.reduce_max(kl),
|
||||
})
|
||||
else:
|
||||
self.KL_stats = {
|
||||
"mean_KL": tf.reduce_mean(kls[0]),
|
||||
"max_KL": tf.reduce_max(kls[0]),
|
||||
}
|
||||
|
||||
# Initialize TFPolicy
|
||||
loss_in = [
|
||||
(SampleBatch.ACTIONS, actions),
|
||||
@@ -318,7 +297,7 @@ class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
|
||||
self.stats_fetches = {
|
||||
LEARNER_STATS_KEY: dict({
|
||||
LEARNER_STATS_KEY: {
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"policy_loss": self.loss.pi_loss,
|
||||
"entropy": self.loss.entropy,
|
||||
@@ -328,7 +307,7 @@ class VTraceTFPolicy(LearningRateSchedule, VTracePostprocessing, TFPolicy):
|
||||
"vf_explained_var": explained_variance(
|
||||
tf.reshape(self.loss.vtrace_returns.vs, [-1]),
|
||||
tf.reshape(make_time_major(values, drop_last=True), [-1])),
|
||||
}, **self.KL_stats),
|
||||
},
|
||||
}
|
||||
|
||||
@override(TFPolicy)
|
||||
|
||||
@@ -13,6 +13,7 @@ import gym
|
||||
import ray
|
||||
from ray.rllib.agents.impala import vtrace
|
||||
from ray.rllib.evaluation.postprocessing import Postprocessing
|
||||
from ray.rllib.models.action_dist import Categorical
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.policy.tf_policy_template import build_tf_policy
|
||||
from ray.rllib.policy.tf_policy import LearningRateSchedule
|
||||
@@ -220,10 +221,8 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
behaviour_logits, output_hidden_shape, axis=1)
|
||||
unpacked_outputs = tf.split(
|
||||
policy.model.outputs, output_hidden_shape, axis=1)
|
||||
prev_dist_inputs = unpacked_behaviour_logits if is_multidiscrete else \
|
||||
behaviour_logits
|
||||
action_dist = policy.action_dist
|
||||
prev_action_dist = policy.dist_class(prev_dist_inputs)
|
||||
prev_action_dist = policy.dist_class(behaviour_logits)
|
||||
values = policy.value_function
|
||||
|
||||
if policy.model.state_in:
|
||||
@@ -257,7 +256,7 @@ def build_appo_surrogate_loss(policy, batch_tensors):
|
||||
rewards=make_time_major(rewards, drop_last=True),
|
||||
values=make_time_major(values, drop_last=True),
|
||||
bootstrap_value=make_time_major(values)[-1],
|
||||
dist_class=policy.dist_class,
|
||||
dist_class=Categorical if is_multidiscrete else policy.dist_class,
|
||||
valid_mask=make_time_major(mask, drop_last=True),
|
||||
vf_loss_coeff=policy.config["vf_loss_coeff"],
|
||||
entropy_coeff=policy.config["entropy_coeff"],
|
||||
|
||||
@@ -76,7 +76,7 @@ class Categorical(ActionDistribution):
|
||||
@override(ActionDistribution)
|
||||
def logp(self, x):
|
||||
return -tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits=self.inputs, labels=x)
|
||||
logits=self.inputs, labels=tf.cast(x, tf.int32))
|
||||
|
||||
@override(ActionDistribution)
|
||||
def entropy(self):
|
||||
@@ -126,14 +126,17 @@ class Categorical(ActionDistribution):
|
||||
class MultiCategorical(ActionDistribution):
|
||||
"""Categorical distribution for discrete action spaces."""
|
||||
|
||||
def __init__(self, inputs):
|
||||
self.cats = [Categorical(input_) for input_ in inputs]
|
||||
def __init__(self, inputs, input_lens):
|
||||
self.cats = [
|
||||
Categorical(input_)
|
||||
for input_ in tf.split(inputs, input_lens, axis=1)
|
||||
]
|
||||
self.sample_op = self._build_sample_op()
|
||||
|
||||
def logp(self, actions):
|
||||
# If tensor is provided, unstack it into list
|
||||
if isinstance(actions, tf.Tensor):
|
||||
actions = tf.unstack(actions, axis=1)
|
||||
actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
|
||||
logps = tf.stack(
|
||||
[cat.logp(act) for cat, act in zip(self.cats, actions)])
|
||||
return tf.reduce_sum(logps, axis=0)
|
||||
|
||||
@@ -149,7 +149,8 @@ class ModelCatalog(object):
|
||||
elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
|
||||
if torch:
|
||||
raise NotImplementedError
|
||||
return MultiCategorical, int(sum(action_space.nvec))
|
||||
return partial(MultiCategorical, input_lens=action_space.nvec), \
|
||||
int(sum(action_space.nvec))
|
||||
|
||||
raise NotImplementedError("Unsupported args: {} {}".format(
|
||||
action_space, dist_type))
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
import traceback
|
||||
|
||||
import gym
|
||||
from gym.spaces import Box, Discrete, Tuple, Dict
|
||||
from gym.spaces import Box, Discrete, Tuple, Dict, MultiDiscrete
|
||||
from gym.envs.registration import EnvSpec
|
||||
import numpy as np
|
||||
import sys
|
||||
@@ -17,6 +17,7 @@ from ray.tune.registry import register_env
|
||||
ACTION_SPACES_TO_TEST = {
|
||||
"discrete": Discrete(5),
|
||||
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
|
||||
"multidiscrete": MultiDiscrete([1, 2, 3, 4]),
|
||||
"tuple": Tuple(
|
||||
[Discrete(2),
|
||||
Discrete(3),
|
||||
@@ -61,7 +62,7 @@ def make_stub_env(action_space, obs_space, check_action_bounds):
|
||||
return StubEnv
|
||||
|
||||
|
||||
def check_support(alg, config, stats, check_bounds=False):
|
||||
def check_support(alg, config, stats, check_bounds=False, name=None):
|
||||
for a_name, action_space in ACTION_SPACES_TO_TEST.items():
|
||||
for o_name, obs_space in OBSERVATION_SPACES_TO_TEST.items():
|
||||
print("=== Testing", alg, action_space, obs_space, "===")
|
||||
@@ -87,7 +88,7 @@ def check_support(alg, config, stats, check_bounds=False):
|
||||
pass
|
||||
print(stat)
|
||||
print()
|
||||
stats[alg, a_name, o_name] = stat
|
||||
stats[name or alg, a_name, o_name] = stat
|
||||
|
||||
|
||||
def check_support_multiagent(alg, config):
|
||||
@@ -114,6 +115,11 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
stats = {}
|
||||
check_support("IMPALA", {"num_gpus": 0}, stats)
|
||||
check_support("APPO", {"num_gpus": 0, "vtrace": False}, stats)
|
||||
check_support(
|
||||
"APPO", {
|
||||
"num_gpus": 0,
|
||||
"vtrace": True
|
||||
}, stats, name="APPO-vt")
|
||||
check_support(
|
||||
"DDPG", {
|
||||
"exploration_ou_noise_scale": 100.0,
|
||||
|
||||
Reference in New Issue
Block a user