mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 12:02:04 +08:00
Add action space to model (#4210)
This commit is contained in:
committed by
Eric Liang
parent
5adb4a6941
commit
36cbde651a
@@ -58,7 +58,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"prev_actions": self.prev_actions,
|
||||
"prev_rewards": self.prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, logit_dim, self.config["model"])
|
||||
}, observation_space, action_space, logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs)
|
||||
self.vf = self.model.value_function()
|
||||
self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
|
||||
|
||||
@@ -78,7 +78,7 @@ class GenericPolicy(object):
|
||||
|
||||
model = ModelCatalog.get_model({
|
||||
"obs": self.inputs
|
||||
}, obs_space, dist_dim, model_config)
|
||||
}, obs_space, action_space, dist_dim, model_config)
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
|
||||
@@ -217,7 +217,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
# Actor: P (policy) network
|
||||
with tf.variable_scope(P_SCOPE) as scope:
|
||||
p_values, self.p_model = self._build_p_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.cur_observations, observation_space, action_space)
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Noise vars for P network except for layer normalization vars
|
||||
@@ -256,14 +256,16 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
# p network evaluation
|
||||
with tf.variable_scope(P_SCOPE, reuse=True) as scope:
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
self.p_t, _ = self._build_p_network(self.obs_t, observation_space)
|
||||
self.p_t, _ = self._build_p_network(self.obs_t, observation_space,
|
||||
action_space)
|
||||
p_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
|
||||
# target p network evaluation
|
||||
with tf.variable_scope(P_TARGET_SCOPE) as scope:
|
||||
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space)
|
||||
p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space,
|
||||
action_space)
|
||||
target_p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# Action outputs
|
||||
@@ -283,7 +285,7 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_t, self.q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.obs_t, observation_space, action_space, self.act_t)
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
self.stats = {
|
||||
"mean_q": tf.reduce_mean(q_t),
|
||||
@@ -292,11 +294,11 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
}
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp0, _ = self._build_q_network(self.obs_t, observation_space,
|
||||
output_actions)
|
||||
action_space, output_actions)
|
||||
if self.config["twin_q"]:
|
||||
with tf.variable_scope(TWIN_Q_SCOPE) as scope:
|
||||
twin_q_t, self.twin_q_model = self._build_q_network(
|
||||
self.obs_t, observation_space, self.act_t)
|
||||
self.obs_t, observation_space, action_space, self.act_t)
|
||||
self.twin_q_func_vars = _scope_vars(scope.name)
|
||||
q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops)
|
||||
@@ -304,12 +306,14 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space,
|
||||
action_space,
|
||||
output_actions_estimated)
|
||||
target_q_func_vars = _scope_vars(scope.name)
|
||||
if self.config["twin_q"]:
|
||||
with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope:
|
||||
twin_q_tp1, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space, output_actions_estimated)
|
||||
self.obs_tp1, observation_space, action_space,
|
||||
output_actions_estimated)
|
||||
twin_target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
if self.config["twin_q"]:
|
||||
@@ -492,23 +496,23 @@ class DDPGPolicyGraph(TFPolicyGraph):
|
||||
TFPolicyGraph.set_state(self, state[0])
|
||||
self.set_epsilon(state[1])
|
||||
|
||||
def _build_q_network(self, obs, obs_space, actions):
|
||||
def _build_q_network(self, obs, obs_space, action_space, actions):
|
||||
q_net = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"]), actions,
|
||||
}, obs_space, action_space, 1, self.config["model"]), actions,
|
||||
self.config["critic_hiddens"],
|
||||
self.config["critic_hidden_activation"])
|
||||
return q_net.value, q_net.model
|
||||
|
||||
def _build_p_network(self, obs, obs_space):
|
||||
def _build_p_network(self, obs, obs_space, action_space):
|
||||
policy_net = PNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, 1, self.config["model"]), self.dim_actions,
|
||||
self.config["actor_hiddens"],
|
||||
}, obs_space, action_space, 1, self.config["model"]),
|
||||
self.dim_actions, self.config["actor_hiddens"],
|
||||
self.config["actor_hidden_activation"],
|
||||
self.config["parameter_noise"])
|
||||
return policy_net.action_scores, policy_net.model
|
||||
|
||||
@@ -312,7 +312,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
# Action Q network
|
||||
with tf.variable_scope(Q_SCOPE) as scope:
|
||||
q_values, q_logits, q_dist, _ = self._build_q_network(
|
||||
self.cur_observations, observation_space)
|
||||
self.cur_observations, observation_space, action_space)
|
||||
self.q_values = q_values
|
||||
self.q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
@@ -342,7 +342,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS))
|
||||
q_t, q_logits_t, q_dist_t, model = self._build_q_network(
|
||||
self.obs_t, observation_space)
|
||||
self.obs_t, observation_space, action_space)
|
||||
q_batchnorm_update_ops = list(
|
||||
set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) -
|
||||
prev_update_ops)
|
||||
@@ -350,7 +350,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
# target q network evalution
|
||||
with tf.variable_scope(Q_TARGET_SCOPE) as scope:
|
||||
q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space)
|
||||
self.obs_tp1, observation_space, action_space)
|
||||
self.target_q_func_vars = _scope_vars(scope.name)
|
||||
|
||||
# q scores for actions which we know were selected in the given state.
|
||||
@@ -364,7 +364,7 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
with tf.variable_scope(Q_SCOPE, reuse=True):
|
||||
q_tp1_using_online_net, q_logits_tp1_using_online_net, \
|
||||
q_dist_tp1_using_online_net, _ = self._build_q_network(
|
||||
self.obs_tp1, observation_space)
|
||||
self.obs_tp1, observation_space, action_space)
|
||||
q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
|
||||
q_tp1_best_one_hot_selection = tf.one_hot(
|
||||
q_tp1_best_using_online_net, self.num_actions)
|
||||
@@ -556,13 +556,14 @@ class DQNPolicyGraph(TFPolicyGraph):
|
||||
def set_epsilon(self, epsilon):
|
||||
self.cur_epsilon = epsilon
|
||||
|
||||
def _build_q_network(self, obs, space):
|
||||
def _build_q_network(self, obs, obs_space, action_space):
|
||||
qnet = QNetwork(
|
||||
ModelCatalog.get_model({
|
||||
"obs": obs,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, space, self.num_actions, self.config["model"]),
|
||||
self.num_actions, self.config["dueling"], self.config["hiddens"],
|
||||
}, obs_space, action_space, self.num_actions,
|
||||
self.config["model"]), self.num_actions,
|
||||
self.config["dueling"], self.config["hiddens"],
|
||||
self.config["noisy"], self.config["num_atoms"],
|
||||
self.config["v_min"], self.config["v_max"], self.config["sigma0"],
|
||||
self.config["parameter_noise"])
|
||||
|
||||
@@ -56,7 +56,7 @@ class GenericPolicy(object):
|
||||
self.action_space, model_options, dist_type="deterministic")
|
||||
model = ModelCatalog.get_model({
|
||||
"obs": self.inputs
|
||||
}, obs_space, dist_dim, model_options)
|
||||
}, obs_space, action_space, dist_dim, model_options)
|
||||
dist = dist_class(model.outputs)
|
||||
self.sampler = dist.sample()
|
||||
|
||||
|
||||
@@ -159,6 +159,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
|
||||
@@ -73,7 +73,8 @@ class MARWILPolicyGraph(TFPolicyGraph):
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, logit_dim, self.config["model"])
|
||||
}, observation_space, action_space, logit_dim,
|
||||
self.config["model"])
|
||||
logits = self.model.outputs
|
||||
self.p_func_vars = _scope_vars(scope.name)
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ class PGPolicyGraph(TFPolicyGraph):
|
||||
"prev_actions": prev_actions,
|
||||
"prev_rewards": prev_rewards,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, obs_space, self.logit_dim, self.config["model"])
|
||||
}, obs_space, action_space, self.logit_dim, self.config["model"])
|
||||
action_dist = dist_class(self.model.outputs) # logit for each action
|
||||
|
||||
# Setup policy loss
|
||||
|
||||
@@ -235,6 +235,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
|
||||
@@ -169,6 +169,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
},
|
||||
observation_space,
|
||||
action_space,
|
||||
logit_dim,
|
||||
self.config["model"],
|
||||
state_in=existing_state_in,
|
||||
@@ -208,7 +209,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph):
|
||||
"prev_actions": prev_actions_ph,
|
||||
"prev_rewards": prev_rewards_ph,
|
||||
"is_training": self._get_is_training_placeholder(),
|
||||
}, observation_space, 1, vf_config).outputs
|
||||
}, observation_space, action_space, 1, vf_config).outputs
|
||||
self.value_function = tf.reshape(self.value_function, [-1])
|
||||
else:
|
||||
self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1])
|
||||
|
||||
@@ -41,7 +41,8 @@ class CustomLossModel(Model):
|
||||
self.obs_in = input_dict["obs"]
|
||||
with tf.variable_scope("shared", reuse=tf.AUTO_REUSE):
|
||||
self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space,
|
||||
num_outputs, options)
|
||||
self.action_space, num_outputs,
|
||||
options)
|
||||
return self.fcnet.outputs, self.fcnet.last_layer
|
||||
|
||||
def custom_loss(self, policy_loss, loss_inputs):
|
||||
|
||||
@@ -187,6 +187,7 @@ class ModelCatalog(object):
|
||||
@DeveloperAPI
|
||||
def get_model(input_dict,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
options,
|
||||
state_in=None,
|
||||
@@ -197,10 +198,11 @@ class ModelCatalog(object):
|
||||
input_dict (dict): Dict of input tensors to the model, including
|
||||
the observation under the "obs" key.
|
||||
obs_space (Space): Observation space of the target gym env.
|
||||
action_space (Space): Action space of the target gym env.
|
||||
num_outputs (int): The size of the output vector of the model.
|
||||
options (dict): Optional args to pass to the model constructor.
|
||||
state_in (list): Optional RNN state in tensors.
|
||||
seq_in (Tensor): Optional RNN sequence length tensor.
|
||||
seq_lens (Tensor): Optional RNN sequence length tensor.
|
||||
|
||||
Returns:
|
||||
model (models.Model): Neural network model.
|
||||
@@ -208,33 +210,36 @@ class ModelCatalog(object):
|
||||
|
||||
assert isinstance(input_dict, dict)
|
||||
options = options or MODEL_DEFAULTS
|
||||
model = ModelCatalog._get_model(input_dict, obs_space, num_outputs,
|
||||
options, state_in, seq_lens)
|
||||
model = ModelCatalog._get_model(input_dict, obs_space, action_space,
|
||||
num_outputs, options, state_in,
|
||||
seq_lens)
|
||||
|
||||
if options.get("use_lstm"):
|
||||
copy = dict(input_dict)
|
||||
copy["obs"] = model.last_layer
|
||||
feature_space = gym.spaces.Box(
|
||||
-1, 1, shape=(model.last_layer.shape[1], ))
|
||||
model = LSTM(copy, feature_space, num_outputs, options, state_in,
|
||||
seq_lens)
|
||||
model = LSTM(copy, feature_space, action_space, num_outputs,
|
||||
options, state_in, seq_lens)
|
||||
|
||||
logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
|
||||
model, input_dict, obs_space, state_in, seq_lens, model.outputs,
|
||||
model.state_out))
|
||||
logger.debug(
|
||||
"Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format(
|
||||
model, input_dict, obs_space, action_space, state_in, seq_lens,
|
||||
model.outputs, model.state_out))
|
||||
|
||||
model._validate_output_shape()
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _get_model(input_dict, obs_space, num_outputs, options, state_in,
|
||||
seq_lens):
|
||||
def _get_model(input_dict, obs_space, action_space, num_outputs, options,
|
||||
state_in, seq_lens):
|
||||
if options.get("custom_model"):
|
||||
model = options["custom_model"]
|
||||
logger.debug("Using custom model {}".format(model))
|
||||
return _global_registry.get(RLLIB_MODEL, model)(
|
||||
input_dict,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
options,
|
||||
state_in=state_in,
|
||||
@@ -243,10 +248,11 @@ class ModelCatalog(object):
|
||||
obs_rank = len(input_dict["obs"].shape) - 1
|
||||
|
||||
if obs_rank > 1:
|
||||
return VisionNetwork(input_dict, obs_space, num_outputs, options)
|
||||
return VisionNetwork(input_dict, obs_space, action_space,
|
||||
num_outputs, options)
|
||||
|
||||
return FullyConnectedNetwork(input_dict, obs_space, num_outputs,
|
||||
options)
|
||||
return FullyConnectedNetwork(input_dict, obs_space, action_space,
|
||||
num_outputs, options)
|
||||
|
||||
@staticmethod
|
||||
@DeveloperAPI
|
||||
|
||||
@@ -48,6 +48,7 @@ class Model(object):
|
||||
def __init__(self,
|
||||
input_dict,
|
||||
obs_space,
|
||||
action_space,
|
||||
num_outputs,
|
||||
options,
|
||||
state_in=None,
|
||||
@@ -59,6 +60,7 @@ class Model(object):
|
||||
self.state_in = state_in or []
|
||||
self.state_out = []
|
||||
self.obs_space = obs_space
|
||||
self.action_space = action_space
|
||||
self.num_outputs = num_outputs
|
||||
self.options = options
|
||||
self.scope = tf.get_variable_scope()
|
||||
|
||||
@@ -73,13 +73,14 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
with tf.variable_scope("test1"):
|
||||
p1 = ModelCatalog.get_model({
|
||||
"obs": tf.zeros((10, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {})
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {})
|
||||
self.assertEqual(type(p1), FullyConnectedNetwork)
|
||||
|
||||
with tf.variable_scope("test2"):
|
||||
p2 = ModelCatalog.get_model({
|
||||
"obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32)
|
||||
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {})
|
||||
}, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), Discrete(5), 5,
|
||||
{})
|
||||
self.assertEqual(type(p2), VisionNetwork)
|
||||
|
||||
def testCustomModel(self):
|
||||
@@ -87,7 +88,7 @@ class ModelCatalogTest(unittest.TestCase):
|
||||
ModelCatalog.register_custom_model("foo", CustomModel)
|
||||
p1 = ModelCatalog.get_model({
|
||||
"obs": tf.constant([1, 2, 3])
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), 5,
|
||||
}, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5,
|
||||
{"custom_model": "foo"})
|
||||
self.assertEqual(str(type(p1)), str(CustomModel))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user