From 981d9818c13731abd3826f069d41a2e6f39663fb Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 6 Aug 2018 12:10:59 -0700 Subject: [PATCH] [rllib] Support the timesteps_per_batch in simple optimizer PPO mode (#2558) * support ts * doc * Update sync_samples_optimizer.py --- doc/source/rllib-package-ref.rst | 1 + python/ray/rllib/agents/ppo/ppo.py | 6 ++++-- .../optimizers/sync_samples_optimizer.py | 19 ++++++++++++------- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/doc/source/rllib-package-ref.rst b/doc/source/rllib-package-ref.rst index 38a578dbd..da8bf138e 100644 --- a/doc/source/rllib-package-ref.rst +++ b/doc/source/rllib-package-ref.rst @@ -14,6 +14,7 @@ ray.rllib.agents .. autoclass:: ray.rllib.agents.dqn.DQNAgent .. autoclass:: ray.rllib.agents.es.ESAgent .. autoclass:: ray.rllib.agents.pg.PGAgent +.. autoclass:: ray.rllib.agents.impala.ImpalaAgent .. autoclass:: ray.rllib.agents.ppo.PPOAgent ray.rllib.env diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index cfdd0d4cb..7a3697867 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -83,8 +83,10 @@ class PPOAgent(Agent): }) if self.config["simple_optimizer"]: self.optimizer = SyncSamplesOptimizer( - self.local_evaluator, self.remote_evaluators, - {"num_sgd_iter": self.config["num_sgd_iter"]}) + self.local_evaluator, self.remote_evaluators, { + "num_sgd_iter": self.config["num_sgd_iter"], + "timesteps_per_batch": self.config["timesteps_per_batch"] + }) else: self.optimizer = LocalMultiGPUOptimizer( self.local_evaluator, self.remote_evaluators, { diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index 76d2d9c46..7af87fcd3 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -17,12 +17,13 @@ class SyncSamplesOptimizer(PolicyOptimizer): model weights are then broadcast to all remote evaluators. """ - def _init(self, num_sgd_iter=1): + def _init(self, num_sgd_iter=1, timesteps_per_batch=1): self.update_weights_timer = TimerStat() self.sample_timer = TimerStat() self.grad_timer = TimerStat() self.throughput = RunningStat() self.num_sgd_iter = num_sgd_iter + self.timesteps_per_batch = timesteps_per_batch def step(self): with self.update_weights_timer: @@ -32,12 +33,16 @@ class SyncSamplesOptimizer(PolicyOptimizer): e.set_weights.remote(weights) with self.sample_timer: - if self.remote_evaluators: - samples = SampleBatch.concat_samples( - ray.get( - [e.sample.remote() for e in self.remote_evaluators])) - else: - samples = self.local_evaluator.sample() + samples = [] + while sum(s.count for s in samples) < self.timesteps_per_batch: + if self.remote_evaluators: + samples.extend( + ray.get([ + e.sample.remote() for e in self.remote_evaluators + ])) + else: + samples.append(self.local_evaluator.sample()) + samples = SampleBatch.concat_samples(samples) with self.grad_timer: for i in range(self.num_sgd_iter):