[rllib] Add custom value functions, fix up and document multi-agent variable sharing (#3151)

This commit is contained in:
Eric Liang
2018-10-29 19:37:27 -07:00
committed by GitHub
parent e49839c73f
commit a221f55b0d
18 changed files with 199 additions and 46 deletions
@@ -13,7 +13,6 @@ from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.catalog import ModelCatalog
@@ -57,9 +56,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"prev_rewards": prev_rewards
}, observation_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
self.vf = tf.reshape(
linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
[-1])
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
@@ -144,7 +141,10 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
def get_initial_state(self):
return self.model.state_init
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
@@ -62,7 +62,10 @@ class A3CTorchPolicyGraph(TorchPolicyGraph):
def optimizer(self):
return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"])
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
@@ -332,7 +332,10 @@ class DDPGPolicyGraph(TFPolicyGraph):
"td_error": self.loss.td_error,
}
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return _postprocess_dqn(self, sample_batch)
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
@@ -414,7 +414,10 @@ class DQNPolicyGraph(TFPolicyGraph):
"td_error": self.loss.td_error,
}
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
return _postprocess_dqn(self, sample_batch)
def compute_td_error(self, obs_t, act_t, rew_t, obs_tp1, done_mask,
@@ -14,7 +14,6 @@ from ray.rllib.agents.impala import vtrace
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.error import UnsupportedSpaceException
from ray.rllib.utils.explained_variance import explained_variance
@@ -140,9 +139,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
state_in=existing_state_in,
seq_lens=existing_seq_lens)
action_dist = dist_class(self.model.outputs)
values = tf.reshape(
linear(self.model.last_layer, 1, "value", normc_initializer(1.0)),
[-1])
values = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
tf.get_variable_scope().name)
@@ -251,7 +248,10 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
def extra_compute_grad_fetches(self):
return self.stats_fetches
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
del sample_batch.data["new_obs"] # not used, so save some bandwidth
return sample_batch
+18 -6
View File
@@ -11,21 +11,27 @@ from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
class PGLoss(object):
    """Simple policy gradient loss.

    The loss is the negated mean of the per-sample action log-likelihoods
    weighted by their advantages.
    """

    def __init__(self, action_dist, actions, advantages):
        """Build the loss tensor and store it on ``self.loss``.

        Arguments:
            action_dist: object exposing ``logp(actions)``.
            actions: actions passed to ``action_dist.logp``.
            advantages: weights multiplied with the log-likelihoods.
        """
        weighted_logp = action_dist.logp(actions) * advantages
        self.loss = -tf.reduce_mean(weighted_logp)
