mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:38:18 +08:00
[rllib] Fix use_lstm option when using custom model with dict space (#3368)
## What do these changes do?

This passes in the right obs space to the LSTM model wrapper, so that it doesn't attempt to un-flatten the already-processed dict observation.

## Related issue number

Closes https://github.com/ray-project/ray/issues/3367
This commit is contained in:
@@ -331,7 +331,7 @@ class Agent(Trainable):
|
||||
self.env_creator = lambda env_config: None
|
||||
|
||||
# Merge the supplied config with the class default
|
||||
merged_config = self._default_config.copy()
|
||||
merged_config = copy.deepcopy(self._default_config)
|
||||
merged_config = deep_update(merged_config, config,
|
||||
self._allow_unknown_configs,
|
||||
self._allow_unknown_subkeys)
|
||||
|
||||
@@ -200,7 +200,9 @@ class ModelCatalog(object):
|
||||
if options.get("use_lstm"):
|
||||
copy = dict(input_dict)
|
||||
copy["obs"] = model.last_layer
|
||||
model = LSTM(copy, obs_space, num_outputs, options, state_in,
|
||||
feature_space = gym.spaces.Box(
|
||||
-1, 1, shape=(model.last_layer.shape[1], ))
|
||||
model = LSTM(copy, feature_space, num_outputs, options, state_in,
|
||||
seq_lens)
|
||||
|
||||
logger.debug("Created model {}: ({} of {}, {}, {}) -> {}, {}".format(
|
||||
|
||||
@@ -174,7 +174,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
},
|
||||
}))
|
||||
|
||||
def doTestNestedDict(self, make_env):
|
||||
def doTestNestedDict(self, make_env, test_lstm=False):
|
||||
ModelCatalog.register_custom_model("composite", DictSpyModel)
|
||||
register_env("nested", make_env)
|
||||
pg = PGAgent(
|
||||
@@ -184,6 +184,7 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
"sample_batch_size": 5,
|
||||
"model": {
|
||||
"custom_model": "composite",
|
||||
"use_lstm": test_lstm,
|
||||
},
|
||||
})
|
||||
pg.train()
|
||||
@@ -230,6 +231,9 @@ class NestedSpacesTest(unittest.TestCase):
|
||||
def testNestedDictGym(self):
|
||||
self.doTestNestedDict(lambda _: NestedDictEnv())
|
||||
|
||||
def testNestedDictGymLSTM(self):
|
||||
self.doTestNestedDict(lambda _: NestedDictEnv(), test_lstm=True)
|
||||
|
||||
def testNestedDictVector(self):
|
||||
self.doTestNestedDict(
|
||||
lambda _: VectorEnv.wrap(lambda i: NestedDictEnv()))
|
||||
|
||||
Reference in New Issue
Block a user