mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 22:01:12 +08:00
[RLlib] Torch LR schedule not working. Fix and added test case. (#12396)
This commit is contained in:
@@ -3,6 +3,7 @@ import numpy as np
|
||||
import unittest
|
||||
|
||||
import ray
|
||||
from ray.rllib.agents.callbacks import DefaultCallbacks
|
||||
import ray.rllib.agents.ppo as ppo
|
||||
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \
|
||||
postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf
|
||||
@@ -36,6 +37,30 @@ FAKE_BATCH = {
|
||||
}
|
||||
|
||||
|
||||
class MyCallbacks(DefaultCallbacks):
|
||||
@staticmethod
|
||||
def _check_lr_torch(policy, policy_id):
|
||||
for j, opt in enumerate(policy._optimizers):
|
||||
for p in opt.param_groups:
|
||||
assert p["lr"] == policy.cur_lr, "LR scheduling error!"
|
||||
|
||||
@staticmethod
|
||||
def _check_lr_tf(policy, policy_id):
|
||||
lr = policy.cur_lr
|
||||
sess = policy.get_session()
|
||||
if sess:
|
||||
lr = sess.run(lr)
|
||||
optim_lr = sess.run(policy._optimizer._lr)
|
||||
else:
|
||||
lr = lr.numpy()
|
||||
optim_lr = policy._optimizer.lr.numpy()
|
||||
assert lr == optim_lr, "LR scheduling error!"
|
||||
|
||||
def on_train_result(self, *, trainer, result: dict, **kwargs):
|
||||
trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[
|
||||
"framework"] == "torch" else self._check_lr_tf)
|
||||
|
||||
|
||||
class TestPPO(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@@ -45,9 +70,12 @@ class TestPPO(unittest.TestCase):
|
||||
def tearDownClass(cls):
|
||||
ray.shutdown()
|
||||
|
||||
def test_ppo_compilation(self):
|
||||
def test_ppo_compilation_and_lr_schedule(self):
|
||||
"""Test whether a PPOTrainer can be built with all frameworks."""
|
||||
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
|
||||
# for checking lr-schedule correctness
|
||||
config["callbacks"] = MyCallbacks
|
||||
|
||||
config["num_workers"] = 1
|
||||
config["num_sgd_iter"] = 2
|
||||
# Settings in case we use an LSTM.
|
||||
|
||||
@@ -612,15 +612,11 @@ class LearningRateSchedule:
|
||||
|
||||
@override(Policy)
|
||||
def on_global_var_update(self, global_vars):
|
||||
super(LearningRateSchedule, self).on_global_var_update(global_vars)
|
||||
super().on_global_var_update(global_vars)
|
||||
self.cur_lr = self.lr_schedule.value(global_vars["timestep"])
|
||||
|
||||
@override(TorchPolicy)
|
||||
def optimizer(self):
|
||||
for opt in self._optimizers:
|
||||
for p in opt.param_groups:
|
||||
p["lr"] = self.cur_lr
|
||||
return self._optimizers
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
|
||||
Reference in New Issue
Block a user