mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:06:31 +08:00
[rllib] Add more warnings when multi-agent envs might not be set up right (#3061)
This commit is contained in:
+5
@@ -268,12 +268,17 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
raise ValueError("Env {} is already done".format(env_id))
|
||||
env = self.envs[env_id]
|
||||
obs, rewards, dones, infos = env.step(agent_dict)
|
||||
assert isinstance(obs, dict), "Not a multi-agent obs"
|
||||
assert isinstance(rewards, dict), "Not a multi-agent reward"
|
||||
assert isinstance(dones, dict), "Not a multi-agent return"
|
||||
assert isinstance(infos, dict), "Not a multi-agent info"
|
||||
if dones["__all__"]:
|
||||
self.dones.add(env_id)
|
||||
self.env_states[env_id].observe(obs, rewards, dones, infos)
|
||||
|
||||
def try_reset(self, env_id):
|
||||
obs = self.env_states[env_id].reset()
|
||||
assert isinstance(obs, dict), "Not a multi-agent obs"
|
||||
if obs is not None and env_id in self.dones:
|
||||
self.dones.remove(env_id)
|
||||
return obs
|
||||
|
||||
@@ -168,6 +168,11 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
model_config = model_config or {}
|
||||
policy_mapping_fn = (policy_mapping_fn
|
||||
or (lambda agent_id: DEFAULT_POLICY_ID))
|
||||
if not callable(policy_mapping_fn):
|
||||
raise ValueError(
|
||||
"Policy mapping function not callable. If you're using Tune, "
|
||||
"make sure to escape the function with tune.function() "
|
||||
"to prevent it from being evaluated as an expression.")
|
||||
self.env_creator = env_creator
|
||||
self.sample_batch_size = batch_steps * num_envs
|
||||
self.batch_mode = batch_mode
|
||||
@@ -230,7 +235,14 @@ class PolicyEvaluator(EvaluatorInterface):
|
||||
self.policy_map = self._build_policy_map(policy_dict,
|
||||
policy_config)
|
||||
|
||||
- self.multiagent = self.policy_map.keys() != {DEFAULT_POLICY_ID}
|
||||
+ self.multiagent = set(self.policy_map.keys()) != {DEFAULT_POLICY_ID}
|
||||
if self.multiagent:
|
||||
if not (isinstance(self.env, MultiAgentEnv)
|
||||
or isinstance(self.env, AsyncVectorEnv)):
|
||||
raise ValueError(
|
||||
"Have multiple policy graphs {}, but the env ".format(
|
||||
self.policy_map) +
|
||||
"{} is not a subclass of MultiAgentEnv?".format(self.env))
|
||||
|
||||
self.filters = {
|
||||
policy_id: get_filter(observation_filter,
|
||||
|
||||
@@ -218,7 +218,7 @@ def _env_runner(async_vector_env,
|
||||
horizon = (
|
||||
async_vector_env.get_unwrapped()[0].spec.max_episode_steps)
|
||||
except Exception:
|
||||
- print("Warning, no horizon specified, assuming infinite")
|
||||
+ print("*** WARNING ***: no episode horizon specified, assuming inf")
|
||||
if not horizon:
|
||||
horizon = float("inf")
|
||||
|
||||
|
||||
@@ -155,6 +155,11 @@ def _resolve_lambda_vars(spec, lambda_vars):
|
||||
value = fn(_UnresolvedAccessGuard(spec))
|
||||
except RecursiveDependencyError as e:
|
||||
error = e
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"Failed to evaluate expression: {}: {}".format(path, fn) +
|
||||
". If you meant to pass this as a function literal, use "
|
||||
"tune.function() to escape it.")
|
||||
else:
|
||||
_assign_value(spec, path, value)
|
||||
resolved[path] = value
|
||||
|
||||
Reference in New Issue
Block a user