mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 00:50:21 +08:00
[rllib] Try fixing torch GPU and masking errors (#10168)
This commit is contained in:
@@ -102,7 +102,7 @@ Community Examples
|
||||
with RLlib-generated baselines.
|
||||
- `CARLA <https://github.com/layssi/Carla_Ray_Rlib>`__:
|
||||
Example of training autonomous vehicles with RLlib and `CARLA <http://carla.org/>`__ simulator.
|
||||
- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning <https://arxiv.org/pdf/2008.02616.pdf>`__:
|
||||
- `The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning <https://arxiv.org/pdf/2008.02616.pdf>`__:
|
||||
Using Graph Neural Networks and RLlib to train multiple cooperative and adversarial agents to solve the
|
||||
"cover the area"-problem, thereby learning how to best communicate (or - in the adversarial case - how to disturb communication).
|
||||
- `Flatland <https://flatland.aicrowd.com/intro.html>`__:
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.rllib.utils.error import UnsupportedSpaceException
|
||||
from ray.rllib.utils.exploration.parameter_noise import ParameterNoise
|
||||
from ray.rllib.utils.framework import try_import_torch
|
||||
from ray.rllib.utils.torch_ops import huber_loss, reduce_mean_ignore_inf, \
|
||||
softmax_cross_entropy_with_logits
|
||||
softmax_cross_entropy_with_logits, FLOAT_MIN
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
F = None
|
||||
@@ -215,7 +215,7 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
|
||||
policy.action_space.n)
|
||||
q_t_selected = torch.sum(
|
||||
torch.where(q_t > -float("inf"), q_t,
|
||||
torch.where(q_t > FLOAT_MIN, q_t,
|
||||
torch.tensor(0.0, device=policy.device)) *
|
||||
one_hot_selection, 1)
|
||||
q_logits_t_selected = torch.sum(
|
||||
@@ -234,7 +234,7 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net,
|
||||
policy.action_space.n)
|
||||
q_tp1_best = torch.sum(
|
||||
torch.where(q_tp1 > -float("inf"), q_tp1,
|
||||
torch.where(q_tp1 > FLOAT_MIN, q_tp1,
|
||||
torch.tensor(0.0, device=policy.device)) *
|
||||
q_tp1_best_one_hot_selection, 1)
|
||||
q_probs_tp1_best = torch.sum(
|
||||
@@ -243,7 +243,8 @@ def build_q_losses(policy, model, _, train_batch):
|
||||
q_tp1_best_one_hot_selection = F.one_hot(
|
||||
torch.argmax(q_tp1, 1), policy.action_space.n)
|
||||
q_tp1_best = torch.sum(
|
||||
torch.where(q_tp1 > -float("inf"), q_tp1, torch.tensor(0.0)) *
|
||||
torch.where(q_tp1 > FLOAT_MIN, q_tp1,
|
||||
torch.tensor(0.0, device=policy.device)) *
|
||||
q_tp1_best_one_hot_selection, 1)
|
||||
q_probs_tp1_best = torch.sum(
|
||||
q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ray.rllib.agents.dqn.dqn_torch_model import \
|
||||
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
|
||||
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
|
||||
from ray.rllib.utils.framework import try_import_tf, try_import_torch
|
||||
from ray.rllib.utils.torch_ops import FLOAT_MIN, FLOAT_MAX
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, nn = try_import_torch()
|
||||
@@ -101,8 +102,8 @@ class TorchParametricActionsModel(DQNTorchModel):
|
||||
# Mask out invalid actions (tagged with FLOAT_MIN instead of -inf).
|
||||
# These are then recognized by the EpsilonGreedy exploration component
|
||||
# as invalid actions that are not to be chosen.
|
||||
inf_mask = torch.clamp(
|
||||
torch.log(action_mask), -float("inf"), float("inf"))
|
||||
inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
|
||||
|
||||
return action_logits + inf_mask, state
|
||||
|
||||
def value_function(self):
|
||||
|
||||
@@ -7,6 +7,7 @@ from ray.rllib.utils.framework import try_import_tf, try_import_torch, \
|
||||
get_variable
|
||||
from ray.rllib.utils.from_config import from_config
|
||||
from ray.rllib.utils.schedules import Schedule, PiecewiseSchedule
|
||||
from ray.rllib.utils.torch_ops import FLOAT_MIN
|
||||
|
||||
tf1, tf, tfv = try_import_tf()
|
||||
torch, _ = try_import_torch()
|
||||
@@ -139,7 +140,7 @@ class EpsilonGreedy(Exploration):
|
||||
# Mask out actions with Q-values <= FLOAT_MIN, so that we don't
|
||||
# even consider them for exploration.
|
||||
random_valid_action_logits = torch.where(
|
||||
q_values == -float("inf"),
|
||||
q_values <= FLOAT_MIN,
|
||||
torch.ones_like(q_values) * 0.0, torch.ones_like(q_values))
|
||||
# A random action.
|
||||
random_actions = torch.squeeze(
|
||||
|
||||
@@ -6,6 +6,11 @@ from ray.rllib.utils.framework import try_import_torch
|
||||
|
||||
torch, nn = try_import_torch()
|
||||
|
||||
# Finite limits usable in place of -inf / inf logits. These are useful
|
||||
# since -inf / inf cause NaNs during backprop.
|
||||
FLOAT_MIN = -3.4e38
|
||||
FLOAT_MAX = 3.4e38
|
||||
|
||||
|
||||
def atanh(x):
|
||||
return 0.5 * torch.log((1 + x) / (1 - x))
|
||||
|
||||
Reference in New Issue
Block a user