[rllib] Fix and add test for LR annealing config

This commit is contained in:
Eric Liang
2019-11-07 12:17:27 -08:00
committed by GitHub
parent fcb6bdbc39
commit 1f043daf69
2 changed files with 19 additions and 5 deletions
+5 -5
View File
@@ -160,26 +160,26 @@ def build_tf_policy(name,
if optimizer_fn:
return optimizer_fn(self, self.config)
else:
return TFPolicy.optimizer(self)
return base.optimizer(self)
@override(TFPolicy)
def gradients(self, optimizer, loss):
    """Compute gradients of `loss`, preferring the user-supplied hook.

    Args:
        optimizer: Optimizer object forwarded unchanged to the hook or
            to the fallback implementation.
        loss: Loss forwarded unchanged to the hook or to the fallback
            implementation.

    Returns:
        Whatever `gradients_fn` returns when one was supplied, otherwise
        whatever `base.gradients` returns.
    """
    if gradients_fn:
        return gradients_fn(self, optimizer, loss)
    else:
        # Fall back through `base` only. A preceding
        # `return TFPolicy.gradients(...)` made this return unreachable.
        # NOTE(review): `base` is defined outside this view; presumably
        # it is TFPolicy with the policy's mixins applied - confirm.
        return base.gradients(self, optimizer, loss)
@override(TFPolicy)
def build_apply_op(self, optimizer, grads_and_vars):
    """Build the op that applies gradients, preferring the user hook.

    Args:
        optimizer: Optimizer object forwarded unchanged to the hook or
            to the fallback implementation.
        grads_and_vars: Gradient/variable pairs forwarded unchanged to
            the hook or to the fallback implementation.

    Returns:
        Whatever `apply_gradients_fn` returns when one was supplied,
        otherwise whatever `base.build_apply_op` returns.
    """
    if apply_gradients_fn:
        return apply_gradients_fn(self, optimizer, grads_and_vars)
    else:
        # Fall back through `base` only. A preceding
        # `return TFPolicy.build_apply_op(...)` made this return
        # unreachable.
        # NOTE(review): `base` is defined outside this view; presumably
        # it is TFPolicy with the policy's mixins applied - confirm.
        return base.build_apply_op(self, optimizer, grads_and_vars)
@override(TFPolicy)
def extra_compute_action_fetches(self):
    """Return the base action fetches merged with this policy's extras.

    Returns:
        A new dict built from `base.extra_compute_action_fetches(self)`
        and then updated with `self._extra_action_fetches`; on a key
        collision the entry from `self._extra_action_fetches` wins.
    """
    # `dict()` accepts a single positional mapping. Passing both the
    # `TFPolicy` and the `base` fetches positionally raised TypeError, so
    # only the `base` mapping is kept.
    return dict(
        base.extra_compute_action_fetches(self),
        **self._extra_action_fetches)
@override(TFPolicy)
@@ -190,7 +190,7 @@ def build_tf_policy(name,
LEARNER_STATS_KEY: {}
}, **extra_learn_fetches_fn(self))
else:
return TFPolicy.extra_compute_grad_fetches(self)
return base.extra_compute_grad_fetches(self)
@staticmethod
def with_updates(**overrides):
+14
View File
@@ -21,6 +21,20 @@ from ray.rllib.utils import try_import_tf
tf = try_import_tf()
class LRScheduleTest(unittest.TestCase):
    """Checks that PPO accepts an `lr_schedule` config and trains with it."""

    def tearDown(self):
        # Shut Ray down after every test method.
        ray.shutdown()

    def testBasic(self):
        """Train PPO with a tiny scheduled learning rate; expect no learning."""
        ray.init(num_cpus=2)
        ppo = PPOTrainer(
            env="CartPole-v0",
            # Schedule points: 1e-5 at 0 and 0.0 at 1000.
            # NOTE(review): presumably [timestep, lr] pairs - confirm.
            config={"lr_schedule": [[0, 1e-5], [1000, 0.0]]})
        for _ in range(10):
            result = ppo.train()
        # Use a unittest assertion instead of a bare `assert`, which is
        # stripped under `python -O` and would let the test pass vacuously.
        self.assertLess(
            result["episode_reward_mean"], 100, "should not have learned")
class AsyncOptimizerTest(unittest.TestCase):
def tearDown(self):
ray.shutdown()