diff --git a/rllib/agents/ppo/tests/test_ppo.py b/rllib/agents/ppo/tests/test_ppo.py index 50e3b99bc..c00cd36ba 100644 --- a/rllib/agents/ppo/tests/test_ppo.py +++ b/rllib/agents/ppo/tests/test_ppo.py @@ -3,6 +3,7 @@ import numpy as np import unittest import ray +from ray.rllib.agents.callbacks import DefaultCallbacks import ray.rllib.agents.ppo as ppo from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \ postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf @@ -36,6 +37,30 @@ FAKE_BATCH = { } +class MyCallbacks(DefaultCallbacks): + @staticmethod + def _check_lr_torch(policy, policy_id): + for j, opt in enumerate(policy._optimizers): + for p in opt.param_groups: + assert p["lr"] == policy.cur_lr, "LR scheduling error!" + + @staticmethod + def _check_lr_tf(policy, policy_id): + lr = policy.cur_lr + sess = policy.get_session() + if sess: + lr = sess.run(lr) + optim_lr = sess.run(policy._optimizer._lr) + else: + lr = lr.numpy() + optim_lr = policy._optimizer.lr.numpy() + assert lr == optim_lr, "LR scheduling error!" + + def on_train_result(self, *, trainer, result: dict, **kwargs): + trainer.workers.foreach_policy(self._check_lr_torch if trainer.config[ + "framework"] == "torch" else self._check_lr_tf) + + class TestPPO(unittest.TestCase): @classmethod def setUpClass(cls): @@ -45,9 +70,12 @@ class TestPPO(unittest.TestCase): def tearDownClass(cls): ray.shutdown() - def test_ppo_compilation(self): + def test_ppo_compilation_and_lr_schedule(self): """Test whether a PPOTrainer can be built with all frameworks.""" config = copy.deepcopy(ppo.DEFAULT_CONFIG) + # for checking lr-schedule correctness + config["callbacks"] = MyCallbacks + config["num_workers"] = 1 config["num_sgd_iter"] = 2 # Settings in case we use an LSTM. diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index aaae43246..7b20a34ac 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -612,15 +612,11 @@ class LearningRateSchedule: @override(Policy) def on_global_var_update(self, global_vars): - super(LearningRateSchedule, self).on_global_var_update(global_vars) + super().on_global_var_update(global_vars) self.cur_lr = self.lr_schedule.value(global_vars["timestep"]) - - @override(TorchPolicy) - def optimizer(self): for opt in self._optimizers: for p in opt.param_groups: p["lr"] = self.cur_lr - return self._optimizers @DeveloperAPI