mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 16:31:25 +08:00
[RLlib] Fix multiple Unity3DEnvs trying to connect to the same custom port (#13519)
This commit is contained in:
@@ -16,3 +16,7 @@ kaggle_environments
|
||||
|
||||
# For MAML on PyTorch.
|
||||
higher
|
||||
|
||||
# Unity3D testing
|
||||
mlagents
|
||||
mlagents_envs
|
||||
|
||||
@@ -1069,6 +1069,13 @@ sh_test(
|
||||
data = glob(["examples/serving/*.py"]),
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "env/wrappers/tests/test_unity3d_env",
|
||||
tags = ["env"],
|
||||
size = "small",
|
||||
srcs = ["env/wrappers/tests/test_unity3d_env.py"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "env/wrappers/tests/test_recsim_wrapper",
|
||||
tags = ["env"],
|
||||
|
||||
+55
@@ -0,0 +1,55 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from ray.rllib.env.wrappers.unity3d_env import Unity3DEnv
|
||||
|
||||
|
||||
@patch("mlagents_envs.environment.UnityEnvironment")
|
||||
class TestUnity3DEnv(unittest.TestCase):
|
||||
def test_port_editor(self, mock_unity3d):
|
||||
"""Test if the environment uses the editor port
|
||||
when no environment file is provided"""
|
||||
|
||||
_ = Unity3DEnv(port=None)
|
||||
args, kwargs = mock_unity3d.call_args
|
||||
mock_unity3d.assert_called_once()
|
||||
self.assertEqual(5004, kwargs.get("base_port"))
|
||||
|
||||
def test_port_app(self, mock_unity3d):
|
||||
"""Test if the environment uses the correct port
|
||||
when the environment file is provided"""
|
||||
|
||||
_ = Unity3DEnv(file_name="app", port=None)
|
||||
args, kwargs = mock_unity3d.call_args
|
||||
mock_unity3d.assert_called_once()
|
||||
self.assertEqual(5005, kwargs.get("base_port"))
|
||||
|
||||
def test_ports_multi_app(self, mock_unity3d):
|
||||
"""Test if the base_port + worker_id
|
||||
is different for each environment"""
|
||||
|
||||
_ = Unity3DEnv(file_name="app", port=None)
|
||||
args, kwargs_first = mock_unity3d.call_args
|
||||
_ = Unity3DEnv(file_name="app", port=None)
|
||||
args, kwargs_second = mock_unity3d.call_args
|
||||
self.assertNotEqual(
|
||||
kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
|
||||
kwargs_second.get("base_port") + kwargs_second.get("worker_id"))
|
||||
|
||||
def test_custom_port_app(self, mock_unity3d):
|
||||
"""Test if the base_port + worker_id is different
|
||||
for each environment when using custom ports"""
|
||||
|
||||
_ = Unity3DEnv(file_name="app", port=5010)
|
||||
args, kwargs_first = mock_unity3d.call_args
|
||||
_ = Unity3DEnv(file_name="app", port=5010)
|
||||
args, kwargs_second = mock_unity3d.call_args
|
||||
self.assertNotEqual(
|
||||
kwargs_first.get("base_port") + kwargs_first.get("worker_id"),
|
||||
kwargs_second.get("base_port") + kwargs_second.get("worker_id"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
Vendored
+16
-5
@@ -27,7 +27,12 @@ class Unity3DEnv(MultiAgentEnv):
|
||||
inside an RLlib PolicyClient for cloud/distributed training of Unity games.
|
||||
"""
|
||||
|
||||
_BASE_PORT = 5004
|
||||
# Default base port when connecting directly to the Editor
|
||||
_BASE_PORT_EDITOR = 5004
|
||||
# Default base port when connecting to a compiled environment
|
||||
_BASE_PORT_ENVIRONMENT = 5005
|
||||
# The worker_id for each environment instance
|
||||
_WORKER_ID = 0
|
||||
|
||||
def __init__(self,
|
||||
file_name: str = None,
|
||||
@@ -73,18 +78,24 @@ class Unity3DEnv(MultiAgentEnv):
|
||||
# environments (num_workers >> 1). Otherwise, would lead to port
|
||||
# conflicts sometimes.
|
||||
time.sleep(random.randint(1, 10))
|
||||
port_ = port or self._BASE_PORT
|
||||
self._BASE_PORT += 1
|
||||
port_ = port or (self._BASE_PORT_ENVIRONMENT
|
||||
if file_name else self._BASE_PORT_EDITOR)
|
||||
# cache the worker_id and
|
||||
# increase it for the next environment
|
||||
worker_id_ = Unity3DEnv._WORKER_ID if file_name else 0
|
||||
Unity3DEnv._WORKER_ID += 1
|
||||
try:
|
||||
self.unity_env = UnityEnvironment(
|
||||
file_name=file_name,
|
||||
worker_id=0,
|
||||
worker_id=worker_id_,
|
||||
base_port=port_,
|
||||
seed=seed,
|
||||
no_graphics=no_graphics,
|
||||
timeout_wait=timeout_wait,
|
||||
)
|
||||
print("Created UnityEnvironment for port {}".format(port_))
|
||||
print(
|
||||
"Created UnityEnvironment for port {}".format(port_ +
|
||||
worker_id_))
|
||||
except mlagents_envs.exception.UnityWorkerInUseException:
|
||||
pass
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user