diff --git a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py index 750f6a7e1..a5a91abb5 100644 --- a/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py +++ b/python/ray/rllib/agents/a3c/a3c_tf_policy_graph.py @@ -58,7 +58,7 @@ class A3CPolicyGraph(LearningRateSchedule, TFPolicyGraph): "prev_actions": self.prev_actions, "prev_rewards": self.prev_rewards, "is_training": self._get_is_training_placeholder(), - }, observation_space, logit_dim, self.config["model"]) + }, observation_space, action_space, logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) self.vf = self.model.value_function() self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, diff --git a/python/ray/rllib/agents/ars/policies.py b/python/ray/rllib/agents/ars/policies.py index 7c4defd69..fe82be5b6 100644 --- a/python/ray/rllib/agents/ars/policies.py +++ b/python/ray/rllib/agents/ars/policies.py @@ -78,7 +78,7 @@ class GenericPolicy(object): model = ModelCatalog.get_model({ "obs": self.inputs - }, obs_space, dist_dim, model_config) + }, obs_space, action_space, dist_dim, model_config) dist = dist_class(model.outputs) self.sampler = dist.sample() diff --git a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py index 439671e93..c329a8b64 100644 --- a/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py +++ b/python/ray/rllib/agents/ddpg/ddpg_policy_graph.py @@ -217,7 +217,7 @@ class DDPGPolicyGraph(TFPolicyGraph): # Actor: P (policy) network with tf.variable_scope(P_SCOPE) as scope: p_values, self.p_model = self._build_p_network( - self.cur_observations, observation_space) + self.cur_observations, observation_space, action_space) self.p_func_vars = _scope_vars(scope.name) # Noise vars for P network except for layer normalization vars @@ -256,14 +256,16 @@ class DDPGPolicyGraph(TFPolicyGraph): # p network evaluation with tf.variable_scope(P_SCOPE, reuse=True) as scope: prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - self.p_t, _ = self._build_p_network(self.obs_t, observation_space) + self.p_t, _ = self._build_p_network(self.obs_t, observation_space, + action_space) p_batchnorm_update_ops = list( set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) # target p network evaluation with tf.variable_scope(P_TARGET_SCOPE) as scope: - p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space) + p_tp1, _ = self._build_p_network(self.obs_tp1, observation_space, + action_space) target_p_func_vars = _scope_vars(scope.name) # Action outputs @@ -283,7 +285,7 @@ class DDPGPolicyGraph(TFPolicyGraph): prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) with tf.variable_scope(Q_SCOPE) as scope: q_t, self.q_model = self._build_q_network( - self.obs_t, observation_space, self.act_t) + self.obs_t, observation_space, action_space, self.act_t) self.q_func_vars = _scope_vars(scope.name) self.stats = { "mean_q": tf.reduce_mean(q_t), @@ -292,11 +294,11 @@ class DDPGPolicyGraph(TFPolicyGraph): } with tf.variable_scope(Q_SCOPE, reuse=True): q_tp0, _ = self._build_q_network(self.obs_t, observation_space, - output_actions) + action_space, output_actions) if self.config["twin_q"]: with tf.variable_scope(TWIN_Q_SCOPE) as scope: twin_q_t, self.twin_q_model = self._build_q_network( - self.obs_t, observation_space, self.act_t) + self.obs_t, observation_space, action_space, self.act_t) self.twin_q_func_vars = _scope_vars(scope.name) q_batchnorm_update_ops = list( set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) @@ -304,12 +306,14 @@ class DDPGPolicyGraph(TFPolicyGraph): # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: q_tp1, _ = self._build_q_network(self.obs_tp1, observation_space, + action_space, output_actions_estimated) target_q_func_vars = _scope_vars(scope.name) if self.config["twin_q"]: with tf.variable_scope(TWIN_Q_TARGET_SCOPE) as scope: twin_q_tp1, _ = self._build_q_network( - self.obs_tp1, observation_space, output_actions_estimated) + self.obs_tp1, observation_space, action_space, + output_actions_estimated) twin_target_q_func_vars = _scope_vars(scope.name) if self.config["twin_q"]: @@ -492,23 +496,23 @@ class DDPGPolicyGraph(TFPolicyGraph): TFPolicyGraph.set_state(self, state[0]) self.set_epsilon(state[1]) - def _build_q_network(self, obs, obs_space, actions): + def _build_q_network(self, obs, obs_space, action_space, actions): q_net = QNetwork( ModelCatalog.get_model({ "obs": obs, "is_training": self._get_is_training_placeholder(), - }, obs_space, 1, self.config["model"]), actions, + }, obs_space, action_space, 1, self.config["model"]), actions, self.config["critic_hiddens"], self.config["critic_hidden_activation"]) return q_net.value, q_net.model - def _build_p_network(self, obs, obs_space): + def _build_p_network(self, obs, obs_space, action_space): policy_net = PNetwork( ModelCatalog.get_model({ "obs": obs, "is_training": self._get_is_training_placeholder(), - }, obs_space, 1, self.config["model"]), self.dim_actions, - self.config["actor_hiddens"], + }, obs_space, action_space, 1, self.config["model"]), + self.dim_actions, self.config["actor_hiddens"], self.config["actor_hidden_activation"], self.config["parameter_noise"]) return policy_net.action_scores, policy_net.model diff --git a/python/ray/rllib/agents/dqn/dqn_policy_graph.py b/python/ray/rllib/agents/dqn/dqn_policy_graph.py index d422b0f3c..686e09312 100644 --- a/python/ray/rllib/agents/dqn/dqn_policy_graph.py +++ b/python/ray/rllib/agents/dqn/dqn_policy_graph.py @@ -312,7 +312,7 @@ class DQNPolicyGraph(TFPolicyGraph): # Action Q network with tf.variable_scope(Q_SCOPE) as scope: q_values, q_logits, q_dist, _ = self._build_q_network( - self.cur_observations, observation_space) + self.cur_observations, observation_space, action_space) self.q_values = q_values self.q_func_vars = _scope_vars(scope.name) @@ -342,7 +342,7 @@ class DQNPolicyGraph(TFPolicyGraph): with tf.variable_scope(Q_SCOPE, reuse=True): prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) q_t, q_logits_t, q_dist_t, model = self._build_q_network( - self.obs_t, observation_space) + self.obs_t, observation_space, action_space) q_batchnorm_update_ops = list( set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) - prev_update_ops) @@ -350,7 +350,7 @@ class DQNPolicyGraph(TFPolicyGraph): # target q network evalution with tf.variable_scope(Q_TARGET_SCOPE) as scope: q_tp1, q_logits_tp1, q_dist_tp1, _ = self._build_q_network( - self.obs_tp1, observation_space) + self.obs_tp1, observation_space, action_space) self.target_q_func_vars = _scope_vars(scope.name) # q scores for actions which we know were selected in the given state. @@ -364,7 +364,7 @@ class DQNPolicyGraph(TFPolicyGraph): with tf.variable_scope(Q_SCOPE, reuse=True): q_tp1_using_online_net, q_logits_tp1_using_online_net, \ q_dist_tp1_using_online_net, _ = self._build_q_network( - self.obs_tp1, observation_space) + self.obs_tp1, observation_space, action_space) q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1) q_tp1_best_one_hot_selection = tf.one_hot( q_tp1_best_using_online_net, self.num_actions) @@ -556,13 +556,14 @@ class DQNPolicyGraph(TFPolicyGraph): def set_epsilon(self, epsilon): self.cur_epsilon = epsilon - def _build_q_network(self, obs, space): + def _build_q_network(self, obs, obs_space, action_space): qnet = QNetwork( ModelCatalog.get_model({ "obs": obs, "is_training": self._get_is_training_placeholder(), - }, space, self.num_actions, self.config["model"]), - self.num_actions, self.config["dueling"], self.config["hiddens"], + }, obs_space, action_space, self.num_actions, + self.config["model"]), self.num_actions, + self.config["dueling"], self.config["hiddens"], self.config["noisy"], self.config["num_atoms"], self.config["v_min"], self.config["v_max"], self.config["sigma0"], self.config["parameter_noise"]) diff --git a/python/ray/rllib/agents/es/policies.py b/python/ray/rllib/agents/es/policies.py index 61f748ce0..78ff29da4 100644 --- a/python/ray/rllib/agents/es/policies.py +++ b/python/ray/rllib/agents/es/policies.py @@ -56,7 +56,7 @@ class GenericPolicy(object): self.action_space, model_options, dist_type="deterministic") model = ModelCatalog.get_model({ "obs": self.inputs - }, obs_space, dist_dim, model_options) + }, obs_space, action_space, dist_dim, model_options) dist = dist_class(model.outputs) self.sampler = dist.sample() diff --git a/python/ray/rllib/agents/impala/vtrace_policy_graph.py b/python/ray/rllib/agents/impala/vtrace_policy_graph.py index 7e888cf85..9d16c337d 100644 --- a/python/ray/rllib/agents/impala/vtrace_policy_graph.py +++ b/python/ray/rllib/agents/impala/vtrace_policy_graph.py @@ -159,6 +159,7 @@ class VTracePolicyGraph(LearningRateSchedule, TFPolicyGraph): "is_training": self._get_is_training_placeholder(), }, observation_space, + action_space, logit_dim, self.config["model"], state_in=existing_state_in, diff --git a/python/ray/rllib/agents/marwil/marwil_policy_graph.py b/python/ray/rllib/agents/marwil/marwil_policy_graph.py index 8d52807ef..7b66350d8 100644 --- a/python/ray/rllib/agents/marwil/marwil_policy_graph.py +++ b/python/ray/rllib/agents/marwil/marwil_policy_graph.py @@ -73,7 +73,8 @@ class MARWILPolicyGraph(TFPolicyGraph): "prev_actions": prev_actions_ph, "prev_rewards": prev_rewards_ph, "is_training": self._get_is_training_placeholder(), - }, observation_space, logit_dim, self.config["model"]) + }, observation_space, action_space, logit_dim, + self.config["model"]) logits = self.model.outputs self.p_func_vars = _scope_vars(scope.name) diff --git a/python/ray/rllib/agents/pg/pg_policy_graph.py b/python/ray/rllib/agents/pg/pg_policy_graph.py index 4907e00ff..8928bb108 100644 --- a/python/ray/rllib/agents/pg/pg_policy_graph.py +++ b/python/ray/rllib/agents/pg/pg_policy_graph.py @@ -39,7 +39,7 @@ class PGPolicyGraph(TFPolicyGraph): "prev_actions": prev_actions, "prev_rewards": prev_rewards, "is_training": self._get_is_training_placeholder(), - }, obs_space, self.logit_dim, self.config["model"]) + }, obs_space, action_space, self.logit_dim, self.config["model"]) action_dist = dist_class(self.model.outputs) # logit for each action # Setup policy loss diff --git a/python/ray/rllib/agents/ppo/appo_policy_graph.py b/python/ray/rllib/agents/ppo/appo_policy_graph.py index d613c64a7..378e089c5 100644 --- a/python/ray/rllib/agents/ppo/appo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/appo_policy_graph.py @@ -235,6 +235,7 @@ class AsyncPPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): "is_training": self._get_is_training_placeholder(), }, observation_space, + action_space, logit_dim, self.config["model"], state_in=existing_state_in, diff --git a/python/ray/rllib/agents/ppo/ppo_policy_graph.py b/python/ray/rllib/agents/ppo/ppo_policy_graph.py index bd098e697..cd0e68ab7 100644 --- a/python/ray/rllib/agents/ppo/ppo_policy_graph.py +++ b/python/ray/rllib/agents/ppo/ppo_policy_graph.py @@ -169,6 +169,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): "is_training": self._get_is_training_placeholder(), }, observation_space, + action_space, logit_dim, self.config["model"], state_in=existing_state_in, @@ -208,7 +209,7 @@ class PPOPolicyGraph(LearningRateSchedule, TFPolicyGraph): "prev_actions": prev_actions_ph, "prev_rewards": prev_rewards_ph, "is_training": self._get_is_training_placeholder(), - }, observation_space, 1, vf_config).outputs + }, observation_space, action_space, 1, vf_config).outputs self.value_function = tf.reshape(self.value_function, [-1]) else: self.value_function = tf.zeros(shape=tf.shape(obs_ph)[:1]) diff --git a/python/ray/rllib/examples/custom_loss.py b/python/ray/rllib/examples/custom_loss.py index 85855992c..005428b00 100644 --- a/python/ray/rllib/examples/custom_loss.py +++ b/python/ray/rllib/examples/custom_loss.py @@ -41,7 +41,8 @@ class CustomLossModel(Model): self.obs_in = input_dict["obs"] with tf.variable_scope("shared", reuse=tf.AUTO_REUSE): self.fcnet = FullyConnectedNetwork(input_dict, self.obs_space, - num_outputs, options) + self.action_space, num_outputs, + options) return self.fcnet.outputs, self.fcnet.last_layer def custom_loss(self, policy_loss, loss_inputs): diff --git a/python/ray/rllib/models/catalog.py b/python/ray/rllib/models/catalog.py index 40bc2a13f..bcd79cbfe 100644 --- a/python/ray/rllib/models/catalog.py +++ b/python/ray/rllib/models/catalog.py @@ -187,6 +187,7 @@ class ModelCatalog(object): @DeveloperAPI def get_model(input_dict, obs_space, + action_space, num_outputs, options, state_in=None, @@ -197,10 +198,11 @@ class ModelCatalog(object): input_dict (dict): Dict of input tensors to the model, including the observation under the "obs" key. obs_space (Space): Observation space of the target gym env. + action_space (Space): Action space of the target gym env. num_outputs (int): The size of the output vector of the model. options (dict): Optional args to pass to the model constructor. state_in (list): Optional RNN state in tensors. - seq_in (Tensor): Optional RNN sequence length tensor. + seq_lens (Tensor): Optional RNN sequence length tensor. Returns: model (models.Model): Neural network model. @@ -208,33 +210,36 @@ class ModelCatalog(object): assert isinstance(input_dict, dict) options = options or MODEL_DEFAULTS - model = ModelCatalog._get_model(input_dict, obs_space, num_outputs, - options, state_in, seq_lens) + model = ModelCatalog._get_model(input_dict, obs_space, action_space, + num_outputs, options, state_in, + seq_lens) if options.get("use_lstm"): copy = dict(input_dict) copy["obs"] = model.last_layer feature_space = gym.spaces.Box( -1, 1, shape=(model.last_layer.shape[1], )) - model = LSTM(copy, feature_space, num_outputs, options, state_in, - seq_lens) + model = LSTM(copy, feature_space, action_space, num_outputs, + options, state_in, seq_lens) - logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format( - model, input_dict, obs_space, state_in, seq_lens, model.outputs, - model.state_out)) + logger.debug( + "Created model {}: ({} of {}, {}, {}, {}) -> {}, {}".format( + model, input_dict, obs_space, action_space, state_in, seq_lens, + model.outputs, model.state_out)) model._validate_output_shape() return model @staticmethod - def _get_model(input_dict, obs_space, num_outputs, options, state_in, - seq_lens): + def _get_model(input_dict, obs_space, action_space, num_outputs, options, + state_in, seq_lens): if options.get("custom_model"): model = options["custom_model"] logger.debug("Using custom model {}".format(model)) return _global_registry.get(RLLIB_MODEL, model)( input_dict, obs_space, + action_space, num_outputs, options, state_in=state_in, @@ -243,10 +248,11 @@ class ModelCatalog(object): obs_rank = len(input_dict["obs"].shape) - 1 if obs_rank > 1: - return VisionNetwork(input_dict, obs_space, num_outputs, options) + return VisionNetwork(input_dict, obs_space, action_space, + num_outputs, options) - return FullyConnectedNetwork(input_dict, obs_space, num_outputs, - options) + return FullyConnectedNetwork(input_dict, obs_space, action_space, + num_outputs, options) @staticmethod @DeveloperAPI diff --git a/python/ray/rllib/models/model.py b/python/ray/rllib/models/model.py index 39324ee81..b5664057d 100644 --- a/python/ray/rllib/models/model.py +++ b/python/ray/rllib/models/model.py @@ -48,6 +48,7 @@ class Model(object): def __init__(self, input_dict, obs_space, + action_space, num_outputs, options, state_in=None, @@ -59,6 +60,7 @@ class Model(object): self.state_in = state_in or [] self.state_out = [] self.obs_space = obs_space + self.action_space = action_space self.num_outputs = num_outputs self.options = options self.scope = tf.get_variable_scope() diff --git a/python/ray/rllib/tests/test_catalog.py b/python/ray/rllib/tests/test_catalog.py index efa1aba0e..9346e1064 100644 --- a/python/ray/rllib/tests/test_catalog.py +++ b/python/ray/rllib/tests/test_catalog.py @@ -73,13 +73,14 @@ class ModelCatalogTest(unittest.TestCase): with tf.variable_scope("test1"): p1 = ModelCatalog.get_model({ "obs": tf.zeros((10, 3), dtype=tf.float32) - }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, {}) + }, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {}) self.assertEqual(type(p1), FullyConnectedNetwork) with tf.variable_scope("test2"): p2 = ModelCatalog.get_model({ "obs": tf.zeros((10, 84, 84, 3), dtype=tf.float32) - }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), 5, {}) + }, Box(0, 1, shape=(84, 84, 3), dtype=np.float32), Discrete(5), 5, + {}) self.assertEqual(type(p2), VisionNetwork) def testCustomModel(self): @@ -87,7 +88,7 @@ class ModelCatalogTest(unittest.TestCase): ModelCatalog.register_custom_model("foo", CustomModel) p1 = ModelCatalog.get_model({ "obs": tf.constant([1, 2, 3]) - }, Box(0, 1, shape=(3, ), dtype=np.float32), 5, + }, Box(0, 1, shape=(3, ), dtype=np.float32), Discrete(5), 5, {"custom_model": "foo"}) self.assertEqual(str(type(p1)), str(CustomModel))