From 2b26d2ca1b7ad4f400dec0d97feddcd2f0bf6026 Mon Sep 17 00:00:00 2001 From: Philsik Chang Date: Tue, 6 Oct 2020 14:01:55 +0900 Subject: [PATCH] [rllib] Fix for Torch checkpoint taken on GPU fails to deserialize on CPU (#11071) (#11208) --- rllib/policy/torch_policy.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 10d8604ac..91aa3ebe4 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -480,7 +480,8 @@ class TorchPolicy(Policy): state = super().get_state() state["_optimizer_variables"] = [] for i, o in enumerate(self._optimizers): - state["_optimizer_variables"].append(o.state_dict()) + optim_state_dict = convert_to_non_torch_type(o.state_dict()) + state["_optimizer_variables"].append(optim_state_dict) return state @override(Policy) @@ -492,7 +493,9 @@ class TorchPolicy(Policy): if optimizer_vars: assert len(optimizer_vars) == len(self._optimizers) for o, s in zip(self._optimizers, optimizer_vars): - o.load_state_dict(s) + optim_state_dict = convert_to_torch_tensor( + s, device=self.device) + o.load_state_dict(optim_state_dict) # Then the Policy's (NN) weights. super().set_state(state)