From be48e1964b7d685f77baa5a45aca29d3cfcb9883 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Tue, 10 Mar 2020 11:14:14 -0700 Subject: [PATCH] [rllib] Fix per-worker exploration in Ape-X; make more kwargs required for future safety (#7504) * fix sched * lintc * lint * fix * add unit test * fix * format * fix test * fix test --- rllib/BUILD | 8 +++++ rllib/agents/dqn/tests/test_apex.py | 29 +++++++++++++++++++ rllib/evaluation/worker_set.py | 1 + rllib/optimizers/sync_replay_optimizer.py | 3 +- rllib/policy/policy.py | 4 +-- rllib/policy/tf_policy.py | 13 +++++---- rllib/policy/torch_policy.py | 13 +++++---- rllib/utils/exploration/epsilon_greedy.py | 2 +- rllib/utils/exploration/exploration.py | 4 +-- .../exploration/per_worker_epsilon_greedy.py | 15 +++++----- .../exploration/per_worker_gaussian_noise.py | 9 +++--- .../per_worker_ornstein_uhlenbeck_noise.py | 11 ++++--- rllib/utils/schedules/constant_schedule.py | 6 ++-- rllib/utils/schedules/exponential_schedule.py | 4 +-- rllib/utils/schedules/piecewise_schedule.py | 4 +-- rllib/utils/schedules/polynomial_schedule.py | 4 +-- rllib/utils/schedules/schedule.py | 2 +- .../test_framework_agnostic_components.py | 11 ++++--- 18 files changed, 97 insertions(+), 46 deletions(-) create mode 100644 rllib/agents/dqn/tests/test_apex.py diff --git a/rllib/BUILD b/rllib/BUILD index 110bbbbf6..92023e937 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -82,6 +82,14 @@ py_test( srcs = ["agents/dqn/tests/test_dqn.py"] ) +# APEXTrainer +py_test( + name = "test_apex", + tags = ["agents_dir"], + size = "medium", + srcs = ["agents/dqn/tests/test_apex.py"] +) + # IMPALA py_test( name = "test_vtrace", diff --git a/rllib/agents/dqn/tests/test_apex.py b/rllib/agents/dqn/tests/test_apex.py new file mode 100644 index 000000000..c840957b8 --- /dev/null +++ b/rllib/agents/dqn/tests/test_apex.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +import unittest + +import ray +import ray.rllib.agents.dqn.apex as apex + + +class TestApex(unittest.TestCase): + def setUp(self): + ray.init(num_cpus=4) + + def tearDown(self): + ray.shutdown() + + def test_apex_epsilon_distribution(self): + config = apex.APEX_DEFAULT_CONFIG.copy() + config["num_workers"] = 3 + config["optimizer"]["num_replay_buffer_shards"] = 1 + trainer = apex.ApexTrainer(config, env="CartPole-v0") + infos = trainer.workers.foreach_policy( + lambda p, _: p.get_exploration_info()) + eps = [i["cur_epsilon"] for i in infos] + assert np.allclose(eps, [1.0, 0.016190862, 0.00065536, 2.6527108e-05]) + + +if __name__ == "__main__": + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index c4922d2eb..8c64cb1b5 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -85,6 +85,7 @@ class WorkerSet: num_workers (int): The number of remote Workers to add to this WorkerSet. """ + self._num_workers = num_workers remote_args = { "num_cpus": self._remote_config["num_cpus_per_worker"], "num_gpus": self._remote_config["num_gpus_per_worker"], diff --git a/rllib/optimizers/sync_replay_optimizer.py b/rllib/optimizers/sync_replay_optimizer.py index 6db44d27e..fef50e778 100644 --- a/rllib/optimizers/sync_replay_optimizer.py +++ b/rllib/optimizers/sync_replay_optimizer.py @@ -70,7 +70,8 @@ class SyncReplayOptimizer(PolicyOptimizer): endpoints=[(0, prioritized_replay_beta), (prioritized_replay_beta_annealing_timesteps, final_prioritized_replay_beta)], - outside_value=final_prioritized_replay_beta) + outside_value=final_prioritized_replay_beta, + framework=None) self.prioritized_replay_eps = prioritized_replay_eps self.train_batch_size = train_batch_size self.before_learn_on_batch = before_learn_on_batch diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 4e0b2d19d..6d24b82c5 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -386,8 +386,8 @@ class Policy(metaclass=ABCMeta): Exploration, config.get("exploration_config", {"type": "StochasticSampling"}), action_space=action_space, - num_workers=config.get("num_workers"), - worker_index=config.get("worker_index"), + num_workers=config.get("num_workers", 0), + worker_index=config.get("worker_index", 0), framework=getattr(self, "framework", "tf")) # If config is further passed around, it'll contain an already # instantiated object. diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index 11f6b17f2..0e06902fe 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -663,10 +663,10 @@ class LearningRateSchedule: def __init__(self, lr, lr_schedule): self.cur_lr = tf.get_variable("lr", initializer=lr, trainable=False) if lr_schedule is None: - self.lr_schedule = ConstantSchedule(lr) + self.lr_schedule = ConstantSchedule(lr, framework=None) else: self.lr_schedule = PiecewiseSchedule( - lr_schedule, outside_value=lr_schedule[-1][-1]) + lr_schedule, outside_value=lr_schedule[-1][-1], framework=None) @override(Policy) def on_global_var_update(self, global_vars): @@ -690,18 +690,21 @@ class EntropyCoeffSchedule: "entropy_coeff", initializer=entropy_coeff, trainable=False) if entropy_coeff_schedule is None: - self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff) + self.entropy_coeff_schedule = ConstantSchedule( + entropy_coeff, framework=None) else: # Allows for custom schedule similar to lr_schedule format if isinstance(entropy_coeff_schedule, list): self.entropy_coeff_schedule = PiecewiseSchedule( entropy_coeff_schedule, - outside_value=entropy_coeff_schedule[-1][-1]) + outside_value=entropy_coeff_schedule[-1][-1], + framework=None) else: # Implements previous version but enforces outside_value self.entropy_coeff_schedule = PiecewiseSchedule( [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]], - outside_value=0.0) + outside_value=0.0, + framework=None) @override(Policy) def on_global_var_update(self, global_vars): diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index 417107fbe..d9b6dacb4 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -278,10 +278,10 @@ class LearningRateSchedule: def __init__(self, lr, lr_schedule): self.cur_lr = lr if lr_schedule is None: - self.lr_schedule = ConstantSchedule(lr) + self.lr_schedule = ConstantSchedule(lr, framework=None) else: self.lr_schedule = PiecewiseSchedule( - lr_schedule, outside_value=lr_schedule[-1][-1]) + lr_schedule, outside_value=lr_schedule[-1][-1], framework=None) @override(Policy) def on_global_var_update(self, global_vars): @@ -304,18 +304,21 @@ class EntropyCoeffSchedule: self.entropy_coeff = entropy_coeff if entropy_coeff_schedule is None: - self.entropy_coeff_schedule = ConstantSchedule(entropy_coeff) + self.entropy_coeff_schedule = ConstantSchedule( + entropy_coeff, framework=None) else: # Allows for custom schedule similar to lr_schedule format if isinstance(entropy_coeff_schedule, list): self.entropy_coeff_schedule = PiecewiseSchedule( entropy_coeff_schedule, - outside_value=entropy_coeff_schedule[-1][-1]) + outside_value=entropy_coeff_schedule[-1][-1], + framework=None) else: # Implements previous version but enforces outside_value self.entropy_coeff_schedule = PiecewiseSchedule( [[0, entropy_coeff], [entropy_coeff_schedule, 0.0]], - outside_value=0.0) + outside_value=0.0, + framework=None) @override(Policy) def on_global_var_update(self, global_vars): diff --git a/rllib/utils/exploration/epsilon_greedy.py b/rllib/utils/exploration/epsilon_greedy.py index 8d803f03e..d0195b6f8 100644 --- a/rllib/utils/exploration/epsilon_greedy.py +++ b/rllib/utils/exploration/epsilon_greedy.py @@ -94,7 +94,7 @@ class EpsilonGreedy(Exploration): chose_random = tf.random_uniform( tf.stack([batch_size]), - minval=0, maxval=1, dtype=epsilon.dtype) \ + minval=0, maxval=1, dtype=tf.float32) \ < epsilon action = tf.cond( diff --git a/rllib/utils/exploration/exploration.py b/rllib/utils/exploration/exploration.py index ad8914edf..78f1eaa95 100644 --- a/rllib/utils/exploration/exploration.py +++ b/rllib/utils/exploration/exploration.py @@ -17,8 +17,8 @@ class Exploration: def __init__(self, action_space: Space, - num_workers: int = 0, - worker_index: int = 0, + num_workers: int, + worker_index: int, framework: str = "tf"): """ Args: diff --git a/rllib/utils/exploration/per_worker_epsilon_greedy.py b/rllib/utils/exploration/per_worker_epsilon_greedy.py index 5d79b88ee..ed6d07f90 100644 --- a/rllib/utils/exploration/per_worker_epsilon_greedy.py +++ b/rllib/utils/exploration/per_worker_epsilon_greedy.py @@ -10,12 +10,7 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy): See Ape-X paper. """ - def __init__(self, - action_space, - *, - num_workers=0, - worker_index=0, - framework="tf", + def __init__(self, action_space, *, num_workers, worker_index, framework, **kwargs): """Create a PerWorkerEpsilonGreedy exploration class. @@ -28,17 +23,21 @@ class PerWorkerEpsilonGreedy(EpsilonGreedy): """ epsilon_schedule = None # Use a fixed, different epsilon per worker. See: Ape-X paper. + assert worker_index <= num_workers, (worker_index, num_workers) if num_workers > 0: if worker_index >= 0: exponent = (1 + worker_index / float(num_workers - 1) * 7) - epsilon_schedule = ConstantSchedule(0.4**exponent) + epsilon_schedule = ConstantSchedule( + 0.4**exponent, framework=framework) # Local worker should have zero exploration so that eval # rollouts run properly. else: - epsilon_schedule = ConstantSchedule(0.0) + epsilon_schedule = ConstantSchedule(0.0, framework=framework) super().__init__( action_space, epsilon_schedule=epsilon_schedule, framework=framework, + num_workers=num_workers, + worker_index=worker_index, **kwargs) diff --git a/rllib/utils/exploration/per_worker_gaussian_noise.py b/rllib/utils/exploration/per_worker_gaussian_noise.py index cca902b17..72834849d 100644 --- a/rllib/utils/exploration/per_worker_gaussian_noise.py +++ b/rllib/utils/exploration/per_worker_gaussian_noise.py @@ -13,8 +13,8 @@ class PerWorkerGaussianNoise(GaussianNoise): def __init__(self, action_space, *, - num_workers=0, - worker_index=0, + num_workers, + worker_index, framework="tf", **kwargs): """ @@ -30,11 +30,12 @@ class PerWorkerGaussianNoise(GaussianNoise): if num_workers > 0: if worker_index >= 0: exponent = (1 + worker_index / float(num_workers - 1) * 7) - scale_schedule = ConstantSchedule(0.4**exponent) + scale_schedule = ConstantSchedule( + 0.4**exponent, framework=framework) # Local worker should have zero exploration so that eval # rollouts run properly. else: - scale_schedule = ConstantSchedule(0.0) + scale_schedule = ConstantSchedule(0.0, framework=framework) super().__init__( action_space, diff --git a/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py b/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py index 21eb804da..acc5ec019 100644 --- a/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py +++ b/rllib/utils/exploration/per_worker_ornstein_uhlenbeck_noise.py @@ -14,8 +14,8 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise): def __init__(self, action_space, *, - num_workers=0, - worker_index=0, + num_workers, + worker_index, framework="tf", **kwargs): """ @@ -31,14 +31,17 @@ class PerWorkerOrnsteinUhlenbeckNoise(OrnsteinUhlenbeckNoise): if num_workers > 0: if worker_index >= 0: exponent = (1 + worker_index / float(num_workers - 1) * 7) - scale_schedule = ConstantSchedule(0.4**exponent) + scale_schedule = ConstantSchedule( + 0.4**exponent, framework=framework) # Local worker should have zero exploration so that eval # rollouts run properly. else: - scale_schedule = ConstantSchedule(0.0) + scale_schedule = ConstantSchedule(0.0, framework=framework) super().__init__( action_space, scale_schedule=scale_schedule, + num_workers=num_workers, + worker_index=worker_index, framework=framework, **kwargs) diff --git a/rllib/utils/schedules/constant_schedule.py b/rllib/utils/schedules/constant_schedule.py index 321c438d4..0a52aceec 100644 --- a/rllib/utils/schedules/constant_schedule.py +++ b/rllib/utils/schedules/constant_schedule.py @@ -6,13 +6,13 @@ class ConstantSchedule(Schedule): A Schedule where the value remains constant over time. """ - def __init__(self, value, framework=None): + def __init__(self, value, framework): """ Args: value (float): The constant value to return, independently of time. """ - super().__init__(framework=None) + super().__init__(framework=framework) self._v = value - def _value(self, t=None): + def _value(self, t): return self._v diff --git a/rllib/utils/schedules/exponential_schedule.py b/rllib/utils/schedules/exponential_schedule.py index 6cf452661..507797062 100644 --- a/rllib/utils/schedules/exponential_schedule.py +++ b/rllib/utils/schedules/exponential_schedule.py @@ -4,9 +4,9 @@ from ray.rllib.utils.schedules.schedule import Schedule class ExponentialSchedule(Schedule): def __init__(self, schedule_timesteps, + framework, initial_p=1.0, - decay_rate=0.1, - framework=None): + decay_rate=0.1): """ Exponential decay schedule from initial_p to final_p over schedule_timesteps. After this many time steps always `final_p` is diff --git a/rllib/utils/schedules/piecewise_schedule.py b/rllib/utils/schedules/piecewise_schedule.py index c39e58d9c..febf2d220 100644 --- a/rllib/utils/schedules/piecewise_schedule.py +++ b/rllib/utils/schedules/piecewise_schedule.py @@ -8,9 +8,9 @@ def _linear_interpolation(l, r, alpha): class PiecewiseSchedule(Schedule): def __init__(self, endpoints, + framework, interpolation=_linear_interpolation, - outside_value=None, - framework=None): + outside_value=None): """ Args: endpoints (List[Tuple[int,float]]): A list of tuples diff --git a/rllib/utils/schedules/polynomial_schedule.py b/rllib/utils/schedules/polynomial_schedule.py index 9ddc10a44..d015ff205 100644 --- a/rllib/utils/schedules/polynomial_schedule.py +++ b/rllib/utils/schedules/polynomial_schedule.py @@ -8,9 +8,9 @@ class PolynomialSchedule(Schedule): def __init__(self, schedule_timesteps, final_p, + framework, initial_p=1.0, - power=2.0, - framework=None): + power=2.0): """ Polynomial interpolation between initial_p and final_p over schedule_timesteps. After this many time steps always `final_p` is diff --git a/rllib/utils/schedules/schedule.py b/rllib/utils/schedules/schedule.py index f35d42e0f..65b6566ef 100644 --- a/rllib/utils/schedules/schedule.py +++ b/rllib/utils/schedules/schedule.py @@ -22,7 +22,7 @@ class Schedule(metaclass=ABCMeta): value and returns the value dependent on the Schedule and the passed time. """ - def __init__(self, framework=None): + def __init__(self, framework): self.framework = check_framework(framework) @abstractmethod diff --git a/rllib/utils/tests/test_framework_agnostic_components.py b/rllib/utils/tests/test_framework_agnostic_components.py index 82fa4a66e..8d6305f4a 100644 --- a/rllib/utils/tests/test_framework_agnostic_components.py +++ b/rllib/utils/tests/test_framework_agnostic_components.py @@ -75,10 +75,13 @@ class TestFrameWorkAgnosticComponents(unittest.TestCase): check(component.add(-1.1).numpy(), -2.1) # prop_b == -1.0 # Test recognizing default package path. - component = from_config(Exploration, { - "type": "EpsilonGreedy", - "action_space": Discrete(2) - }) + component = from_config( + Exploration, { + "type": "EpsilonGreedy", + "action_space": Discrete(2), + "num_workers": 0, + "worker_index": 0, + }) check(component.epsilon_schedule.outside_value, 0.05) # default # Create torch Component from yaml-string.