From a9a26b756098dcbd108966c52e3e7e9c9bfaf7ee Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 25 Jun 2018 22:33:57 -0700 Subject: [PATCH] [rllib] Part 2 of multiagent support (#2286) * wip * cls * re * wip * wip * a3c working * torch support * pg works * lint * rm v2 * consumer id * clean up pg * clean up more * fix python 2.7 * tf session management * docs * dqn wip * fix compile * dqn * apex runs * up * impotrs * ddpg * quotes * fix tests * fix last r * fix tests * lint * pass checkpoint restore * kwar * nits * policy graph * fix yapf * com * class * pyt * vectorization * update * test cpe * unit test * fix ddpg2 * changes * wip * args * faster test * common * fix * add alg option * batch mode and policy serving * multi serving test * todo * wip * serving test * doc async env * num envs * comments * thread * remove init hook * update * fix ppo * comments1 * fix * updates * add jenkins tests * fix * fix pytorch * fix * fixes * fix a3c policy * fix squeeze * fix trunc on apex * fix squeezing for real * update * remove horizon test for now * multiagent wip * update * fix race condition * fix ma * t * doc * st * wip * example * wip * working * cartpole * wip * batch wip * fix bug * make other_batches None default * working * debug * nit * warn * comments * fix ppo * fix obs filter * update * fix obs filter * pass thru worker index * fix * fix log action * debug name * fix sphinx --- python/ray/rllib/a3c/a3c.py | 19 +- python/ray/rllib/a3c/a3c_tf_policy.py | 4 +- python/ray/rllib/a3c/a3c_torch_policy.py | 2 + python/ray/rllib/agent.py | 3 +- python/ray/rllib/ddpg/ddpg.py | 10 +- python/ray/rllib/ddpg/ddpg_policy_graph.py | 4 +- python/ray/rllib/dqn/dqn.py | 27 +- python/ray/rllib/dqn/dqn_policy_graph.py | 5 +- python/ray/rllib/es/es.py | 1 - .../ray/rllib/examples/multiagent_cartpole.py | 70 +++++ python/ray/rllib/optimizers/apex_optimizer.py | 1 + .../ray/rllib/optimizers/async_optimizer.py | 3 + .../ray/rllib/optimizers/local_sync_replay.py | 91 ++++--- python/ray/rllib/optimizers/multi_gpu.py | 3 +- .../ray/rllib/optimizers/policy_evaluator.py | 9 +- .../ray/rllib/optimizers/policy_optimizer.py | 10 +- python/ray/rllib/optimizers/sample_batch.py | 57 +++- python/ray/rllib/pg/pg.py | 12 +- python/ray/rllib/pg/pg_policy_graph.py | 4 +- python/ray/rllib/ppo/ppo.py | 1 - .../test/test_common_policy_evaluator.py | 16 ++ python/ray/rllib/test/test_multi_agent_env.py | 235 ++++++++++++++-- python/ray/rllib/utils/async_vector_env.py | 10 +- .../rllib/utils/common_policy_evaluator.py | 255 ++++++++++++++---- python/ray/rllib/utils/env_context.py | 22 ++ python/ray/rllib/utils/policy_graph.py | 4 + python/ray/rllib/utils/sampler.py | 77 ++++-- python/ray/rllib/utils/serving_env.py | 2 + python/ray/rllib/utils/tf_policy_graph.py | 96 ++++--- python/ray/rllib/utils/tf_run_builder.py | 82 ++++++ python/ray/tune/result.py | 3 + test/jenkins_tests/run_multi_node_tests.sh | 3 + 32 files changed, 939 insertions(+), 202 deletions(-) create mode 100644 python/ray/rllib/examples/multiagent_cartpole.py create mode 100644 python/ray/rllib/utils/env_context.py create mode 100644 python/ray/rllib/utils/tf_run_builder.py diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 04c9ce4df..f18ebc05c 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -63,13 +63,18 @@ DEFAULT_CONFIG = { }, # Arguments to pass to the env creator "env_config": {}, + + # === Multiagent === + "multiagent": { + "policy_graphs": {}, + "policy_mapping_fn": None, + }, } class A3CAgent(Agent): _agent_name = "A3C" _default_config = DEFAULT_CONFIG - _allow_unknown_subkeys = ["model", "optimizer", "env_config"] @classmethod def default_resource_request(cls, config): @@ -98,7 +103,9 @@ class A3CAgent(Agent): remote_cls = CommonPolicyEvaluator.as_remote( num_gpus=1 if self.config["use_gpu_for_workers"] else 0) self.local_evaluator = CommonPolicyEvaluator( - self.env_creator, self.policy_cls, + self.env_creator, + self.config["multiagent"]["policy_graphs"] or self.policy_cls, + policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], batch_steps=self.config["batch_size"], batch_mode="truncate_episodes", tf_session_creator=session_creator, @@ -107,13 +114,17 @@ class A3CAgent(Agent): num_envs=self.config["num_envs"]) self.remote_evaluators = [ remote_cls.remote( - self.env_creator, self.policy_cls, + self.env_creator, + self.config["multiagent"]["policy_graphs"] or self.policy_cls, + policy_mapping_fn=( + self.config["multiagent"]["policy_mapping_fn"]), batch_steps=self.config["batch_size"], batch_mode="truncate_episodes", sample_async=True, tf_session_creator=session_creator, env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, - num_envs=self.config["num_envs"]) + num_envs=self.config["num_envs"], + worker_index=i+1) for i in range(self.config["num_workers"])] self.optimizer = AsyncOptimizer( diff --git a/python/ray/rllib/a3c/a3c_tf_policy.py b/python/ray/rllib/a3c/a3c_tf_policy.py index 8532734c2..a23d4b9c4 100644 --- a/python/ray/rllib/a3c/a3c_tf_policy.py +++ b/python/ray/rllib/a3c/a3c_tf_policy.py @@ -5,6 +5,7 @@ from __future__ import print_function import tensorflow as tf import gym +import ray from ray.rllib.utils.error import UnsupportedSpaceException from ray.rllib.utils.process_rollout import compute_advantages from ray.rllib.utils.tf_policy_graph import TFPolicyGraph @@ -14,6 +15,7 @@ class A3CTFPolicyGraph(TFPolicyGraph): """The TF policy base class.""" def __init__(self, ob_space, action_space, config): + config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config) self.local_steps = 0 self.config = config self.summarize = config.get("summarize") @@ -27,7 +29,7 @@ class A3CTFPolicyGraph(TFPolicyGraph): self.sess = tf.get_default_session() TFPolicyGraph.__init__( - self, self.sess, obs_input=self.x, + self, ob_space, action_space, self.sess, obs_input=self.x, action_sampler=self.action_dist.sample(), loss=self.loss, loss_inputs=self.loss_in, is_training=self.is_training, state_inputs=self.state_in, state_outputs=self.state_out) diff --git a/python/ray/rllib/a3c/a3c_torch_policy.py b/python/ray/rllib/a3c/a3c_torch_policy.py index a1cb5d866..79ae75c11 100644 --- a/python/ray/rllib/a3c/a3c_torch_policy.py +++ b/python/ray/rllib/a3c/a3c_torch_policy.py @@ -8,6 +8,7 @@ from threading import Lock import torch import torch.nn.functional as F +import ray from ray.rllib.models.pytorch.misc import var_to_np, convert_batch from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.process_rollout import compute_advantages @@ -18,6 +19,7 @@ class SharedTorchPolicy(PolicyGraph): """A simple, non-recurrent PyTorch policy example.""" def __init__(self, obs_space, action_space, config): + config = dict(ray.rllib.a3c.a3c.DEFAULT_CONFIG, **config) PolicyGraph.__init__(self, obs_space, action_space, config) self.local_steps = 0 self.config = config diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index 5e1db81c0..195d76a9b 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -59,7 +59,8 @@ class Agent(Trainable): """ _allow_unknown_configs = False - _allow_unknown_subkeys = ["env_config", "model", "optimizer"] + _allow_unknown_subkeys = [ + "tf_session_args", "env_config", "model", "optimizer", "multiagent"] @classmethod def resource_help(cls, config): diff --git a/python/ray/rllib/ddpg/ddpg.py b/python/ray/rllib/ddpg/ddpg.py index cf4bf4431..adb323843 100644 --- a/python/ray/rllib/ddpg/ddpg.py +++ b/python/ray/rllib/ddpg/ddpg.py @@ -108,14 +108,18 @@ DEFAULT_CONFIG = { # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. - "worker_side_prioritization": False + "worker_side_prioritization": False, + + # === Multiagent === + "multiagent": { + "policy_graphs": {}, + "policy_mapping_fn": None, + }, } class DDPGAgent(DQNAgent): _agent_name = "DDPG" - _allow_unknown_subkeys = [ - "model", "optimizer", "tf_session_args", "env_config"] _default_config = DEFAULT_CONFIG _policy_graph = DDPGPolicyGraph diff --git a/python/ray/rllib/ddpg/ddpg_policy_graph.py b/python/ray/rllib/ddpg/ddpg_policy_graph.py index da1b64a30..a76d3fa8d 100644 --- a/python/ray/rllib/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/ddpg/ddpg_policy_graph.py @@ -82,6 +82,7 @@ def _build_q_network(inputs, action_inputs, config): class DDPGPolicyGraph(TFPolicyGraph): def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.ddpg.ddpg.DEFAULT_CONFIG, **config) if not isinstance(action_space, Box): raise UnsupportedSpaceException( "Action space {} is not supported for DDPG.".format( @@ -232,7 +233,8 @@ class DDPGPolicyGraph(TFPolicyGraph): ] self.is_training = tf.placeholder_with_default(True, ()) TFPolicyGraph.__init__( - self, self.sess, obs_input=self.cur_observations, + self, observation_space, action_space, self.sess, + obs_input=self.cur_observations, action_sampler=self.output_actions, loss=self.loss, loss_inputs=self.loss_inputs, is_training=self.is_training) self.sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 83dc1078e..960b30185 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -102,14 +102,18 @@ DEFAULT_CONFIG = { # Whether to use a distribution of epsilons across workers for exploration. "per_worker_exploration": False, # Whether to compute priorities on workers. - "worker_side_prioritization": False + "worker_side_prioritization": False, + + # === Multiagent === + "multiagent": { + "policy_graphs": {}, + "policy_mapping_fn": None, + }, } class DQNAgent(Agent): _agent_name = "DQN" - _allow_unknown_subkeys = [ - "model", "optimizer", "tf_session_args", "env_config"] _default_config = DEFAULT_CONFIG _policy_graph = DQNPolicyGraph @@ -125,7 +129,9 @@ class DQNAgent(Agent): adjusted_batch_size = ( self.config["sample_batch_size"] + self.config["n_step"] - 1) self.local_evaluator = CommonPolicyEvaluator( - self.env_creator, self._policy_graph, + self.env_creator, + self.config["multiagent"]["policy_graphs"] or self._policy_graph, + policy_mapping_fn=self.config["multiagent"]["policy_mapping_fn"], batch_steps=adjusted_batch_size, batch_mode="truncate_episodes", preprocessor_pref="deepmind", compress_observations=True, @@ -143,8 +149,9 @@ class DQNAgent(Agent): compress_observations=True, env_config=self.config["env_config"], model_config=self.config["model"], policy_config=self.config, - num_envs=self.config["num_envs"]) - for _ in range(self.config["num_workers"])] + num_envs=self.config["num_envs"], + worker_index=i+1) + for i in range(self.config["num_workers"])] self.exploration0 = self._make_exploration_schedule(0) self.explorations = [ @@ -185,7 +192,7 @@ class DQNAgent(Agent): def update_target_if_needed(self): if self.global_timestep - self.last_target_update_ts > \ self.config["target_network_update_freq"]: - self.local_evaluator.for_policy(lambda p: p.update_target()) + self.local_evaluator.foreach_policy(lambda p, _: p.update_target()) self.last_target_update_ts = self.global_timestep self.num_target_updates += 1 @@ -198,11 +205,11 @@ class DQNAgent(Agent): self.update_target_if_needed() exp_vals = [self.exploration0.value(self.global_timestep)] - self.local_evaluator.for_policy( - lambda p: p.set_epsilon(exp_vals[0])) + self.local_evaluator.foreach_policy( + lambda p, _: p.set_epsilon(exp_vals[0])) for i, e in enumerate(self.remote_evaluators): exp_val = self.explorations[i].value(self.global_timestep) - e.for_policy.remote(lambda p: p.set_epsilon(exp_val)) + e.foreach_policy.remote(lambda p, _: p.set_epsilon(exp_val)) exp_vals.append(exp_val) result = collect_metrics( diff --git a/python/ray/rllib/dqn/dqn_policy_graph.py b/python/ray/rllib/dqn/dqn_policy_graph.py index 9c7ceedc4..5db7bc651 100644 --- a/python/ray/rllib/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/dqn/dqn_policy_graph.py @@ -7,6 +7,7 @@ import numpy as np import tensorflow as tf import tensorflow.contrib.layers as layers +import ray from ray.rllib.models import ModelCatalog from ray.rllib.optimizers.sample_batch import SampleBatch from ray.rllib.utils.error import UnsupportedSpaceException @@ -47,6 +48,7 @@ def adjust_nstep(n_step, gamma, obs, actions, rewards, new_obs, dones): class DQNPolicyGraph(TFPolicyGraph): def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.dqn.dqn.DEFAULT_CONFIG, **config) if not isinstance(action_space, Discrete): raise UnsupportedSpaceException( "Action space {} is not supported for DQN.".format( @@ -144,7 +146,8 @@ class DQNPolicyGraph(TFPolicyGraph): ] self.is_training = tf.placeholder_with_default(True, ()) TFPolicyGraph.__init__( - self, self.sess, obs_input=self.cur_observations, + self, observation_space, action_space, self.sess, + obs_input=self.cur_observations, action_sampler=self.output_actions, loss=self.loss, loss_inputs=self.loss_inputs, is_training=self.is_training) self.sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index 8e5dbe064..b900f88a7 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -137,7 +137,6 @@ class Worker(object): class ESAgent(agent.Agent): _agent_name = "ES" _default_config = DEFAULT_CONFIG - _allow_unknown_subkeys = ["env_config"] @classmethod def default_resource_request(cls, config): diff --git a/python/ray/rllib/examples/multiagent_cartpole.py b/python/ray/rllib/examples/multiagent_cartpole.py new file mode 100644 index 000000000..158fec293 --- /dev/null +++ b/python/ray/rllib/examples/multiagent_cartpole.py @@ -0,0 +1,70 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +"""Simple example of setting up a multi-agent policy mapping. + +Control the number of agents and policies via --num-agents and --num-policies. + +This works with hundreds of agents and policies, but note that initializing +many TF policy graphs will take some time. + +Also, TF evals might slow down with large numbers of policies. To debug TF +execution, set the TF_TIMELINE_DIR environment variable. +""" + +import argparse +import gym +import random + +import ray +from ray.rllib.pg.pg import PGAgent +from ray.rllib.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.test.test_multi_agent_env import MultiCartpole +from ray.tune.logger import pretty_print +from ray.tune.registry import register_env + + +parser = argparse.ArgumentParser() + +parser.add_argument("--num-agents", type=int, default=4) +parser.add_argument("--num-policies", type=int, default=2) +parser.add_argument("--num-iters", type=int, default=20) + + +if __name__ == "__main__": + args = parser.parse_args() + ray.init() + + # Simple environment with `num_agents` independent cartpole entities + register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents)) + single_env = gym.make("CartPole-v0") + obs_space = single_env.observation_space + act_space = single_env.action_space + + def gen_policy(): + config = { + "gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]), + "n_step": random.choice([1, 2, 3, 4, 5]), + } + return (PGPolicyGraph, obs_space, act_space, config) + + # Setup PG with an ensemble of `num_policies` different policy graphs + policy_graphs = { + "policy_{}".format(i): gen_policy() for i in range(args.num_policies) + } + policy_ids = list(policy_graphs.keys()) + + agent = PGAgent( + env="multi_cartpole", + config={ + "multiagent": { + "policy_graphs": policy_graphs, + "policy_mapping_fn": ( + lambda agent_id: random.choice(policy_ids)), + }, + }) + + for i in range(args.num_iters): + print("== Iteration", i, "==") + print(pretty_print(agent.train())) diff --git a/python/ray/rllib/optimizers/apex_optimizer.py b/python/ray/rllib/optimizers/apex_optimizer.py index e113213b4..cb07342f3 100644 --- a/python/ray/rllib/optimizers/apex_optimizer.py +++ b/python/ray/rllib/optimizers/apex_optimizer.py @@ -217,6 +217,7 @@ class ApexOptimizer(PolicyOptimizer): with self.timers["sample_processing"]: for ev, sample_batch in self.sample_tasks.completed(): + self._check_not_multiagent(sample_batch) sample_timesteps += self.sample_batch_size # Send the data to the replay buffer diff --git a/python/ray/rllib/optimizers/async_optimizer.py b/python/ray/rllib/optimizers/async_optimizer.py index 93c363345..2fd253e95 100644 --- a/python/ray/rllib/optimizers/async_optimizer.py +++ b/python/ray/rllib/optimizers/async_optimizer.py @@ -20,6 +20,9 @@ class AsyncOptimizer(PolicyOptimizer): self.dispatch_timer = TimerStat() self.grads_per_step = grads_per_step self.batch_size = batch_size + if not self.remote_evaluators: + raise ValueError( + "Async optimizer requires at least 1 remote evaluator") def step(self): weights = ray.put(self.local_evaluator.get_weights()) diff --git a/python/ray/rllib/optimizers/local_sync_replay.py b/python/ray/rllib/optimizers/local_sync_replay.py index ac430c6a1..b6545cb85 100644 --- a/python/ray/rllib/optimizers/local_sync_replay.py +++ b/python/ray/rllib/optimizers/local_sync_replay.py @@ -2,13 +2,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import numpy as np import ray from ray.rllib.optimizers.replay_buffer import ReplayBuffer, \ PrioritizedReplayBuffer from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.optimizers.sample_batch import SampleBatch +from ray.rllib.optimizers.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \ + MultiAgentBatch from ray.rllib.utils.compression import pack_if_needed from ray.rllib.utils.filter import RunningStat from ray.rllib.utils.timer import TimerStat @@ -41,11 +43,15 @@ class LocalSyncReplayOptimizer(PolicyOptimizer): # Set up replay buffer if prioritized_replay: - self.replay_buffer = PrioritizedReplayBuffer( - buffer_size, alpha=prioritized_replay_alpha, - clip_rewards=clip_rewards) + def new_buffer(): + return PrioritizedReplayBuffer( + buffer_size, alpha=prioritized_replay_alpha, + clip_rewards=clip_rewards) else: - self.replay_buffer = ReplayBuffer(buffer_size, clip_rewards) + def new_buffer(): + return ReplayBuffer(buffer_size, clip_rewards) + + self.replay_buffers = collections.defaultdict(new_buffer) assert buffer_size >= self.replay_starts @@ -63,47 +69,64 @@ class LocalSyncReplayOptimizer(PolicyOptimizer): [e.sample.remote() for e in self.remote_evaluators])) else: batch = self.local_evaluator.sample() - for row in batch.rows(): - self.replay_buffer.add( - pack_if_needed(row["obs"]), row["actions"], row["rewards"], - pack_if_needed(row["new_obs"]), - row["dones"], row["weights"]) - if len(self.replay_buffer) >= self.replay_starts: + # Handle everything as if multiagent + if isinstance(batch, SampleBatch): + batch = MultiAgentBatch( + {DEFAULT_POLICY_ID: batch}, batch.count) + + for policy_id, s in batch.policy_batches.items(): + for row in s.rows(): + if "weights" not in row: + row["weights"] = np.ones_like(row["rewards"]) + self.replay_buffers[policy_id].add( + pack_if_needed(row["obs"]), row["actions"], + row["rewards"], pack_if_needed(row["new_obs"]), + row["dones"], row["weights"]) + + if self.num_steps_sampled >= self.replay_starts: self._optimize() self.num_steps_sampled += batch.count def _optimize(self): - with self.replay_timer: - if isinstance(self.replay_buffer, PrioritizedReplayBuffer): - (obses_t, actions, rewards, obses_tp1, - dones, weights, batch_indexes) = self.replay_buffer.sample( - self.train_batch_size, - beta=self.prioritized_replay_beta) - else: - (obses_t, actions, rewards, obses_tp1, - dones) = self.replay_buffer.sample( - self.train_batch_size) - weights = np.ones_like(rewards) - batch_indexes = - np.ones_like(rewards) - samples = SampleBatch({ - "obs": obses_t, "actions": actions, "rewards": rewards, - "new_obs": obses_tp1, "dones": dones, "weights": weights, - "batch_indexes": batch_indexes}) + samples = self._replay() with self.grad_timer: - info = self.local_evaluator.compute_apply(samples) - if isinstance(self.replay_buffer, PrioritizedReplayBuffer): - td_error = info["td_error"] - new_priorities = ( - np.abs(td_error) + self.prioritized_replay_eps) - self.replay_buffer.update_priorities( - samples["batch_indexes"], new_priorities) + info_dict = self.local_evaluator.compute_apply(samples) + for policy_id, info in info_dict.items(): + replay_buffer = self.replay_buffers[policy_id] + if isinstance(replay_buffer, PrioritizedReplayBuffer): + td_error = info["td_error"] + new_priorities = ( + np.abs(td_error) + self.prioritized_replay_eps) + replay_buffer.update_priorities( + samples.policy_batches[policy_id]["batch_indexes"], + new_priorities) self.grad_timer.push_units_processed(samples.count) self.num_steps_trained += samples.count + def _replay(self): + samples = {} + with self.replay_timer: + for policy_id, replay_buffer in self.replay_buffers.items(): + if isinstance(replay_buffer, PrioritizedReplayBuffer): + (obses_t, actions, rewards, obses_tp1, + dones, weights, batch_indexes) = replay_buffer.sample( + self.train_batch_size, + beta=self.prioritized_replay_beta) + else: + (obses_t, actions, rewards, obses_tp1, + dones) = replay_buffer.sample(self.train_batch_size) + weights = np.ones_like(rewards) + batch_indexes = - np.ones_like(rewards) + samples[policy_id] = SampleBatch({ + "obs": obses_t, "actions": actions, "rewards": rewards, + "new_obs": obses_tp1, "dones": dones, "weights": weights, + "batch_indexes": batch_indexes}) + return MultiAgentBatch(samples, self.train_batch_size) + def stats(self): return dict(PolicyOptimizer.stats(self), **{ "sample_time_ms": round(1000 * self.sample_timer.mean, 3), diff --git a/python/ray/rllib/optimizers/multi_gpu.py b/python/ray/rllib/optimizers/multi_gpu.py index 2e002534f..fedfb6dbb 100644 --- a/python/ray/rllib/optimizers/multi_gpu.py +++ b/python/ray/rllib/optimizers/multi_gpu.py @@ -9,7 +9,6 @@ import tensorflow as tf import ray from ray.rllib.optimizers.policy_evaluator import TFMultiGPUSupport from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer -from ray.rllib.optimizers.sample_batch import SampleBatch from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer from ray.rllib.utils.timer import TimerStat @@ -90,7 +89,7 @@ class LocalMultiGPUOptimizer(PolicyOptimizer): self.timesteps_per_batch) else: samples = self.local_evaluator.sample() - assert isinstance(samples, SampleBatch) + self._check_not_multiagent(samples) if postprocess_fn: postprocess_fn(samples) diff --git a/python/ray/rllib/optimizers/policy_evaluator.py b/python/ray/rllib/optimizers/policy_evaluator.py index e62dea20b..e3bf9518e 100644 --- a/python/ray/rllib/optimizers/policy_evaluator.py +++ b/python/ray/rllib/optimizers/policy_evaluator.py @@ -20,7 +20,8 @@ class PolicyEvaluator(object): This method must be implemented by subclasses. Returns: - SampleBatch: A columnar batch of experiences (e.g., tensors). + SampleBatch|MultiAgentBatch: A columnar batch of experiences + (e.g., tensors), or a multi-agent batch. Examples: >>> print(ev.sample()) @@ -35,8 +36,10 @@ class PolicyEvaluator(object): This method must be implemented by subclasses. Returns: - object: A gradient that can be applied on a compatible evaluator. - info: dictionary of extra metadata. + (grads, info): A list of gradients that can be applied on a + compatible evaluator. In the multi-agent case, returns a dict + of gradients keyed by policy graph ids. An info dictionary of + extra metadata is also returned. Examples: >>> batch = ev.sample() diff --git a/python/ray/rllib/optimizers/policy_optimizer.py b/python/ray/rllib/optimizers/policy_optimizer.py index 96e40a9d8..f44aa4847 100644 --- a/python/ray/rllib/optimizers/policy_optimizer.py +++ b/python/ray/rllib/optimizers/policy_optimizer.py @@ -3,6 +3,7 @@ from __future__ import division from __future__ import print_function import ray +from ray.rllib.optimizers.sample_batch import MultiAgentBatch class PolicyOptimizer(object): @@ -54,8 +55,8 @@ class PolicyOptimizer(object): else: local_evaluator = evaluator_cls(**evaluator_args) remote_evaluators = [ - remote_cls.remote(**evaluator_args) - for _ in range(num_workers)] + remote_cls.remote(worker_index=i+1, **evaluator_args) + for i in range(num_workers)] return cls(optimizer_config, local_evaluator, remote_evaluators) def __init__(self, config, local_evaluator, remote_evaluators): @@ -130,3 +131,8 @@ class PolicyOptimizer(object): [ev.apply.remote(func, i + 1) for i, ev in enumerate(self.remote_evaluators)]) return local_result + remote_results + + def _check_not_multiagent(self, sample_batch): + if isinstance(sample_batch, MultiAgentBatch): + raise NotImplementedError( + "This optimizer does not support multi-agent yet.") diff --git a/python/ray/rllib/optimizers/sample_batch.py b/python/ray/rllib/optimizers/sample_batch.py index 83df66aa2..620eced0f 100644 --- a/python/ray/rllib/optimizers/sample_batch.py +++ b/python/ray/rllib/optimizers/sample_batch.py @@ -6,6 +6,10 @@ import collections import numpy as np +# Defaults policy id for single agent environments +DEFAULT_POLICY_ID = "default" + + class SampleBatchBuilder(object): """Util to build a SampleBatch incrementally. @@ -107,7 +111,7 @@ class MultiAgentSampleBatchBuilder(object): pre_batch, other_batches) # Append into policy batches and reset - for agent_id, post_batch in post_batches.items(): + for agent_id, post_batch in sorted(post_batches.items()): self.policy_builders[self.agent_to_policy[agent_id]].add_batch( post_batch) self.agent_builders.clear() @@ -122,33 +126,62 @@ class MultiAgentSampleBatchBuilder(object): self.postprocess_batch_so_far() policy_batches = {} - for policy_id, policy_batch_builder in self.policy_builders.items(): - policy_batches[policy_id] = policy_batch_builder.build_and_reset() + for policy_id, builder in self.policy_builders.items(): + if builder.count > 0: + policy_batches[policy_id] = builder.build_and_reset() + old_count = self.count self.count = 0 - return MultiAgentBatch.wrap_as_needed(policy_batches) + return MultiAgentBatch.wrap_as_needed(policy_batches, old_count) class MultiAgentBatch(object): - def __init__(self, policy_batches): + """A batch of experiences from multiple policies in the environment. + + Attributes: + policy_batches (dict): Mapping from policy id to a normal SampleBatch + of experiences. Note that these batches may be of different length. + count (int): The number of timesteps in the environment this batch + contains. This will be less than the number of transitions this + batch contains across all policies in total. + """ + + def __init__(self, policy_batches, count): self.policy_batches = policy_batches + self.count = count @staticmethod - def wrap_as_needed(batches): - if len(batches) == 1 and "default" in batches: - return batches["default"] - return MultiAgentBatch(batches) + def wrap_as_needed(batches, count): + if len(batches) == 1 and DEFAULT_POLICY_ID in batches: + return batches[DEFAULT_POLICY_ID] + return MultiAgentBatch(batches, count) @staticmethod def concat_samples(samples): policy_batches = collections.defaultdict(list) + total_count = 0 for s in samples: assert isinstance(s, MultiAgentBatch) for policy_id, batch in s.policy_batches.items(): policy_batches[policy_id].append(batch) + total_count += s.count out = {} for policy_id, batches in policy_batches.items(): out[policy_id] = SampleBatch.concat_samples(batches) - return MultiAgentBatch(out) + return MultiAgentBatch(out, total_count) + + def total(self): + ct = 0 + for batch in self.policy_batches.values(): + ct += batch.count + return ct + + def __str__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) + + def __repr__(self): + return "MultiAgentBatch({}, count={})".format( + str(self.policy_batches), self.count) class SampleBatch(object): @@ -166,11 +199,15 @@ class SampleBatch(object): for k, v in self.data.copy().items(): assert type(k) == str, self lengths.append(len(v)) + if not lengths: + raise ValueError("Empty sample batch") assert len(set(lengths)) == 1, "data columns must be same length" self.count = lengths[0] @staticmethod def concat_samples(samples): + if isinstance(samples[0], MultiAgentBatch): + return MultiAgentBatch.concat_samples(samples) out = {} samples = [s for s in samples if s.count > 0] for k in samples[0].keys(): diff --git a/python/ray/rllib/pg/pg.py b/python/ray/rllib/pg/pg.py index 1ca4eb493..7d78c3c38 100644 --- a/python/ray/rllib/pg/pg.py +++ b/python/ray/rllib/pg/pg.py @@ -29,6 +29,12 @@ DEFAULT_CONFIG = { "model": {"fcnet_hiddens": [128, 128]}, # Arguments to pass to the env creator "env_config": {}, + + # === Multiagent === + "multiagent": { + "policy_graphs": {}, + "policy_mapping_fn": None, + }, } @@ -52,7 +58,11 @@ class PGAgent(Agent): evaluator_cls=CommonPolicyEvaluator, evaluator_args={ "env_creator": self.env_creator, - "policy_graph": PGPolicyGraph, + "policy_graph": ( + self.config["multiagent"]["policy_graphs"] or + PGPolicyGraph), + "policy_mapping_fn": + self.config["multiagent"]["policy_mapping_fn"], "batch_steps": self.config["batch_size"], "batch_mode": "truncate_episodes", "model_config": self.config["model"], diff --git a/python/ray/rllib/pg/pg_policy_graph.py b/python/ray/rllib/pg/pg_policy_graph.py index af4518bd0..b2c3e1f8f 100644 --- a/python/ray/rllib/pg/pg_policy_graph.py +++ b/python/ray/rllib/pg/pg_policy_graph.py @@ -4,6 +4,7 @@ from __future__ import print_function import tensorflow as tf +import ray from ray.rllib.models.catalog import ModelCatalog from ray.rllib.utils.process_rollout import compute_advantages from ray.rllib.utils.tf_policy_graph import TFPolicyGraph @@ -12,6 +13,7 @@ from ray.rllib.utils.tf_policy_graph import TFPolicyGraph class PGPolicyGraph(TFPolicyGraph): def __init__(self, obs_space, action_space, config): + config = dict(ray.rllib.pg.pg.DEFAULT_CONFIG, **config) self.config = config # setup policy @@ -36,7 +38,7 @@ class PGPolicyGraph(TFPolicyGraph): ] self.is_training = tf.placeholder_with_default(True, ()) TFPolicyGraph.__init__( - self, self.sess, obs_input=self.x, + self, obs_space, action_space, self.sess, obs_input=self.x, action_sampler=self.dist.sample(), loss=self.loss, loss_inputs=self.loss_in, is_training=self.is_training) self.sess.run(tf.global_variables_initializer()) diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 144241c44..a609b2429 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -88,7 +88,6 @@ DEFAULT_CONFIG = { class PPOAgent(Agent): _agent_name = "PPO" - _allow_unknown_subkeys = ["model", "tf_session_args", "env_config"] _default_config = DEFAULT_CONFIG @classmethod diff --git a/python/ray/rllib/test/test_common_policy_evaluator.py b/python/ray/rllib/test/test_common_policy_evaluator.py index 10a31b098..c256f2780 100644 --- a/python/ray/rllib/test/test_common_policy_evaluator.py +++ b/python/ray/rllib/test/test_common_policy_evaluator.py @@ -48,6 +48,22 @@ class MockEnv(gym.Env): return 0, 1, self.i >= self.episode_length, {} +class MockEnv2(gym.Env): + def __init__(self, episode_length): + self.episode_length = episode_length + self.i = 0 + self.observation_space = gym.spaces.Discrete(100) + self.action_space = gym.spaces.Discrete(2) + + def reset(self): + self.i = 0 + return self.i + + def step(self, action): + self.i += 1 + return self.i, 100, self.i >= self.episode_length, {} + + class MockVectorEnv(VectorEnv): def __init__(self, episode_length, num_envs): self.envs = [ diff --git a/python/ray/rllib/test/test_multi_agent_env.py b/python/ray/rllib/test/test_multi_agent_env.py index 2e8b8169c..a37810568 100644 --- a/python/ray/rllib/test/test_multi_agent_env.py +++ b/python/ray/rllib/test/test_multi_agent_env.py @@ -2,12 +2,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import gym +import random import unittest import ray -from ray.rllib.test.test_common_policy_evaluator import MockEnv +from ray.rllib.pg import PGAgent +from ray.rllib.pg.pg_policy_graph import PGPolicyGraph +from ray.rllib.dqn.dqn_policy_graph import DQNPolicyGraph +from ray.rllib.optimizers import LocalSyncOptimizer, \ + LocalSyncReplayOptimizer, AsyncOptimizer +from ray.rllib.test.test_common_policy_evaluator import MockEnv, MockEnv2, \ + MockPolicyGraph +from ray.rllib.utils.common_policy_evaluator import CommonPolicyEvaluator, \ + collect_metrics from ray.rllib.utils.async_vector_env import _MultiAgentEnvToAsync from ray.rllib.utils.multi_agent_env import MultiAgentEnv +from ray.tune.registry import register_env class BasicMultiAgent(MultiAgentEnv): @@ -16,6 +27,8 @@ class BasicMultiAgent(MultiAgentEnv): def __init__(self, num): self.agents = [MockEnv(25) for _ in range(num)] self.dones = set() + self.observation_space = gym.spaces.Discrete(2) + self.action_space = gym.spaces.Discrete(2) def reset(self): self.dones = set() @@ -36,8 +49,13 @@ class RoundRobinMultiAgent(MultiAgentEnv): On each step() of the env, only one agent takes an action.""" - def __init__(self, num): - self.agents = [MockEnv(5) for _ in range(num)] + def __init__(self, num, increment_obs=False): + if increment_obs: + # Observations are 0, 1, 2, 3... etc. as time advances + self.agents = [MockEnv2(5) for _ in range(num)] + else: + # Observations are all zeros + self.agents = [MockEnv(5) for _ in range(num)] self.dones = set() self.last_obs = {} self.last_rew = {} @@ -45,24 +63,59 @@ class RoundRobinMultiAgent(MultiAgentEnv): self.last_info = {} self.i = 0 self.num = num + self.observation_space = gym.spaces.Discrete(2) + self.action_space = gym.spaces.Discrete(2) def reset(self): self.dones = set() - return {i: a.reset() for i, a in enumerate(self.agents)} + self.last_obs = {} + self.last_rew = {} + self.last_done = {} + self.last_info = {} + self.i = 0 + for i, a in enumerate(self.agents): + self.last_obs[i] = a.reset() + self.last_rew[i] = None + self.last_done[i] = False + self.last_info[i] = {} + obs_dict = {self.i: self.last_obs[self.i]} + self.i = (self.i + 1) % self.num + return obs_dict def step(self, action_dict): assert len(self.dones) != len(self.agents) for i, action in action_dict.items(): (self.last_obs[i], self.last_rew[i], self.last_done[i], self.last_info[i]) = self.agents[i].step(action) - if self.last_done[i]: + obs = {self.i: self.last_obs[self.i]} + rew = {self.i: self.last_rew[self.i]} + done = {self.i: self.last_done[self.i]} + info = {self.i: self.last_info[self.i]} + if done[self.i]: + rew[self.i] = 0 + self.dones.add(self.i) + self.i = (self.i + 1) % self.num + done["__all__"] = len(self.dones) == len(self.agents) + return obs, rew, done, info + + +class MultiCartpole(MultiAgentEnv): + def __init__(self, num): + self.agents = [gym.make("CartPole-v0") for _ in range(num)] + self.dones = set() + self.observation_space = self.agents[0].observation_space + self.action_space = self.agents[0].action_space + + def reset(self): + self.dones = set() + return {i: a.reset() for i, a in enumerate(self.agents)} + + def step(self, action_dict): + obs, rew, done, info = {}, {}, {}, {} + for i, action in action_dict.items(): + obs[i], rew[i], done[i], info[i] = self.agents[i].step(action) + if done[i]: self.dones.add(i) - obs = {self.i: self.last_obs[i]} - rew = {self.i: self.last_rew[i]} - done = {self.i: self.last_done[i]} - info = {self.i: self.last_info[i]} - self.i += 1 - self.i %= self.num done["__all__"] = len(self.dones) == len(self.agents) return obs, rew, done, info @@ -86,15 +139,15 @@ class TestMultiAgentEnv(unittest.TestCase): def testRoundRobinMock(self): env = RoundRobinMultiAgent(2) obs = env.reset() - self.assertEqual(obs, {0: 0, 1: 0}) - obs, rew, done, info = env.step({0: 0, 1: 0}) self.assertEqual(obs, {0: 0}) - for _ in range(4): + for _ in range(5): obs, rew, done, info = env.step({0: 0}) self.assertEqual(obs, {1: 0}) self.assertEqual(done["__all__"], False) obs, rew, done, info = env.step({1: 0}) self.assertEqual(obs, {0: 0}) + self.assertEqual(done["__all__"], False) + obs, rew, done, info = env.step({0: 0}) self.assertEqual(done["__all__"], True) def testVectorizeBasic(self): @@ -140,14 +193,160 @@ class TestMultiAgentEnv(unittest.TestCase): def testVectorizeRoundRobin(self): env = _MultiAgentEnvToAsync(lambda: RoundRobinMultiAgent(2), [], 2) obs, rew, dones, _, _ = env.poll() - self.assertEqual(obs, {0: {0: 0, 1: 0}, 1: {0: 0, 1: 0}}) - self.assertEqual(rew, {0: {0: None, 1: None}, 1: {0: None, 1: None}}) - env.send_actions({0: {0: 0}, 1: {0: 0}}) - obs, rew, dones, _, _ = env.poll() self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}}) + self.assertEqual(rew, {0: {0: None}, 1: {0: None}}) env.send_actions({0: {0: 0}, 1: {0: 0}}) obs, rew, dones, _, _ = env.poll() self.assertEqual(obs, {0: {1: 0}, 1: {1: 0}}) + env.send_actions({0: {1: 0}, 1: {1: 0}}) + obs, rew, dones, _, _ = env.poll() + self.assertEqual(obs, {0: {0: 0}, 1: {0: 0}}) + + def testMultiAgentSample(self): + act_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(2) + ev = CommonPolicyEvaluator( + env_creator=lambda _: BasicMultiAgent(5), + policy_graph={ + "p0": (MockPolicyGraph, obs_space, act_space, {}), + "p1": (MockPolicyGraph, obs_space, act_space, {}), + }, + policy_mapping_fn=lambda agent_id: "p{}".format(agent_id % 2), + batch_steps=50) + batch = ev.sample() + self.assertEqual(batch.count, 50) + self.assertEqual(batch.policy_batches["p0"].count, 150) + self.assertEqual(batch.policy_batches["p1"].count, 100) + self.assertEqual( + batch.policy_batches["p0"]["t"].tolist(), + list(range(25)) * 6) + + def testMultiAgentSampleRoundRobin(self): + act_space = gym.spaces.Discrete(2) + obs_space = gym.spaces.Discrete(2) + ev = CommonPolicyEvaluator( + env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), + policy_graph={ + "p0": (MockPolicyGraph, obs_space, act_space, {}), + }, + policy_mapping_fn=lambda agent_id: "p0", + batch_steps=50) + batch = ev.sample() + self.assertEqual(batch.count, 50) + # since we round robin introduce agents into the env, some of the env + # steps don't count as proper transitions + self.assertEqual(batch.policy_batches["p0"].count, 42) + self.assertEqual( + batch.policy_batches["p0"]["obs"].tolist()[:10], + [0, 1, 2, 3, 4] * 2) + self.assertEqual( + batch.policy_batches["p0"]["new_obs"].tolist()[:10], + [1, 2, 3, 4, 5] * 2) + self.assertEqual( + batch.policy_batches["p0"]["rewards"].tolist()[:10], + [100, 100, 100, 100, 0] * 2) + self.assertEqual( + batch.policy_batches["p0"]["dones"].tolist()[:10], + [False, False, False, False, True] * 2) + self.assertEqual( + batch.policy_batches["p0"]["t"].tolist()[:10], + [4, 9, 14, 19, 24, 5, 10, 15, 20, 25]) + + def testTrainMultiCartpoleSinglePolicy(self): + n = 10 + register_env("multi_cartpole", lambda _: MultiCartpole(n)) + pg = PGAgent(env="multi_cartpole", config={"num_workers": 0}) + for i in range(100): + result = pg.train() + print("Iteration {}, reward {}, timesteps {}".format( + i, result.episode_reward_mean, result.timesteps_total)) + if result.episode_reward_mean >= 50 * n: + return + raise Exception("failed to improve reward") + + def _testWithOptimizer(self, optimizer_cls): + n = 3 + env = gym.make("CartPole-v0") + act_space = env.action_space + obs_space = env.observation_space + dqn_config = {"gamma": 0.95, "n_step": 3} + if optimizer_cls == LocalSyncReplayOptimizer: + # TODO: support replay with non-DQN graphs. Currently this can't + # happen since the replay buffer doesn't encode extra fields like + # "advantages" that PG uses. + policies = { + "p1": (DQNPolicyGraph, obs_space, act_space, {}), + "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + } + else: + policies = { + "p1": (PGPolicyGraph, obs_space, act_space, dqn_config), + "p2": (DQNPolicyGraph, obs_space, act_space, dqn_config), + } + ev = CommonPolicyEvaluator( + env_creator=lambda _: MultiCartpole(n), + policy_graph=policies, + policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2], + batch_steps=50) + if optimizer_cls == AsyncOptimizer: + remote_evs = [CommonPolicyEvaluator.as_remote().remote( + env_creator=lambda _: MultiCartpole(n), + policy_graph=policies, + policy_mapping_fn=lambda agent_id: ["p1", "p2"][agent_id % 2], + batch_steps=50)] + else: + remote_evs = [] + optimizer = optimizer_cls({}, ev, remote_evs) + ev.foreach_policy( + lambda p, _: p.set_epsilon(0.02) + if isinstance(p, DQNPolicyGraph) else None) + for i in range(200): + optimizer.step() + result = collect_metrics(ev, remote_evs) + if i % 20 == 0: + ev.foreach_policy( + lambda p, _: p.update_target() + if isinstance(p, DQNPolicyGraph) else None) + print("Iter {}, rew {}".format(i, result.policy_reward_mean)) + print("Total reward", result.episode_reward_mean) + if result.episode_reward_mean >= 25 * n: + return + print(result) + raise Exception("failed to improve reward") + + def testMultiAgentSyncOptimizer(self): + self._testWithOptimizer(LocalSyncOptimizer) + + def testMultiAgentAsyncOptimizer(self): + self._testWithOptimizer(AsyncOptimizer) + + def testMultiAgentReplayOptimizer(self): + self._testWithOptimizer(LocalSyncReplayOptimizer) + + def testTrainMultiCartpoleManyPolicies(self): + n = 20 + env = gym.make("CartPole-v0") + act_space = env.action_space + obs_space = env.observation_space + policies = {} + for i in range(20): + policies["pg_{}".format(i)] = ( + PGPolicyGraph, obs_space, act_space, {}) + policy_ids = list(policies.keys()) + ev = CommonPolicyEvaluator( + env_creator=lambda _: MultiCartpole(n), + policy_graph=policies, + policy_mapping_fn=lambda agent_id: random.choice(policy_ids), + batch_steps=100) + optimizer = LocalSyncOptimizer({}, ev, []) + for i in range(100): + optimizer.step() + result = collect_metrics(ev) + print("Iteration {}, rew {}".format(i, result.policy_reward_mean)) + print("Total reward", result.episode_reward_mean) + if result.episode_reward_mean >= 25 * n: + return + raise Exception("failed to improve reward") if __name__ == '__main__': diff --git a/python/ray/rllib/utils/async_vector_env.py b/python/ray/rllib/utils/async_vector_env.py index 266907a3a..268a7896c 100644 --- a/python/ray/rllib/utils/async_vector_env.py +++ b/python/ray/rllib/utils/async_vector_env.py @@ -284,14 +284,12 @@ class _MultiAgentEnvState(object): self.reset() def poll(self): - if self.last_obs is None: - raise ValueError("Need to send action after polling") obs, rew, dones, info = ( self.last_obs, self.last_rewards, self.last_dones, self.last_infos) - self.last_obs = None - self.last_rewards = None - self.last_dones = None - self.last_infos = None + self.last_obs = {} + self.last_rewards = {} + self.last_dones = {"__all__": False} + self.last_infos = {} return obs, rew, dones, info def observe(self, obs, rewards, dones, infos): diff --git a/python/ray/rllib/utils/common_policy_evaluator.py b/python/ray/rllib/utils/common_policy_evaluator.py index c5b0e1e03..bb1891629 100644 --- a/python/ray/rllib/utils/common_policy_evaluator.py +++ b/python/ray/rllib/utils/common_policy_evaluator.py @@ -2,31 +2,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import pickle +import collections +import gym import numpy as np +import pickle import tensorflow as tf import ray from ray.rllib.models import ModelCatalog -from ray.rllib.optimizers import MultiAgentBatch from ray.rllib.optimizers.policy_evaluator import PolicyEvaluator +from ray.rllib.optimizers.sample_batch import MultiAgentBatch, \ + DEFAULT_POLICY_ID from ray.rllib.utils.async_vector_env import AsyncVectorEnv from ray.rllib.utils.atari_wrappers import wrap_deepmind, is_atari from ray.rllib.utils.compression import pack +from ray.rllib.utils.env_context import EnvContext from ray.rllib.utils.filter import get_filter from ray.rllib.utils.multi_agent_env import MultiAgentEnv +from ray.rllib.utils.policy_graph import PolicyGraph from ray.rllib.utils.sampler import AsyncSampler, SyncSampler from ray.rllib.utils.serving_env import ServingEnv from ray.rllib.utils.tf_policy_graph import TFPolicyGraph +from ray.rllib.utils.tf_run_builder import TFRunBuilder from ray.rllib.utils.vector_env import VectorEnv from ray.tune.result import TrainingResult -def collect_metrics(local_evaluator, remote_evaluators): +def collect_metrics(local_evaluator, remote_evaluators=[]): """Gathers episode metrics from CommonPolicyEvaluator instances.""" episode_rewards = [] episode_lengths = [] + policy_rewards = collections.defaultdict(list) metric_lists = ray.get( [a.apply.remote(lambda ev: ev.sampler.get_metrics()) for a in remote_evaluators]) @@ -35,6 +42,8 @@ def collect_metrics(local_evaluator, remote_evaluators): for episode in metrics: episode_lengths.append(episode.episode_length) episode_rewards.append(episode.episode_reward) + for (_, policy_id), reward in episode.agent_rewards.items(): + policy_rewards[policy_id].append(reward) if episode_rewards: min_reward = min(episode_rewards) max_reward = max(episode_rewards) @@ -45,19 +54,22 @@ def collect_metrics(local_evaluator, remote_evaluators): avg_length = np.mean(episode_lengths) timesteps = np.sum(episode_lengths) + for policy_id, rewards in policy_rewards.copy().items(): + policy_rewards[policy_id] = np.mean(rewards) + return TrainingResult( episode_reward_max=max_reward, episode_reward_min=min_reward, episode_reward_mean=avg_reward, episode_len_mean=avg_length, episodes_total=len(episode_lengths), - timesteps_this_iter=timesteps) + timesteps_this_iter=timesteps, + policy_reward_mean=dict(policy_rewards)) class CommonPolicyEvaluator(PolicyEvaluator): """Policy evaluator implementation that operates on a rllib.PolicyGraph. - TODO: multi-agent TODO: multi-gpu Examples: @@ -65,9 +77,10 @@ class CommonPolicyEvaluator(PolicyEvaluator): >>> evaluator = CommonPolicyEvaluator( env_creator=lambda _: gym.make("CartPole-v0"), policy_graph=PGPolicyGraph) - >>> print(evaluator.sample().keys()) - {"obs": [[...]], "actions": [[...]], "rewards": [[...]], - "dones": [[...]], "new_obs": [[...]]} + >>> print(evaluator.sample()) + SampleBatch({ + "obs": [[...]], "actions": [[...]], "rewards": [[...]], + "dones": [[...]], "new_obs": [[...]]}) # Creating policy evaluators using optimizer_cls.make(). >>> optimizer = LocalSyncOptimizer.make( @@ -78,6 +91,28 @@ class CommonPolicyEvaluator(PolicyEvaluator): }, num_workers=10) >>> for _ in range(10): optimizer.step() + + # Creating a multi-agent policy evaluator + >>> evaluator = CommonPolicyEvaluator( + env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), + policy_graph={ + # Use an ensemble of two policies for car agents + "car_policy1": + (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.99}), + "car_policy2": + (PGPolicyGraph, Box(...), Discrete(...), {"gamma": 0.95}), + # Use a single shared policy for all traffic lights + "traffic_light_policy": + (PGPolicyGraph, Box(...), Discrete(...), {}), + }, + policy_mapping_fn=lambda agent_id: + random.choice(["car_policy1", "car_policy2"]) + if agent_id.startswith("car_") else "traffic_light_policy") + >>> print(evaluator.sample().keys()) + MultiAgentBatch({ + "car_policy1": SampleBatch(...), + "car_policy2": SampleBatch(...), + "traffic_light_policy": SampleBatch(...)}) """ @classmethod @@ -88,6 +123,7 @@ class CommonPolicyEvaluator(PolicyEvaluator): self, env_creator, policy_graph, + policy_mapping_fn=None, tf_session_creator=None, batch_steps=100, batch_mode="truncate_episodes", @@ -99,14 +135,22 @@ class CommonPolicyEvaluator(PolicyEvaluator): observation_filter="NoFilter", env_config=None, model_config=None, - policy_config=None): + policy_config=None, + worker_index=0): """Initialize a policy evaluator. Arguments: env_creator (func): Function that returns a gym.Env given an - env config dict. - policy_graph (class): A class implementing rllib.PolicyGraph or - rllib.TFPolicyGraph. + EnvContext wrapped configuration. + policy_graph (class|dict): Either a class implementing + PolicyGraph, or a dictionary of policy id strings to + (PolicyGraph, obs_space, action_space, config) tuples. If a + dict is specified, then we are in multi-agent mode and a + policy_mapping_fn should also be set. + policy_mapping_fn (func): A function that maps agent ids to + policy ids in multi-agent mode. This function will be called + each time a new agent appears in an episode, to bind that agent + to a policy for the duration of the episode. tf_session_creator (func): A function that returns a TF session. This is optional and only useful with TFPolicyGraph. batch_steps (int): The target number of env transitions to include @@ -138,19 +182,26 @@ class CommonPolicyEvaluator(PolicyEvaluator): observation_filter (str): Name of observation filter to use. env_config (dict): Config to pass to the env creator. model_config (dict): Config to use when creating the policy model. - policy_config (dict): Config to pass to the policy. + policy_config (dict): Config to pass to the policy. In the + multi-agent case, this config will be merged with the + per-policy configs specified by `policy_graph`. + worker_index (int): For remote evaluators, this should be set to a + non-zero and unique value. This index is passed to created envs + through EnvContext so that envs can be configured per worker. """ - env_config = env_config or {} + env_context = EnvContext(env_config or {}, worker_index) policy_config = policy_config or {} model_config = model_config or {} + policy_mapping_fn = ( + policy_mapping_fn or (lambda agent_id: DEFAULT_POLICY_ID)) self.env_creator = env_creator self.policy_graph = policy_graph self.batch_steps = batch_steps self.batch_mode = batch_mode self.compress_observations = compress_observations - self.env = env_creator(env_config) + self.env = env_creator(env_context) if isinstance(self.env, VectorEnv) or \ isinstance(self.env, ServingEnv) or \ isinstance(self.env, MultiAgentEnv) or \ @@ -169,32 +220,29 @@ class CommonPolicyEvaluator(PolicyEvaluator): self.env = wrap(self.env) def make_env(): - return wrap(env_creator(env_config)) + return wrap(env_creator(env_context)) - if issubclass(policy_graph, TFPolicyGraph): + self.tf_sess = None + policy_dict = _validate_and_canonicalize(policy_graph, self.env) + if _has_tensorflow_graph(policy_dict): with tf.Graph().as_default(): if tf_session_creator: - self.sess = tf_session_creator() + self.tf_sess = tf_session_creator() else: - self.sess = tf.Session(config=tf.ConfigProto( + self.tf_sess = tf.Session(config=tf.ConfigProto( gpu_options=tf.GPUOptions(allow_growth=True))) - with self.sess.as_default(): - policy = policy_graph( - self.env.observation_space, self.env.action_space, - policy_config) + with self.tf_sess.as_default(): + self.policy_map = self._build_policy_map( + policy_dict, policy_config) else: - policy = policy_graph( - self.env.observation_space, self.env.action_space, - policy_config) + self.policy_map = self._build_policy_map( + policy_dict, policy_config) - self.policy_map = { - "default": policy - } + self.multiagent = self.policy_map.keys() != set(DEFAULT_POLICY_ID) self.filters = { - # TODO(ekl) make the obs space dependent on policy policy_id: get_filter( - observation_filter, self.env.observation_space.shape) + observation_filter, policy.observation_space.shape) for (policy_id, policy) in self.policy_map.items() } @@ -218,15 +266,25 @@ class CommonPolicyEvaluator(PolicyEvaluator): "Unsupported batch mode: {}".format(self.batch_mode)) if sample_async: self.sampler = AsyncSampler( - self.async_env, self.policy_map, lambda agent_id: "default", + self.async_env, self.policy_map, policy_mapping_fn, self.filters, batch_steps, horizon=episode_horizon, - pack=pack_episodes) + pack=pack_episodes, tf_sess=self.tf_sess) self.sampler.start() else: self.sampler = SyncSampler( - self.async_env, self.policy_map, lambda agent_id: "default", + self.async_env, self.policy_map, policy_mapping_fn, self.filters, batch_steps, horizon=episode_horizon, - pack=pack_episodes) + pack=pack_episodes, tf_sess=self.tf_sess) + + def _build_policy_map(self, policy_dict, policy_config): + policy_map = {} + for name, (cls, obs_space, act_space, conf) in sorted( + policy_dict.items()): + merged_conf = policy_config.copy() + merged_conf.update(conf) + with tf.variable_scope(name): + policy_map[name] = cls(obs_space, act_space, merged_conf) + return policy_map def sample(self): """Evaluate the current policies and return a batch of experiences. @@ -254,10 +312,15 @@ class CommonPolicyEvaluator(PolicyEvaluator): return batch - def for_policy(self, func): - """Apply the given function to this evaluator's default policy.""" + def for_policy(self, func, policy_id=DEFAULT_POLICY_ID): + """Apply the given function to the specified policy graph.""" - return func(self.policy_map["default"]) + return func(self.policy_map[policy_id]) + + def foreach_policy(self, func): + """Apply the given function to each (policy, policy_id) tuple.""" + + return [func(policy, pid) for pid, policy in self.policy_map.items()] def sync_filters(self, new_filters): """Changes self's filter to given and rebases any accumulated delta. @@ -286,28 +349,126 @@ class CommonPolicyEvaluator(PolicyEvaluator): return return_filters def get_weights(self): - return self.policy_map["default"].get_weights() + return { + pid: policy.get_weights() + for pid, policy in self.policy_map.items()} def set_weights(self, weights): - return self.policy_map["default"].set_weights(weights) + for pid, w in weights.items(): + self.policy_map[pid].set_weights(w) def compute_gradients(self, samples): - return self.policy_map["default"].compute_gradients(samples) + if isinstance(samples, MultiAgentBatch): + grad_out, info_out = {}, {} + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "compute_gradients") + for pid, batch in samples.policy_batches.items(): + grad_out[pid], info_out[pid] = ( + self.policy_map[pid].build_compute_gradients( + builder, batch)) + grad_out = {k: builder.get(v) for k, v in grad_out.items()} + info_out = {k: builder.get(v) for k, v in info_out.items()} + else: + for pid, batch in samples.policy_batches.items(): + grad_out[pid], info_out[pid] = ( + self.policy_map[pid].compute_gradients(batch)) + return grad_out, info_out + else: + return self.policy_map[DEFAULT_POLICY_ID].compute_gradients( + samples) def apply_gradients(self, grads): - return self.policy_map["default"].apply_gradients(grads) + if isinstance(grads, dict): + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "apply_gradients") + outputs = { + pid: self.policy_map[pid].build_apply_gradients( + builder, grad) + for pid, grad in grads.items() + } + return { + k: builder.get(v) for k, v in outputs.items() + } + else: + return { + pid: self.policy_map[pid].apply_gradients(g) + for pid, g in grads.items() + } + else: + return self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) def compute_apply(self, samples): - grad_fetch, apply_fetch = self.policy_map["default"].compute_apply( - samples) - return grad_fetch + if isinstance(samples, MultiAgentBatch): + info_out = {} + if self.tf_sess is not None: + builder = TFRunBuilder(self.tf_sess, "compute_apply") + for pid, batch in samples.policy_batches.items(): + info_out[pid], _ = ( + self.policy_map[pid].build_compute_apply( + builder, batch)) + info_out = {k: builder.get(v) for k, v in info_out.items()} + else: + for pid, batch in samples.policy_batches.items(): + info_out[pid], _ = ( + self.policy_map[pid].compute_apply(batch)) + return info_out + else: + grad_fetch, apply_fetch = ( + self.policy_map[DEFAULT_POLICY_ID].compute_apply(samples)) + return grad_fetch def save(self): filters = self.get_filters(flush_after=True) - state = self.policy_map["default"].get_state() + state = { + pid: self.policy_map[pid].get_state() + for pid in self.policy_map + } return pickle.dumps({"filters": filters, "state": state}) def restore(self, objs): objs = pickle.loads(objs) self.sync_filters(objs["filters"]) - self.policy_map["default"].set_state(objs["state"]) + for pid, state in objs["state"].items(): + self.policy_map[pid].set_state(state) + + +def _validate_and_canonicalize(policy_graph, env): + if isinstance(policy_graph, dict): + for k, v in policy_graph.items(): + if not isinstance(k, str): + raise ValueError( + "policy_graph keys must be strs, got {}".format(type(k))) + if not isinstance(v, tuple) or len(v) != 4: + raise ValueError( + "policy_graph values must be tuples of " + "(cls, obs_space, action_space, config), got {}".format(v)) + if not issubclass(v[0], PolicyGraph): + raise ValueError( + "policy_graph tuple value 0 must be a rllib.PolicyGraph " + "class, got {}".format(v[0])) + if not isinstance(v[1], gym.Space): + raise ValueError( + "policy_graph tuple value 1 (observation_space) must be a " + "gym.Space, got {}".format(type(v[1]))) + if not isinstance(v[2], gym.Space): + raise ValueError( + "policy_graph tuple value 2 (action_space) must be a " + "gym.Space, got {}".format(type(v[2]))) + if not isinstance(v[3], dict): + raise ValueError( + "policy_graph tuple value 3 (config) must be a dict, " + "got {}".format(type(v[3]))) + return policy_graph + elif not issubclass(policy_graph, PolicyGraph): + raise ValueError("policy_graph must be a rllib.PolicyGraph class") + else: + return { + DEFAULT_POLICY_ID: ( + policy_graph, env.observation_space, env.action_space, {})} + + +def _has_tensorflow_graph(policy_dict): + for policy, _, _, _ in policy_dict.values(): + if issubclass(policy, TFPolicyGraph): + return True + return False diff --git a/python/ray/rllib/utils/env_context.py b/python/ray/rllib/utils/env_context.py new file mode 100644 index 000000000..e3885b0d3 --- /dev/null +++ b/python/ray/rllib/utils/env_context.py @@ -0,0 +1,22 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class EnvContext(dict): + """Wraps env configurations to include extra rllib metadata. + + These attributes can be used to parameterize environments per process. + For example, one might use `worker_index` to control which data file an + environment reads in on initialization. + + RLlib auto-sets these attributes when constructing registered envs. + + Attributes: + worker_index (int): When there are multiple workers created, this + uniquely identifies the worker the env is created in. + """ + + def __init__(self, env_config, worker_index): + dict.__init__(self, env_config) + self.worker_index = worker_index diff --git a/python/ray/rllib/utils/policy_graph.py b/python/ray/rllib/utils/policy_graph.py index 91272a75a..d7c526401 100644 --- a/python/ray/rllib/utils/policy_graph.py +++ b/python/ray/rllib/utils/policy_graph.py @@ -15,6 +15,10 @@ class PolicyGraph(object): find TFPolicyGraph simpler to implement. TFPolicyGraph also enables RLlib to apply TensorFlow-specific optimizations such as fusing multiple policy graphs and multi-GPU support. + + Attributes: + observation_space (gym.Space): Observation space of the policy. + action_space (gym.Space): Action space of the policy. """ def __init__(self, observation_space, action_space, config): diff --git a/python/ray/rllib/utils/sampler.py b/python/ray/rllib/utils/sampler.py index 0a0aa36a2..ca6f4dda0 100644 --- a/python/ray/rllib/utils/sampler.py +++ b/python/ray/rllib/utils/sampler.py @@ -10,6 +10,7 @@ import threading from ray.rllib.optimizers.sample_batch import MultiAgentSampleBatchBuilder, \ MultiAgentBatch from ray.rllib.utils.async_vector_env import AsyncVectorEnv +from ray.rllib.utils.tf_run_builder import TFRunBuilder RolloutMetrics = namedtuple( @@ -30,7 +31,7 @@ class SyncSampler(object): def __init__( self, env, policies, policy_mapping_fn, obs_filters, - num_local_steps, horizon=None, pack=False): + num_local_steps, horizon=None, pack=False, tf_sess=None): self.async_vector_env = AsyncVectorEnv.wrap_async(env) self.num_local_steps = num_local_steps self.horizon = horizon @@ -39,7 +40,8 @@ class SyncSampler(object): self._obs_filters = obs_filters self.rollout_provider = _env_runner( self.async_vector_env, self.policies, self.policy_mapping_fn, - self.num_local_steps, self.horizon, self._obs_filters, pack) + self.num_local_steps, self.horizon, self._obs_filters, pack, + tf_sess) self.metrics_queue = queue.Queue() def get_data(self): @@ -68,7 +70,7 @@ class AsyncSampler(threading.Thread): def __init__( self, env, policies, policy_mapping_fn, obs_filters, - num_local_steps, horizon=None, pack=False): + num_local_steps, horizon=None, pack=False, tf_sess=None): for _, f in obs_filters.items(): assert getattr(f, "is_concurrent", False), \ "Observation Filter must support concurrent updates." @@ -83,6 +85,7 @@ class AsyncSampler(threading.Thread): self._obs_filters = obs_filters self.daemon = True self.pack = pack + self.tf_sess = tf_sess def run(self): try: @@ -94,7 +97,8 @@ class AsyncSampler(threading.Thread): def _run(self): rollout_provider = _env_runner( self.async_vector_env, self.policies, self.policy_mapping_fn, - self.num_local_steps, self.horizon, self._obs_filters, self.pack) + self.num_local_steps, self.horizon, self._obs_filters, self.pack, + self.tf_sess) while True: # The timeout variable exists because apparently, if one worker # dies, the other workers won't die with it, unless the timeout is @@ -140,7 +144,7 @@ class AsyncSampler(threading.Thread): def _env_runner( async_vector_env, policies, policy_mapping_fn, num_local_steps, - horizon, obs_filters, pack): + horizon, obs_filters, pack, tf_sess=None): """This implements the common experience collection logic. Args: @@ -156,6 +160,8 @@ def _env_runner( observations for the policy. pack (bool): Whether to pack multiple episodes into each batch. This guarantees batches will be exactly `num_local_steps` in size. + tf_sess (Session|None): Optional tensorflow session to use for batching + TF policy evaluations. Yields: rollout (SampleBatch): Object containing state, action, reward, @@ -192,6 +198,9 @@ def _env_runner( # Map of policy_id to list of PolicyEvalData to_eval = defaultdict(list) + # Map of env_id -> agent_id -> action replies + actions_to_send = defaultdict(dict) + # For each environment for env_id, agent_obs in unfiltered_obs.items(): new_episode = env_id not in active_episodes @@ -209,11 +218,13 @@ def _env_runner( dict(episode.agent_rewards)) else: all_done = False + # At least send an empty dict if not done + actions_to_send[env_id] # For each agent in the environment for agent_id, raw_obs in agent_obs.items(): policy_id = episode.policy_for(agent_id) - filtered_obs = obs_filters[policy_id](raw_obs) + filtered_obs = _get_or_raise(obs_filters, policy_id)(raw_obs) agent_done = bool(all_done or dones[env_id].get(agent_id)) if not agent_done: to_eval[policy_id].append( @@ -263,24 +274,40 @@ def _env_runner( episode = active_episodes[env_id] for agent_id, raw_obs in resetted_obs.items(): policy_id = episode.policy_for(agent_id) - filtered_obs = obs_filters[policy_id](raw_obs) + filtered_obs = _get_or_raise( + obs_filters, policy_id)(raw_obs) episode.set_last_observation(agent_id, filtered_obs) to_eval[policy_id].append( PolicyEvalData( env_id, agent_id, filtered_obs, episode.rnn_state_for(agent_id))) - # Map of env_id -> agent_id -> action - action_dict = defaultdict(dict) - - # TODO(ekl) fuse all policy evaluation into one TF run + # Batch eval policy actions if possible + if tf_sess: + builder = TFRunBuilder(tf_sess, "policy_eval") + else: + builder = None + eval_results = {} + rnn_in_cols = {} for policy_id, eval_data in to_eval.items(): - rnn_in_cols = _to_column_format([t.rnn_state for t in eval_data]) - actions, rnn_out_cols, pi_info_cols = \ - policies[policy_id].compute_actions( - [t.obs for t in eval_data], rnn_in_cols, is_training=True) + rnn_in = _to_column_format([t.rnn_state for t in eval_data]) + rnn_in_cols[policy_id] = rnn_in + policy = _get_or_raise(policies, policy_id) + if builder: + eval_results[policy_id] = policy.build_compute_actions( + builder, [t.obs for t in eval_data], rnn_in, + is_training=True) + else: + eval_results[policy_id] = policy.compute_actions( + [t.obs for t in eval_data], rnn_in, is_training=True) + if builder: + eval_results = {k: builder.get(v) for k, v in eval_results.items()} + + # Record the policy eval results + for policy_id, eval_data in to_eval.items(): + actions, rnn_out_cols, pi_info_cols = eval_results[policy_id] # Add RNN state info - for f_i, column in enumerate(rnn_in_cols): + for f_i, column in enumerate(rnn_in_cols[policy_id]): pi_info_cols["state_in_{}".format(f_i)] = column for f_i, column in enumerate(rnn_out_cols): pi_info_cols["state_out_{}".format(f_i)] = column @@ -288,7 +315,7 @@ def _env_runner( for i, action in enumerate(actions): env_id = eval_data[i].env_id agent_id = eval_data[i].agent_id - action_dict[env_id][agent_id] = action + actions_to_send[env_id][agent_id] = action episode = active_episodes[env_id] episode.set_rnn_state(agent_id, [c[i] for c in rnn_out_cols]) episode.set_last_pi_info( @@ -302,7 +329,7 @@ def _env_runner( # Return computed actions to ready envs. We also send to envs that have # taken off-policy actions; those envs are free to ignore the action. - async_vector_env.send_actions(dict(action_dict)) + async_vector_env.send_actions(dict(actions_to_send)) def _to_column_format(rnn_state_rows): @@ -311,6 +338,14 @@ def _to_column_format(rnn_state_rows): [row[i] for row in rnn_state_rows] for i in range(num_cols)] +def _get_or_raise(mapping, policy_id): + if policy_id not in mapping: + raise ValueError( + "Could not find policy for agent: agent policy id `{}` not " + "in policy map keys {}.".format(policy_id, mapping.keys())) + return mapping[policy_id] + + class _MultiAgentEpisode(object): def __init__(self, policies, policy_mapping_fn, batch_builder_factory): self.batch_builder = batch_builder_factory() @@ -327,8 +362,10 @@ class _MultiAgentEpisode(object): def add_agent_rewards(self, reward_dict): for agent_id, reward in reward_dict.items(): - self.agent_rewards[agent_id] += reward - self.total_reward += reward + if reward is not None: + self.agent_rewards[ + agent_id, self.policy_for(agent_id)] += reward + self.total_reward += reward def policy_for(self, agent_id): if agent_id not in self._agent_to_policy: diff --git a/python/ray/rllib/utils/serving_env.py b/python/ray/rllib/utils/serving_env.py index d3928536b..0c1e3ec0d 100644 --- a/python/ray/rllib/utils/serving_env.py +++ b/python/ray/rllib/utils/serving_env.py @@ -35,6 +35,8 @@ class ServingEnv(threading.Thread): def __init__(self, action_space, observation_space, max_concurrent=100): """Initialize a serving env. + ServingEnv subclasses must call this during their __init__. + Arguments: action_space (gym.Space): Action space of the env. observation_space (gym.Space): Observation space of the env. diff --git a/python/ray/rllib/utils/tf_policy_graph.py b/python/ray/rllib/utils/tf_policy_graph.py index a3dfd174b..74cf1345b 100644 --- a/python/ray/rllib/utils/tf_policy_graph.py +++ b/python/ray/rllib/utils/tf_policy_graph.py @@ -6,6 +6,7 @@ import tensorflow as tf import ray from ray.rllib.utils.policy_graph import PolicyGraph +from ray.rllib.utils.tf_run_builder import TFRunBuilder class TFPolicyGraph(PolicyGraph): @@ -29,11 +30,15 @@ class TFPolicyGraph(PolicyGraph): """ def __init__( - self, sess, obs_input, action_sampler, loss, loss_inputs, + self, observation_space, action_space, sess, obs_input, + action_sampler, loss, loss_inputs, is_training, state_inputs=None, state_outputs=None): """Initialize the policy. Arguments: + observation_space (gym.Space): Observation space of the env. + action_space (gym.Space): Action space of the env. + sess (Session): TensorFlow session to use. obs_input (Tensor): input placeholder for observations. action_sampler (Tensor): Tensor for sampling an action. loss (Tensor): scalar policy loss output tensor. @@ -46,6 +51,8 @@ class TFPolicyGraph(PolicyGraph): state_outputs (list): list of initial state values. """ + self.observation_space = observation_space + self.action_space = action_space self._sess = sess self._obs_input = obs_input self._sampler = action_sampler @@ -55,7 +62,9 @@ class TFPolicyGraph(PolicyGraph): self._state_inputs = state_inputs or [] self._state_outputs = state_outputs or [] self._optimizer = self.optimizer() - self._grads_and_vars = self.gradients(self._optimizer) + self._grads_and_vars = [ + (g, v) for (g, v) in self.gradients(self._optimizer) + if g is not None] self._grads = [g for (g, v) in self._grads_and_vars] self._apply_op = self._optimizer.apply_gradients(self._grads_and_vars) self._variables = ray.experimental.TensorFlowVariables( @@ -64,21 +73,27 @@ class TFPolicyGraph(PolicyGraph): assert len(self._state_inputs) == len(self._state_outputs) == \ len(self.get_initial_state()) - def compute_actions( - self, obs_batch, state_batches=None, is_training=False): + def build_compute_actions( + self, builder, obs_batch, state_batches=None, is_training=False): state_batches = state_batches or [] assert len(self._state_inputs) == len(state_batches), \ (self._state_inputs, state_batches) - feed_dict = self.extra_compute_action_feed_dict() - feed_dict[self._obs_input] = obs_batch - feed_dict[self._is_training] = is_training - for ph, value in zip(self._state_inputs, state_batches): - feed_dict[ph] = value - fetches = self._sess.run( - ([self._sampler] + self._state_outputs + - [self.extra_compute_action_fetches()]), feed_dict=feed_dict) + builder.add_feed_dict(self.extra_compute_action_feed_dict()) + builder.add_feed_dict({self._obs_input: obs_batch}) + builder.add_feed_dict({self._is_training: is_training}) + builder.add_feed_dict(dict(zip(self._state_inputs, state_batches))) + fetches = builder.add_fetches( + [self._sampler] + self._state_outputs + + [self.extra_compute_action_fetches()]) return fetches[0], fetches[1:-1], fetches[-1] + def compute_actions( + self, obs_batch, state_batches=None, is_training=False): + builder = TFRunBuilder(self._sess, "compute_actions") + fetches = self.build_compute_actions( + builder, obs_batch, state_batches, is_training) + return builder.get(fetches) + def _get_loss_inputs_dict(self, postprocessed_batch): feed_dict = {} for key, ph in self._loss_inputs: @@ -90,37 +105,48 @@ class TFPolicyGraph(PolicyGraph): feed_dict[ph] = postprocessed_batch[key] return feed_dict - def compute_gradients(self, postprocessed_batch): - feed_dict = self.extra_compute_grad_feed_dict() - feed_dict[self._is_training] = True - feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch)) - fetches = self._sess.run( - [self._grads, self.extra_compute_grad_fetches()], - feed_dict=feed_dict) + def build_compute_gradients(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + fetches = builder.add_fetches( + [self._grads, self.extra_compute_grad_fetches()]) return fetches[0], fetches[1] - def apply_gradients(self, gradients): + def compute_gradients(self, postprocessed_batch): + builder = TFRunBuilder(self._sess, "compute_gradients") + fetches = self.build_compute_gradients(builder, postprocessed_batch) + return builder.get(fetches) + + def build_apply_gradients(self, builder, gradients): assert len(gradients) == len(self._grads), (gradients, self._grads) - feed_dict = self.extra_apply_grad_feed_dict() - feed_dict[self._is_training] = True - for ph, value in zip(self._grads, gradients): - feed_dict[ph] = value - fetches = self._sess.run( - [self._apply_op, self.extra_apply_grad_fetches()], - feed_dict=feed_dict) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + builder.add_feed_dict({self._is_training: True}) + builder.add_feed_dict(dict(zip(self._grads, gradients))) + fetches = builder.add_fetches( + [self._apply_op, self.extra_apply_grad_fetches()]) return fetches[1] - def compute_apply(self, postprocessed_batch): - feed_dict = self.extra_compute_grad_feed_dict() - feed_dict.update(self.extra_apply_grad_feed_dict()) - feed_dict.update(self._get_loss_inputs_dict(postprocessed_batch)) - feed_dict[self._is_training] = True - fetches = self._sess.run( + def apply_gradients(self, gradients): + builder = TFRunBuilder(self._sess, "apply_gradients") + fetches = self.build_apply_gradients(builder, gradients) + return builder.get(fetches) + + def build_compute_apply(self, builder, postprocessed_batch): + builder.add_feed_dict(self.extra_compute_grad_feed_dict()) + builder.add_feed_dict(self.extra_apply_grad_feed_dict()) + builder.add_feed_dict(self._get_loss_inputs_dict(postprocessed_batch)) + builder.add_feed_dict({self._is_training: True}) + fetches = builder.add_fetches( [self._apply_op, self.extra_compute_grad_fetches(), - self.extra_apply_grad_fetches()], - feed_dict=feed_dict) + self.extra_apply_grad_fetches()]) return fetches[1], fetches[2] + def compute_apply(self, postprocessed_batch): + builder = TFRunBuilder(self._sess, "compute_apply") + fetches = self.build_compute_apply(builder, postprocessed_batch) + return builder.get(fetches) + def get_weights(self): return self._variables.get_flat() diff --git a/python/ray/rllib/utils/tf_run_builder.py b/python/ray/rllib/utils/tf_run_builder.py new file mode 100644 index 000000000..6512fc85c --- /dev/null +++ b/python/ray/rllib/utils/tf_run_builder.py @@ -0,0 +1,82 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import time + +import tensorflow as tf +from tensorflow.python.client import timeline + + +class TFRunBuilder(object): + """Used to incrementally build up a TensorFlow run. + + This is particularly useful for batching ops from multiple different + policies in the multi-agent setting. + """ + + def __init__(self, session, debug_name): + self.session = session + self.debug_name = debug_name + self.feed_dict = {} + self.fetches = [] + self._executed = None + + def add_feed_dict(self, feed_dict): + assert not self._executed + for k in feed_dict: + assert k not in self.feed_dict + self.feed_dict.update(feed_dict) + + def add_fetches(self, fetches): + assert not self._executed + base_index = len(self.fetches) + self.fetches.extend(fetches) + return list(range(base_index, len(self.fetches))) + + def get(self, to_fetch): + if self._executed is None: + try: + self._executed = run_timeline( + self.session, self.fetches, self.debug_name, + self.feed_dict, os.environ.get("TF_TIMELINE_DIR")) + except Exception as e: + print("Error fetching: {}, feed_dict={}".format( + self.fetches, self.feed_dict)) + raise e + if isinstance(to_fetch, int): + return self._executed[to_fetch] + elif isinstance(to_fetch, list): + return [self.get(x) for x in to_fetch] + elif isinstance(to_fetch, tuple): + return tuple(self.get(x) for x in to_fetch) + else: + raise ValueError("Unsupported fetch type: {}".format(to_fetch)) + + +_count = 0 + + +def run_timeline(sess, ops, debug_name, feed_dict={}, timeline_dir=None): + if timeline_dir: + run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + start = time.time() + fetches = sess.run( + ops, options=run_options, run_metadata=run_metadata, + feed_dict=feed_dict) + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + global _count + outf = os.path.join( + timeline_dir, + "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count)) + _count += 1 + trace_file = open(outf, "w") + print( + "Wrote tf timeline ({} s) to {}".format( + time.time() - start, os.path.abspath(outf))) + trace_file.write(trace.generate_chrome_trace_format()) + else: + fetches = sess.run(ops, feed_dict=feed_dict) + return fetches diff --git a/python/ray/tune/result.py b/python/ray/tune/result.py index 261ca6e90..1b9eb0a68 100644 --- a/python/ray/tune/result.py +++ b/python/ray/tune/result.py @@ -43,6 +43,9 @@ TrainingResult = namedtuple( # (Optional) The number of episodes total. "episodes_total", + # (Optional) Per-policy reward information in multi-agent RL. + "policy_reward_mean", + # (Optional) The current training accuracy if applicable. "mean_accuracy", diff --git a/test/jenkins_tests/run_multi_node_tests.sh b/test/jenkins_tests/run_multi_node_tests.sh index d29077ac6..0a29af107 100755 --- a/test/jenkins_tests/run_multi_node_tests.sh +++ b/test/jenkins_tests/run_multi_node_tests.sh @@ -257,3 +257,6 @@ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ python /ray/python/ray/rllib/examples/legacy_multiagent/multiagent_pendulum.py + +docker run --rm --shm-size=10G --memory=10G $DOCKER_SHA \ + python /ray/python/ray/rllib/examples/multiagent_cartpole.py