mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 01:06:07 +08:00
[rllib] Multi-GPU support for Multi-Agent PPO (#3479)
* wip
* fix
* remove check
* fix null
* revert
* lint and kl
* also fix rollout
This commit is contained in:
@@ -130,8 +130,10 @@ class PPOAgent(Agent):
|
||||
"Episode truncation is not supported without a value function")
|
||||
if (self.config["multiagent"]["policy_graphs"]
|
||||
and not self.config["simple_optimizer"]):
|
||||
logger.warn("forcing simple_optimizer=True in multi-agent mode")
|
||||
self.config["simple_optimizer"] = True
|
||||
logger.info(
|
||||
"In multi-agent mode, policies will be optimized sequentially "
|
||||
"by the multi-GPU optimizer. Consider setting "
|
||||
"simple_optimizer=True if this doesn't work for you.")
|
||||
if self.config["observation_filter"] != "NoFilter":
|
||||
# TODO(ekl): consider setting the default to be NoFilter
|
||||
logger.warn(
|
||||
|
||||
@@ -247,6 +247,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
self.explained_variance = explained_variance(value_targets_ph,
|
||||
self.value_function)
|
||||
self.stats_fetches = {
|
||||
"cur_kl_coeff": self.kl_coeff,
|
||||
"cur_lr": tf.cast(self.cur_lr, tf.float64),
|
||||
"total_loss": self.loss_obj.loss,
|
||||
"policy_loss": self.loss_obj.mean_policy_loss,
|
||||
|
||||
@@ -107,7 +107,8 @@ if __name__ == "__main__":
|
||||
"training_iteration": args.num_iters
|
||||
},
|
||||
"config": {
|
||||
"simple_optimizer": True,
|
||||
"log_level": "DEBUG",
|
||||
"num_sgd_iter": 10,
|
||||
"multiagent": {
|
||||
"policy_graphs": policy_graphs,
|
||||
"policy_mapping_fn": tune.function(
|
||||
|
||||
@@ -57,7 +57,6 @@ if __name__ == "__main__":
|
||||
"policy_mapping_fn": policy_mapping_fn,
|
||||
"policies_to_train": ["ppo_policy"],
|
||||
},
|
||||
"simple_optimizer": True,
|
||||
# disable filters, otherwise we would need to synchronize those
|
||||
# as well to the DQN agent
|
||||
"observation_filter": "NoFilter",
|
||||
|
||||
@@ -14,6 +14,8 @@ from ray.rllib.optimizers.policy_optimizer import PolicyOptimizer
|
||||
from ray.rllib.optimizers.multi_gpu_impl import LocalSyncParallelOptimizer
|
||||
from ray.rllib.utils.annotations import override
|
||||
from ray.rllib.utils.timer import TimerStat
|
||||
from ray.rllib.evaluation.sample_batch import SampleBatch, DEFAULT_POLICY_ID, \
|
||||
MultiAgentBatch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,32 +65,33 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
|
||||
logger.info("LocalMultiGPUOptimizer devices {}".format(self.devices))
|
||||
|
||||
if set(self.local_evaluator.policy_map.keys()) != {"default"}:
|
||||
raise ValueError(
|
||||
"Multi-agent is not supported with multi-GPU. Try using the "
|
||||
"simple optimizer instead.")
|
||||
self.policy = self.local_evaluator.policy_map["default"]
|
||||
if not isinstance(self.policy, TFPolicyGraph):
|
||||
raise ValueError(
|
||||
"Only TF policies are supported with multi-GPU. Try using the "
|
||||
"simple optimizer instead.")
|
||||
self.policies = self.local_evaluator.policy_map
|
||||
for policy_id, policy in self.policies.items():
|
||||
if not isinstance(policy, TFPolicyGraph):
|
||||
raise ValueError(
|
||||
"Only TF policies are supported with multi-GPU. Try using "
|
||||
"the simple optimizer instead.")
|
||||
|
||||
# per-GPU graph copies created below must share vars with the policy
|
||||
# reuse is set to AUTO_REUSE because Adam nodes are created after
|
||||
# all of the device copies are created.
|
||||
self.optimizers = {}
|
||||
with self.local_evaluator.tf_sess.graph.as_default():
|
||||
with self.local_evaluator.tf_sess.as_default():
|
||||
with tf.variable_scope("default", reuse=tf.AUTO_REUSE):
|
||||
if self.policy._state_inputs:
|
||||
rnn_inputs = self.policy._state_inputs + [
|
||||
self.policy._seq_lens
|
||||
]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
self.par_opt = LocalSyncParallelOptimizer(
|
||||
self.policy.optimizer(), self.devices,
|
||||
[v for _, v in self.policy._loss_inputs], rnn_inputs,
|
||||
self.per_device_batch_size, self.policy.copy)
|
||||
for policy_id, policy in self.policies.items():
|
||||
with tf.variable_scope(policy_id, reuse=tf.AUTO_REUSE):
|
||||
if policy._state_inputs:
|
||||
rnn_inputs = policy._state_inputs + [
|
||||
policy._seq_lens
|
||||
]
|
||||
else:
|
||||
rnn_inputs = []
|
||||
self.optimizers[policy_id] = (
|
||||
LocalSyncParallelOptimizer(
|
||||
policy._optimizer, self.devices,
|
||||
[v
|
||||
for _, v in policy._loss_inputs], rnn_inputs,
|
||||
self.per_device_batch_size, policy.copy))
|
||||
|
||||
self.sess = self.local_evaluator.tf_sess
|
||||
self.sess.run(tf.global_variables_initializer())
|
||||
@@ -109,47 +112,62 @@ class LocalMultiGPUOptimizer(PolicyOptimizer):
|
||||
self.train_batch_size)
|
||||
else:
|
||||
samples = self.local_evaluator.sample()
|
||||
self._check_not_multiagent(samples)
|
||||
# Handle everything as if multiagent
|
||||
if isinstance(samples, SampleBatch):
|
||||
samples = MultiAgentBatch({
|
||||
DEFAULT_POLICY_ID: samples
|
||||
}, samples.count)
|
||||
|
||||
for field in self.standardize_fields:
|
||||
value = samples[field]
|
||||
standardized = (value - value.mean()) / max(1e-4, value.std())
|
||||
samples[field] = standardized
|
||||
for _, batch in samples.policy_batches.items():
|
||||
for field in self.standardize_fields:
|
||||
value = batch[field]
|
||||
standardized = (value - value.mean()) / max(1e-4, value.std())
|
||||
batch[field] = standardized
|
||||
|
||||
# Important: don't shuffle RNN sequence elements
|
||||
if not self.policy._state_inputs:
|
||||
samples.shuffle()
|
||||
for policy_id, policy in self.policies.items():
|
||||
# Important: don't shuffle RNN sequence elements
|
||||
if (policy_id in samples.policy_batches
|
||||
and not policy._state_inputs):
|
||||
samples.policy_batches[policy_id].shuffle()
|
||||
|
||||
num_loaded_tuples = {}
|
||||
with self.load_timer:
|
||||
tuples = self.policy._get_loss_inputs_dict(samples)
|
||||
data_keys = [ph for _, ph in self.policy._loss_inputs]
|
||||
if self.policy._state_inputs:
|
||||
state_keys = (
|
||||
self.policy._state_inputs + [self.policy._seq_lens])
|
||||
else:
|
||||
state_keys = []
|
||||
tuples_per_device = self.par_opt.load_data(
|
||||
self.sess, [tuples[k] for k in data_keys],
|
||||
[tuples[k] for k in state_keys])
|
||||
for policy_id, batch in samples.policy_batches.items():
|
||||
policy = self.policies[policy_id]
|
||||
tuples = policy._get_loss_inputs_dict(batch)
|
||||
data_keys = [ph for _, ph in policy._loss_inputs]
|
||||
if policy._state_inputs:
|
||||
state_keys = policy._state_inputs + [policy._seq_lens]
|
||||
else:
|
||||
state_keys = []
|
||||
num_loaded_tuples[policy_id] = (
|
||||
self.optimizers[policy_id].load_data(
|
||||
self.sess, [tuples[k] for k in data_keys],
|
||||
[tuples[k] for k in state_keys]))
|
||||
|
||||
fetches = {}
|
||||
with self.grad_timer:
|
||||
num_batches = (
|
||||
int(tuples_per_device) // int(self.per_device_batch_size))
|
||||
logger.debug("== sgd epochs ==")
|
||||
for i in range(self.num_sgd_iter):
|
||||
iter_extra_fetches = defaultdict(list)
|
||||
permutation = np.random.permutation(num_batches)
|
||||
for batch_index in range(num_batches):
|
||||
batch_fetches = self.par_opt.optimize(
|
||||
self.sess,
|
||||
permutation[batch_index] * self.per_device_batch_size)
|
||||
for k, v in batch_fetches.items():
|
||||
iter_extra_fetches[k].append(v)
|
||||
logger.debug("{} {}".format(i, _averaged(iter_extra_fetches)))
|
||||
for policy_id, tuples_per_device in num_loaded_tuples.items():
|
||||
optimizer = self.optimizers[policy_id]
|
||||
num_batches = (
|
||||
int(tuples_per_device) // int(self.per_device_batch_size))
|
||||
logger.debug("== sgd epochs for {} ==".format(policy_id))
|
||||
for i in range(self.num_sgd_iter):
|
||||
iter_extra_fetches = defaultdict(list)
|
||||
permutation = np.random.permutation(num_batches)
|
||||
for batch_index in range(num_batches):
|
||||
batch_fetches = optimizer.optimize(
|
||||
self.sess, permutation[batch_index] *
|
||||
self.per_device_batch_size)
|
||||
for k, v in batch_fetches.items():
|
||||
iter_extra_fetches[k].append(v)
|
||||
logger.debug("{} {}".format(i,
|
||||
_averaged(iter_extra_fetches)))
|
||||
fetches[policy_id] = _averaged(iter_extra_fetches)
|
||||
|
||||
self.num_steps_sampled += samples.count
|
||||
self.num_steps_trained += samples.count
|
||||
return _averaged(iter_extra_fetches)
|
||||
return fetches
|
||||
|
||||
@override(PolicyOptimizer)
|
||||
def stats(self):
|
||||
|
||||
@@ -7,7 +7,6 @@ import logging
|
||||
import ray
|
||||
from ray.rllib.evaluation.policy_evaluator import PolicyEvaluator
|
||||
from ray.rllib.evaluation.metrics import collect_episodes, summarize_episodes
|
||||
from ray.rllib.evaluation.sample_batch import MultiAgentBatch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -154,12 +153,6 @@ class PolicyOptimizer(object):
|
||||
])
|
||||
return local_result + remote_results
|
||||
|
||||
@staticmethod
|
||||
def _check_not_multiagent(sample_batch):
|
||||
if isinstance(sample_batch, MultiAgentBatch):
|
||||
raise NotImplementedError(
|
||||
"This optimizer does not support multi-agent yet.")
|
||||
|
||||
@classmethod
|
||||
def make(cls,
|
||||
env_creator,
|
||||
|
||||
@@ -12,7 +12,6 @@ import pickle
|
||||
import gym
|
||||
import ray
|
||||
from ray.rllib.agents.agent import get_agent_class
|
||||
from ray.rllib.models import ModelCatalog
|
||||
|
||||
EXAMPLE_USAGE = """
|
||||
Example Usage via RLlib CLI:
|
||||
@@ -96,7 +95,7 @@ def run(args, parser):
|
||||
if hasattr(agent, "local_evaluator"):
|
||||
env = agent.local_evaluator.env
|
||||
else:
|
||||
env = ModelCatalog.get_preprocessor_as_wrapper(gym.make(args.env))
|
||||
env = gym.make(args.env)
|
||||
if args.out is not None:
|
||||
rollouts = []
|
||||
steps = 0
|
||||
|
||||
@@ -184,7 +184,6 @@ class ModelSupportedSpaces(unittest.TestCase):
|
||||
"train_batch_size": 10,
|
||||
"sample_batch_size": 10,
|
||||
"sgd_minibatch_size": 1,
|
||||
"simple_optimizer": True,
|
||||
})
|
||||
check_support_multiagent("PG", {"num_workers": 1, "optimizer": {}})
|
||||
check_support_multiagent("DDPG", {"timesteps_per_iteration": 1})
|
||||
|
||||
Reference in New Issue
Block a user