From b01b0f80aa33fc10569f3ab36676ef71fc624d08 Mon Sep 17 00:00:00 2001 From: Yuri Rocha Date: Thu, 28 Jan 2021 21:28:08 +0900 Subject: [PATCH] [RLlib] Fix multiple Unity3DEnvs trying to connect to the same custom port (#13519) --- python/requirements_rllib.txt | 4 ++ rllib/BUILD | 7 +++ rllib/env/wrappers/tests/test_unity3d_env.py | 55 ++++++++++++++++++++ rllib/env/wrappers/unity3d_env.py | 21 ++++++-- 4 files changed, 82 insertions(+), 5 deletions(-) create mode 100644 rllib/env/wrappers/tests/test_unity3d_env.py diff --git a/python/requirements_rllib.txt b/python/requirements_rllib.txt index 0cefb0296..5f5a0f991 100644 --- a/python/requirements_rllib.txt +++ b/python/requirements_rllib.txt @@ -16,3 +16,7 @@ kaggle_environments # For MAML on PyTorch. higher + +# Unity3D testing +mlagents +mlagents_envs diff --git a/rllib/BUILD b/rllib/BUILD index f8f1cbd3c..dd1d4c163 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1069,6 +1069,13 @@ sh_test( data = glob(["examples/serving/*.py"]), ) +py_test( + name = "env/wrappers/tests/test_unity3d_env", + tags = ["env"], + size = "small", + srcs = ["env/wrappers/tests/test_unity3d_env.py"] +) + py_test( name = "env/wrappers/tests/test_recsim_wrapper", tags = ["env"], diff --git a/rllib/env/wrappers/tests/test_unity3d_env.py b/rllib/env/wrappers/tests/test_unity3d_env.py new file mode 100644 index 000000000..5e347ed0e --- /dev/null +++ b/rllib/env/wrappers/tests/test_unity3d_env.py @@ -0,0 +1,55 @@ +import unittest +from unittest.mock import patch + +from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv + + +@patch("mlagents_envs.environment.UnityEnvironment") +class TestUnity3DEnv(unittest.TestCase): + def test_port_editor(self, mock_unity3d): + """Test if the environment uses the editor port + when no environment file is provided""" + + _ = Unity3DEnv(port=None) + args, kwargs = mock_unity3d.call_args + mock_unity3d.assert_called_once() + self.assertEqual(5004, kwargs.get("base_port")) + + def test_port_app(self, mock_unity3d): + """Test if the environment uses the correct port + when the environment file is provided""" + + _ = Unity3DEnv(file_name="app", port=None) + args, kwargs = mock_unity3d.call_args + mock_unity3d.assert_called_once() + self.assertEqual(5005, kwargs.get("base_port")) + + def test_ports_multi_app(self, mock_unity3d): + """Test if the base_port + worker_id + is different for each environment""" + + _ = Unity3DEnv(file_name="app", port=None) + args, kwargs_first = mock_unity3d.call_args + _ = Unity3DEnv(file_name="app", port=None) + args, kwargs_second = mock_unity3d.call_args + self.assertNotEqual( + kwargs_first.get("base_port") + kwargs_first.get("worker_id"), + kwargs_second.get("base_port") + kwargs_second.get("worker_id")) + + def test_custom_port_app(self, mock_unity3d): + """Test if the base_port + worker_id is different + for each environment when using custom ports""" + + _ = Unity3DEnv(file_name="app", port=5010) + args, kwargs_first = mock_unity3d.call_args + _ = Unity3DEnv(file_name="app", port=5010) + args, kwargs_second = mock_unity3d.call_args + self.assertNotEqual( + kwargs_first.get("base_port") + kwargs_first.get("worker_id"), + kwargs_second.get("base_port") + kwargs_second.get("worker_id")) + + +if __name__ == "__main__": + import pytest + import sys + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/env/wrappers/unity3d_env.py b/rllib/env/wrappers/unity3d_env.py index 753c23443..876c06e96 100644 --- a/rllib/env/wrappers/unity3d_env.py +++ b/rllib/env/wrappers/unity3d_env.py @@ -27,7 +27,12 @@ class Unity3DEnv(MultiAgentEnv): inside an RLlib PolicyClient for cloud/distributed training of Unity games. """ - _BASE_PORT = 5004 + # Default base port when connecting directly to the Editor + _BASE_PORT_EDITOR = 5004 + # Default base port when connecting to a compiled environment + _BASE_PORT_ENVIRONMENT = 5005 + # The worker_id for each environment instance + _WORKER_ID = 0 def __init__(self, file_name: str = None, @@ -73,18 +78,24 @@ class Unity3DEnv(MultiAgentEnv): # environments (num_workers >> 1). Otherwise, would lead to port # conflicts sometimes. time.sleep(random.randint(1, 10)) - port_ = port or self._BASE_PORT - self._BASE_PORT += 1 + port_ = port or (self._BASE_PORT_ENVIRONMENT + if file_name else self._BASE_PORT_EDITOR) + # cache the worker_id and + # increase it for the next environment + worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0 + Unity3DEnv._WORKER_ID += 1 try: self.unity_env = UnityEnvironment( file_name=file_name, - worker_id=0, + worker_id=worker_id_, base_port=port_, seed=seed, no_graphics=no_graphics, timeout_wait=timeout_wait, ) - print("Created UnityEnvironment for port {}".format(port_)) + print( + "Created UnityEnvironment for port {}".format(port_ + + worker_id_)) except mlagents_envs.exception.UnityWorkerInUseException: pass else: