mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 05:28:20 +08:00
Add missing int-casts for all shape calculating code (using np.product([some shape])). (#10092)
This commit is contained in:
@@ -33,7 +33,7 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
|
||||
# We are using obs_flat, so take the flattened shape as input.
|
||||
inputs = tf.keras.layers.Input(
|
||||
shape=(np.product(obs_space.shape), ), name="observations")
|
||||
shape=(int(np.product(obs_space.shape)), ), name="observations")
|
||||
# Last hidden layer output (before logits outputs).
|
||||
last_layer = inputs
|
||||
# The action distribution outputs.
|
||||
@@ -75,7 +75,7 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
# Adjust num_outputs to be the number of nodes in the last layer.
|
||||
else:
|
||||
self.num_outputs = (
|
||||
[np.product(obs_space.shape)] + hiddens[-1:])[-1]
|
||||
[int(np.product(obs_space.shape))] + hiddens[-1:])[-1]
|
||||
|
||||
# Concat the log std vars to the end of the state-dependent means.
|
||||
if free_log_std and logits_out is not None:
|
||||
|
||||
@@ -78,7 +78,7 @@ class FullyConnectedNetwork(TorchModelV2, nn.Module):
|
||||
activation_fn=None)
|
||||
else:
|
||||
self.num_outputs = (
|
||||
[np.product(obs_space.shape)] + hiddens[-1:])[-1]
|
||||
[int(np.product(obs_space.shape))] + hiddens[-1:])[-1]
|
||||
|
||||
# Layer to add the log std vars to the state-dependent means.
|
||||
if self.free_log_std and self._logits:
|
||||
|
||||
Reference in New Issue
Block a user