mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 01:58:30 +08:00
[rllib] Fix VectorEnv's check for the info object's type (#10982)
This commit is contained in:
Vendored
+1
-1
@@ -145,7 +145,7 @@ class _VectorizedGymEnv(VectorEnv):
|
||||
raise ValueError(
|
||||
"Reward should be finite scalar, got {} ({}). "
|
||||
"Actions={}.".format(r, type(r), actions[i]))
|
||||
if type(info) is not dict:
|
||||
if not isinstance(info, dict):
|
||||
raise ValueError("Info should be a dict, got {} ({})".format(
|
||||
info, type(info)))
|
||||
obs_batch.append(obs)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import gym
|
||||
import unittest
|
||||
|
||||
from ray.rllib.env.vector_env import VectorEnv
|
||||
|
||||
|
||||
class Info(dict):
|
||||
pass
|
||||
|
||||
|
||||
class MockEnvDictSubclass(gym.Env):
|
||||
def __init__(self):
|
||||
self.observation_space = gym.spaces.Discrete(1)
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
|
||||
def reset(self):
|
||||
return 0
|
||||
|
||||
def step(self, action):
|
||||
return 0, 1, True, Info()
|
||||
|
||||
|
||||
class TestExternalEnv(unittest.TestCase):
|
||||
def test_vector_step(self):
|
||||
env = VectorEnv.wrap(lambda _: MockEnvDictSubclass(), num_envs=3)
|
||||
env.vector_step([0] * 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Reference in New Issue
Block a user