From 5f4c196feda42d022c12c07e5cd4182fbbf46cba Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Thu, 14 May 2020 10:15:50 +0200 Subject: [PATCH] [RLlib] Make PyTorch Model forward pass faster in vf-case. (#8422) --- rllib/BUILD | 2 +- .../models/autoregressive_action_model.py | 9 +++++---- rllib/examples/models/batch_norm_model.py | 9 +++++---- .../models/centralized_critic_models.py | 11 +++++++---- rllib/examples/models/fast_model.py | 9 +++++---- .../models/mobilenet_v2_with_lstm_models.py | 10 ++++++---- rllib/examples/models/rnn_model.py | 18 +++++++----------- rllib/examples/models/shared_weights_model.py | 9 +++++---- rllib/models/tf/modelv1_compat.py | 2 +- rllib/models/torch/fcnet.py | 14 +++++++------- rllib/models/torch/visionnet.py | 12 ++++++------ 11 files changed, 55 insertions(+), 50 deletions(-) diff --git a/rllib/BUILD b/rllib/BUILD index 20c7d0460..900eca88b 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1528,7 +1528,7 @@ py_test( name = "examples/cartpole_lstm_impala_tf", main = "examples/cartpole_lstm.py", tags = ["examples", "examples_C"], - size = "small", + size = "medium", srcs = ["examples/cartpole_lstm.py"], args = ["--as-test", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"] ) diff --git a/rllib/examples/models/autoregressive_action_model.py b/rllib/examples/models/autoregressive_action_model.py index ef570a7a5..06f32f78a 100644 --- a/rllib/examples/models/autoregressive_action_model.py +++ b/rllib/examples/models/autoregressive_action_model.py @@ -148,10 +148,11 @@ class TorchAutoregressiveActionModel(TorchModelV2, nn.Module): # [ctx_input, a1_input]) self.action_module = _ActionModel() + self._context = None + def forward(self, input_dict, state, seq_lens): - context = self.context_layer(input_dict["obs"]) - self._value_out = self.value_branch(context) - return context, state + self._context = self.context_layer(input_dict["obs"]) + return self._context, state def value_function(self): - return torch.reshape(self._value_out, [-1]) + return torch.reshape(self.value_branch(self._context), [-1]) diff --git a/rllib/examples/models/batch_norm_model.py b/rllib/examples/models/batch_norm_model.py index 693df3910..3330abd90 100644 --- a/rllib/examples/models/batch_norm_model.py +++ b/rllib/examples/models/batch_norm_model.py @@ -174,17 +174,18 @@ class TorchBatchNormModel(TorchModelV2, nn.Module): activation_fn=None) self._hidden_layers = nn.Sequential(*layers) + self._hidden_out = None @override(ModelV2) def forward(self, input_dict, state, seq_lens): # Set the correct train-mode for our hidden module (only important # b/c we have some batch-norm layers). self._hidden_layers.train(mode=input_dict["is_training"]) - hidden_out = self._hidden_layers(input_dict["obs"]) - logits = self._logits(hidden_out) - self._value_out = self._value_branch(hidden_out) + self._hidden_out = self._hidden_layers(input_dict["obs"]) + logits = self._logits(self._hidden_out) return logits, [] @override(ModelV2) def value_function(self): - return torch.reshape(self._value_out, [-1]) + assert self._hidden_out is not None, "must call forward first!" + return torch.reshape(self._value_branch(self._hidden_out), [-1]) diff --git a/rllib/examples/models/centralized_critic_models.py b/rllib/examples/models/centralized_critic_models.py index 77c40cd54..325b6b493 100644 --- a/rllib/examples/models/centralized_critic_models.py +++ b/rllib/examples/models/centralized_critic_models.py @@ -161,14 +161,17 @@ class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module): self.value_model = TorchFC(obs_space, action_space, 1, model_config, name + "_vf") + self._model_in = None def forward(self, input_dict, state, seq_lens): - self._value_out, _ = self.value_model({ - "obs": input_dict["obs_flat"] - }, state, seq_lens) + # Store model-input for possible `value_function()` call. + self._model_in = [input_dict["obs_flat"], state, seq_lens] return self.action_model({ "obs": input_dict["obs"]["own_obs"] }, state, seq_lens) def value_function(self): - return torch.reshape(self._value_out, [-1]) + value_out, _ = self.value_model({ + "obs": self._model_in[0] + }, self._model_in[1], self._model_in[2]) + return torch.reshape(value_out, [-1]) diff --git a/rllib/examples/models/fast_model.py b/rllib/examples/models/fast_model.py index 15c35a752..7e6528db7 100644 --- a/rllib/examples/models/fast_model.py +++ b/rllib/examples/models/fast_model.py @@ -63,14 +63,15 @@ class TorchFastModel(TorchModelV2, nn.Module): # Only needed to give some params to the optimizer (even though, # they are never used anywhere). self.dummy_layer = SlimFC(1, 1) + self._output = None @override(ModelV2) def forward(self, input_dict, state, seq_lens): - output = self.bias + \ + self._output = self.bias + \ torch.zeros(size=(input_dict["obs"].shape[0], self.num_outputs)) - self._value_out = torch.mean(output, -1) # fake value - return output, [] + return self._output, [] @override(ModelV2) def value_function(self): - return torch.reshape(self._value_out, [-1]) + assert self._output is not None, "must call forward first!" + return torch.reshape(torch.mean(self._output, -1), [-1]) diff --git a/rllib/examples/models/mobilenet_v2_with_lstm_models.py b/rllib/examples/models/mobilenet_v2_with_lstm_models.py index ade63d73f..4ce78f4dc 100644 --- a/rllib/examples/models/mobilenet_v2_with_lstm_models.py +++ b/rllib/examples/models/mobilenet_v2_with_lstm_models.py @@ -114,6 +114,8 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel): # Postprocess LSTM output with another hidden layer and compute values. self.logits = SlimFC(self.lstm_state_size, self.num_outputs) self.value_branch = SlimFC(self.lstm_state_size, 1) + # Holds the current "base" output (before logits layer). + self._features = None @override(RecurrentTFModelV2) def forward_rnn(self, inputs, state, seq_lens): @@ -128,10 +130,9 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel): state[0] = state[0].unsqueeze(0) state[1] = state[1].unsqueeze(0) # Forward through LSTM. - lstm_out, [h, c] = self.lstm(vision_out_time_ranked, state) + self._features, [h, c] = self.lstm(vision_out_time_ranked, state) # Forward LSTM out through logits layer and value layer. - logits = self.logits(lstm_out) - self._value_out = self.value_branch(lstm_out) + logits = self.logits(self._features) return logits, [h.squeeze(0), c.squeeze(0)] @override(ModelV2) @@ -147,4 +148,5 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel): @override(ModelV2) def value_function(self): - return torch.reshape(self._value_out, [-1]) + assert self._features is not None, "must call forward() first" + return torch.reshape(self.value_branch(self._features), [-1]) diff --git a/rllib/examples/models/rnn_model.py b/rllib/examples/models/rnn_model.py index 963623ae2..e13eee24c 100644 --- a/rllib/examples/models/rnn_model.py +++ b/rllib/examples/models/rnn_model.py @@ -97,8 +97,8 @@ class TorchRNNModel(RecurrentTorchModel): self.fc_size, self.lstm_state_size, batch_first=True) self.action_branch = nn.Linear(self.lstm_state_size, num_outputs) self.value_branch = nn.Linear(self.lstm_state_size, 1) - # Store the value output to save an extra forward pass. - self._cur_value = None + # Holds the current "base" output (before logits layer). + self._features = None @override(ModelV2) def get_initial_state(self): @@ -111,8 +111,8 @@ class TorchRNNModel(RecurrentTorchModel): @override(ModelV2) def value_function(self): - assert self._cur_value is not None, "must call forward() first" - return self._cur_value + assert self._features is not None, "must call forward() first" + return torch.reshape(self.value_branch(self._features), [-1]) @override(RecurrentTorchModel) def forward_rnn(self, inputs, state, seq_lens): @@ -127,12 +127,8 @@ class TorchRNNModel(RecurrentTorchModel): The state batches as a List of two items (c- and h-states). """ x = nn.functional.relu(self.fc1(inputs)) - lstm_out = self.lstm( + self._features, [h, c] = self.lstm( x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]) - action_out = self.action_branch(lstm_out[0]) - self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1]) - return action_out, [ - torch.squeeze(lstm_out[1][0], 0), - torch.squeeze(lstm_out[1][1], 0) - ] + action_out = self.action_branch(self._features) + return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)] diff --git a/rllib/examples/models/shared_weights_model.py b/rllib/examples/models/shared_weights_model.py index 1e68902f9..3348fa0a6 100644 --- a/rllib/examples/models/shared_weights_model.py +++ b/rllib/examples/models/shared_weights_model.py @@ -110,15 +110,16 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module): # Non-shared final layer. self.last_layer = SlimFC(32, self.num_outputs, activation_fn=nn.ReLU) self.vf = SlimFC(32, 1, activation_fn=None) + self._output = None @override(ModelV2) def forward(self, input_dict, state, seq_lens): out = self.first_layer(input_dict["obs"]) - out = TORCH_GLOBAL_SHARED_LAYER(out) - model_out = self.last_layer(out) - self._value_out = self.vf(out) + self._output = TORCH_GLOBAL_SHARED_LAYER(out) + model_out = self.last_layer(self._output) return model_out, [] @override(ModelV2) def value_function(self): - return torch.reshape(self._value_out, [-1]) + assert self._output is not None, "must call forward first!" + return torch.reshape(self.vf(self._output), [-1]) diff --git a/rllib/models/tf/modelv1_compat.py b/rllib/models/tf/modelv1_compat.py index 6a2917309..d5ff99150 100644 --- a/rllib/models/tf/modelv1_compat.py +++ b/rllib/models/tf/modelv1_compat.py @@ -110,7 +110,7 @@ def make_v1_wrapper(legacy_model_cls): @override(ModelV2) def value_function(self): - assert self.cur_instance, "must call forward first" + assert self.cur_instance is not None, "must call forward first" with tf.variable_scope(self.variable_scope): with tf.variable_scope("value_function", reuse=tf.AUTO_REUSE): diff --git a/rllib/models/torch/fcnet.py b/rllib/models/torch/fcnet.py index 4065309bb..5dfa6ba86 100644 --- a/rllib/models/torch/fcnet.py +++ b/rllib/models/torch/fcnet.py @@ -96,20 +96,20 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module): out_size=1, initializer=normc_initializer(1.0), activation_fn=None) - # Holds the current value output. - self._cur_value = None + # Holds the current "base" output (before logits layer). + self._features = None @override(TorchModelV2) def forward(self, input_dict, state, seq_lens): obs = input_dict["obs_flat"].float() - features = self._hidden_layers(obs.reshape(obs.shape[0], -1)) - logits = self._logits(features) if self._logits else features + self._features = self._hidden_layers(obs.reshape(obs.shape[0], -1)) + logits = self._logits(self._features) if self._logits else \ + self._features if self.free_log_std: logits = self._append_free_log_std(logits) - self._cur_value = self._value_branch(features).squeeze(1) return logits, state @override(TorchModelV2) def value_function(self): - assert self._cur_value is not None, "must call forward() first" - return self._cur_value + assert self._features is not None, "must call forward() first" + return self._value_branch(self._features).squeeze(1) diff --git a/rllib/models/torch/visionnet.py b/rllib/models/torch/visionnet.py index 7dd55b4b5..66bb04f78 100644 --- a/rllib/models/torch/visionnet.py +++ b/rllib/models/torch/visionnet.py @@ -58,19 +58,19 @@ class VisionNetwork(TorchModelV2, nn.Module): out_channels, num_outputs, initializer=nn.init.xavier_uniform_) self._value_branch = SlimFC( out_channels, 1, initializer=normc_initializer()) - self._cur_value = None + # Holds the current "base" output (before logits layer). + self._features = None @override(TorchModelV2) def forward(self, input_dict, state, seq_lens): - features = self._hidden_layers(input_dict["obs"].float()) - logits = self._logits(features) - self._cur_value = self._value_branch(features).squeeze(1) + self._features = self._hidden_layers(input_dict["obs"].float()) + logits = self._logits(self._features) return logits, state @override(TorchModelV2) def value_function(self): - assert self._cur_value is not None, "must call forward() first" - return self._cur_value + assert self._features is not None, "must call forward() first" + return self._value_branch(self._features).squeeze(1) def _hidden_layers(self, obs): res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major