diff --git a/rllib/policy/tf_policy_template.py b/rllib/policy/tf_policy_template.py
index 980f5e749..c95fdcbf5 100644
--- a/rllib/policy/tf_policy_template.py
+++ b/rllib/policy/tf_policy_template.py
@@ -160,26 +160,26 @@ def build_tf_policy(name,
             if optimizer_fn:
                 return optimizer_fn(self, self.config)
             else:
-                return TFPolicy.optimizer(self)
+                return base.optimizer(self)
 
         @override(TFPolicy)
         def gradients(self, optimizer, loss):
             if gradients_fn:
                 return gradients_fn(self, optimizer, loss)
             else:
-                return TFPolicy.gradients(self, optimizer, loss)
+                return base.gradients(self, optimizer, loss)
 
         @override(TFPolicy)
         def build_apply_op(self, optimizer, grads_and_vars):
             if apply_gradients_fn:
                 return apply_gradients_fn(self, optimizer, grads_and_vars)
             else:
-                return TFPolicy.build_apply_op(self, optimizer, grads_and_vars)
+                return base.build_apply_op(self, optimizer, grads_and_vars)
 
         @override(TFPolicy)
         def extra_compute_action_fetches(self):
             return dict(
-                TFPolicy.extra_compute_action_fetches(self),
+                base.extra_compute_action_fetches(self),
                 **self._extra_action_fetches)
 
         @override(TFPolicy)
@@ -190,7 +190,7 @@ def build_tf_policy(name,
                     LEARNER_STATS_KEY: {}
                 }, **extra_learn_fetches_fn(self))
             else:
-                return TFPolicy.extra_compute_grad_fetches(self)
+                return base.extra_compute_grad_fetches(self)
 
         @staticmethod
         def with_updates(**overrides):
diff --git a/rllib/tests/test_optimizers.py b/rllib/tests/test_optimizers.py
index 58e5fef3b..395aaeda6 100644
--- a/rllib/tests/test_optimizers.py
+++ b/rllib/tests/test_optimizers.py
@@ -21,6 +21,20 @@ from ray.rllib.utils import try_import_tf
 tf = try_import_tf()
 
 
+class LRScheduleTest(unittest.TestCase):
+    def tearDown(self):
+        ray.shutdown()
+
+    def testBasic(self):
+        ray.init(num_cpus=2)
+        ppo = PPOTrainer(
+            env="CartPole-v0",
+            config={"lr_schedule": [[0, 1e-5], [1000, 0.0]]})
+        for _ in range(10):
+            result = ppo.train()
+        assert result["episode_reward_mean"] < 100, "should not have learned"
+
+
 class AsyncOptimizerTest(unittest.TestCase):
     def tearDown(self):
         ray.shutdown()