mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 03:21:00 +08:00
[RLlib] Make PyTorch Model forward pass faster in vf-case. (#8422)
This commit is contained in:
+1
-1
@@ -1528,7 +1528,7 @@ py_test(
|
||||
name = "examples/cartpole_lstm_impala_tf",
|
||||
main = "examples/cartpole_lstm.py",
|
||||
tags = ["examples", "examples_C"],
|
||||
size = "small",
|
||||
size = "medium",
|
||||
srcs = ["examples/cartpole_lstm.py"],
|
||||
args = ["--as-test", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
|
||||
)
|
||||
|
||||
@@ -148,10 +148,11 @@ class TorchAutoregressiveActionModel(TorchModelV2, nn.Module):
|
||||
# [ctx_input, a1_input])
|
||||
self.action_module = _ActionModel()
|
||||
|
||||
self._context = None
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
context = self.context_layer(input_dict["obs"])
|
||||
self._value_out = self.value_branch(context)
|
||||
return context, state
|
||||
self._context = self.context_layer(input_dict["obs"])
|
||||
return self._context, state
|
||||
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
return torch.reshape(self.value_branch(self._context), [-1])
|
||||
|
||||
@@ -174,17 +174,18 @@ class TorchBatchNormModel(TorchModelV2, nn.Module):
|
||||
activation_fn=None)
|
||||
|
||||
self._hidden_layers = nn.Sequential(*layers)
|
||||
self._hidden_out = None
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
# Set the correct train-mode for our hidden module (only important
|
||||
# b/c we have some batch-norm layers).
|
||||
self._hidden_layers.train(mode=input_dict["is_training"])
|
||||
hidden_out = self._hidden_layers(input_dict["obs"])
|
||||
logits = self._logits(hidden_out)
|
||||
self._value_out = self._value_branch(hidden_out)
|
||||
self._hidden_out = self._hidden_layers(input_dict["obs"])
|
||||
logits = self._logits(self._hidden_out)
|
||||
return logits, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
assert self._hidden_out is not None, "must call forward first!"
|
||||
return torch.reshape(self._value_branch(self._hidden_out), [-1])
|
||||
|
||||
@@ -161,14 +161,17 @@ class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
|
||||
|
||||
self.value_model = TorchFC(obs_space, action_space, 1, model_config,
|
||||
name + "_vf")
|
||||
self._model_in = None
|
||||
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
self._value_out, _ = self.value_model({
|
||||
"obs": input_dict["obs_flat"]
|
||||
}, state, seq_lens)
|
||||
# Store model-input for possible `value_function()` call.
|
||||
self._model_in = [input_dict["obs_flat"], state, seq_lens]
|
||||
return self.action_model({
|
||||
"obs": input_dict["obs"]["own_obs"]
|
||||
}, state, seq_lens)
|
||||
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
value_out, _ = self.value_model({
|
||||
"obs": self._model_in[0]
|
||||
}, self._model_in[1], self._model_in[2])
|
||||
return torch.reshape(value_out, [-1])
|
||||
|
||||
@@ -63,14 +63,15 @@ class TorchFastModel(TorchModelV2, nn.Module):
|
||||
# Only needed to give some params to the optimizer (even though,
|
||||
# they are never used anywhere).
|
||||
self.dummy_layer = SlimFC(1, 1)
|
||||
self._output = None
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
output = self.bias + \
|
||||
self._output = self.bias + \
|
||||
torch.zeros(size=(input_dict["obs"].shape[0], self.num_outputs))
|
||||
self._value_out = torch.mean(output, -1) # fake value
|
||||
return output, []
|
||||
return self._output, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
assert self._output is not None, "must call forward first!"
|
||||
return torch.reshape(torch.mean(self._output, -1), [-1])
|
||||
|
||||
@@ -114,6 +114,8 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
|
||||
# Postprocess LSTM output with another hidden layer and compute values.
|
||||
self.logits = SlimFC(self.lstm_state_size, self.num_outputs)
|
||||
self.value_branch = SlimFC(self.lstm_state_size, 1)
|
||||
# Holds the current "base" output (before logits layer).
|
||||
self._features = None
|
||||
|
||||
@override(RecurrentTFModelV2)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
@@ -128,10 +130,9 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
|
||||
state[0] = state[0].unsqueeze(0)
|
||||
state[1] = state[1].unsqueeze(0)
|
||||
# Forward through LSTM.
|
||||
lstm_out, [h, c] = self.lstm(vision_out_time_ranked, state)
|
||||
self._features, [h, c] = self.lstm(vision_out_time_ranked, state)
|
||||
# Forward LSTM out through logits layer and value layer.
|
||||
logits = self.logits(lstm_out)
|
||||
self._value_out = self.value_branch(lstm_out)
|
||||
logits = self.logits(self._features)
|
||||
return logits, [h.squeeze(0), c.squeeze(0)]
|
||||
|
||||
@override(ModelV2)
|
||||
@@ -147,4 +148,5 @@ class TorchMobileV2PlusRNNModel(RecurrentTorchModel):
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
assert self._features is not None, "must call forward() first"
|
||||
return torch.reshape(self.value_branch(self._features), [-1])
|
||||
|
||||
@@ -97,8 +97,8 @@ class TorchRNNModel(RecurrentTorchModel):
|
||||
self.fc_size, self.lstm_state_size, batch_first=True)
|
||||
self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
|
||||
self.value_branch = nn.Linear(self.lstm_state_size, 1)
|
||||
# Store the value output to save an extra forward pass.
|
||||
self._cur_value = None
|
||||
# Holds the current "base" output (before logits layer).
|
||||
self._features = None
|
||||
|
||||
@override(ModelV2)
|
||||
def get_initial_state(self):
|
||||
@@ -111,8 +111,8 @@ class TorchRNNModel(RecurrentTorchModel):
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
assert self._features is not None, "must call forward() first"
|
||||
return torch.reshape(self.value_branch(self._features), [-1])
|
||||
|
||||
@override(RecurrentTorchModel)
|
||||
def forward_rnn(self, inputs, state, seq_lens):
|
||||
@@ -127,12 +127,8 @@ class TorchRNNModel(RecurrentTorchModel):
|
||||
The state batches as a List of two items (c- and h-states).
|
||||
"""
|
||||
x = nn.functional.relu(self.fc1(inputs))
|
||||
lstm_out = self.lstm(
|
||||
self._features, [h, c] = self.lstm(
|
||||
x, [torch.unsqueeze(state[0], 0),
|
||||
torch.unsqueeze(state[1], 0)])
|
||||
action_out = self.action_branch(lstm_out[0])
|
||||
self._cur_value = torch.reshape(self.value_branch(lstm_out[0]), [-1])
|
||||
return action_out, [
|
||||
torch.squeeze(lstm_out[1][0], 0),
|
||||
torch.squeeze(lstm_out[1][1], 0)
|
||||
]
|
||||
action_out = self.action_branch(self._features)
|
||||
return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
|
||||
|
||||
@@ -110,15 +110,16 @@ class TorchSharedWeightsModel(TorchModelV2, nn.Module):
|
||||
# Non-shared final layer.
|
||||
self.last_layer = SlimFC(32, self.num_outputs, activation_fn=nn.ReLU)
|
||||
self.vf = SlimFC(32, 1, activation_fn=None)
|
||||
self._output = None
|
||||
|
||||
@override(ModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
out = self.first_layer(input_dict["obs"])
|
||||
out = TORCH_GLOBAL_SHARED_LAYER(out)
|
||||
model_out = self.last_layer(out)
|
||||
self._value_out = self.vf(out)
|
||||
self._output = TORCH_GLOBAL_SHARED_LAYER(out)
|
||||
model_out = self.last_layer(self._output)
|
||||
return model_out, []
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
return torch.reshape(self._value_out, [-1])
|
||||
assert self._output is not None, "must call forward first!"
|
||||
return torch.reshape(self.vf(self._output), [-1])
|
||||
|
||||
@@ -110,7 +110,7 @@ def make_v1_wrapper(legacy_model_cls):
|
||||
|
||||
@override(ModelV2)
|
||||
def value_function(self):
|
||||
assert self.cur_instance, "must call forward first"
|
||||
assert self.cur_instance is not None, "must call forward first"
|
||||
|
||||
with tf.variable_scope(self.variable_scope):
|
||||
with tf.variable_scope("value_function", reuse=tf.AUTO_REUSE):
|
||||
|
||||
@@ -96,20 +96,20 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
|
||||
out_size=1,
|
||||
initializer=normc_initializer(1.0),
|
||||
activation_fn=None)
|
||||
# Holds the current value output.
|
||||
self._cur_value = None
|
||||
# Holds the current "base" output (before logits layer).
|
||||
self._features = None
|
||||
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
obs = input_dict["obs_flat"].float()
|
||||
features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
|
||||
logits = self._logits(features) if self._logits else features
|
||||
self._features = self._hidden_layers(obs.reshape(obs.shape[0], -1))
|
||||
logits = self._logits(self._features) if self._logits else \
|
||||
self._features
|
||||
if self.free_log_std:
|
||||
logits = self._append_free_log_std(logits)
|
||||
self._cur_value = self._value_branch(features).squeeze(1)
|
||||
return logits, state
|
||||
|
||||
@override(TorchModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
assert self._features is not None, "must call forward() first"
|
||||
return self._value_branch(self._features).squeeze(1)
|
||||
|
||||
@@ -58,19 +58,19 @@ class VisionNetwork(TorchModelV2, nn.Module):
|
||||
out_channels, num_outputs, initializer=nn.init.xavier_uniform_)
|
||||
self._value_branch = SlimFC(
|
||||
out_channels, 1, initializer=normc_initializer())
|
||||
self._cur_value = None
|
||||
# Holds the current "base" output (before logits layer).
|
||||
self._features = None
|
||||
|
||||
@override(TorchModelV2)
|
||||
def forward(self, input_dict, state, seq_lens):
|
||||
features = self._hidden_layers(input_dict["obs"].float())
|
||||
logits = self._logits(features)
|
||||
self._cur_value = self._value_branch(features).squeeze(1)
|
||||
self._features = self._hidden_layers(input_dict["obs"].float())
|
||||
logits = self._logits(self._features)
|
||||
return logits, state
|
||||
|
||||
@override(TorchModelV2)
|
||||
def value_function(self):
|
||||
assert self._cur_value is not None, "must call forward() first"
|
||||
return self._cur_value
|
||||
assert self._features is not None, "must call forward() first"
|
||||
return self._value_branch(self._features).squeeze(1)
|
||||
|
||||
def _hidden_layers(self, obs):
|
||||
res = self._convs(obs.permute(0, 3, 1, 2)) # switch to channel-major
|
||||
|
||||
Reference in New Issue
Block a user