[rllib] Try fixing torch GPU and masking errors (#10168)

This commit is contained in:
Eric Liang
2020-08-25 18:34:19 -07:00
committed by GitHub
parent 6fcb816fdd
commit deea1861ab
5 changed files with 16 additions and 8 deletions
+1 -1
View File
@@ -102,7 +102,7 @@ Community Examples
with RLlib-generated baselines.
- `CARLA <https://github.com/layssi/Carla_Ray_Rlib>`__:
Example of training autonomous vehicles with RLlib and `CARLA <http://carla.org/>`__ simulator.
- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning <https://arxiv.org/pdf/2008.02616.pdf>`__:
- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning <https://arxiv.org/pdf/2008.02616.pdf>`__:
Using Graph Neural Networks and RLlib to train multiple cooperative and adversarial agents to solve the
"cover the area"-problem, thereby learning how to best communicate (or - in the adversarial case - how to disturb communication).
- `Flatland <https://flatland.aicrowd.com/intro.html>`__:
+5 -4
View File
@@ -15,7 +15,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf, \
softmax_cross_entropy_with_logits
softmax_cross_entropy_with_logits, FLOAT_MIN
torch, nn = try_import_torch()
F = None
@@ -215,7 +215,7 @@ def build_q_losses(policy, model, _, train_batch):
one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
policy.action_space.n)
q_t_selected = torch.sum(
torch.where(q_t > -float("inf"), q_t,
torch.where(q_t > FLOAT_MIN, q_t,
torch.tensor(0.0, device=policy.device)) *
one_hot_selection, 1)
q_logits_t_selected = torch.sum(
@@ -234,7 +234,7 @@ def build_q_losses(policy, model, _, train_batch):
q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net,
policy.action_space.n)
q_tp1_best = torch.sum(
torch.where(q_tp1 > -float("inf"), q_tp1,
torch.where(q_tp1 > FLOAT_MIN, q_tp1,
torch.tensor(0.0, device=policy.device)) *
q_tp1_best_one_hot_selection, 1)
q_probs_tp1_best = torch.sum(
@@ -243,7 +243,8 @@ def build_q_losses(policy, model, _, train_batch):
q_tp1_best_one_hot_selection = F.one_hot(
torch.argmax(q_tp1, 1), policy.action_space.n)
q_tp1_best = torch.sum(
torch.where(q_tp1 > -float("inf"), q_tp1, torch.tensor(0.0)) *
torch.where(q_tp1 > FLOAT_MIN, q_tp1,
torch.tensor(0.0, device=policy.device)) *
q_tp1_best_one_hot_selection, 1)
q_probs_tp1_best = torch.sum(
q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1)
@@ -7,6 +7,7 @@ from ray.rllib.agents.dqn.dqn_torch_model import \
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_ops import FLOAT_MIN, FLOAT_MAX
tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
@@ -101,8 +102,8 @@ class TorchParametricActionsModel(DQNTorchModel):
# Mask out invalid actions (use -inf to tag invalid).
# These are then recognized by the EpsilonGreedy exploration component
# as invalid actions that are not to be chosen.
inf_mask = torch.clamp(
torch.log(action_mask), -float("inf"), float("inf"))
inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
return action_logits + inf_mask, state
def value_function(self):
+2 -1
View File
@@ -7,6 +7,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
get_variable
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
from ray.rllib.utils.torch_ops import FLOAT_MIN
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
@@ -139,7 +140,7 @@ class EpsilonGreedy(Exploration):
# Mask out actions, whose Q-values are -inf, so that we don't
# even consider them for exploration.
random_valid_action_logits = torch.where(
q_values == -float("inf"),
q_values <= FLOAT_MIN,
torch.ones_like(q_values) * 0.0, torch.ones_like(q_values))
# A random action.
random_actions = torch.squeeze(
+5
View File
@@ -6,6 +6,11 @@ from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()
# Limit values suitable for use as close to a -inf logit. These are useful
# since -inf / inf cause NaNs during backprop.
FLOAT_MIN = -3.4e38
FLOAT_MAX = 3.4e38
def atanh(x):
return 0.5 * torch.log((1 + x) / (1 - x))