[RLlib] Fix multiple Unity3DEnvs trying to connect to the same custom port (#13519)

This commit is contained in:
Yuri Rocha
2021-01-28 21:28:08 +09:00
committed by GitHub
parent d4ef5c5993
commit b01b0f80aa
4 changed files with 82 additions and 5 deletions
+4
View File
@@ -16,3 +16,7 @@ kaggle_environments
# For MAML on PyTorch.
higher
# Unity3D testing
mlagents
mlagents_envs
+7
View File
@@ -1069,6 +1069,13 @@ sh_test(
data = glob(["examples/serving/*.py"]),
)
py_test(
name = "env/wrappers/tests/test_unity3d_env",
tags = ["env"],
size = "small",
srcs = ["env/wrappers/tests/test_unity3d_env.py"]
)
py_test(
name = "env/wrappers/tests/test_recsim_wrapper",
tags = ["env"],
+55
View File
@@ -0,0 +1,55 @@
import unittest
from unittest.mock import patch
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
@patch("mlagents_envs.environment.UnityEnvironment")
class TestUnity3DEnv(unittest.TestCase):
def test_port_editor(self, mock_unity3d):
"""Test if the environment uses the editor port
when no environment file is provided"""
_ = Unity3DEnv(port=None)
args, kwargs = mock_unity3d.call_args
mock_unity3d.assert_called_once()
self.assertEqual(5004, kwargs.get("base_port"))
def test_port_app(self, mock_unity3d):
"""Test if the environment uses the correct port
when the environment file is provided"""
_ = Unity3DEnv(file_name="app", port=None)
args, kwargs = mock_unity3d.call_args
mock_unity3d.assert_called_once()
self.assertEqual(5005, kwargs.get("base_port"))
def test_ports_multi_app(self, mock_unity3d):
"""Test if the base_port + worker_id
is different for each environment"""
_ = Unity3DEnv(file_name="app", port=None)
args, kwargs_first = mock_unity3d.call_args
_ = Unity3DEnv(file_name="app", port=None)
args, kwargs_second = mock_unity3d.call_args
self.assertNotEqual(
kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
kwargs_second.get("base_port") + kwargs_second.get("worker_id"))
def test_custom_port_app(self, mock_unity3d):
"""Test if the base_port + worker_id is different
for each environment when using custom ports"""
_ = Unity3DEnv(file_name="app", port=5010)
args, kwargs_first = mock_unity3d.call_args
_ = Unity3DEnv(file_name="app", port=5010)
args, kwargs_second = mock_unity3d.call_args
self.assertNotEqual(
kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
kwargs_second.get("base_port") + kwargs_second.get("worker_id"))
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
+16 -5
View File
@@ -27,7 +27,12 @@ class Unity3DEnv(MultiAgentEnv):
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
"""
_BASE_PORT = 5004
# Default base port when connecting directly to the Editor
_BASE_PORT_EDITOR = 5004
# Default base port when connecting to a compiled environment
_BASE_PORT_ENVIRONMENT = 5005
# The worker_id for each environment instance
_WORKER_ID = 0
def __init__(self,
file_name: str = None,
@@ -73,18 +78,24 @@ class Unity3DEnv(MultiAgentEnv):
# environments (num_workers >> 1). Otherwise, would lead to port
# conflicts sometimes.
time.sleep(random.randint(1, 10))
port_ = port or self._BASE_PORT
self._BASE_PORT += 1
port_ = port or (self._BASE_PORT_ENVIRONMENT
if file_name else self._BASE_PORT_EDITOR)
# cache the worker_id and
# increase it for the next environment
worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
Unity3DEnv._WORKER_ID += 1
try:
self.unity_env = UnityEnvironment(
file_name=file_name,
worker_id=0,
worker_id=worker_id_,
base_port=port_,
seed=seed,
no_graphics=no_graphics,
timeout_wait=timeout_wait,
)
print("Created UnityEnvironment for port {}".format(port_))
print(
"Created UnityEnvironment for port {}".format(port_ +
worker_id_))
except mlagents_envs.exception.UnityWorkerInUseException:
pass
else: