[rllib] Multi-GPU support for Multi-Agent PPO (#3479)

* wip

* fix

* remove check

* fix null

* revert

* lint and kl

* also fix rollout
This commit is contained in:
Eric Liang
2018-12-08 18:02:33 -08:00
committed by GitHub
parent 8b5827b9da
commit 7aec357501
8 changed files with 78 additions and 66 deletions
+4 -2
View File
@@ -130,8 +130,10 @@ class PPOAgent(Agent):
"Episode truncation is not supported without a value function")
if (self.config["multiagent"]["policy_graphs"]
and not self.config["simple_optimizer"]):
logger.warn("forcing simple_optimizer=True in multi-agent mode")
self.config["simple_optimizer"] = True
logger.info(
"In multi-agent mode, policies will be optimized sequentially "
"by the multi-GPU optimizer. Consider setting "
"simple_optimizer=True if this doesn't work for you.")
if self.config["observation_filter"] != "NoFilter":
# TODO(ekl): consider setting the default to be NoFilter
logger.warn(
@@ -247,6 +247,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.explained_variance = explained_variance(value_targets_ph,
self.value_function)
self.stats_fetches = {
"cur_kl_coeff": self.kl_coeff,
"cur_lr": tf.cast(self.cur_lr, tf.float64),
"total_loss": self.loss_obj.loss,
"policy_loss": self.loss_obj.mean_policy_loss,
@@ -107,7 +107,8 @@ if __name__ == "__main__":
"training_iteration": args.num_iters
},
"config": {
"simple_optimizer": True,
"log_level": "DEBUG",
"num_sgd_iter": 10,
"multiagent": {
"policy_graphs": policy_graphs,
"policy_mapping_fn": tune.function(
@@ -57,7 +57,6 @@ if __name__ == "__main__":
"policy_mapping_fn": policy_mapping_fn,
"policies_to_train": ["ppo_policy"],
},
"simple_optimizer": True,
# disable filters, otherwise we would need to synchronize those
# as well to the DQN agent
"observation_filter": "NoFilter",
@@ -14,6 +14,8 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.timer import TimerStat
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
MultiAgentBatch
logger = logging.getLogger(__name__)
@@ -63,32 +65,33 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
raise ValueError(
"Multi-agent is not supported with multi-GPU. Try using the "
"simple optimizer instead.")
self.policy = self.local_evaluator.policy_map["default"]
if not isinstance(self.policy, TFPolicyGraph):
raise ValueError(
"Only TF policies are supported with multi-GPU. Try using the "
"simple optimizer instead.")
self.policies = self.local_evaluator.policy_map
for policy_id, policy in self.policies.items():
if not isinstance(policy, TFPolicyGraph):
raise ValueError(
"Only TF policies are supported with multi-GPU. Try using "
"the simple optimizer instead.")
# Per-GPU graph copies created below must share vars with the policy.
# Reuse is set to AUTO_REUSE because Adam nodes are created after
# all of the device copies are created.
self.optimizers = {}
with self.local_evaluator.tf_sess.graph.as_default():
with self.local_evaluator.tf_sess.as_default():
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
if self.policy._state_inputs:
rnn_inputs = self.policy._state_inputs + [
self.policy._seq_lens
]
else:
rnn_inputs = []
self.par_opt = LocalSyncParallelOptimizer(
self.policy.optimizer(), self.devices,
[v for _, v in self.policy._loss_inputs], rnn_inputs,
self.per_device_batch_size, self.policy.copy)
for policy_id, policy in self.policies.items():
with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
if policy._state_inputs:
rnn_inputs = policy._state_inputs + [
policy._seq_lens
]
else:
rnn_inputs = []
self.optimizers[policy_id] = (
LocalSyncParallelOptimizer(
policy._optimizer, self.devices,
[v
for _, v in policy._loss_inputs], rnn_inputs,
self.per_device_batch_size, policy.copy))
self.sess = self.local_evaluator.tf_sess
self.sess.run(tf.global_variables_initializer())
@@ -109,47 +112,62 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
self.train_batch_size)
else:
samples = self.local_evaluator.sample()
self._check_not_multiagent(samples)
# Handle everything as if multiagent
if isinstance(samples, SampleBatch):
samples = MultiAgentBatch({
DEFAULT_POLICY_ID: samples
}, samples.count)
for field in self.standardize_fields:
value = samples[field]
standardized = (value - value.mean()) / max(1e-4, value.std())
samples[field] = standardized
for _, batch in samples.policy_batches.items():
for field in self.standardize_fields:
value = batch[field]
standardized = (value - value.mean()) / max(1e-4, value.std())
batch[field] = standardized
# Important: don't shuffle RNN sequence elements
if not self.policy._state_inputs:
samples.shuffle()
for policy_id, policy in self.policies.items():
# Important: don't shuffle RNN sequence elements
if (policy_id in samples.policy_batches
and not policy._state_inputs):
samples.policy_batches[policy_id].shuffle()
num_loaded_tuples = {}
with self.load_timer:
tuples = self.policy._get_loss_inputs_dict(samples)
data_keys = [ph for _, ph in self.policy._loss_inputs]
if self.policy._state_inputs:
state_keys = (
self.policy._state_inputs + [self.policy._seq_lens])
else:
state_keys = []
tuples_per_device = self.par_opt.load_data(
self.sess, [tuples[k] for k in data_keys],
[tuples[k] for k in state_keys])
for policy_id, batch in samples.policy_batches.items():
policy = self.policies[policy_id]
tuples = policy._get_loss_inputs_dict(batch)
data_keys = [ph for _, ph in policy._loss_inputs]
if policy._state_inputs:
state_keys = policy._state_inputs + [policy._seq_lens]
else:
state_keys = []
num_loaded_tuples[policy_id] = (
self.optimizers[policy_id].load_data(
self.sess, [tuples[k] for k in data_keys],
[tuples[k] for k in state_keys]))
fetches = {}
with self.grad_timer:
num_batches = (
int(tuples_per_device) // int(self.per_device_batch_size))
logger.debug("== sgd epochs ==")
for i in range(self.num_sgd_iter):
iter_extra_fetches = defaultdict(list)
permutation = np.random.permutation(num_batches)
for batch_index in range(num_batches):
batch_fetches = self.par_opt.optimize(
self.sess,
permutation[batch_index] * self.per_device_batch_size)
for k, v in batch_fetches.items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
for policy_id, tuples_per_device in num_loaded_tuples.items():
optimizer = self.optimizers[policy_id]
num_batches = (
int(tuples_per_device) // int(self.per_device_batch_size))
logger.debug("== sgd epochs for {} ==".format(policy_id))
for i in range(self.num_sgd_iter):
iter_extra_fetches = defaultdict(list)
permutation = np.random.permutation(num_batches)
for batch_index in range(num_batches):
batch_fetches = optimizer.optimize(
self.sess, permutation[batch_index] *
self.per_device_batch_size)
for k, v in batch_fetches.items():
iter_extra_fetches[k].append(v)
logger.debug("{} {}".format(i,
_averaged(iter_extra_fetches)))
fetches[policy_id] = _averaged(iter_extra_fetches)
self.num_steps_sampled += samples.count
self.num_steps_trained += samples.count
return _averaged(iter_extra_fetches)
return fetches
@override(PolicyOptimizer)
def stats(self):
@@ -7,7 +7,6 @@ import logging
import ray
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
logger = logging.getLogger(__name__)
@@ -154,12 +153,6 @@ class PolicyOptimizer(object):
])
return local_result + remote_results
@staticmethod
def _check_not_multiagent(sample_batch):
if isinstance(sample_batch, MultiAgentBatch):
raise NotImplementedError(
"This optimizer does not support multi-agent yet.")
@classmethod
def make(cls,
env_creator,
+1 -2
View File
@@ -12,7 +12,6 @@ import pickle
import gym
import ray
from ray.rllib.agents.agent import get_agent_class
from ray.rllib.models import ModelCatalog
EXAMPLE_USAGE = """
Example Usage via RLlib CLI:
@@ -96,7 +95,7 @@ def run(args, parser):
if hasattr(agent, "local_evaluator"):
env = agent.local_evaluator.env
else:
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
env = gym.make(args.env)
if args.out is not None:
rollouts = []
steps = 0
@@ -184,7 +184,6 @@ class ModelSupportedSpaces(unittest.TestCase):
"train_batch_size": 10,
"sample_batch_size": 10,
"sgd_minibatch_size": 1,
"simple_optimizer": True,
})
check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})
check_support_multiagent("DDPG", {"timesteps_per_iteration": 1})