mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 13:19:38 +08:00
This commit is contained in:
@@ -480,7 +480,8 @@ class TorchPolicy(Policy):
|
||||
state = super().get_state()
|
||||
state["_optimizer_variables"] = []
|
||||
for i, o in enumerate(self._optimizers):
|
||||
state["_optimizer_variables"].append(o.state_dict())
|
||||
optim_state_dict = convert_to_non_torch_type(o.state_dict())
|
||||
state["_optimizer_variables"].append(optim_state_dict)
|
||||
return state
|
||||
|
||||
@override(Policy)
|
||||
@@ -492,7 +493,9 @@ class TorchPolicy(Policy):
|
||||
if optimizer_vars:
|
||||
assert len(optimizer_vars) == len(self._optimizers)
|
||||
for o, s in zip(self._optimizers, optimizer_vars):
|
||||
o.load_state_dict(s)
|
||||
optim_state_dict = convert_to_torch_tensor(
|
||||
s, device=self.device)
|
||||
o.load_state_dict(optim_state_dict)
|
||||
# Then the Policy's (NN) weights.
|
||||
super().set_state(state)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user