mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:44:07 +08:00
[rllib] Properly flatten 2-d observations as input to FCnet (#5733)
This commit is contained in:
@@ -450,7 +450,7 @@ class ModelCatalog(object):
|
||||
else:
|
||||
obs_rank = len(obs_space.shape)
|
||||
|
||||
if obs_rank > 1:
|
||||
if obs_rank > 2:
|
||||
return PyTorchVisionNet(obs_space, action_space, num_outputs,
|
||||
model_config, name)
|
||||
|
||||
@@ -506,7 +506,7 @@ class ModelCatalog(object):
|
||||
|
||||
obs_rank = len(input_dict["obs"].shape) - 1
|
||||
|
||||
if obs_rank > 1:
|
||||
if obs_rank > 2:
|
||||
return VisionNetwork(input_dict, obs_space, action_space,
|
||||
num_outputs, options)
|
||||
|
||||
@@ -521,7 +521,7 @@ class ModelCatalog(object):
|
||||
if options.get("use_lstm"):
|
||||
return None # TODO: default LSTM v2 not implemented
|
||||
|
||||
if obs_rank > 1:
|
||||
if obs_rank > 2:
|
||||
return VisionNetV2
|
||||
|
||||
return FCNetV2
|
||||
|
||||
@@ -190,6 +190,18 @@ class Model(object):
|
||||
self._num_outputs, shape))
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def flatten(obs, framework):
|
||||
"""Flatten the given tensor."""
|
||||
if framework == "tf":
|
||||
return tf.layers.flatten(obs)
|
||||
elif framework == "torch":
|
||||
import torch
|
||||
return torch.flatten(obs, start_dim=1)
|
||||
else:
|
||||
raise NotImplementedError("flatten", framework)
|
||||
|
||||
|
||||
@DeveloperAPI
|
||||
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
|
||||
"""Unpacks Dict and Tuple space observations into their original form.
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from ray.rllib.policy.sample_batch import SampleBatch
|
||||
from ray.rllib.models.model import restore_original_dimensions
|
||||
from ray.rllib.models.model import restore_original_dimensions, flatten
|
||||
from ray.rllib.utils.annotations import PublicAPI
|
||||
|
||||
|
||||
@@ -146,7 +146,10 @@ class ModelV2(object):
|
||||
restored = input_dict.copy()
|
||||
restored["obs"] = restore_original_dimensions(
|
||||
input_dict["obs"], self.obs_space, self.framework)
|
||||
restored["obs_flat"] = input_dict["obs"]
|
||||
if len(input_dict["obs"].shape) > 2:
|
||||
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
|
||||
else:
|
||||
restored["obs_flat"] = input_dict["obs"]
|
||||
with self.context():
|
||||
res = self.forward(restored, state or [], seq_lens)
|
||||
if ((not isinstance(res, list) and not isinstance(res, tuple))
|
||||
|
||||
@@ -25,6 +25,9 @@ class FullyConnectedNetwork(Model):
|
||||
hiddens = options.get("fcnet_hiddens")
|
||||
activation = get_activation_fn(options.get("fcnet_activation"))
|
||||
|
||||
if len(inputs.shape) > 2:
|
||||
inputs = tf.layers.flatten(inputs)
|
||||
|
||||
with tf.name_scope("fc_net"):
|
||||
i = 1
|
||||
last_layer = inputs
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
|
||||
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
|
||||
from ray.rllib.utils import try_import_tf
|
||||
@@ -22,8 +24,9 @@ class FullyConnectedNetwork(TFModelV2):
|
||||
no_final_linear = model_config.get("no_final_linear")
|
||||
vf_share_layers = model_config.get("vf_share_layers")
|
||||
|
||||
# we are using obs_flat, so take the flattened shape as input
|
||||
inputs = tf.keras.layers.Input(
|
||||
shape=obs_space.shape, name="observations")
|
||||
shape=(np.product(obs_space.shape), ), name="observations")
|
||||
last_layer = inputs
|
||||
i = 1
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ ACTION_SPACES_TO_TEST = {
|
||||
OBSERVATION_SPACES_TO_TEST = {
|
||||
"discrete": Discrete(5),
|
||||
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
|
||||
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
|
||||
"image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
|
||||
"atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
|
||||
"tuple": Tuple([Discrete(10),
|
||||
@@ -106,6 +107,7 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
|
||||
def check_support_multiagent(alg, config):
|
||||
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
|
||||
register_env("multi_cartpole", lambda _: MultiCartpole(2))
|
||||
config["log_level"] = "ERROR"
|
||||
if "DDPG" in alg:
|
||||
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user