mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 17:49:47 +08:00
[rllib] Fix and add test for LR annealing config
This commit is contained in:
@@ -160,26 +160,26 @@ def build_tf_policy(name,
|
||||
if optimizer_fn:
|
||||
return optimizer_fn(self, self.config)
|
||||
else:
|
||||
return TFPolicy.optimizer(self)
|
||||
return base.optimizer(self)
|
||||
|
||||
@override(TFPolicy)
|
||||
def gradients(self, optimizer, loss):
|
||||
if gradients_fn:
|
||||
return gradients_fn(self, optimizer, loss)
|
||||
else:
|
||||
return TFPolicy.gradients(self, optimizer, loss)
|
||||
return base.gradients(self, optimizer, loss)
|
||||
|
||||
@override(TFPolicy)
|
||||
def build_apply_op(self, optimizer, grads_and_vars):
|
||||
if apply_gradients_fn:
|
||||
return apply_gradients_fn(self, optimizer, grads_and_vars)
|
||||
else:
|
||||
return TFPolicy.build_apply_op(self, optimizer, grads_and_vars)
|
||||
return base.build_apply_op(self, optimizer, grads_and_vars)
|
||||
|
||||
@override(TFPolicy)
|
||||
def extra_compute_action_fetches(self):
|
||||
return dict(
|
||||
TFPolicy.extra_compute_action_fetches(self),
|
||||
base.extra_compute_action_fetches(self),
|
||||
**self._extra_action_fetches)
|
||||
|
||||
@override(TFPolicy)
|
||||
@@ -190,7 +190,7 @@ def build_tf_policy(name,
|
||||
LEARNER_STATS_KEY: {}
|
||||
}, **extra_learn_fetches_fn(self))
|
||||
else:
|
||||
return TFPolicy.extra_compute_grad_fetches(self)
|
||||
return base.extra_compute_grad_fetches(self)
|
||||
|
||||
@staticmethod
|
||||
def with_updates(**overrides):
|
||||
|
||||
@@ -21,6 +21,20 @@ from ray.rllib.utils import try_import_tf
|
||||
tf = try_import_tf()
|
||||
|
||||
|
||||
class LRScheduleTest(unittest.TestCase):
    """Smoke test for training PPO with an ``lr_schedule`` config entry."""

    def tearDown(self):
        """Shut Ray down after each test."""
        ray.shutdown()

    def testBasic(self):
        """Train PPO on CartPole-v0 with an ``lr_schedule`` for 10 iterations.

        The test passes when the mean episode reward reported by the last
        training iteration is below 100.
        """
        ray.init(num_cpus=2)

        # NOTE(review): presumably a list of [timestep, learning_rate]
        # pairs, i.e. the rate drops from 1e-5 to 0.0 -- confirm against
        # the trainer's config documentation.
        schedule_config = {"lr_schedule": [[0, 1e-5], [1000, 0.0]]}
        trainer = PPOTrainer(env="CartPole-v0", config=schedule_config)

        # Only the result of the final iteration is inspected.
        for _ in range(10):
            last_result = trainer.train()

        mean_reward = last_result["episode_reward_mean"]
        assert mean_reward < 100, "should not have learned"
|
||||
|
||||
|
||||
class AsyncOptimizerTest(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
Reference in New Issue
Block a user