mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 17:34:51 +08:00
[rllib] Don't call reset() unless necessary for multi-agent envs
This commit is contained in:
+8
-1
@@ -319,14 +319,21 @@ class _MultiAgentEnvToAsync(AsyncVectorEnv):
|
||||
self.dones.remove(env_id)
|
||||
return obs
|
||||
|
||||
@override(AsyncVectorEnv)
|
||||
def get_unwrapped(self):
|
||||
return [state.env for state in self.env_states]
|
||||
|
||||
|
||||
class _MultiAgentEnvState(object):
|
||||
def __init__(self, env):
|
||||
assert isinstance(env, MultiAgentEnv)
|
||||
self.env = env
|
||||
self.reset()
|
||||
self.initialized = False
|
||||
|
||||
def poll(self):
|
||||
if not self.initialized:
|
||||
self.reset()
|
||||
self.initialized = True
|
||||
obs, rew, dones, info = (self.last_obs, self.last_rewards,
|
||||
self.last_dones, self.last_infos)
|
||||
self.last_obs = {}
|
||||
|
||||
@@ -36,8 +36,10 @@ class BasicMultiAgent(MultiAgentEnv):
|
||||
self.dones = set()
|
||||
self.observation_space = gym.spaces.Discrete(2)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.resetted = False
|
||||
|
||||
def reset(self):
|
||||
self.resetted = True
|
||||
self.dones = set()
|
||||
return {i: a.reset() for i, a in enumerate(self.agents)}
|
||||
|
||||
@@ -173,6 +175,12 @@ class TestMultiAgentEnv(unittest.TestCase):
|
||||
obs, rew, done, info = env.step({0: 0})
|
||||
self.assertEqual(done["__all__"], True)
|
||||
|
||||
def testNoResetUntilPoll(self):
|
||||
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 1)
|
||||
self.assertFalse(env.get_unwrapped()[0].resetted)
|
||||
env.poll()
|
||||
self.assertTrue(env.get_unwrapped()[0].resetted)
|
||||
|
||||
def testVectorizeBasic(self):
|
||||
env = _MultiAgentEnvToAsync(lambda v: BasicMultiAgent(2), [], 2)
|
||||
obs, rew, dones, _, _ = env.poll()
|
||||
|
||||
Reference in New Issue
Block a user