mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 21:56:20 +08:00
[rllib] Set ape-x local exploration to 0, also load explorations before training steps (#3349)
## What do these changes do? This should fix high explorations being used after restore / for rollouts. ## Related issue number (dev list issue)
This commit is contained in:
@@ -118,9 +118,16 @@ class DDPGAgent(DQNAgent):
|
||||
if self.config["per_worker_exploration"]:
|
||||
assert self.config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
exponent = (
|
||||
1 + worker_index / float(self.config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(self.config["noise_scale"] * 0.4**exponent)
|
||||
if worker_index >= 0:
|
||||
exponent = (
|
||||
1 +
|
||||
worker_index / float(self.config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(
|
||||
self.config["noise_scale"] * 0.4**exponent)
|
||||
else:
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
else:
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(self.config["exploration_fraction"] *
|
||||
|
||||
@@ -124,7 +124,7 @@ class DQNAgent(Agent):
|
||||
self.config["n_step"])
|
||||
self.config["sample_batch_size"] = adjusted_batch_size
|
||||
|
||||
self.exploration0 = self._make_exploration_schedule(0)
|
||||
self.exploration0 = self._make_exploration_schedule(-1)
|
||||
self.explorations = [
|
||||
self._make_exploration_schedule(i)
|
||||
for i in range(self.config["num_workers"])
|
||||
@@ -170,9 +170,15 @@ class DQNAgent(Agent):
|
||||
if self.config["per_worker_exploration"]:
|
||||
assert self.config["num_workers"] > 1, \
|
||||
"This requires multiple workers"
|
||||
exponent = (
|
||||
1 + worker_index / float(self.config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
if worker_index >= 0:
|
||||
exponent = (
|
||||
1 +
|
||||
worker_index / float(self.config["num_workers"] - 1) * 7)
|
||||
return ConstantSchedule(0.4**exponent)
|
||||
else:
|
||||
# local ev should have zero exploration so that eval rollouts
|
||||
# run properly
|
||||
return ConstantSchedule(0.0)
|
||||
return LinearSchedule(
|
||||
schedule_timesteps=int(self.config["exploration_fraction"] *
|
||||
self.config["schedule_max_timesteps"]),
|
||||
@@ -194,13 +200,7 @@ class DQNAgent(Agent):
|
||||
def _train(self):
|
||||
start_timestep = self.global_timestep
|
||||
|
||||
start = time.time()
|
||||
while (self.global_timestep - start_timestep <
|
||||
self.config["timesteps_per_iteration"]
|
||||
) or time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
self.update_target_if_needed()
|
||||
|
||||
# Update worker explorations
|
||||
exp_vals = [self.exploration0.value(self.global_timestep)]
|
||||
self.local_evaluator.foreach_trainable_policy(
|
||||
lambda p, _: p.set_epsilon(exp_vals[0]))
|
||||
@@ -210,6 +210,14 @@ class DQNAgent(Agent):
|
||||
lambda p, _: p.set_epsilon(exp_val))
|
||||
exp_vals.append(exp_val)
|
||||
|
||||
# Do optimization steps
|
||||
start = time.time()
|
||||
while (self.global_timestep - start_timestep <
|
||||
self.config["timesteps_per_iteration"]
|
||||
) or time.time() - start < self.config["min_iter_time_s"]:
|
||||
self.optimizer.step()
|
||||
self.update_target_if_needed()
|
||||
|
||||
if self.config["per_worker_exploration"]:
|
||||
# Only collect metrics from the third of workers with lowest eps
|
||||
result = collect_metrics(
|
||||
|
||||
Reference in New Issue
Block a user