diff --git a/python/ray/rllib/agents/a3c/a3c.py b/python/ray/rllib/agents/a3c/a3c.py index f37af7c42..d5f201459 100644 --- a/python/ray/rllib/agents/a3c/a3c.py +++ b/python/ray/rllib/agents/a3c/a3c.py @@ -94,8 +94,8 @@ class A3CAgent(Agent): self.env_creator, policy_cls, self.config["num_workers"], {"num_gpus": 1 if self.config["use_gpu_for_workers"] else 0}) self.optimizer = AsyncGradientsOptimizer( - self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.local_evaluator, self.remote_evaluators, + self.config["optimizer"]) def _train(self): self.optimizer.step() diff --git a/python/ray/rllib/agents/bc/bc.py b/python/ray/rllib/agents/bc/bc.py index 8dee9f6e9..1484a5dbe 100644 --- a/python/ray/rllib/agents/bc/bc.py +++ b/python/ray/rllib/agents/bc/bc.py @@ -72,8 +72,8 @@ class BCAgent(Agent): remote_cls.remote(self.env_creator, self.config, self.logdir) for _ in range(self.config["num_workers"])] self.optimizer = AsyncGradientsOptimizer( - self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.local_evaluator, self.remote_evaluators, + self.config["optimizer"]) def _train(self): self.optimizer.step() diff --git a/python/ray/rllib/agents/dqn/dqn.py b/python/ray/rllib/agents/dqn/dqn.py index ba1224732..f60b43fbe 100644 --- a/python/ray/rllib/agents/dqn/dqn.py +++ b/python/ray/rllib/agents/dqn/dqn.py @@ -136,8 +136,8 @@ class DQNAgent(Agent): {"num_cpus": self.config["num_cpus_per_worker"], "num_gpus": self.config["num_gpus_per_worker"]}) self.optimizer = getattr(optimizers, self.config["optimizer_class"])( - self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.local_evaluator, self.remote_evaluators, + self.config["optimizer"]) self.last_target_update_ts = 0 self.num_target_updates = 0 diff --git a/python/ray/rllib/agents/pg/pg.py b/python/ray/rllib/agents/pg/pg.py index 05a600cdb..b971c8126 100644 --- a/python/ray/rllib/agents/pg/pg.py +++ b/python/ray/rllib/agents/pg/pg.py @@ -45,8 +45,8 
@@ class PGAgent(Agent): self.remote_evaluators = self.make_remote_evaluators( self.env_creator, PGPolicyGraph, self.config["num_workers"], {}) self.optimizer = SyncSamplesOptimizer( - self.config["optimizer"], self.local_evaluator, - self.remote_evaluators) + self.local_evaluator, self.remote_evaluators, + self.config["optimizer"]) def _train(self): self.optimizer.step() diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index a83c10f3b..0ed3e03be 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -74,11 +74,11 @@ class PPOAgent(Agent): {"num_cpus": self.config["num_cpus_per_worker"], "num_gpus": self.config["num_gpus_per_worker"]}) self.optimizer = LocalMultiGPUOptimizer( + self.local_evaluator, self.remote_evaluators, {"sgd_batch_size": self.config["sgd_batchsize"], "sgd_stepsize": self.config["sgd_stepsize"], "num_sgd_iter": self.config["num_sgd_iter"], - "timesteps_per_batch": self.config["timesteps_per_batch"]}, - self.local_evaluator, self.remote_evaluators) + "timesteps_per_batch": self.config["timesteps_per_batch"]}) def _train(self): def postprocess_samples(batch): diff --git a/python/ray/rllib/evaluation/policy_graph.py b/python/ray/rllib/evaluation/policy_graph.py index e156bdde2..96ca8e660 100644 --- a/python/ray/rllib/evaluation/policy_graph.py +++ b/python/ray/rllib/evaluation/policy_graph.py @@ -106,6 +106,22 @@ class PolicyGraph(object): """ raise NotImplementedError + def compute_apply(self, samples): + """Fused compute gradients and apply gradients call. + + Returns: + grad_info: dictionary of extra metadata from compute_gradients(). + apply_info: dictionary of extra metadata from apply_gradients(). + + Examples: + >>> batch = ev.sample() + >>> ev.compute_apply(batch) + """ + + grads, grad_info = self.compute_gradients(samples) + apply_info = self.apply_gradients(grads) + return grad_info, apply_info + def get_weights(self): """Returns model weights. 
diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 4a30b7521..5d78e5e82 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -31,7 +31,7 @@ class PolicyOptimizer(object): evaluators created by this optimizer. """ - def __init__(self, config, local_evaluator, remote_evaluators): + def __init__(self, local_evaluator, remote_evaluators=None, config=None): """Create an optimizer instance. Args: @@ -41,10 +41,10 @@ class PolicyOptimizer(object): evaluators instances. If empty, the optimizer should fall back to using only the local evaluator. """ - self.config = config self.local_evaluator = local_evaluator - self.remote_evaluators = remote_evaluators - self._init(**config) + self.remote_evaluators = remote_evaluators or [] + self.config = config or {} + self._init(**self.config) # Counters that should be updated by sub-classes self.num_steps_trained = 0 diff --git a/python/ray/rllib/optimizers/sync_samples_optimizer.py b/python/ray/rllib/optimizers/sync_samples_optimizer.py index c1c8e7c1a..ba6eb4cef 100644 --- a/python/ray/rllib/optimizers/sync_samples_optimizer.py +++ b/python/ray/rllib/optimizers/sync_samples_optimizer.py @@ -40,8 +40,7 @@ class SyncSamplesOptimizer(PolicyOptimizer): samples = self.local_evaluator.sample() with self.grad_timer: - grad, _ = self.local_evaluator.compute_gradients(samples) - self.local_evaluator.apply_gradients(grad) + self.local_evaluator.compute_apply(samples) self.grad_timer.push_units_processed(samples.count) self.num_steps_sampled += samples.count diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index e1146dcca..8d6f38b35 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -296,7 +296,7 @@ class TestMultiAgentEnv(unittest.TestCase): batch_steps=50)] else: remote_evs = [] - optimizer = 
optimizer_cls({}, ev, remote_evs) + optimizer = optimizer_cls(ev, remote_evs, {}) for i in range(200): ev.foreach_policy( lambda p, _: p.set_epsilon(max(0.02, 1 - i * .02)) @@ -338,7 +338,7 @@ class TestMultiAgentEnv(unittest.TestCase): policy_graph=policies, policy_mapping_fn=lambda agent_id: random.choice(policy_ids), batch_steps=100) - optimizer = SyncSamplesOptimizer({}, ev, []) + optimizer = SyncSamplesOptimizer(ev, [], {}) for i in range(100): optimizer.step() result = collect_metrics(ev) diff --git a/python/ray/rllib/test/test_optimizers.py b/python/ray/rllib/test/test_optimizers.py index f3a4fc917..a64079dc7 100644 --- a/python/ray/rllib/test/test_optimizers.py +++ b/python/ray/rllib/test/test_optimizers.py @@ -21,9 +21,8 @@ class AsyncOptimizerTest(unittest.TestCase): local = _MockEvaluator() remotes = ray.remote(_MockEvaluator) remote_evaluators = [remotes.remote() for i in range(5)] - test_optimizer = AsyncGradientsOptimizer({ - "grads_per_step": 10 - }, local, remote_evaluators) + test_optimizer = AsyncGradientsOptimizer( + local, remote_evaluators, {"grads_per_step": 10}) test_optimizer.step() self.assertTrue(all(local.get_weights() == 0))