mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
[rllib] Implement learn_on_batch() in torch policy graph
This commit is contained in:
@@ -85,6 +85,24 @@ class TorchPolicyGraph(PolicyGraph):
|
||||
[h.cpu().numpy() for h in state],
|
||||
self.extra_action_out(model_out))
|
||||
|
||||
@override(PolicyGraph)
|
||||
def learn_on_batch(self, postprocessed_batch):
    """Compute the loss on one batch and apply a single optimizer step.

    Args:
        postprocessed_batch: Mapping indexed by every key in
            ``self._loss_inputs``. Each looked-up value is handed to
            ``torch.from_numpy``, so it is expected to be a NumPy array.

    Returns:
        A dict with the single key ``LEARNER_STATS_KEY`` whose value is
        the object returned by ``self.extra_grad_info()`` after it has
        been updated in place with the result of
        ``self.extra_grad_process()``.
    """
    # The whole update runs under the policy lock.
    with self.lock:
        # Convert the loss inputs to tensors on the policy's device, in
        # the order given by self._loss_inputs.
        input_tensors = [
            torch.from_numpy(postprocessed_batch[name]).to(self.device)
            for name in self._loss_inputs
        ]
        loss = self._loss(self._model, *input_tensors)

        # Clear previously accumulated gradients before backprop.
        self._optimizer.zero_grad()
        loss.backward()

        # This hook is invoked after backward() and before step();
        # presumably it post-processes gradients -- confirm against the
        # subclass implementations.
        process_stats = self.extra_grad_process()
        self._optimizer.step()

        # Stats are collected after the step; entries from the
        # processing hook overwrite same-named keys.
        stats = self.extra_grad_info()
        stats.update(process_stats)
        return {LEARNER_STATS_KEY: stats}
|
||||
|
||||
@override(PolicyGraph)
|
||||
def compute_gradients(self, postprocessed_batch):
|
||||
with self.lock:
|
||||
|
||||
Reference in New Issue
Block a user