From 5972c29d28534c499e028cd2491ed72376937129 Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 19 Nov 2018 20:36:25 -0800 Subject: [PATCH] [rllib] Set ape-x local exploration to 0, also load explorations before training steps (#3349) ## What do these changes do? This should fix high exploration values being used after restore and for rollouts: with per-worker exploration enabled, the local evaluator's exploration is now fixed at 0, and exploration values are applied to the policies before the optimization steps of each training iteration instead of after them. ## Related issue number (dev list issue) --- python/ray/rllib/agents/ddpg/ddpg.py | 13 +++++++++--- python/ray/rllib/agents/dqn/dqn.py | 30 ++++++++++++++++++---------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/python/ray/rllib/agents/ddpg/ddpg.py b/python/ray/rllib/agents/ddpg/ddpg.py index b3dded59d..e5980805b 100644 --- a/python/ray/rllib/agents/ddpg/ddpg.py +++ b/python/ray/rllib/agents/ddpg/ddpg.py @@ -118,9 +118,16 @@ class DDPGAgent(DQNAgent): if self.config["per_worker_exploration"]: assert self.config["num_workers"] > 1, \ "This requires multiple workers" - exponent = ( - 1 + worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(self.config["noise_scale"] * 0.4**exponent) + if worker_index >= 0: + exponent = ( + 1 + + worker_index / float(self.config["num_workers"] - 1) * 7) + return ConstantSchedule( + self.config["noise_scale"] * 0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) else: return LinearSchedule( schedule_timesteps=int(self.config["exploration_fraction"] * diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index 29003b4a3..cdace0865 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -124,7 +124,7 @@ class DQNAgent(Agent): self.config["n_step"]) self.config["sample_batch_size"] = adjusted_batch_size - self.exploration0 = self._make_exploration_schedule(0) + self.exploration0 = self._make_exploration_schedule(-1) self.explorations = [ self._make_exploration_schedule(i) for i in
range(self.config["num_workers"]) @@ -170,9 +170,15 @@ class DQNAgent(Agent): if self.config["per_worker_exploration"]: assert self.config["num_workers"] > 1, \ "This requires multiple workers" - exponent = ( - 1 + worker_index / float(self.config["num_workers"] - 1) * 7) - return ConstantSchedule(0.4**exponent) + if worker_index >= 0: + exponent = ( + 1 + + worker_index / float(self.config["num_workers"] - 1) * 7) + return ConstantSchedule(0.4**exponent) + else: + # local ev should have zero exploration so that eval rollouts + # run properly + return ConstantSchedule(0.0) return LinearSchedule( schedule_timesteps=int(self.config["exploration_fraction"] * self.config["schedule_max_timesteps"]), @@ -194,13 +200,7 @@ class DQNAgent(Agent): def _train(self): start_timestep = self.global_timestep - start = time.time() - while (self.global_timestep - start_timestep < - self.config["timesteps_per_iteration"] - ) or time.time() - start < self.config["min_iter_time_s"]: - self.optimizer.step() - self.update_target_if_needed() - + # Update worker explorations exp_vals = [self.exploration0.value(self.global_timestep)] self.local_evaluator.foreach_trainable_policy( lambda p, _: p.set_epsilon(exp_vals[0])) @@ -210,6 +210,14 @@ class DQNAgent(Agent): lambda p, _: p.set_epsilon(exp_val)) exp_vals.append(exp_val) + # Do optimization steps + start = time.time() + while (self.global_timestep - start_timestep < + self.config["timesteps_per_iteration"] + ) or time.time() - start < self.config["min_iter_time_s"]: + self.optimizer.step() + self.update_target_if_needed() + if self.config["per_worker_exploration"]: # Only collect metrics from the third of workers with lowest eps result = collect_metrics(