Add action space to model (#4210)

This commit is contained in:
Stefan Pantic
2019-03-10 04:23:12 +01:00
committed by Eric Liang
parent 5adb4a6941
commit 36cbde651a
14 changed files with 61 additions and 42 deletions
@@ -58,7 +58,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"prev_actions": self.prev_actions,
"prev_rewards": self.prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, observation_space, logit_dim, self.config["model"])
}, observation_space, action_space, logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs)
self.vf = self.model.value_function()
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
+1 -1
View File
@@ -78,7 +78,7 @@ class GenericPolicy(object):
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, dist_dim, model_config)
}, obs_space, action_space, dist_dim, model_config)
dist = dist_class(model.outputs)
self.sampler = dist.sample()
@@ -217,7 +217,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
# Actor: P (policy) network
with tf.variable_scope(P_SCOPE) as scope:
p_values, self.p_model = self._build_p_network(
self.cur_observations, observation_space)
self.cur_observations, observation_space, action_space)
self.p_func_vars = _scope_vars(scope.name)
# Noise vars for P network except for layer normalization vars
@@ -256,14 +256,16 @@ class DDPGPolicyGraph(TFPolicyGraph):
# p network evaluation
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
self.p_t, _ = self._build_p_network(self.obs_t, observation_space)
self.p_t, _ = self._build_p_network(self.obs_t, observation_space,
action_space)
p_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
# target p network evaluation
with tf.variable_scope(P_TARGET_SCOPE) as scope:
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space)
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space,
action_space)
target_p_func_vars = _scope_vars(scope.name)
# Action outputs
@@ -283,7 +285,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
with tf.variable_scope(Q_SCOPE) as scope:
q_t, self.q_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.obs_t, observation_space, action_space, self.act_t)
self.q_func_vars = _scope_vars(scope.name)
self.stats = {
"mean_q": tf.reduce_mean(q_t),
@@ -292,11 +294,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
}
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
output_actions)
action_space, output_actions)
if self.config["twin_q"]:
with tf.variable_scope(TWIN_Q_SCOPE) as scope:
twin_q_t, self.twin_q_model = self._build_q_network(
self.obs_t, observation_space, self.act_t)
self.obs_t, observation_space, action_space, self.act_t)
self.twin_q_func_vars = _scope_vars(scope.name)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
@@ -304,12 +306,14 @@ class DDPGPolicyGraph(TFPolicyGraph):
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
action_space,
output_actions_estimated)
target_q_func_vars = _scope_vars(scope.name)
if self.config["twin_q"]:
with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope:
twin_q_tp1, _ = self._build_q_network(
self.obs_tp1, observation_space, output_actions_estimated)
self.obs_tp1, observation_space, action_space,
output_actions_estimated)
twin_target_q_func_vars = _scope_vars(scope.name)
if self.config["twin_q"]:
@@ -492,23 +496,23 @@ class DDPGPolicyGraph(TFPolicyGraph):
TFPolicyGraph.set_state(self, state[0])
self.set_epsilon(state[1])
def _build_q_network(self, obs, obs_space, actions):
def _build_q_network(self, obs, obs_space, action_space, actions):
q_net = QNetwork(
ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), actions,
}, obs_space, action_space, 1, self.config["model"]), actions,
self.config["critic_hiddens"],
self.config["critic_hidden_activation"])
return q_net.value, q_net.model
def _build_p_network(self, obs, obs_space):
def _build_p_network(self, obs, obs_space, action_space):
policy_net = PNetwork(
ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, obs_space, 1, self.config["model"]), self.dim_actions,
self.config["actor_hiddens"],
}, obs_space, action_space, 1, self.config["model"]),
self.dim_actions, self.config["actor_hiddens"],
self.config["actor_hidden_activation"],
self.config["parameter_noise"])
return policy_net.action_scores, policy_net.model
@@ -312,7 +312,7 @@ class DQNPolicyGraph(TFPolicyGraph):
# Action Q network
with tf.variable_scope(Q_SCOPE) as scope:
q_values, q_logits, q_dist, _ = self._build_q_network(
self.cur_observations, observation_space)
self.cur_observations, observation_space, action_space)
self.q_values = q_values
self.q_func_vars = _scope_vars(scope.name)
@@ -342,7 +342,7 @@ class DQNPolicyGraph(TFPolicyGraph):
with tf.variable_scope(Q_SCOPE, reuse=True):
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
self.obs_t, observation_space)
self.obs_t, observation_space, action_space)
q_batchnorm_update_ops = list(
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
prev_update_ops)
@@ -350,7 +350,7 @@ class DQNPolicyGraph(TFPolicyGraph):
# target q network evaluation
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
self.obs_tp1, observation_space)
self.obs_tp1, observation_space, action_space)
self.target_q_func_vars = _scope_vars(scope.name)
# q scores for actions which we know were selected in the given state.
@@ -364,7 +364,7 @@ class DQNPolicyGraph(TFPolicyGraph):
with tf.variable_scope(Q_SCOPE, reuse=True):
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
q_dist_tp1_using_online_net, _ = self._build_q_network(
self.obs_tp1, observation_space)
self.obs_tp1, observation_space, action_space)
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
q_tp1_best_one_hot_selection = tf.one_hot(
q_tp1_best_using_online_net, self.num_actions)
@@ -556,13 +556,14 @@ class DQNPolicyGraph(TFPolicyGraph):
def set_epsilon(self, epsilon):
self.cur_epsilon = epsilon
def _build_q_network(self, obs, space):
def _build_q_network(self, obs, obs_space, action_space):
qnet = QNetwork(
ModelCatalog.get_model({
"obs": obs,
"is_training": self._get_is_training_placeholder(),
}, space, self.num_actions, self.config["model"]),
self.num_actions, self.config["dueling"], self.config["hiddens"],
}, obs_space, action_space, self.num_actions,
self.config["model"]), self.num_actions,
self.config["dueling"], self.config["hiddens"],
self.config["noisy"], self.config["num_atoms"],
self.config["v_min"], self.config["v_max"], self.config["sigma0"],
self.config["parameter_noise"])
+1 -1
View File
@@ -56,7 +56,7 @@ class GenericPolicy(object):
self.action_space, model_options, dist_type="deterministic")
model = ModelCatalog.get_model({
"obs": self.inputs
}, obs_space, dist_dim, model_options)
}, obs_space, action_space, dist_dim, model_options)
dist = dist_class(model.outputs)
self.sampler = dist.sample()
@@ -159,6 +159,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,
logit_dim,
self.config["model"],
state_in=existing_state_in,
@@ -73,7 +73,8 @@ class MARWILPolicyGraph(TFPolicyGraph):
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
}, observation_space, logit_dim, self.config["model"])
}, observation_space, action_space, logit_dim,
self.config["model"])
logits = self.model.outputs
self.p_func_vars = _scope_vars(scope.name)
@@ -39,7 +39,7 @@ class PGPolicyGraph(TFPolicyGraph):
"prev_actions": prev_actions,
"prev_rewards": prev_rewards,
"is_training": self._get_is_training_placeholder(),
}, obs_space, self.logit_dim, self.config["model"])
}, obs_space, action_space, self.logit_dim, self.config["model"])
action_dist = dist_class(self.model.outputs) # logit for each action
# Setup policy loss
@@ -235,6 +235,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,
logit_dim,
self.config["model"],
state_in=existing_state_in,
@@ -169,6 +169,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"is_training": self._get_is_training_placeholder(),
},
observation_space,
action_space,
logit_dim,
self.config["model"],
state_in=existing_state_in,
@@ -208,7 +209,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
"prev_actions": prev_actions_ph,
"prev_rewards": prev_rewards_ph,
"is_training": self._get_is_training_placeholder(),
}, observation_space, 1, vf_config).outputs
}, observation_space, action_space, 1, vf_config).outputs
self.value_function = tf.reshape(self.value_function, [-1])
else:
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
+2 -1
View File
@@ -41,7 +41,8 @@ class CustomLossModel(Model):
self.obs_in = input_dict["obs"]
with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
num_outputs, options)
self.action_space, num_outputs,
options)
return self.fcnet.outputs, self.fcnet.last_layer
def custom_loss(self, policy_loss, loss_inputs):
+19 -13
View File
@@ -187,6 +187,7 @@ class ModelCatalog(object):
@DeveloperAPI
def get_model(input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=None,
@@ -197,10 +198,11 @@ class ModelCatalog(object):
input_dict (dict): Dict of input tensors to the model, including
the observation under the "obs" key.
obs_space (Space): Observation space of the target gym env.
action_space (Space): Action space of the target gym env.
num_outputs (int): The size of the output vector of the model.
options (dict): Optional args to pass to the model constructor.
state_in (list): Optional RNN state in tensors.
seq_in (Tensor): Optional RNN sequence length tensor.
seq_lens (Tensor): Optional RNN sequence length tensor.
Returns:
model (models.Model): Neural network model.
@@ -208,33 +210,36 @@ class ModelCatalog(object):
assert isinstance(input_dict, dict)
options = options or MODEL_DEFAULTS
model = ModelCatalog._get_model(input_dict, obs_space, num_outputs,
options, state_in, seq_lens)
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
num_outputs, options, state_in,
seq_lens)
if options.get("use_lstm"):
copy = dict(input_dict)
copy["obs"] = model.last_layer
feature_space = gym.spaces.Box(
-1, 1, shape=(model.last_layer.shape[1], ))
model = LSTM(copy, feature_space, num_outputs, options, state_in,
seq_lens)
model = LSTM(copy, feature_space, action_space, num_outputs,
options, state_in, seq_lens)
logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
model, input_dict, obs_space, state_in, seq_lens, model.outputs,
model.state_out))
logger.debug(
"Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
model, input_dict, obs_space, action_space, state_in, seq_lens,
model.outputs, model.state_out))
model._validate_output_shape()
return model
@staticmethod
def _get_model(input_dict, obs_space, num_outputs, options, state_in,
seq_lens):
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
state_in, seq_lens):
if options.get("custom_model"):
model = options["custom_model"]
logger.debug("Using custom model {}".format(model))
return _global_registry.get(RLLIB_MODEL, model)(
input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=state_in,
@@ -243,10 +248,11 @@ class ModelCatalog(object):
obs_rank = len(input_dict["obs"].shape) - 1
if obs_rank > 1:
return VisionNetwork(input_dict, obs_space, num_outputs, options)
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)
return FullyConnectedNetwork(input_dict, obs_space, num_outputs,
options)
return FullyConnectedNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@staticmethod
@DeveloperAPI
+2
View File
@@ -48,6 +48,7 @@ class Model(object):
def __init__(self,
input_dict,
obs_space,
action_space,
num_outputs,
options,
state_in=None,
@@ -59,6 +60,7 @@ class Model(object):
self.state_in = state_in or []
self.state_out = []
self.obs_space = obs_space
self.action_space = action_space
self.num_outputs = num_outputs
self.options = options
self.scope = tf.get_variable_scope()
+4 -3
View File
@@ -73,13 +73,14 @@ class ModelCatalogTest(unittest.TestCase):
with tf.variable_scope("test1"):
p1 = ModelCatalog.get_model({
"obs": tf.zeros((10, 3), dtype=tf.float32)
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {})
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {})
self.assertEqual(type(p1), FullyConnectedNetwork)
with tf.variable_scope("test2"):
p2 = ModelCatalog.get_model({
"obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {})
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), Discrete(5), 5,
{})
self.assertEqual(type(p2), VisionNetwork)
def testCustomModel(self):
@@ -87,7 +88,7 @@ class ModelCatalogTest(unittest.TestCase):
ModelCatalog.register_custom_model("foo", CustomModel)
p1 = ModelCatalog.get_model({
"obs": tf.constant([1, 2, 3])
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5,
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5,
{"custom_model": "foo"})
self.assertEqual(str(type(p1)), str(CustomModel))