From ea25482f6a4467e8cc3aa6543d83da47543b44b6 Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 9 Dec 2020 20:49:21 +0100 Subject: [PATCH] WIP. (#12706) --- rllib/agents/trainer_template.py | 2 +- rllib/evaluation/worker_set.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index cd622631a..b896958b6 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -65,7 +65,7 @@ def build_trainer( Optional callable that takes the config to check for correctness. It may mutate the config as needed. default_policy (Optional[Type[Policy]]): The default Policy class to - use. + use if `get_policy_class` returns None. get_policy_class (Optional[Callable[ TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable that takes a config and returns the policy class or None. If None diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 16626fb86..b6a4131ff 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -79,7 +79,10 @@ class WorkerSet: remote_spaces = ray.get(self.remote_workers( )[0].foreach_policy.remote( lambda p, pid: (pid, p.observation_space, p.action_space))) - spaces = {e[0]: (e[1], e[2]) for e in remote_spaces} + spaces = { + e[0]: (getattr(e[1], "original_space", e[1]), e[2]) + for e in remote_spaces + } else: spaces = None