[rllib] Fix use_lstm option when using custom model with dict space (#3368)

## What do these changes do?

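This passes the correct observation space to the LSTM model wrapper: a `Box` space sized to the custom model's last layer, rather than the original dict observation space. As a result, the wrapper no longer attempts to un-flatten the already-processed dict observation.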

## Related issue number

Closes https://github.com/ray-project/ray/issues/3367
This commit is contained in:
Eric Liang
2018-11-23 22:51:08 -08:00
committed by Richard Liaw
parent 8b76bab25c
commit 55fca828ce
5 changed files with 9 additions and 29 deletions
+1 -1
View File
@@ -331,7 +331,7 @@ class Agent(Trainable):
self.env_creator = lambda env_config: None
# Merge the supplied config with the class default
merged_config = self._default_config.copy()
merged_config = copy.deepcopy(self._default_config)
merged_config = deep_update(merged_config, config,
self._allow_unknown_configs,
self._allow_unknown_subkeys)
+3 -1
View File
@@ -200,7 +200,9 @@ class ModelCatalog(object):
if options.get("use_lstm"):
copy = dict(input_dict)
copy["obs"] = model.last_layer
model = LSTM(copy, obs_space, num_outputs, options, state_in,
feature_space = gym.spaces.Box(
-1, 1, shape=(model.last_layer.shape[1], ))
model = LSTM(copy, feature_space, num_outputs, options, state_in,
seq_lens)
logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
+5 -1
View File
@@ -174,7 +174,7 @@ class NestedSpacesTest(unittest.TestCase):
},
}))
def doTestNestedDict(self, make_env):
def doTestNestedDict(self, make_env, test_lstm=False):
ModelCatalog.register_custom_model("composite", DictSpyModel)
register_env("nested", make_env)
pg = PGAgent(
@@ -184,6 +184,7 @@ class NestedSpacesTest(unittest.TestCase):
"sample_batch_size": 5,
"model": {
"custom_model": "composite",
"use_lstm": test_lstm,
},
})
pg.train()
@@ -230,6 +231,9 @@ class NestedSpacesTest(unittest.TestCase):
def testNestedDictGym(self):
self.doTestNestedDict(lambda _: NestedDictEnv())
def testNestedDictGymLSTM(self):
self.doTestNestedDict(lambda _: NestedDictEnv(), test_lstm=True)
def testNestedDictVector(self):
self.doTestNestedDict(
lambda _: VectorEnv.wrap(lambda i: NestedDictEnv()))