mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
make es worker count independent (#740)
This commit is contained in:
committed by
Robert Nishihara
parent
80e8426b5e
commit
86a7909149
@@ -30,8 +30,8 @@ Result = namedtuple("Result", [
|
||||
DEFAULT_CONFIG = dict(
|
||||
l2coeff=0.005,
|
||||
noise_stdev=0.02,
|
||||
episodes_per_batch=10000,
|
||||
timesteps_per_batch=100000,
|
||||
episodes_per_batch=1000,
|
||||
timesteps_per_batch=10000,
|
||||
calc_obstat_prob=0.01,
|
||||
eval_prob=0,
|
||||
snapshot_freq=0,
|
||||
@@ -188,6 +188,25 @@ class EvolutionStrategies(Algorithm):
|
||||
self.tstart = time.time()
|
||||
self.iteration = 0
|
||||
|
||||
def _collect_results(self, theta_id, min_eps, min_timesteps):
|
||||
num_eps, num_timesteps = 0, 0
|
||||
results = []
|
||||
while num_eps < min_eps or num_timesteps < min_timesteps:
|
||||
print(
|
||||
"Collected {} episodes {} timesteps so far this iter".format(
|
||||
num_eps, num_timesteps))
|
||||
rollout_ids = [worker.do_rollouts.remote(
|
||||
theta_id,
|
||||
self.ob_stat.mean if self.policy.needs_ob_stat else None,
|
||||
self.ob_stat.std if self.policy.needs_ob_stat else None)
|
||||
for worker in self.workers]
|
||||
# Get the results of the rollouts.
|
||||
for result in ray.get(rollout_ids):
|
||||
results.append(result)
|
||||
num_eps += result.lengths_n2.size
|
||||
num_timesteps += result.lengths_n2.sum()
|
||||
return results
|
||||
|
||||
def train(self):
|
||||
config = self.config
|
||||
|
||||
@@ -199,14 +218,10 @@ class EvolutionStrategies(Algorithm):
|
||||
theta_id = ray.put(theta)
|
||||
# Use the actors to do rollouts, note that we pass in the ID of the
|
||||
# policy weights.
|
||||
rollout_ids = [worker.do_rollouts.remote(
|
||||
results = self._collect_results(
|
||||
theta_id,
|
||||
self.ob_stat.mean if self.policy.needs_ob_stat else None,
|
||||
self.ob_stat.std if self.policy.needs_ob_stat else None)
|
||||
for worker in self.workers]
|
||||
|
||||
# Get the results of the rollouts.
|
||||
results = ray.get(rollout_ids)
|
||||
config["episodes_per_batch"],
|
||||
config["timesteps_per_batch"])
|
||||
|
||||
curr_task_results = []
|
||||
ob_count_this_batch = 0
|
||||
|
||||
Reference in New Issue
Block a user