[rllib] Set ape-x local exploration to 0, also load explorations before training steps (#3349)

## What do these changes do?

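This should fix high exploration values being used after restoring from a checkpoint and during rollouts: with per-worker exploration enabled, the local evaluator's exploration is now set to zero, and exploration values are applied before the optimization steps run rather than after them.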

## Related issue number

(Issue reported on the dev list.)
This commit is contained in:
Eric Liang
2018-11-19 20:36:25 -08:00
committed by Richard Liaw
parent afc48d7b77
commit 5972c29d28
2 changed files with 29 additions and 14 deletions
+10 -3
View File
@@ -118,9 +118,16 @@ class DDPGAgent(DQNAgent):
if self.config["per_worker_exploration"]:
assert self.config["num_workers"] > 1, \
"This requires multiple workers"
exponent = (
1 + worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(self.config["noise_scale"] * 0.4**exponent)
if worker_index >= 0:
exponent = (
1 +
worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(
self.config["noise_scale"] * 0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
else:
return LinearSchedule(
schedule_timesteps=int(self.config["exploration_fraction"] *
+19 -11
View File
@@ -124,7 +124,7 @@ class DQNAgent(Agent):
self.config["n_step"])
self.config["sample_batch_size"] = adjusted_batch_size
self.exploration0 = self._make_exploration_schedule(0)
self.exploration0 = self._make_exploration_schedule(-1)
self.explorations = [
self._make_exploration_schedule(i)
for i in range(self.config["num_workers"])
@@ -170,9 +170,15 @@ class DQNAgent(Agent):
if self.config["per_worker_exploration"]:
assert self.config["num_workers"] > 1, \
"This requires multiple workers"
exponent = (
1 + worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(0.4**exponent)
if worker_index >= 0:
exponent = (
1 +
worker_index / float(self.config["num_workers"] - 1) * 7)
return ConstantSchedule(0.4**exponent)
else:
# local ev should have zero exploration so that eval rollouts
# run properly
return ConstantSchedule(0.0)
return LinearSchedule(
schedule_timesteps=int(self.config["exploration_fraction"] *
self.config["schedule_max_timesteps"]),
@@ -194,13 +200,7 @@ class DQNAgent(Agent):
def _train(self):
start_timestep = self.global_timestep
start = time.time()
while (self.global_timestep - start_timestep <
self.config["timesteps_per_iteration"]
) or time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
self.update_target_if_needed()
# Update worker explorations
exp_vals = [self.exploration0.value(self.global_timestep)]
self.local_evaluator.foreach_trainable_policy(
lambda p, _: p.set_epsilon(exp_vals[0]))
@@ -210,6 +210,14 @@ class DQNAgent(Agent):
lambda p, _: p.set_epsilon(exp_val))
exp_vals.append(exp_val)
# Do optimization steps
start = time.time()
while (self.global_timestep - start_timestep <
self.config["timesteps_per_iteration"]
) or time.time() - start < self.config["min_iter_time_s"]:
self.optimizer.step()
self.update_target_if_needed()
if self.config["per_worker_exploration"]:
# Only collect metrics from the third of workers with lowest eps
result = collect_metrics(