From 1f043daf697517e85e6885bfbed4037990e0ec44 Mon Sep 17 00:00:00 2001
From: Eric Liang
Date: Thu, 7 Nov 2019 12:17:27 -0800
Subject: [PATCH] [rllib] Fix and add test for LR annealing config

---
 rllib/policy/tf_policy_template.py | 10 +++++-----
 rllib/tests/test_optimizers.py     | 14 ++++++++++++++
 2 files changed, 19 insertions(+), 5 deletions(-)

diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py
index 980f5e749..c95fdcbf5 100644
--- a/rllib/policy/tf_policy_template.py
+++ b/rllib/policy/tf_policy_template.py
@@ -160,26 +160,26 @@ def build_tf_policy(name,
             if optimizer_fn:
                 return optimizer_fn(self, self.config)
             else:
-                return TFPolicy.optimizer(self)
+                return base.optimizer(self)
 
         @override(TFPolicy)
         def gradients(self, optimizer, loss):
             if gradients_fn:
                 return gradients_fn(self, optimizer, loss)
             else:
-                return TFPolicy.gradients(self, optimizer, loss)
+                return base.gradients(self, optimizer, loss)
 
         @override(TFPolicy)
         def build_apply_op(self, optimizer, grads_and_vars):
             if apply_gradients_fn:
                 return apply_gradients_fn(self, optimizer, grads_and_vars)
             else:
-                return TFPolicy.build_apply_op(self, optimizer, grads_and_vars)
+                return base.build_apply_op(self, optimizer, grads_and_vars)
 
         @override(TFPolicy)
         def extra_compute_action_fetches(self):
             return dict(
-                TFPolicy.extra_compute_action_fetches(self),
+                base.extra_compute_action_fetches(self),
                 **self._extra_action_fetches)
 
         @override(TFPolicy)
@@ -190,7 +190,7 @@ def build_tf_policy(name,
                     LEARNER_STATS_KEY: {}
                 }, **extra_learn_fetches_fn(self))
             else:
-                return TFPolicy.extra_compute_grad_fetches(self)
+                return base.extra_compute_grad_fetches(self)
 
     @staticmethod
     def with_updates(**overrides):
diff --git a/rllib/tests/test_optimizers.py b/rllib/tests/test_optimizers.py
index 58e5fef3b..395aaeda6 100644
--- a/rllib/tests/test_optimizers.py
+++ b/rllib/tests/test_optimizers.py
@@ -21,6 +21,20 @@ from ray.rllib.utils import try_import_tf
 tf = try_import_tf()
 
 
+class LRScheduleTest(unittest.TestCase):
+    def tearDown(self):
+        ray.shutdown()
+
+    def testBasic(self):
+        ray.init(num_cpus=2)
+        ppo = PPOTrainer(
+            env="CartPole-v0",
+            config={"lr_schedule": [[0, 1e-5], [1000, 0.0]]})
+        for _ in range(10):
+            result = ppo.train()
+        assert result["episode_reward_mean"] < 100, "should not have learned"
+
+
 class AsyncOptimizerTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()