From c7cb2f5416be220ff7e6facc7112bc8aeea697c1 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Mon, 11 May 2020 22:03:27 +0200 Subject: [PATCH] [RLlib] IMPALA PyTorch GPU fixes (#8397) --- rllib/agents/impala/vtrace_torch.py | 10 ++++++---- rllib/agents/impala/vtrace_torch_policy.py | 9 ++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/rllib/agents/impala/vtrace_torch.py b/rllib/agents/impala/vtrace_torch.py index 540f3bf0b..0a4cfb9ba 100644 --- a/rllib/agents/impala/vtrace_torch.py +++ b/rllib/agents/impala/vtrace_torch.py @@ -201,9 +201,9 @@ def multi_from_logits(behaviour_policy_logits, target_policy_logits, device="cpu") actions = convert_to_torch_tensor(actions, device="cpu") + # Make sure tensor ranks are as expected. + # The rest will be checked by from_action_log_probs. for i in range(len(behaviour_policy_logits)): - # Make sure tensor ranks are as expected. - # The rest will be checked by from_action_log_probs. assert len(behaviour_policy_logits[i].size()) == 3 assert len(target_policy_logits[i].size()) == 3 @@ -215,9 +215,11 @@ def multi_from_logits(behaviour_policy_logits, # can't use precalculated values, recompute them. Note that # recomputing won't work well for autoregressive action dists # which may have variables not captured by 'logits' - behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions( - behaviour_policy_logits, actions, dist_class, model)) + behaviour_action_log_probs = multi_log_probs_from_logits_and_actions( + behaviour_policy_logits, actions, dist_class, model) + behaviour_action_log_probs = convert_to_torch_tensor( + behaviour_action_log_probs, device="cpu") behaviour_action_log_probs = force_list(behaviour_action_log_probs) log_rhos = get_log_rhos(target_action_log_probs, behaviour_action_log_probs) diff --git a/rllib/agents/impala/vtrace_torch_policy.py b/rllib/agents/impala/vtrace_torch_policy.py index e5095a867..3a2d8c718 100644 --- a/rllib/agents/impala/vtrace_torch_policy.py +++ b/rllib/agents/impala/vtrace_torch_policy.py @@ -77,6 +77,7 @@ class VTraceLoss: # Compute vtrace on the CPU for better perf # (devices handled inside `vtrace.multi_from_logits`). + device = behaviour_action_logp[0].device self.vtrace_returns = vtrace.multi_from_logits( behaviour_action_log_probs=behaviour_action_logp, behaviour_policy_logits=behaviour_logits, @@ -90,14 +91,16 @@ class VTraceLoss: model=model, clip_rho_threshold=clip_rho_threshold, clip_pg_rho_threshold=clip_pg_rho_threshold) - self.value_targets = self.vtrace_returns.vs + # Move v-trace results back to GPU for actual loss computing. + self.value_targets = self.vtrace_returns.vs.to(device) # The policy gradients loss self.pi_loss = -torch.sum( - actions_logp * self.vtrace_returns.pg_advantages * valid_mask) + actions_logp * self.vtrace_returns.pg_advantages.to(device) * + valid_mask) # The baseline loss - delta = (values - self.vtrace_returns.vs) * valid_mask + delta = (values - self.value_targets) * valid_mask self.vf_loss = 0.5 * torch.sum(torch.pow(delta, 2.0)) # The entropy loss