diff --git a/python/requirements.txt b/python/requirements.txt index 3fefd37c8..2be4b13b8 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -56,7 +56,7 @@ mypy networkx numba openpyxl -pettingzoo +pettingzoo>=1.3.2 Pillow; platform_system != "Windows" pygments pytest==5.4.3 diff --git a/rllib/env/pettingzoo_env.py b/rllib/env/pettingzoo_env.py index 1704cfbc9..816752a28 100644 --- a/rllib/env/pettingzoo_env.py +++ b/rllib/env/pettingzoo_env.py @@ -156,13 +156,21 @@ class PettingZooEnv(MultiAgentEnv): infos (dict): Optional info values for each agent id. """ stepped_agents = set() - while self.aec_env.agent_selection not in stepped_agents: + while (self.aec_env.agent_selection not in stepped_agents + and self.aec_env.dones[self.aec_env.agent_selection]): agent = self.aec_env.agent_selection - assert agent in action_dict, \ + self.aec_env.step(None) + stepped_agents.add(agent) + stepped_agents = set() + # print(action_dict) + while (self.aec_env.agent_selection not in stepped_agents): + agent = self.aec_env.agent_selection + assert agent in action_dict or self.aec_env.dones[agent], \ "Live environment agent is not in actions dictionary" self.aec_env.step(action_dict[agent]) stepped_agents.add(agent) - + # print(self.aec_env.dones) + # print(stepped_agents) assert all(agent in stepped_agents or self.aec_env.dones[agent] for agent in action_dict), \ "environment has a nontrivial ordering, and cannot be used with"\ @@ -234,11 +242,18 @@ class ParallelPettingZooEnv(MultiAgentEnv): return self.par_env.reset() def step(self, action_dict): - for agent in self.agents: - action_dict[agent] = self.action_space.sample() - obs, rew, dones, info = self.par_env.step(action_dict) - dones["__all__"] = all(dones.values()) - return obs, rew, dones, info + aobs, arew, adones, ainfo = self.par_env.step(action_dict) + obss = {} + rews = {} + dones = {} + infos = {} + for agent in action_dict: + obss[agent] = aobs[agent] + rews[agent] = arew[agent] + dones[agent] = adones[agent] + infos[agent] = ainfo[agent] + dones["__all__"] = all(adones.values()) + return obss, rews, dones, infos def close(self): self.par_env.close()