From deea1861ab87b64394d461f8739814937fed69bc Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 25 Aug 2020 18:34:19 -0700 Subject: [PATCH] [rllib] Try fixing torch GPU and masking errors (#10168) --- doc/source/rllib-examples.rst | 2 +- rllib/agents/dqn/dqn_torch_policy.py | 9 +++++---- rllib/examples/models/parametric_actions_model.py | 5 +++-- rllib/utils/exploration/epsilon_greedy.py | 3 ++- rllib/utils/torch_ops.py | 5 +++++ 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/doc/source/rllib-examples.rst b/doc/source/rllib-examples.rst index d19971bea..0f70a536a 100644 --- a/doc/source/rllib-examples.rst +++ b/doc/source/rllib-examples.rst @@ -102,7 +102,7 @@ Community Examples with RLlib-generated baselines. - `CARLA `__: Example of training autonomous vehicles with RLlib and `CARLA `__ simulator. -- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning `__: +- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning `__: Using Graph Neural Networks and RLlib to train multiple cooperative and adversarial agents to solve the "cover the area"-problem, thereby learning how to best communicate (or - in the adversarial case - how to disturb communication). - `Flatland `__: diff --git a/rllib/agents/dqn/dqn_torch_policy.py b/rllib/agents/dqn/dqn_torch_policy.py index 97c91b848..cb6cf77ad 100644 --- a/rllib/agents/dqn/dqn_torch_policy.py +++ b/rllib/agents/dqn/dqn_torch_policy.py @@ -15,7 +15,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.exploration.parameter_noise import ParameterNoise from ray.rllib.utils.framework import try_import_torch from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf, \ - softmax_cross_entropy_with_logits + softmax_cross_entropy_with_logits, FLOAT_MIN torch, nn = try_import_torch() F = None @@ -215,7 +215,7 @@ def build_q_losses(policy, model, _, train_batch): one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS], policy.action_space.n) q_t_selected = torch.sum( - torch.where(q_t > -float("inf"), q_t, + torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device)) * one_hot_selection, 1) q_logits_t_selected = torch.sum( @@ -234,7 +234,7 @@ def build_q_losses(policy, model, _, train_batch): q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net, policy.action_space.n) q_tp1_best = torch.sum( - torch.where(q_tp1 > -float("inf"), q_tp1, + torch.where(q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=policy.device)) * q_tp1_best_one_hot_selection, 1) q_probs_tp1_best = torch.sum( @@ -243,7 +243,8 @@ def build_q_losses(policy, model, _, train_batch): q_tp1_best_one_hot_selection = F.one_hot( torch.argmax(q_tp1, 1), policy.action_space.n) q_tp1_best = torch.sum( - torch.where(q_tp1 > -float("inf"), q_tp1, torch.tensor(0.0)) * + torch.where(q_tp1 > FLOAT_MIN, q_tp1, + torch.tensor(0.0, device=policy.device)) * q_tp1_best_one_hot_selection, 1) q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1) diff --git a/rllib/examples/models/parametric_actions_model.py b/rllib/examples/models/parametric_actions_model.py index 1dc2945f9..abffbcafd 100644 --- a/rllib/examples/models/parametric_actions_model.py +++ b/rllib/examples/models/parametric_actions_model.py @@ -7,6 +7,7 @@ from ray.rllib.agents.dqn.dqn_torch_model import \ from ray.rllib.models.tf.fcnet import FullyConnectedNetwork from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.torch_ops import FLOAT_MIN, FLOAT_MAX tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -101,8 +102,8 @@ class TorchParametricActionsModel(DQNTorchModel): # Mask out invalid actions (use -inf to tag invalid). # These are then recognized by the EpsilonGreedy exploration component # as invalid actions that are not to be chosen. - inf_mask = torch.clamp( - torch.log(action_mask), -float("inf"), float("inf")) + inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX) + return action_logits + inf_mask, state def value_function(self): diff --git a/rllib/utils/exploration/epsilon_greedy.py b/rllib/utils/exploration/epsilon_greedy.py index 3558d73b3..c8155a893 100644 --- a/rllib/utils/exploration/epsilon_greedy.py +++ b/rllib/utils/exploration/epsilon_greedy.py @@ -7,6 +7,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \ get_variable from ray.rllib.utils.from_config import from_config from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule +from ray.rllib.utils.torch_ops import FLOAT_MIN tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -139,7 +140,7 @@ class EpsilonGreedy(Exploration): # Mask out actions, whose Q-values are -inf, so that we don't # even consider them for exploration. random_valid_action_logits = torch.where( - q_values == -float("inf"), + q_values <= FLOAT_MIN, torch.ones_like(q_values) * 0.0, torch.ones_like(q_values)) # A random action. random_actions = torch.squeeze( diff --git a/rllib/utils/torch_ops.py b/rllib/utils/torch_ops.py index 8df44ea57..4005bcdf0 100644 --- a/rllib/utils/torch_ops.py +++ b/rllib/utils/torch_ops.py @@ -6,6 +6,11 @@ from ray.rllib.utils.framework import try_import_torch torch, nn = try_import_torch() +# Limit values suitable for use as close to a -inf logit. These are useful +# since -inf / inf cause NaNs during backprop. +FLOAT_MIN = -3.4e38 +FLOAT_MAX = 3.4e38 + def atanh(x): return 0.5 * torch.log((1 + x) / (1 - x))