diff --git a/rllib/agents/trainer_template.py b/rllib/agents/trainer_template.py index cd622631a..b896958b6 100644 --- a/rllib/agents/trainer_template.py +++ b/rllib/agents/trainer_template.py @@ -65,7 +65,7 @@ def build_trainer( Optional callable that takes the config to check for correctness. It may mutate the config as needed. default_policy (Optional[Type[Policy]]): The default Policy class to - use. + use if `get_policy_class` returns None. get_policy_class (Optional[Callable[ TrainerConfigDict, Optional[Type[Policy]]]]): Optional callable that takes a config and returns the policy class or None. If None diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index 16626fb86..b6a4131ff 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -79,7 +79,10 @@ class WorkerSet: remote_spaces = ray.get(self.remote_workers( )[0].foreach_policy.remote( lambda p, pid: (pid, p.observation_space, p.action_space))) - spaces = {e[0]: (e[1], e[2]) for e in remote_spaces} + spaces = { + e[0]: (getattr(e[1], "original_space", e[1]), e[2]) + for e in remote_spaces + } else: spaces = None