mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 08:40:02 +08:00
[RLlib] IMPALA PyTorch GPU fixes (#8397)
This commit is contained in:
@@ -201,9 +201,9 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
target_policy_logits, device="cpu")
|
||||
actions = convert_to_torch_tensor(actions, device="cpu")
|
||||
|
||||
# Make sure tensor ranks are as expected.
|
||||
# The rest will be checked by from_action_log_probs.
|
||||
for i in range(len(behaviour_policy_logits)):
|
||||
# Make sure tensor ranks are as expected.
|
||||
# The rest will be checked by from_action_log_probs.
|
||||
assert len(behaviour_policy_logits[i].size()) == 3
|
||||
assert len(target_policy_logits[i].size()) == 3
|
||||
|
||||
@@ -215,9 +215,11 @@ def multi_from_logits(behaviour_policy_logits,
|
||||
# can't use precalculated values, recompute them. Note that
|
||||
# recomputing won't work well for autoregressive action dists
|
||||
# which may have variables not captured by 'logits'
|
||||
behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
|
||||
behaviour_policy_logits, actions, dist_class, model))
|
||||
behaviour_action_log_probs = multi_log_probs_from_logits_and_actions(
|
||||
behaviour_policy_logits, actions, dist_class, model)
|
||||
|
||||
behaviour_action_log_probs = convert_to_torch_tensor(
|
||||
behaviour_action_log_probs, device="cpu")
|
||||
behaviour_action_log_probs = force_list(behaviour_action_log_probs)
|
||||
log_rhos = get_log_rhos(target_action_log_probs,
|
||||
behaviour_action_log_probs)
|
||||
|
||||
@@ -77,6 +77,7 @@ class VTraceLoss:
|
||||
|
||||
# Compute vtrace on the CPU for better perf
|
||||
# (devices handled inside `vtrace.multi_from_logits`).
|
||||
device = behaviour_action_logp[0].device
|
||||
self.vtrace_returns = vtrace.multi_from_logits(
|
||||
behaviour_action_log_probs=behaviour_action_logp,
|
||||
behaviour_policy_logits=behaviour_logits,
|
||||
@@ -90,14 +91,16 @@ class VTraceLoss:
|
||||
model=model,
|
||||
clip_rho_threshold=clip_rho_threshold,
|
||||
clip_pg_rho_threshold=clip_pg_rho_threshold)
|
||||
self.value_targets = self.vtrace_returns.vs
|
||||
# Move v-trace results back to GPU for actual loss computing.
|
||||
self.value_targets = self.vtrace_returns.vs.to(device)
|
||||
|
||||
# The policy gradients loss
|
||||
self.pi_loss = -torch.sum(
|
||||
actions_logp * self.vtrace_returns.pg_advantages * valid_mask)
|
||||
actions_logp * self.vtrace_returns.pg_advantages.to(device) *
|
||||
valid_mask)
|
||||
|
||||
# The baseline loss
|
||||
delta = (values - self.vtrace_returns.vs) * valid_mask
|
||||
delta = (values - self.value_targets) * valid_mask
|
||||
self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0))
|
||||
|
||||
# The entropy loss
|
||||
|
||||
Reference in New Issue
Block a user