Accept dict subclasses as the per-env `info` value in _VectorizedGymEnv.

The info validation switches from an exact type comparison to isinstance(),
so an env whose step() returns a dict subclass no longer raises ValueError.
Adds rllib/tests/test_vector_env.py, which steps a wrapped env that returns
such a subclass.

NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed; the indentation of the vector_env.py context lines is reconstructed
and should be confirmed against the target file before applying.

diff --git a/rllib/env/vector_env.py b/rllib/env/vector_env.py
index 10b14bfbb..49d4bdf6d 100644
--- a/rllib/env/vector_env.py
+++ b/rllib/env/vector_env.py
@@ -145,7 +145,7 @@ class _VectorizedGymEnv(VectorEnv):
                 raise ValueError(
                     "Reward should be finite scalar, got {} ({}). "
                     "Actions={}.".format(r, type(r), actions[i]))
-            if type(info) is not dict:
+            if not isinstance(info, dict):
                 raise ValueError("Info should be a dict, got {} ({})".format(
                     info, type(info)))
             obs_batch.append(obs)
diff --git a/rllib/tests/test_vector_env.py b/rllib/tests/test_vector_env.py
new file mode 100644
index 000000000..ef03dd1af
--- /dev/null
+++ b/rllib/tests/test_vector_env.py
@@ -0,0 +1,32 @@
+import gym
+import unittest
+
+from ray.rllib.env.vector_env import VectorEnv
+
+
+class Info(dict):
+    pass
+
+
+class MockEnvDictSubclass(gym.Env):
+    def __init__(self):
+        self.observation_space = gym.spaces.Discrete(1)
+        self.action_space = gym.spaces.Discrete(2)
+
+    def reset(self):
+        return 0
+
+    def step(self, action):
+        return 0, 1, True, Info()
+
+
+class TestExternalEnv(unittest.TestCase):
+    def test_vector_step(self):
+        env = VectorEnv.wrap(lambda _: MockEnvDictSubclass(), num_envs=3)
+        env.vector_step([0] * 3)
+
+
+if __name__ == "__main__":
+    import pytest
+    import sys
+    sys.exit(pytest.main(["-v", __file__]))