[RLlib] Torch LR schedule not working. Fix and added test case. (#12396)

This commit is contained in:
Sven Mika
2020-11-26 13:14:11 +01:00
committed by GitHub
parent d5215745e4
commit 6475297bd3
2 changed files with 30 additions and 6 deletions
+29 -1
View File
@@ -3,6 +3,7 @@ import numpy as np
import unittest
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
import ray.rllib.agents.ppo as ppo
from ray.rllib.agents.ppo.ppo_tf_policy import postprocess_ppo_gae as \
postprocess_ppo_gae_tf, ppo_surrogate_loss as ppo_surrogate_loss_tf
@@ -36,6 +37,30 @@ FAKE_BATCH = {
}
class MyCallbacks(DefaultCallbacks):
    """Callbacks that verify each policy's optimizer LR matches `cur_lr`.

    The check is executed from `on_train_result` and raises an
    `AssertionError` ("LR scheduling error!") on any mismatch.
    """

    @staticmethod
    def _check_lr_torch(policy, policy_id):
        """Assert every param group of every torch optimizer uses `cur_lr`.

        Args:
            policy: Policy exposing `_optimizers` (each with `param_groups`)
                and `cur_lr`.
            policy_id: Unused; accepted because `foreach_policy` is handed
                this function as its callback.
        """
        assert all(
            group["lr"] == policy.cur_lr
            for optimizer in policy._optimizers
            for group in optimizer.param_groups
        ), "LR scheduling error!"

    @staticmethod
    def _check_lr_tf(policy, policy_id):
        """Assert the tf optimizer's learning rate equals `policy.cur_lr`.

        Args:
            policy: Policy exposing `cur_lr`, `get_session()` and
                `_optimizer`.
            policy_id: Unused; accepted because `foreach_policy` is handed
                this function as its callback.
        """
        scheduled = policy.cur_lr
        session = policy.get_session()
        if session:
            # A session is available (presumably graph mode): evaluate both
            # values through it.
            scheduled_value = session.run(scheduled)
            optimizer_value = session.run(policy._optimizer._lr)
        else:
            # No session (presumably eager mode): read the values directly.
            scheduled_value = scheduled.numpy()
            optimizer_value = policy._optimizer.lr.numpy()
        assert scheduled_value == optimizer_value, "LR scheduling error!"

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        """Run the framework-specific LR check on every policy.

        Args:
            trainer: Trainer whose `config["framework"]` selects the check
                and whose `workers.foreach_policy` applies it.
            result: Training result dict (not inspected here).
            **kwargs: Ignored.
        """
        if trainer.config["framework"] == "torch":
            check_fn = self._check_lr_torch
        else:
            check_fn = self._check_lr_tf
        trainer.workers.foreach_policy(check_fn)
class TestPPO(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -45,9 +70,12 @@ class TestPPO(unittest.TestCase):
def tearDownClass(cls):
ray.shutdown()
def test_ppo_compilation(self):
def test_ppo_compilation_and_lr_schedule(self):
"""Test whether a PPOTrainer can be built with all frameworks."""
config = copy.deepcopy(ppo.DEFAULT_CONFIG)
# for checking lr-schedule correctness
config["callbacks"] = MyCallbacks
config["num_workers"] = 1
config["num_sgd_iter"] = 2
# Settings in case we use an LSTM.
+1 -5
View File
@@ -612,15 +612,11 @@ class LearningRateSchedule:
@override(Policy)
def on_global_var_update(self, global_vars):
super(LearningRateSchedule, self).on_global_var_update(global_vars)
super().on_global_var_update(global_vars)
self.cur_lr = self.lr_schedule.value(global_vars["timestep"])
@override(TorchPolicy)
def optimizer(self):
for opt in self._optimizers:
for p in opt.param_groups:
p["lr"] = self.cur_lr
return self._optimizers
@DeveloperAPI