[rllib] Properly flatten 2-d observations as input to FCnet (#5733)

This commit is contained in:
Eric Liang
2019-09-19 12:10:31 -07:00
parent 7131166d44
commit 6da7eff4b2
6 changed files with 29 additions and 6 deletions
+3 -3
View File
@@ -450,7 +450,7 @@ class ModelCatalog(object):
else:
obs_rank = len(obs_space.shape)
if obs_rank > 1:
if obs_rank > 2:
return PyTorchVisionNet(obs_space, action_space, num_outputs,
model_config, name)
@@ -506,7 +506,7 @@ class ModelCatalog(object):
obs_rank = len(input_dict["obs"].shape) - 1
if obs_rank > 1:
if obs_rank > 2:
return VisionNetwork(input_dict, obs_space, action_space,
num_outputs, options)
@@ -521,7 +521,7 @@ class ModelCatalog(object):
if options.get("use_lstm"):
return None # TODO: default LSTM v2 not implemented
if obs_rank > 1:
if obs_rank > 2:
return VisionNetV2
return FCNetV2
+12
View File
@@ -190,6 +190,18 @@ class Model(object):
self._num_outputs, shape))
@DeveloperAPI
def flatten(obs, framework):
"""Flatten the given tensor."""
if framework == "tf":
return tf.layers.flatten(obs)
elif framework == "torch":
import torch
return torch.flatten(obs, start_dim=1)
else:
raise NotImplementedError("flatten", framework)
@DeveloperAPI
def restore_original_dimensions(obs, obs_space, tensorlib=tf):
"""Unpacks Dict and Tuple space observations into their original form.
+5 -2
View File
@@ -3,7 +3,7 @@ from __future__ import division
from __future__ import print_function
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.models.model import restore_original_dimensions
from ray.rllib.models.model import restore_original_dimensions, flatten
from ray.rllib.utils.annotations import PublicAPI
@@ -146,7 +146,10 @@ class ModelV2(object):
restored = input_dict.copy()
restored["obs"] = restore_original_dimensions(
input_dict["obs"], self.obs_space, self.framework)
restored["obs_flat"] = input_dict["obs"]
if len(input_dict["obs"].shape) > 2:
restored["obs_flat"] = flatten(input_dict["obs"], self.framework)
else:
restored["obs_flat"] = input_dict["obs"]
with self.context():
res = self.forward(restored, state or [], seq_lens)
if ((not isinstance(res, list) and not isinstance(res, tuple))
+3
View File
@@ -25,6 +25,9 @@ class FullyConnectedNetwork(Model):
hiddens = options.get("fcnet_hiddens")
activation = get_activation_fn(options.get("fcnet_activation"))
if len(inputs.shape) > 2:
inputs = tf.layers.flatten(inputs)
with tf.name_scope("fc_net"):
i = 1
last_layer = inputs
+4 -1
View File
@@ -2,6 +2,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.misc import normc_initializer, get_activation_fn
from ray.rllib.utils import try_import_tf
@@ -22,8 +24,9 @@ class FullyConnectedNetwork(TFModelV2):
no_final_linear = model_config.get("no_final_linear")
vf_share_layers = model_config.get("vf_share_layers")
# we are using obs_flat, so take the flattened shape as input
inputs = tf.keras.layers.Input(
shape=obs_space.shape, name="observations")
shape=(np.product(obs_space.shape), ), name="observations")
last_layer = inputs
i = 1
+2
View File
@@ -31,6 +31,7 @@ ACTION_SPACES_TO_TEST = {
OBSERVATION_SPACES_TO_TEST = {
"discrete": Discrete(5),
"vector": Box(-1.0, 1.0, (5, ), dtype=np.float32),
"vector2": Box(-1.0, 1.0, (5, 5), dtype=np.float32),
"image": Box(-1.0, 1.0, (84, 84, 1), dtype=np.float32),
"atari": Box(-1.0, 1.0, (210, 160, 3), dtype=np.float32),
"tuple": Tuple([Discrete(10),
@@ -106,6 +107,7 @@ def check_support(alg, config, stats, check_bounds=False, name=None):
def check_support_multiagent(alg, config):
register_env("multi_mountaincar", lambda _: MultiMountainCar(2))
register_env("multi_cartpole", lambda _: MultiCartpole(2))
config["log_level"] = "ERROR"
if "DDPG" in alg:
a = get_agent_class(alg)(config=config, env="multi_mountaincar")
else: