diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 10d8604ac..91aa3ebe4 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -480,7 +480,8 @@ class TorchPolicy(Policy): state = super().get_state() state["_optimizer_variables"] = [] for i, o in enumerate(self._optimizers): - state["_optimizer_variables"].append(o.state_dict()) + optim_state_dict = convert_to_non_torch_type(o.state_dict()) + state["_optimizer_variables"].append(optim_state_dict) return state @override(Policy) @@ -492,7 +493,9 @@ class TorchPolicy(Policy): if optimizer_vars: assert len(optimizer_vars) == len(self._optimizers) for o, s in zip(self._optimizers, optimizer_vars): - o.load_state_dict(s) + optim_state_dict = convert_to_torch_tensor( + s, device=self.device) + o.load_state_dict(optim_state_dict) # Then the Policy's (NN) weights. super().set_state(state)