[rllib] Fix Torch checkpoint taken on GPU failing to deserialize on CPU (#11071) (#11208)

This commit is contained in:
Philsik Chang
2020-10-06 14:01:55 +09:00
committed by GitHub
parent dc7c2a70b8
commit 2b26d2ca1b
+5 -2
View File
@@ -480,7 +480,8 @@ class TorchPolicy(Policy):
state = super().get_state()
state["_optimizer_variables"] = []
for i, o in enumerate(self._optimizers):
state["_optimizer_variables"].append(o.state_dict())
optim_state_dict = convert_to_non_torch_type(o.state_dict())
state["_optimizer_variables"].append(optim_state_dict)
return state
@override(Policy)
@@ -492,7 +493,9 @@ class TorchPolicy(Policy):
if optimizer_vars:
assert len(optimizer_vars) == len(self._optimizers)
for o, s in zip(self._optimizers, optimizer_vars):
o.load_state_dict(s)
optim_state_dict = convert_to_torch_tensor(
s, device=self.device)
o.load_state_dict(optim_state_dict)
# Then the Policy's (NN) weights.
super().set_state(state)