[RLlib] Make PyTorch Model forward pass faster in vf-case. (#8422)

This commit is contained in:
Sven Mika
2020-05-14 10:15:50 +02:00
committed by GitHub
parent 212f78f735
commit 5f4c196fed
11 changed files with 55 additions and 50 deletions
+1 -1
View File
@@ -1528,7 +1528,7 @@ py_test(
name = "examples/cartpole_lstm_impala_tf",
main = "examples/cartpole_lstm.py",
tags = ["examples", "examples_C"],
size = "small",
size = "medium",
srcs = ["examples/cartpole_lstm.py"],
args = ["--as-test", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
)
@@ -148,10 +148,11 @@ class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
# [ctx_input, a1_input])
self.action_module = _ActionModel()
self._context = None
def forward(self, input_dict, state, seq_lens):
context = self.context_layer(input_dict["obs"])
self._value_out = self.value_branch(context)
return context, state
self._context = self.context_layer(input_dict["obs"])
return self._context, state
def value_function(self):
return torch.reshape(self._value_out, [-1])
return torch.reshape(self.value_branch(self._context), [-1])
+5 -4
View File
@@ -174,17 +174,18 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
activation_fn=None)
self._hidden_layers = nn.Sequential(*layers)
self._hidden_out = None
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
# Set the correct train-mode for our hidden module (only important
# b/c we have some batch-norm layers).
self._hidden_layers.train(mode=input_dict["is_training"])
hidden_out = self._hidden_layers(input_dict["obs"])
logits = self._logits(hidden_out)
self._value_out = self._value_branch(hidden_out)
self._hidden_out = self._hidden_layers(input_dict["obs"])
logits = self._logits(self._hidden_out)
return logits, []
@override(ModelV2)
def value_function(self):
return torch.reshape(self._value_out, [-1])
assert self._hidden_out is not None, "must call forward first!"
return torch.reshape(self._value_branch(self._hidden_out), [-1])
@@ -161,14 +161,17 @@ class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
self.value_model = TorchFC(obs_space, action_space, 1, model_config,
name + "_vf")
self._model_in = None
def forward(self, input_dict, state, seq_lens):
self._value_out, _ = self.value_model({
"obs": input_dict["obs_flat"]
}, state, seq_lens)
# Store model-input for possible `value_function()` call.
self._model_in = [input_dict["obs_flat"], state, seq_lens]
return self.action_model({
"obs": input_dict["obs"]["own_obs"]
}, state, seq_lens)
def value_function(self):
return torch.reshape(self._value_out, [-1])
value_out, _ = self.value_model({
"obs": self._model_in[0]
}, self._model_in[1], self._model_in[2])
return torch.reshape(value_out, [-1])
+5 -4
View File
@@ -63,14 +63,15 @@ class TorchFastModel(TorchModelV2, nn.Module):
# Only needed to give some params to the optimizer (even though,
# they are never used anywhere).
self.dummy_layer = SlimFC(1, 1)
self._output = None
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
output = self.bias + \
self._output = self.bias + \
torch.zeros(size=(input_dict["obs"].shape[0], self.num_outputs))
self._value_out = torch.mean(output, -1) # fake value
return output, []
return self._output, []
@override(ModelV2)
def value_function(self):
return torch.reshape(self._value_out, [-1])
assert self._output is not None, "must call forward first!"
return torch.reshape(torch.mean(self._output, -1), [-1])
@@ -114,6 +114,8 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
# Postprocess LSTM output with another hidden layer and compute values.
self.logits = SlimFC(self.lstm_state_size, self.num_outputs)
self.value_branch = SlimFC(self.lstm_state_size, 1)
# Holds the current "base" output (before logits layer).
self._features = None
@override(RecurrentTFModelV2)
def forward_rnn(self, inputs, state, seq_lens):
@@ -128,10 +130,9 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
state[0] = state[0].unsqueeze(0)
state[1] = state[1].unsqueeze(0)
# Forward through LSTM.
lstm_out, [h, c] = self.lstm(vision_out_time_ranked, state)
self._features, [h, c] = self.lstm(vision_out_time_ranked, state)
# Forward LSTM out through logits layer and value layer.
logits = self.logits(lstm_out)
self._value_out = self.value_branch(lstm_out)
logits = self.logits(self._features)
return logits, [h.squeeze(0), c.squeeze(0)]
@override(ModelV2)
@@ -147,4 +148,5 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
@override(ModelV2)
def value_function(self):
return torch.reshape(self._value_out, [-1])
assert self._features is not None, "must call forward() first"
return torch.reshape(self.value_branch(self._features), [-1])
+7 -11
View File
@@ -97,8 +97,8 @@ class TorchRNNModel(RecurrentTorchModel):
self.fc_size, self.lstm_state_size, batch_first=True)
self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
self.value_branch = nn.Linear(self.lstm_state_size, 1)
# Store the value output to save an extra forward pass.
self._cur_value = None
# Holds the current "base" output (before logits layer).
self._features = None
@override(ModelV2)
def get_initial_state(self):
@@ -111,8 +111,8 @@ class TorchRNNModel(RecurrentTorchModel):
@override(ModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
assert self._features is not None, "must call forward() first"
return torch.reshape(self.value_branch(self._features), [-1])
@override(RecurrentTorchModel)
def forward_rnn(self, inputs, state, seq_lens):
@@ -127,12 +127,8 @@ class TorchRNNModel(RecurrentTorchModel):
The state batches as a List of two items (c- and h-states).
"""
x = nn.functional.relu(self.fc1(inputs))
lstm_out = self.lstm(
self._features, [h, c] = self.lstm(
x, [torch.unsqueeze(state[0], 0),
torch.unsqueeze(state[1], 0)])
action_out = self.action_branch(lstm_out[0])
self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1])
return action_out, [
torch.squeeze(lstm_out[1][0], 0),
torch.squeeze(lstm_out[1][1], 0)
]
action_out = self.action_branch(self._features)
return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
@@ -110,15 +110,16 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
# Non-shared final layer.
self.last_layer = SlimFC(32, self.num_outputs, activation_fn=nn.ReLU)
self.vf = SlimFC(32, 1, activation_fn=None)
self._output = None
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
out = self.first_layer(input_dict["obs"])
out = TORCH_GLOBAL_SHARED_LAYER(out)
model_out = self.last_layer(out)
self._value_out = self.vf(out)
self._output = TORCH_GLOBAL_SHARED_LAYER(out)
model_out = self.last_layer(self._output)
return model_out, []
@override(ModelV2)
def value_function(self):
return torch.reshape(self._value_out, [-1])
assert self._output is not None, "must call forward first!"
return torch.reshape(self.vf(self._output), [-1])
+1 -1
View File
@@ -110,7 +110,7 @@ def make_v1_wrapper(legacy_model_cls):
@override(ModelV2)
def value_function(self):
assert self.cur_instance, "must call forward first"
assert self.cur_instance is not None, "must call forward first"
with tf.variable_scope(self.variable_scope):
with tf.variable_scope("value_function", reuse=tf.AUTO_REUSE):
+7 -7
View File
@@ -96,20 +96,20 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
out_size=1,
initializer=normc_initializer(1.0),
activation_fn=None)
# Holds the current value output.
self._cur_value = None
# Holds the current "base" output (before logits layer).
self._features = None
@override(TorchModelV2)
def forward(self, input_dict, state, seq_lens):
obs = input_dict["obs_flat"].float()
features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
logits = self._logits(features) if self._logits else features
self._features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
logits = self._logits(self._features) if self._logits else \
self._features
if self.free_log_std:
logits = self._append_free_log_std(logits)
self._cur_value = self._value_branch(features).squeeze(1)
return logits, state
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
assert self._features is not None, "must call forward() first"
return self._value_branch(self._features).squeeze(1)
+6 -6
View File
@@ -58,19 +58,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
self._value_branch = SlimFC(
out_channels, 1, initializer=normc_initializer())
self._cur_value = None
# Holds the current "base" output (before logits layer).
self._features = None
@override(TorchModelV2)
def forward(self, input_dict, state, seq_lens):
features = self._hidden_layers(input_dict["obs"].float())
logits = self._logits(features)
self._cur_value = self._value_branch(features).squeeze(1)
self._features = self._hidden_layers(input_dict["obs"].float())
logits = self._logits(self._features)
return logits, state
@override(TorchModelV2)
def value_function(self):
assert self._cur_value is not None, "must call forward() first"
return self._cur_value
assert self._features is not None, "must call forward() first"
return self._value_branch(self._features).squeeze(1)
def _hidden_layers(self, obs):
res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major