diff --git a/python/ray/rllib/evaluation/torch_policy_graph.py b/python/ray/rllib/evaluation/torch_policy_graph.py index 35220dc54..fb5c879a1 100644 --- a/python/ray/rllib/evaluation/torch_policy_graph.py +++ b/python/ray/rllib/evaluation/torch_policy_graph.py @@ -85,6 +85,24 @@ class TorchPolicyGraph(PolicyGraph): [h.cpu().numpy() for h in state], self.extra_action_out(model_out)) + @override(PolicyGraph) + def learn_on_batch(self, postprocessed_batch): + with self.lock: + loss_in = [] + for key in self._loss_inputs: + loss_in.append( + torch.from_numpy(postprocessed_batch[key]).to(self.device)) + loss_out = self._loss(self._model, *loss_in) + self._optimizer.zero_grad() + loss_out.backward() + + grad_process_info = self.extra_grad_process() + self._optimizer.step() + + grad_info = self.extra_grad_info() + grad_info.update(grad_process_info) + return {LEARNER_STATS_KEY: grad_info} + @override(PolicyGraph) def compute_gradients(self, postprocessed_batch): with self.lock: