[rllib] Implement learn_on_batch() in torch policy graph

This commit is contained in:
Eric Liang
2019-05-12 21:29:58 -07:00
committed by GitHub
parent f3b8b9093d
commit 69352e3302
@@ -85,6 +85,24 @@ class TorchPolicyGraph(PolicyGraph):
[h.cpu().numpy() for h in state],
self.extra_action_out(model_out))
@override(PolicyGraph)
def learn_on_batch(self, postprocessed_batch):
with self.lock:
loss_in = []
for key in self._loss_inputs:
loss_in.append(
torch.from_numpy(postprocessed_batch[key]).to(self.device))
loss_out = self._loss(self._model, *loss_in)
self._optimizer.zero_grad()
loss_out.backward()
grad_process_info = self.extra_grad_process()
self._optimizer.step()
grad_info = self.extra_grad_info()
grad_info.update(grad_process_info)
return {LEARNER_STATS_KEY: grad_info}
@override(PolicyGraph)
def compute_gradients(self, postprocessed_batch):
with self.lock: