[rllib] Fix per-worker exploration in Ape-X; make more kwargs required for future safety (#7504)

* fix sched

* lintc

* lint

* fix

* add unit test

* fix

* format

* fix test

* fix test
This commit is contained in:
Eric Liang
2020-03-10 11:14:14 -07:00
committed by GitHub
parent d192ef0611
commit be48e1964b
18 changed files with 97 additions and 46 deletions
+8
View File
@@ -82,6 +82,14 @@ py_test(
srcs = ["agents/dqn/tests/test_dqn.py"]
)
# APEXTrainer
py_test(
name = "test_apex",
tags = ["agents_dir"],
size = "medium",
srcs = ["agents/dqn/tests/test_apex.py"]
)
# IMPALA
py_test(
name = "test_vtrace",
+29
View File
@@ -0,0 +1,29 @@
import numpy as np
import pytest
import unittest
import ray
import ray.rllib.agents.dqn.apex as apex
class TestApex(unittest.TestCase):
    """Checks the exploration epsilons reported by an Ape-X trainer."""

    def setUp(self):
        # Every test gets its own Ray runtime; it is torn down in tearDown().
        ray.init(num_cpus=4)

    def tearDown(self):
        ray.shutdown()

    def test_apex_epsilon_distribution(self):
        """Compare each policy's `cur_epsilon` against fixed reference values.

        Builds an ApexTrainer with three workers on CartPole-v0, collects
        `get_exploration_info()` from every policy in the trainer's worker
        set and checks the reported epsilons with `np.allclose`.
        """
        config = apex.APEX_DEFAULT_CONFIG.copy()
        config["num_workers"] = 3
        # `.copy()` above is shallow, so config["optimizer"] is still the very
        # dict object held by APEX_DEFAULT_CONFIG. Writing into it would
        # change the module-level default for everything else running in this
        # process, so install a modified copy of the nested dict instead.
        config["optimizer"] = dict(
            config["optimizer"], num_replay_buffer_shards=1)

        trainer = apex.ApexTrainer(config, env="CartPole-v0")
        infos = trainer.workers.foreach_policy(
            lambda p, _: p.get_exploration_info())
        eps = [i["cur_epsilon"] for i in infos]
        # One reference value per entry returned by foreach_policy().
        assert np.allclose(eps, [1.0, 0.016190862, 0.00065536, 2.6527108e-05])
if __name__ == "__main__":
    # Executed as a script: run this file's tests verbosely through pytest
    # and use pytest's return code as the process exit status.
    import sys

    exit_status = pytest.main(["-v", __file__])
    sys.exit(exit_status)
+1
View File
@@ -85,6 +85,7 @@ class WorkerSet:
num_workers (int): The number of remote Workers to add to this
WorkerSet.
"""
self._num_workers = num_workers
remote_args = {
"num_cpus": self._remote_config["num_cpus_per_worker"],
"num_gpus": self._remote_config["num_gpus_per_worker"],
+2 -1
View File
@@ -70,7 +70,8 @@ class SyncReplayOptimizer(PolicyOptimizer):
endpoints=[(0, prioritized_replay_beta),
(prioritized_replay_beta_annealing_timesteps,
final_prioritized_replay_beta)],
outside_value=final_prioritized_replay_beta)
outside_value=final_prioritized_replay_beta,
framework=None)
self.prioritized_replay_eps = prioritized_replay_eps
self.train_batch_size = train_batch_size
self.before_learn_on_batch = before_learn_on_batch
+2 -2
View File
@@ -386,8 +386,8 @@ class Policy(metaclass=ABCMeta):
Exploration,
config.get("exploration_config", {"type": "StochasticSampling"}),
action_space=action_space,
num_workers=config.get("num_workers"),
worker_index=config.get("worker_index"),
num_workers=config.get("num_workers", 0),
worker_index=config.get("worker_index", 0),
framework=getattr(self, "framework", "tf"))
# If config is further passed around, it'll contain an already
# instantiated object.
+8 -5
View File
@@ -663,10 +663,10 @@ class LearningRateSchedule:
def __init__(self, lr, lr_schedule):
self.cur_lr = tf.get_variable("lr", initializer=lr, trainable=False)
if lr_schedule is None:
self.lr_schedule = ConstantSchedule(lr)
self.lr_schedule = ConstantSchedule(lr, framework=None)
else:
self.lr_schedule = PiecewiseSchedule(
lr_schedule, outside_value=lr_schedule[-1][-1])
lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
@override(Policy)
def on_global_var_update(self, global_vars):
@@ -690,18 +690,21 @@ class EntropyCoeffSchedule:
"entropy_coeff", initializer=entropy_coeff, trainable=False)
if entropy_coeff_schedule is None:
self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
self.entropy_coeff_schedule = ConstantSchedule(
entropy_coeff, framework=None)
else:
# Allows for custom schedule similar to lr_schedule format
if isinstance(entropy_coeff_schedule, list):
self.entropy_coeff_schedule = PiecewiseSchedule(
entropy_coeff_schedule,
outside_value=entropy_coeff_schedule[-1][-1])
outside_value=entropy_coeff_schedule[-1][-1],
framework=None)
else:
# Implements previous version but enforces outside_value
self.entropy_coeff_schedule = PiecewiseSchedule(
[[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
outside_value=0.0)
outside_value=0.0,
framework=None)
@override(Policy)
def on_global_var_update(self, global_vars):
+8 -5
View File
@@ -278,10 +278,10 @@ class LearningRateSchedule:
def __init__(self, lr, lr_schedule):
self.cur_lr = lr
if lr_schedule is None:
self.lr_schedule = ConstantSchedule(lr)
self.lr_schedule = ConstantSchedule(lr, framework=None)
else:
self.lr_schedule = PiecewiseSchedule(
lr_schedule, outside_value=lr_schedule[-1][-1])
lr_schedule, outside_value=lr_schedule[-1][-1], framework=None)
@override(Policy)
def on_global_var_update(self, global_vars):
@@ -304,18 +304,21 @@ class EntropyCoeffSchedule:
self.entropy_coeff = entropy_coeff
if entropy_coeff_schedule is None:
self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff)
self.entropy_coeff_schedule = ConstantSchedule(
entropy_coeff, framework=None)
else:
# Allows for custom schedule similar to lr_schedule format
if isinstance(entropy_coeff_schedule, list):
self.entropy_coeff_schedule = PiecewiseSchedule(
entropy_coeff_schedule,
outside_value=entropy_coeff_schedule[-1][-1])
outside_value=entropy_coeff_schedule[-1][-1],
framework=None)
else:
# Implements previous version but enforces outside_value
self.entropy_coeff_schedule = PiecewiseSchedule(
[[0, entropy_coeff], [entropy_coeff_schedule, 0.0]],
outside_value=0.0)
outside_value=0.0,
framework=None)
@override(Policy)
def on_global_var_update(self, global_vars):
+1 -1
View File
@@ -94,7 +94,7 @@ class EpsilonGreedy(Exploration):
chose_random = tf.random_uniform(
tf.stack([batch_size]),
minval=0, maxval=1, dtype=epsilon.dtype) \
minval=0, maxval=1, dtype=tf.float32) \
< epsilon
action = tf.cond(
+2 -2
View File
@@ -17,8 +17,8 @@ class Exploration:
def __init__(self,
action_space: Space,
num_workers: int = 0,
worker_index: int = 0,
num_workers: int,
worker_index: int,
framework: str = "tf"):
"""
Args:
@@ -10,12 +10,7 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy):
See Ape-X paper.
"""
def __init__(self,
action_space,
*,
num_workers=0,
worker_index=0,
framework="tf",
def __init__(self, action_space, *, num_workers, worker_index, framework,
**kwargs):
"""Create a PerWorkerEpsilonGreedy exploration class.
@@ -28,17 +23,21 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy):
"""
epsilon_schedule = None
# Use a fixed, different epsilon per worker. See: Ape-X paper.
assert worker_index <= num_workers, (worker_index, num_workers)
if num_workers > 0:
if worker_index >= 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
epsilon_schedule = ConstantSchedule(0.4**exponent)
epsilon_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
# Local worker should have zero exploration so that eval
# rollouts run properly.
else:
epsilon_schedule = ConstantSchedule(0.0)
epsilon_schedule = ConstantSchedule(0.0, framework=framework)
super().__init__(
action_space,
epsilon_schedule=epsilon_schedule,
framework=framework,
num_workers=num_workers,
worker_index=worker_index,
**kwargs)
@@ -13,8 +13,8 @@ class PerWorkerGaussianNoise(GaussianNoise):
def __init__(self,
action_space,
*,
num_workers=0,
worker_index=0,
num_workers,
worker_index,
framework="tf",
**kwargs):
"""
@@ -30,11 +30,12 @@ class PerWorkerGaussianNoise(GaussianNoise):
if num_workers > 0:
if worker_index >= 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
scale_schedule = ConstantSchedule(0.4**exponent)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
# Local worker should have zero exploration so that eval
# rollouts run properly.
else:
scale_schedule = ConstantSchedule(0.0)
scale_schedule = ConstantSchedule(0.0, framework=framework)
super().__init__(
action_space,
@@ -14,8 +14,8 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise):
def __init__(self,
action_space,
*,
num_workers=0,
worker_index=0,
num_workers,
worker_index,
framework="tf",
**kwargs):
"""
@@ -31,14 +31,17 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise):
if num_workers > 0:
if worker_index >= 0:
exponent = (1 + worker_index / float(num_workers - 1) * 7)
scale_schedule = ConstantSchedule(0.4**exponent)
scale_schedule = ConstantSchedule(
0.4**exponent, framework=framework)
# Local worker should have zero exploration so that eval
# rollouts run properly.
else:
scale_schedule = ConstantSchedule(0.0)
scale_schedule = ConstantSchedule(0.0, framework=framework)
super().__init__(
action_space,
scale_schedule=scale_schedule,
num_workers=num_workers,
worker_index=worker_index,
framework=framework,
**kwargs)
+3 -3
View File
@@ -6,13 +6,13 @@ class ConstantSchedule(Schedule):
A Schedule where the value remains constant over time.
"""
def __init__(self, value, framework=None):
def __init__(self, value, framework):
"""
Args:
value (float): The constant value to return, independently of time.
"""
super().__init__(framework=None)
super().__init__(framework=framework)
self._v = value
def _value(self, t=None):
def _value(self, t):
return self._v
@@ -4,9 +4,9 @@ from ray.rllib.utils.schedules.schedule import Schedule
class ExponentialSchedule(Schedule):
def __init__(self,
schedule_timesteps,
framework,
initial_p=1.0,
decay_rate=0.1,
framework=None):
decay_rate=0.1):
"""
Exponential decay schedule from initial_p to final_p over
schedule_timesteps. After this many time steps always `final_p` is
+2 -2
View File
@@ -8,9 +8,9 @@ def _linear_interpolation(l, r, alpha):
class PiecewiseSchedule(Schedule):
def __init__(self,
endpoints,
framework,
interpolation=_linear_interpolation,
outside_value=None,
framework=None):
outside_value=None):
"""
Args:
endpoints (List[Tuple[int,float]]): A list of tuples
+2 -2
View File
@@ -8,9 +8,9 @@ class PolynomialSchedule(Schedule):
def __init__(self,
schedule_timesteps,
final_p,
framework,
initial_p=1.0,
power=2.0,
framework=None):
power=2.0):
"""
Polynomial interpolation between initial_p and final_p over
schedule_timesteps. After this many time steps always `final_p` is
+1 -1
View File
@@ -22,7 +22,7 @@ class Schedule(metaclass=ABCMeta):
value and returns the value dependent on the Schedule and the passed time.
"""
def __init__(self, framework=None):
def __init__(self, framework):
self.framework = check_framework(framework)
@abstractmethod
@@ -75,10 +75,13 @@ class TestFrameWorkAgnosticComponents(unittest.TestCase):
check(component.add(-1.1).numpy(), -2.1) # prop_b == -1.0
# Test recognizing default package path.
component = from_config(Exploration, {
"type": "EpsilonGreedy",
"action_space": Discrete(2)
})
component = from_config(
Exploration, {
"type": "EpsilonGreedy",
"action_space": Discrete(2),
"num_workers": 0,
"worker_index": 0,
})
check(component.epsilon_schedule.outside_value, 0.05) # default
# Create torch Component from yaml-string.