From 690b374581ea5cd0e9aaaf911d70dbf19752d99e Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Wed, 24 Jul 2019 13:09:41 -0700 Subject: [PATCH] [rllib] Add Keras LSTM example with ModelV2 (#5258) --- python/ray/rllib/agents/dqn/simple_q_model.py | 3 + python/ray/rllib/agents/ppo/ppo.py | 12 -- python/ray/rllib/examples/cartpole_lstm.py | 2 +- .../ray/rllib/examples/custom_keras_model.py | 5 +- .../rllib/examples/custom_keras_rnn_model.py | 108 ++++++++++++++++++ .../rllib/models/tf/recurrent_tf_modelv2.py | 51 +++++++++ python/ray/rllib/models/tf/tf_modelv2.py | 4 +- 7 files changed, 166 insertions(+), 19 deletions(-) create mode 100644 python/ray/rllib/examples/custom_keras_rnn_model.py create mode 100644 python/ray/rllib/models/tf/recurrent_tf_modelv2.py diff --git a/python/ray/rllib/agents/dqn/simple_q_model.py b/python/ray/rllib/agents/dqn/simple_q_model.py index a365af658..3efec7921 100644 --- a/python/ray/rllib/agents/dqn/simple_q_model.py +++ b/python/ray/rllib/agents/dqn/simple_q_model.py @@ -2,7 +2,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils.annotations import override from ray.rllib.utils import try_import_tf tf = try_import_tf() @@ -54,6 +56,7 @@ class SimpleQModel(TFModelV2): self.q_value_head = tf.keras.Model(self.model_out, q_out) self.register_variables(self.q_value_head.variables) + @override(ModelV2) def forward(self, input_dict, state, seq_lens): """This generates the model_out tensor input. 
diff --git a/python/ray/rllib/agents/ppo/ppo.py b/python/ray/rllib/agents/ppo/ppo.py index e31d74862..760467eee 100644 --- a/python/ray/rllib/agents/ppo/ppo.py +++ b/python/ray/rllib/agents/ppo/ppo.py @@ -104,17 +104,6 @@ def update_kl(trainer, fetches): trainer.workers.local_worker().foreach_trainable_policy(update) -def warn_about_obs_filter(trainer): - if "observation_filter" not in trainer.raw_user_config: - # TODO(ekl) remove this message after a few releases - logger.info( - "Important! Since 0.7.0, observation normalization is no " - "longer enabled by default. To enable running-mean " - "normalization, set 'observation_filter': 'MeanStdFilter'. " - "You can ignore this message if your environment doesn't " - "require observation normalization.") - - def warn_about_bad_reward_scales(trainer, result): # Warn about bad clipping configs if trainer.config["vf_clip_param"] <= 0: @@ -164,5 +153,4 @@ PPOTrainer = build_trainer( make_policy_optimizer=choose_policy_optimizer, validate_config=validate_config, after_optimizer_step=update_kl, - before_train_step=warn_about_obs_filter, after_train_result=warn_about_bad_reward_scales) diff --git a/python/ray/rllib/examples/cartpole_lstm.py b/python/ray/rllib/examples/cartpole_lstm.py index 681647872..98d7a2b2f 100644 --- a/python/ray/rllib/examples/cartpole_lstm.py +++ b/python/ray/rllib/examples/cartpole_lstm.py @@ -24,7 +24,7 @@ class CartPoleStatelessEnv(gym.Env): "video.frames_per_second": 60 } - def __init__(self): + def __init__(self, config=None): self.gravity = 9.8 self.masscart = 1.0 self.masspole = 0.1 diff --git a/python/ray/rllib/examples/custom_keras_model.py b/python/ray/rllib/examples/custom_keras_model.py index e82dee966..407510653 100644 --- a/python/ray/rllib/examples/custom_keras_model.py +++ b/python/ray/rllib/examples/custom_keras_model.py @@ -1,7 +1,4 @@ -"""Example of using a custom ModelV2 Keras-style model. - -TODO(ekl): add this to docs once ModelV2 is fully implemented. 
-""" +"""Example of using a custom ModelV2 Keras-style model.""" from __future__ import absolute_import from __future__ import division diff --git a/python/ray/rllib/examples/custom_keras_rnn_model.py b/python/ray/rllib/examples/custom_keras_rnn_model.py new file mode 100644 index 000000000..7bd793710 --- /dev/null +++ b/python/ray/rllib/examples/custom_keras_rnn_model.py @@ -0,0 +1,108 @@ +"""Example of using a custom RNN keras model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import argparse + +import ray +from ray import tune +from ray.rllib.examples.cartpole_lstm import CartPoleStatelessEnv +from ray.rllib.models import ModelCatalog +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.recurrent_tf_modelv2 import RecurrentTFModelV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + +parser = argparse.ArgumentParser() +parser.add_argument("--run", type=str, default="PPO") +parser.add_argument("--stop", type=int, default=200) + + +class MyKerasRNN(RecurrentTFModelV2): + """Example of using the Keras functional API to define a RNN model.""" + + def __init__(self, + obs_space, + action_space, + num_outputs, + model_config, + name, + hiddens_size=256, + cell_size=64): + super(MyKerasRNN, self).__init__(obs_space, action_space, num_outputs, + model_config, name) + self.cell_size = cell_size + + # Define input layers + input_layer = tf.keras.layers.Input( + shape=(None, obs_space.shape[0]), name="inputs") + state_in_h = tf.keras.layers.Input(shape=(cell_size, ), name="h") + state_in_c = tf.keras.layers.Input(shape=(cell_size, ), name="c") + seq_in = tf.keras.layers.Input(shape=(), name="seq_in") + + # Preprocess observation with a hidden layer and send to LSTM cell + dense1 = tf.keras.layers.Dense( + hiddens_size, activation=tf.nn.relu, name="dense1")(input_layer) + lstm_out, state_h, 
state_c = tf.keras.layers.LSTM( + cell_size, return_sequences=True, return_state=True, name="lstm")( + inputs=dense1, + mask=tf.sequence_mask(seq_in), + initial_state=[state_in_h, state_in_c]) + + # Postprocess LSTM output with another hidden layer and compute values + dense2 = tf.keras.layers.Dense( + hiddens_size, activation=tf.nn.relu, name="dense2")(lstm_out) + logits = tf.keras.layers.Dense( + self.num_outputs, + activation=tf.keras.activations.linear, + name="logits")(dense2) + values = tf.keras.layers.Dense( + 1, activation=None, name="values")(dense2) + + # Create the RNN model + self.rnn_model = tf.keras.Model( + inputs=[input_layer, seq_in, state_in_h, state_in_c], + outputs=[logits, values, state_h, state_c]) + self.register_variables(self.rnn_model.variables) + self.rnn_model.summary() + + @override(RecurrentTFModelV2) + def forward_rnn(self, inputs, state, seq_lens): + model_out, self._value_out, h, c = self.rnn_model([inputs, seq_lens] + + state) + return model_out, [h, c] + + @override(ModelV2) + def get_initial_state(self): + return [ + np.zeros(self.cell_size, np.float32), + np.zeros(self.cell_size, np.float32), + ] + + @override(ModelV2) + def value_function(self): + return tf.reshape(self._value_out, [-1]) + + +if __name__ == "__main__": + ray.init() + args = parser.parse_args() + ModelCatalog.register_custom_model("rnn", MyKerasRNN) + tune.run( + args.run, + stop={"episode_reward_mean": args.stop}, + config={ + "env": CartPoleStatelessEnv, + "num_envs_per_worker": 4, + "num_sgd_iter": 3, + "vf_loss_coeff": 1e-4, + "model": { + "custom_model": "rnn", + "max_seq_len": 7, + }, + }) diff --git a/python/ray/rllib/models/tf/recurrent_tf_modelv2.py b/python/ray/rllib/models/tf/recurrent_tf_modelv2.py new file mode 100644 index 000000000..a496ad554 --- /dev/null +++ b/python/ray/rllib/models/tf/recurrent_tf_modelv2.py @@ -0,0 +1,51 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from 
ray.rllib.models.lstm import add_time_dimension +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_modelv2 import TFModelV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils import try_import_tf + +tf = try_import_tf() + + +class RecurrentTFModelV2(TFModelV2): + """Helper class to simplify implementing RNN models with TFModelV2. + + Instead of implementing forward(), you can implement forward_rnn() which + takes batches with the time dimension added already.""" + + def __init__(self, obs_space, action_space, num_outputs, model_config, + name): + TFModelV2.__init__(self, obs_space, action_space, num_outputs, + model_config, name) + + @override(ModelV2) + def forward(self, input_dict, state, seq_lens): + """Adds time dimension to batch before sending inputs to forward_rnn(). + + You should implement forward_rnn() in your subclass.""" + output, new_state = self.forward_rnn( + add_time_dimension(input_dict["obs_flat"], seq_lens), state, + seq_lens) + return tf.reshape(output, [-1, self.num_outputs]), new_state + + def forward_rnn(self, inputs, state, seq_lens): + """Call the model with the given input tensors and state. + + Arguments: + inputs (Tensor): observation tensor with shape [B, T, obs_size]. + state (list): list of state tensors, each with shape [B, size]. + seq_lens (Tensor): 1d tensor holding input sequence lengths. + + Returns: + (outputs, new_state): The model output tensor of shape + [B, T, num_outputs] and the list of new state tensors each with + shape [B, size]. 
+ """ + raise NotImplementedError("You must implement this for a RNN model") + + def get_initial_state(self): + raise NotImplementedError("You must implement this for a RNN model") diff --git a/python/ray/rllib/models/tf/tf_modelv2.py b/python/ray/rllib/models/tf/tf_modelv2.py index b2769b5ec..34cbbe40d 100644 --- a/python/ray/rllib/models/tf/tf_modelv2.py +++ b/python/ray/rllib/models/tf/tf_modelv2.py @@ -11,13 +11,13 @@ tf = try_import_tf() class TFModelV2(ModelV2): """TF version of ModelV2.""" - def __init__(self, obs_space, action_space, output_spec, model_config, + def __init__(self, obs_space, action_space, num_outputs, model_config, name): ModelV2.__init__( self, obs_space, action_space, - output_spec, + num_outputs, model_config, name, framework="tf")