class PGPolicyGraph(TFPolicyGraph):
"""Simple policy gradient example of defining a policy graph."""
def __init__(self, obs_space, action_space, config):
config = dict(ray.rllib.agents.pg.pg.DEFAULT_CONFIG, **config)
self.config = config
# Setup policy
# Setup placeholders
obs = tf.placeholder(tf.float32, shape=[None] + list(obs_space.shape))
dist_class, self.logit_dim = ModelCatalog.get_action_dist(
action_space, self.config["model"])
prev_actions = ModelCatalog.get_action_placeholder(action_space)
prev_rewards = tf.placeholder(tf.float32, [None], name="prev_reward")
# Create the model network and action outputs
self.model = ModelCatalog.get_model({
"obs": obs,
"prev_actions": prev_actions,
@@ -38,17 +44,19 @@ class PGPolicyGraph(TFPolicyGraph):
advantages = tf.placeholder(tf.float32, [None], name="adv")
loss = PGLoss(action_dist, actions, advantages).loss
# Initialize TFPolicyGraph
sess = tf.get_default_session()
# Mapping from sample batch keys to placeholders
# Mapping from sample batch keys to placeholders. These keys will be
# read from postprocessed sample batches and fed into the specified
# placeholders during loss computation.
loss_in = [
("obs", obs),
("actions", actions),
("prev_actions", prev_actions),
("prev_rewards", prev_rewards),
("advantages", advantages),
("advantages", advantages), # added during postprocessing
]
# Initialize TFPolicyGraph
sess = tf.get_default_session()
TFPolicyGraph.__init__(
self,
obs_space,
@@ -66,7 +74,11 @@ class PGPolicyGraph(TFPolicyGraph):
max_seq_len=config["model"]["max_seq_len"])
sess.run(tf.global_variables_initializer())
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
# This adds the "advantages" column to the sample batch
return compute_advantages(
sample_batch, 0.0, self.config["gamma"], use_gae=False)
@@ -9,7 +9,6 @@ from ray.rllib.evaluation.postprocessing import compute_advantages
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph, \
LearningRateSchedule
from ray.rllib.models.catalog import ModelCatalog
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.utils.explained_variance import explained_variance
@@ -180,9 +179,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
self.sampler = curr_action_dist.sample()
if self.config["use_gae"]:
if self.config["vf_share_layers"]:
self.value_function = tf.reshape(
linear(self.model.last_layer, 1, "value",
normc_initializer(1.0)), [-1])
self.value_function = self.model.value_function()
else:
vf_config = self.config["model"].copy()
# Do not split the last layer of the value function into
@@ -286,7 +283,10 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
vf = self.sess.run(self.value_function, feed_dict)
return vf[0]
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
completed = sample_batch["dones"][-1]
if completed:
last_r = 0.0
@@ -17,10 +17,11 @@ from ray.rllib.evaluation.interface import EvaluatorInterface
from ray.rllib.evaluation.sample_batch import MultiAgentBatch, \
DEFAULT_POLICY_ID
from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.evaluation.policy_graph import PolicyGraph
from ray.rllib.evaluation.tf_policy_graph import TFPolicyGraph
from ray.rllib.utils import merge_dicts
from ray.rllib.utils.compression import pack
from ray.rllib.utils.filter import get_filter
from ray.rllib.utils.tf_run_builder import TFRunBuilder
@@ -299,8 +300,7 @@ class PolicyEvaluator(EvaluatorInterface):
policy_map = {}
for name, (cls, obs_space, act_space,
conf) in sorted(policy_dict.items()):
merged_conf = policy_config.copy()
merged_conf.update(conf)
merged_conf = merge_dicts(policy_config, conf)
with tf.variable_scope(name):
if isinstance(obs_space, gym.spaces.Dict):
raise ValueError(
+8 -2
View File
@@ -83,7 +83,7 @@ class PolicyGraph(object):
is_training (bool): whether we are training the policy
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multiagent algorithms.
multi-agent algorithms.
Returns:
actions (obj): single action
@@ -96,7 +96,10 @@ class PolicyGraph(object):
return action, [s[0] for s in state_out], \
{k: v[0] for k, v in info.items()}
def postprocess_trajectory(self, sample_batch, other_agent_batches=None):
def postprocess_trajectory(self,
sample_batch,
other_agent_batches=None,
episode=None):
"""Implements algorithm-specific trajectory postprocessing.
This will be called on each trajectory fragment computed during policy
@@ -108,6 +111,9 @@ class PolicyGraph(object):
other_agent_batches (dict): In a multi-agent env, this contains a
mapping of agent ids to (policy_graph, agent_batch) tuples
containing the policy graph and experiences of the other agent.
episode (MultiAgentEpisode): this provides access to all of the
internal episode state, which may be useful for model-based or
multi-agent algorithms.
Returns:
SampleBatch: postprocessed sample batch.
+10 -4
View File
@@ -99,11 +99,14 @@ class MultiAgentSampleBatchBuilder(object):
builder = self.agent_builders[agent_id]
builder.add_values(**values)
def postprocess_batch_so_far(self):
def postprocess_batch_so_far(self, episode):
"""Apply policy postprocessors to any unprocessed rows.
This pushes the postprocessed per-agent batches onto the per-policy
builders, clearing per-agent state.
Arguments:
episode: current MultiAgentEpisode object or None
"""
# Materialize the batches so far
@@ -128,7 +131,7 @@ class MultiAgentSampleBatchBuilder(object):
"Batches sent to postprocessing must only contain steps "
"from a single trajectory.", pre_batch)
post_batches[agent_id] = policy.postprocess_trajectory(
pre_batch, other_batches)
pre_batch, other_batches, episode)
# Append into policy batches and reset
for agent_id, post_batch in sorted(post_batches.items()):
@@ -137,14 +140,17 @@ class MultiAgentSampleBatchBuilder(object):
self.agent_builders.clear()
self.agent_to_policy.clear()
def build_and_reset(self):
def build_and_reset(self, episode):
"""Returns the accumulated sample batches for each policy.
Any unprocessed rows will be first postprocessed with a policy
postprocessor. The internal state of this builder will be reset.
Arguments:
episode: current MultiAgentEpisode object or None
"""
self.postprocess_batch_so_far()
self.postprocess_batch_so_far(episode)
policy_batches = {}
for policy_id, builder in self.policy_builders.items():
if builder.count > 0:
+2 -2
View File
@@ -317,10 +317,10 @@ def _env_runner(async_vector_env,
if episode.batch_builder.has_pending_data():
if (all_done and not pack) or \
episode.batch_builder.count >= unroll_length:
yield episode.batch_builder.build_and_reset()
yield episode.batch_builder.build_and_reset(episode)
elif all_done:
# Make sure postprocessor stays within one episode
episode.batch_builder.postprocess_batch_so_far()
episode.batch_builder.postprocess_batch_so_far(episode)
if all_done:
# Handle episode termination
@@ -64,7 +64,9 @@ class TFPolicyGraph(PolicyGraph):
loss_inputs (list): a (name, placeholder) tuple for each loss
input argument. Each placeholder name must correspond to a
SampleBatch column key returned by postprocess_trajectory(),
and has shape [BATCH_SIZE, data...].
and has shape [BATCH_SIZE, data...]. These keys will be read
from postprocessed sample batches and fed into the specified
placeholders during loss computation.
state_inputs (list): list of RNN state input Tensors.
state_outputs (list): list of RNN state output Tensors.
prev_action_input (Tensor): placeholder for previous actions
@@ -16,9 +16,13 @@ import argparse
import gym
import random
import tensorflow as tf
import tensorflow.contrib.slim as slim
import ray
from ray import tune
from ray.rllib.agents.pg.pg_policy_graph import PGPolicyGraph
from ray.rllib.agents.ppo.ppo_policy_graph import PPOPolicyGraph
from ray.rllib.models import Model, ModelCatalog
from ray.rllib.test.test_multi_agent_env import MultiCartpole
from ray.tune import run_experiments
from ray.tune.registry import register_env
@@ -29,26 +33,65 @@ parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--num-iters", type=int, default=20)
class CustomModel1(Model):
    """Example custom model that builds its 'fc1' layer in a shared scope."""

    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Return ``(output, last_layer)`` for the given observation input."""
        # Example of (optional) weight sharing between two different policies.
        # The 'fc1' layer is built inside an explicitly entered variable scope
        # named 'shared' using tf.AUTO_REUSE, so its variables are created in
        # a global scope outside of the policy's normal variable scope.
        shared_scope = tf.VariableScope(tf.AUTO_REUSE, "shared")
        with tf.variable_scope(
                shared_scope, reuse=tf.AUTO_REUSE,
                auxiliary_name_scope=False):
            hidden = slim.fully_connected(
                input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1")
        # NOTE(review): the output layer is assumed to sit outside the shared
        # scope (only 'fc1' is described as shared) -- confirm against the
        # original file's indentation.
        logits = slim.fully_connected(
            hidden, num_outputs, activation_fn=None, scope="fc_out")
        return logits, hidden
class CustomModel2(Model):
    """Example custom model that reuses the 'fc1' layer of the shared scope."""

    def _build_layers_v2(self, input_dict, num_outputs, options):
        """Return ``(output, last_layer)`` for the given observation input."""
        # Weights shared with CustomModel1: the same 'shared' variable scope
        # is entered with tf.AUTO_REUSE before building 'fc1'.
        scope_for_sharing = tf.VariableScope(tf.AUTO_REUSE, "shared")
        with tf.variable_scope(
                scope_for_sharing, reuse=tf.AUTO_REUSE,
                auxiliary_name_scope=False):
            features = slim.fully_connected(
                input_dict["obs"], 64, activation_fn=tf.nn.relu, scope="fc1")
        # NOTE(review): the output layer is assumed to sit outside the shared
        # scope, mirroring CustomModel1 -- confirm against the original
        # file's indentation.
        out = slim.fully_connected(
            features, num_outputs, activation_fn=None, scope="fc_out")
        return out, features
if __name__ == "__main__":
args = parser.parse_args()
ray.init()
# Simple environment with `num_agents` independent cartpole entities
register_env("multi_cartpole", lambda _: MultiCartpole(args.num_agents))
ModelCatalog.register_custom_model("model1", CustomModel1)
ModelCatalog.register_custom_model("model2", CustomModel2)
single_env = gym.make("CartPole-v0")
obs_space = single_env.observation_space
act_space = single_env.action_space
def gen_policy():
# Each policy can have a different configuration (including custom model)
def gen_policy(i):
config = {
"model": {
"custom_model": ["model1", "model2"][i % 2],
},
"gamma": random.choice([0.5, 0.8, 0.9, 0.95, 0.99]),
"n_step": random.choice([1, 2, 3, 4, 5]),
}
return (PGPolicyGraph, obs_space, act_space, config)
return (PPOPolicyGraph, obs_space, act_space, config)
# Setup PG with an ensemble of `num_policies` different policy graphs
# Setup PPO with an ensemble of `num_policies` different policy graphs
policy_graphs = {
"policy_{}".format(i): gen_policy()
"policy_{}".format(i): gen_policy(i)
for i in range(args.num_policies)
}
policy_ids = list(policy_graphs.keys())
+13
View File
@@ -7,6 +7,7 @@ from collections import OrderedDict
import gym
import tensorflow as tf
from ray.rllib.models.misc import linear, normc_initializer
from ray.rllib.models.preprocessors import get_preprocessor
@@ -131,6 +132,18 @@ class Model(object):
"""
raise NotImplementedError
def value_function(self):
    """Build the value function output.

    By default, ``self.last_layer`` is passed through a ``linear`` layer
    named "value" with a single output, and the result is flattened.
    Override this method to customize the implementation of the value
    function (e.g., not sharing hidden layers).

    Returns:
        Tensor of size [BATCH_SIZE] for the value function.
    """
    value_out = linear(self.last_layer, 1, "value", normc_initializer(1.0))
    # Flatten the linear layer's output to one dimension.
    return tf.reshape(value_out, [-1])
def _restore_original_dimensions(input_dict, obs_space):
if hasattr(obs_space, "original_space"):
@@ -359,7 +359,7 @@ class TestMultiAgentEnv(unittest.TestCase):
dones=t == 4,
infos={},
new_obs=obs_batch[0])
batch = builder.build_and_reset()
batch = builder.build_and_reset(episode=None)
episodes[0].add_extra_batch(batch)
# Just return zeros for actions
+10 -2
View File
@@ -28,7 +28,11 @@ class MockPolicyGraph(PolicyGraph):
episodes=None):
return [0] * len(obs_batch), [], {}
def postprocess_trajectory(self, batch, other_agent_batches=None):
def postprocess_trajectory(self,
batch,
other_agent_batches=None,
episode=None):
assert episode is not None
return compute_advantages(batch, 100.0, 0.9, use_gae=False)
@@ -42,7 +46,11 @@ class BadPolicyGraph(PolicyGraph):
episodes=None):
raise Exception("intentional error")
def postprocess_trajectory(self, batch, other_agent_batches=None):
def postprocess_trajectory(self,
batch,
other_agent_batches=None,
episode=None):
assert episode is not None
return compute_advantages(batch, 100.0, 0.9, use_gae=False)