[rllib] Add more warnings when multi-agent envs might not be set up right (#3061)

This commit is contained in:
Eric Liang
2018-10-15 13:42:56 -07:00
committed by GitHub
parent 3c891c6ece
commit 6240ccbc6e
4 changed files with 24 additions and 2 deletions
+5
View File
@@ -268,12 +268,17 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
raise ValueError("Env {} is already done".format(env_id))
env = self.envs[env_id]
obs, rewards, dones, infos = env.step(agent_dict)
assert isinstance(obs, dict), "Not a multi-agent obs"
assert isinstance(rewards, dict), "Not a multi-agent reward"
assert isinstance(dones, dict), "Not a multi-agent return"
assert isinstance(infos, dict), "Not a multi-agent info"
if dones["__all__"]:
self.dones.add(env_id)
self.env_states[env_id].observe(obs, rewards, dones, infos)
def try_reset(self, env_id):
obs = self.env_states[env_id].reset()
assert isinstance(obs, dict), "Not a multi-agent obs"
if obs is not None and env_id in self.dones:
self.dones.remove(env_id)
return obs
@@ -168,6 +168,11 @@ class PolicyEvaluator(EvaluatorInterface):
model_config = model_config or {}
policy_mapping_fn = (policy_mapping_fn
or (lambda agent_id: DEFAULT_POLICY_ID))
if not callable(policy_mapping_fn):
raise ValueError(
"Policy mapping function not callable. If you're using Tune, "
"make sure to escape the function with tune.function() "
"to prevent it from being evaluated as an expression.")
self.env_creator = env_creator
self.sample_batch_size = batch_steps * num_envs
self.batch_mode = batch_mode
@@ -230,7 +235,14 @@ class PolicyEvaluator(EvaluatorInterface):
self.policy_map = self._build_policy_map(policy_dict,
policy_config)
self.multiagent = self.policy_map.keys() != {DEFAULT_POLICY_ID}
self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
if self.multiagent:
if not (isinstance(self.env, MultiAgentEnv)
or isinstance(self.env, AsyncVectorEnv)):
raise ValueError(
"Have multiple policy graphs {}, but the env ".format(
self.policy_map) +
"{} is not a subclass of MultiAgentEnv?".format(self.env))
self.filters = {
policy_id: get_filter(observation_filter,
+1 -1
View File
@@ -218,7 +218,7 @@ def _env_runner(async_vector_env,
horizon = (
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
except Exception:
print("Warning, no horizon specified, assuming infinite")
print("*** WARNING ***: no episode horizon specified, assuming inf")
if not horizon:
horizon = float("inf")
@@ -155,6 +155,11 @@ def _resolve_lambda_vars(spec, lambda_vars):
value = fn(_UnresolvedAccessGuard(spec))
except RecursiveDependencyError as e:
error = e
except Exception:
raise ValueError(
"Failed to evaluate expression: {}: {}".format(path, fn) +
". If you meant to pass this as a function literal, use "
"tune.function() to escape it.")
else:
_assign_value(spec, path, value)
resolved[path] = value