From 69352e3302d1a8eeff594f4ee73c858a874df2b4 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Sun, 12 May 2019 21:29:58 -0700
Subject: [PATCH] [rllib] Implement learn_on_batch() in torch policy graph

---
 .../ray/rllib/evaluation/torch_policy_graph.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py
index 35220dc54..fb5c879a1 100644
--- a/python/ray/rllib/evaluation/torch_policy_graph.py
+++ b/python/ray/rllib/evaluation/torch_policy_graph.py
@@ -85,6 +85,24 @@ class TorchPolicyGraph(PolicyGraph):
                     [h.cpu().numpy() for h in state],
                     self.extra_action_out(model_out))
 
+    @override(PolicyGraph)
+    def learn_on_batch(self, postprocessed_batch):
+        with self.lock:
+            loss_in = []
+            for key in self._loss_inputs:
+                loss_in.append(
+                    torch.from_numpy(postprocessed_batch[key]).to(self.device))
+            loss_out = self._loss(self._model, *loss_in)
+            self._optimizer.zero_grad()
+            loss_out.backward()
+
+            grad_process_info = self.extra_grad_process()
+            self._optimizer.step()
+
+            grad_info = self.extra_grad_info()
+            grad_info.update(grad_process_info)
+            return {LEARNER_STATS_KEY: grad_info}
+
     @override(PolicyGraph)
     def compute_gradients(self, postprocessed_batch):
         with self.lock: