mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
32dc5676b4
* Record per node and raylet cpu / mem usage * Add comments. * Addressed code review.
475 lines
14 KiB
Python
475 lines
14 KiB
Python
import asyncio
|
|
import errno
|
|
import io
|
|
import fnmatch
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
import socket
|
|
import math
|
|
|
|
from contextlib import redirect_stdout, redirect_stderr
|
|
|
|
import ray
|
|
import ray._private.services
|
|
import ray.utils
|
|
import requests
|
|
from prometheus_client.parser import text_string_to_metric_families
|
|
from ray.scripts.scripts import main as ray_main
|
|
|
|
import psutil # We must import psutil after ray because we bundle it with ray.
|
|
|
|
if sys.platform == "win32":
|
|
import _winapi
|
|
|
|
|
|
class RayTestTimeoutException(Exception):
    """Raised by the polling helpers in this module when a wait times out."""
|
|
|
|
|
|
def _pid_alive(pid):
|
|
"""Check if the process with this PID is alive or not.
|
|
|
|
Args:
|
|
pid: The pid to check.
|
|
|
|
Returns:
|
|
This returns false if the process is dead. Otherwise, it returns true.
|
|
"""
|
|
no_such_process = errno.EINVAL if sys.platform == "win32" else errno.ESRCH
|
|
alive = True
|
|
try:
|
|
if sys.platform == "win32":
|
|
SYNCHRONIZE = 0x00100000 # access mask defined in <winnt.h>
|
|
handle = _winapi.OpenProcess(SYNCHRONIZE, False, pid)
|
|
try:
|
|
alive = (_winapi.WaitForSingleObject(handle, 0) !=
|
|
_winapi.WAIT_OBJECT_0)
|
|
finally:
|
|
_winapi.CloseHandle(handle)
|
|
else:
|
|
os.kill(pid, 0)
|
|
except OSError as ex:
|
|
if ex.errno != no_such_process:
|
|
raise
|
|
alive = False
|
|
return alive
|
|
|
|
|
|
def check_call_module(main, argv, capture_stdout=False, capture_stderr=False):
    """Run a CLI entry point in-process as if it were invoked with ``argv``.

    We use this function instead of calling the "ray" command to work around
    some deadlocks that occur when piping ray's output on Windows.

    Args:
        main: Callable invoked with no arguments; presumably an entry point
            that parses ``sys.argv``.
        argv: Full argument vector (including the program name), installed
            as ``sys.argv`` while ``main`` runs and restored afterwards.
        capture_stdout: Redirect stdout into an in-memory buffer.
        capture_stderr: Redirect stderr into the same buffer.

    Returns:
        The captured output as bytes (empty if nothing was captured).

    Raises:
        subprocess.CalledProcessError: if ``main`` exits via SystemExit with
            a truthy code, or raises any other Exception (returncode 1).
    """
    stream = io.TextIOWrapper(io.BytesIO(), encoding=sys.stdout.encoding)
    old_argv = sys.argv[:]
    try:
        sys.argv = argv[:]
        try:
            with redirect_stderr(stream if capture_stderr else sys.stderr):
                with redirect_stdout(stream if capture_stdout else sys.stdout):
                    main()
        finally:
            stream.flush()
    except SystemExit as ex:
        # SystemExit(0) / SystemExit(None) is a normal exit.
        if ex.code:
            output = stream.buffer.getvalue()
            raise subprocess.CalledProcessError(ex.code, argv, output) from ex
    except Exception as ex:
        output = stream.buffer.getvalue()
        # ex.args can be empty (e.g. ``raise ValueError()``); indexing it
        # unconditionally would mask the real error with an IndexError.
        stderr = ex.args[0] if ex.args else None
        raise subprocess.CalledProcessError(1, argv, output, stderr) from ex
    finally:
        sys.argv = old_argv
        # Echo captured output to the real stream.
        if capture_stdout:
            sys.stdout.buffer.write(stream.buffer.getvalue())
        elif capture_stderr:
            sys.stderr.buffer.write(stream.buffer.getvalue())
    return stream.buffer.getvalue()
|
|
|
|
|
|
def check_call_ray(args, capture_stdout=False, capture_stderr=False):
    """Invoke the ``ray`` CLI with ``args`` and return any captured output.

    On Windows the CLI entry point is run in-process through
    check_call_module, to work around deadlocks that occur when piping ray's
    output there. Elsewhere a real ``ray`` subprocess is spawned.

    Args:
        args: List of command-line arguments (without the leading "ray").
        capture_stdout: Capture the command's stdout.
        capture_stderr: Capture the command's stderr.

    Returns:
        The captured bytes (stdout followed by stderr, where captured).

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    command = ["ray"] + args
    if sys.platform == "win32":
        return check_call_module(
            ray_main,
            command,
            capture_stdout=capture_stdout,
            capture_stderr=capture_stderr)

    out_target = subprocess.PIPE if capture_stdout else None
    if not capture_stderr:
        err_target = None
    elif capture_stdout:
        # Both captured: fold stderr into the stdout pipe.
        err_target = subprocess.STDOUT
    else:
        err_target = subprocess.PIPE

    child = subprocess.Popen(command, stdout=out_target, stderr=err_target)
    out, err = child.communicate()
    if child.returncode:
        raise subprocess.CalledProcessError(child.returncode, command, out,
                                            err)
    return b"".join(chunk for chunk in (out, err) if chunk is not None)
|
|
|
|
|
|
def wait_for_pid_to_exit(pid, timeout=20):
    """Block until process ``pid`` is gone, polling every 0.1 seconds.

    Raises:
        RayTestTimeoutException: if the process is still alive after
            ``timeout`` seconds.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        if not _pid_alive(pid):
            return
        time.sleep(0.1)
    raise RayTestTimeoutException(
        f"Timed out while waiting for process {pid} to exit.")
|
|
|
|
|
|
def wait_for_children_of_pid(pid, num_children=1, timeout=20):
    """Wait until process ``pid`` has at least ``num_children`` children.

    The child count is taken from ``psutil.Process(pid).children`` with
    ``recursive=False`` and re-checked every 0.1 seconds. The count is always
    checked at least once, so ``timeout <= 0`` means "check once" instead of
    failing with an unbound local variable.

    Raises:
        RayTestTimeoutException: if fewer than ``num_children`` children
            exist after ``timeout`` seconds.
    """
    p = psutil.Process(pid)
    start_time = time.time()
    while True:
        num_alive = len(p.children(recursive=False))
        if num_alive >= num_children:
            return
        if time.time() - start_time >= timeout:
            break
        time.sleep(0.1)
    raise RayTestTimeoutException(
        "Timed out while waiting for process {} children to start "
        "({}/{} started).".format(pid, num_alive, num_children))
|
|
|
|
|
|
def wait_for_children_of_pid_to_exit(pid, timeout=20):
    """Wait for the children of ``pid`` (per psutil.Process.children) to end.

    Returns immediately when ``pid`` has no children.

    Raises:
        RayTestTimeoutException: listing the names of the children still
            alive after ``timeout`` seconds.
    """
    kids = psutil.Process(pid).children()
    if not kids:
        return

    _gone, still_running = psutil.wait_procs(kids, timeout=timeout)
    if still_running:
        names = [proc.name() for proc in still_running]
        raise RayTestTimeoutException(
            "Timed out while waiting for process children to exit."
            " Children still alive: {}.".format(names))
|
|
|
|
|
|
def kill_process_by_name(name, SIGKILL=False):
    """Stop every process named ``name`` plus the platform EXE suffix.

    Args:
        name: Process name without the executable suffix.
        SIGKILL: If True call ``Process.kill()``, otherwise
            ``Process.terminate()``.
    """
    for proc in psutil.process_iter(attrs=["name"]):
        if proc.info["name"] != name + ray._private.services.EXE_SUFFIX:
            continue
        stop = proc.kill if SIGKILL else proc.terminate
        stop()
|
|
|
|
|
|
def run_string_as_driver(driver_script):
    """Run a driver as a separate process and wait for it to finish.

    The script is piped on stdin to ``python -``, with stderr merged into
    stdout.

    Args:
        driver_script: A string to run as a Python script. It is sent as
            UTF-8, the default source encoding for code read from stdin, so
            non-ASCII text in the script is supported.

    Returns:
        The script's combined output, decoded with ray.utils.decode.

    Raises:
        subprocess.CalledProcessError: if the script exits non-zero; the
            decoded output is printed first.
    """
    proc = subprocess.Popen(
        [sys.executable, "-"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    with proc:
        # Encoding as ASCII would raise UnicodeEncodeError for any non-ASCII
        # character in the script.
        output = proc.communicate(driver_script.encode("utf-8"))[0]
        if proc.returncode:
            print(ray.utils.decode(output))
            # proc.stderr is None here because stderr is merged into stdout.
            raise subprocess.CalledProcessError(proc.returncode, proc.args,
                                                output, proc.stderr)
        out = ray.utils.decode(output)
    return out
|
|
|
|
|
|
def run_string_as_driver_nonblocking(driver_script):
    """Start a driver as a separate process and return immediately.

    A small bootstrap program (passed with ``-c``) reads the whole script
    from stdin as UTF-8, closes stdin, removes its own helper names and
    ``exec``s the script. The script therefore runs without ``sys`` or
    ``script`` left over in its globals.

    Args:
        driver_script: A string to run as a Python script (any Unicode text;
            it is transferred as UTF-8).

    Returns:
        A handle to the driver process (``subprocess.Popen``) with stdout and
        stderr as separate pipes. Stdin is already closed.
    """
    script = "; ".join([
        "import sys",
        # Decode explicitly rather than relying on the child's locale so
        # that the UTF-8 bytes written below always round-trip.
        "script = sys.stdin.buffer.read().decode('utf-8')",
        "sys.stdin.close()",
        "del sys",
        "exec(\"del script\\n\" + script)",
    ])
    proc = subprocess.Popen(
        [sys.executable, "-c", script],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    proc.stdin.write(driver_script.encode("utf-8"))
    proc.stdin.close()
    return proc
|
|
|
|
|
|
def wait_for_num_actors(num_actors, state=None, timeout=10):
    """Poll ``ray.actors()`` until at least ``num_actors`` actors match.

    Args:
        num_actors: Minimum number of matching actors to wait for.
        state: If given, only actors whose "State" entry equals it count.
        timeout: Seconds to keep polling (every 0.1 s).

    Raises:
        RayTestTimeoutException: if the count is not reached in time.
    """
    started = time.time()
    while time.time() - started < timeout:
        matching = sum(
            1 for actor in ray.actors().values()
            if state is None or actor["State"] == state)
        if matching >= num_actors:
            return
        time.sleep(0.1)
    raise RayTestTimeoutException("Timed out while waiting for global state.")
|
|
|
|
|
|
def wait_for_condition(condition_predictor, timeout=10, retry_interval_ms=100):
    """Repeatedly evaluate a predicate until it is truthy or time runs out.

    Args:
        condition_predictor: Zero-argument callable; polling stops as soon
            as it returns a truthy value.
        timeout: Maximum time to keep polling, in seconds.
        retry_interval_ms: Pause between evaluations, in milliseconds.

    Raises:
        RuntimeError: If the condition is not met before the timeout expires.
    """
    pause_s = retry_interval_ms / 1000.0
    begin = time.time()
    while time.time() - begin <= timeout:
        if condition_predictor():
            return
        time.sleep(pause_s)
    raise RuntimeError("The condition wasn't met before the timeout expired.")
|
|
|
|
|
|
def wait_until_succeeded_without_exception(func,
                                           exceptions,
                                           *args,
                                           timeout_ms=1000,
                                           retry_interval_ms=100):
    """Call ``func(*args)`` repeatedly until it completes without raising.

    Only the exception types in ``exceptions`` are retried; any other
    exception propagates to the caller.

    Args:
        func: A function to run.
        exceptions(tuple): Exception types that are expected to occur while
            waiting. Any tuple (including tuple subclasses) is accepted.
        args: Positional arguments passed to ``func``.
        timeout_ms: Maximum time to keep retrying, in milliseconds.
        retry_interval_ms: Pause between attempts, in milliseconds.

    Returns:
        True if ``func`` completed without raising within the timeout,
        False if it kept raising (or if ``exceptions`` is not a tuple).
    """
    if not isinstance(exceptions, tuple):
        print("exceptions arguments should be given as a tuple")
        return False

    start = time.time()
    while True:
        try:
            func(*args)
            return True
        except exceptions:
            pass
        # Stop before sleeping once the budget is spent; sleeping here
        # would only delay the False result.
        if (time.time() - start) * 1000 > timeout_ms:
            return False
        time.sleep(retry_interval_ms / 1000.0)
|
|
|
|
|
|
def recursive_fnmatch(dirpath, pattern):
    """Return paths of all files below ``dirpath`` whose name matches ``pattern``.

    The tree is walked with ``os.walk``, and each directory's file names are
    filtered with ``fnmatch.filter``. Directory names are never matched.
    The result is similar to ``glob.glob(..., recursive=True)``.
    """
    return [
        os.path.join(root, filename)
        for root, _subdirs, files in os.walk(dirpath)
        for filename in fnmatch.filter(files, pattern)
    ]
|
|
|
|
|
|
def generate_system_config_map(**kwargs):
    """Return a kwargs dict whose only key, "_system_config", maps to ``kwargs``."""
    return dict(_system_config=kwargs)
|
|
|
|
|
|
@ray.remote(num_cpus=0)
class SignalActor:
    """Actor wrapping an ``asyncio.Event`` so callers can block on a signal."""

    def __init__(self):
        self.ready_event = asyncio.Event()

    def send(self, clear=False):
        """Set the event; with ``clear=True`` reset it right after setting."""
        event = self.ready_event
        event.set()
        if clear:
            event.clear()

    async def wait(self, should_wait=True):
        """Await the event; return immediately when ``should_wait`` is False."""
        if not should_wait:
            return
        await self.ready_event.wait()
|
|
|
|
|
|
@ray.remote(num_cpus=0)
class Semaphore:
    """Actor exposing an ``asyncio.Semaphore`` through async methods."""

    def __init__(self, value=1):
        # ``value`` is the initial number of available permits.
        self._sema = asyncio.Semaphore(value=value)

    async def acquire(self):
        """Wait until a permit is available, then take it."""
        sema = self._sema
        await sema.acquire()

    async def release(self):
        """Give back one permit."""
        sema = self._sema
        sema.release()

    async def locked(self):
        """Report whether the semaphore cannot be acquired immediately."""
        sema = self._sema
        return sema.locked()
|
|
|
|
|
|
def dicts_equal(dict1, dict2, abs_tol=1e-4):
    """Compare two dicts whose values may be floating point numbers.

    The key sets must be identical. When both values for a key are floats,
    they count as equal if ``math.isclose(..., abs_tol=abs_tol)`` holds.
    All other values are compared with ``!=``.
    """
    if dict1.keys() != dict2.keys():
        return False

    for key, left in dict1.items():
        right = dict2[key]
        both_float = isinstance(left, float) and isinstance(right, float)
        if both_float and math.isclose(left, right, abs_tol=abs_tol):
            continue
        if left != right:
            return False
    return True
|
|
|
|
|
|
def same_elements(elems_a, elems_b):
    """Check that every element of each iterable appears in the other.

    Elements do not have to be hashable, so this also works for lists of
    dicts. Membership is tested with ``in`` on lists, so the check is
    O(len(a) * len(b)). Multiplicity is ignored: [1, 1] matches [1].
    """
    a = list(elems_a)
    b = list(elems_b)
    return all(item in b for item in a) and all(item in a for item in b)
|
|
|
|
|
|
@ray.remote
def _put(obj):
    """Remote identity task: hand back ``obj`` unchanged.

    put_object uses this task as its alternative to ``ray.put``.
    """
    result = obj
    return result
|
|
|
|
|
|
def put_object(obj, use_ray_put):
    """Store ``obj`` via ``ray.put`` or via the ``_put`` task, as requested.

    Returns whatever ``ray.put(obj)`` or ``_put.remote(obj)`` returns.
    """
    return ray.put(obj) if use_ray_put else _put.remote(obj)
|
|
|
|
|
|
def wait_until_server_available(address,
                                timeout_ms=5000,
                                retry_interval_ms=100):
    """Poll until a TCP connection to ``address`` succeeds.

    Args:
        address: "host:port" string. The port is the text after the last
            colon, and brackets around the host are removed, so IPv6 forms
            such as "[::1]:8000" are accepted as well as "127.0.0.1:8000".
        timeout_ms: Maximum time to keep retrying, in milliseconds.
        retry_interval_ms: Pause between attempts, in milliseconds.

    Returns:
        True as soon as a connection is established, False if none succeeded
        before the timeout.
    """
    host, _, port = address.rpartition(":")
    host = host.strip("[]")
    port = int(port)
    start = time.time()
    while True:
        try:
            # create_connection resolves the host for IPv4 or IPv6, and the
            # ``with`` block guarantees the socket is closed.
            with socket.create_connection((host, port), timeout=1):
                return True
        except OSError:
            # Refused / unreachable / timed out: server not up yet.
            pass
        if (time.time() - start) * 1000 > timeout_ms:
            return False
        time.sleep(retry_interval_ms / 1000.0)
|
|
|
|
|
|
def get_other_nodes(cluster, exclude_head=False):
    """Get all nodes except the one that we're connected to.

    A node is "ours" when its raylet socket name matches the global node's.
    With ``exclude_head`` set, nodes whose ``head`` attribute is not
    ``False`` are dropped as well.
    """
    others = []
    for node in cluster.list_all_nodes():
        is_other = (node._raylet_socket_name !=
                    ray.worker._global_node._raylet_socket_name)
        if is_other and (exclude_head is False or node.head is False):
            others.append(node)
    return others
|
|
|
|
|
|
def get_non_head_nodes(cluster):
    """Return the cluster's nodes whose ``head`` attribute is exactly False."""
    return [node for node in cluster.list_all_nodes() if node.head is False]
|
|
|
|
|
|
def init_error_pubsub():
    """Initialize redis error info pub/sub.

    Creates a pubsub object on the global worker's redis client that skips
    subscribe confirmations, pattern-subscribes it to
    ``ray.gcs_utils.RAY_ERROR_PUBSUB_PATTERN``, and returns it.
    """
    subscriber = ray.worker.global_worker.redis_client.pubsub(
        ignore_subscribe_messages=True)
    subscriber.psubscribe(ray.gcs_utils.RAY_ERROR_PUBSUB_PATTERN)
    return subscriber
|
|
|
|
|
|
def get_error_message(pub_sub, num, error_type=None, timeout=20):
    """Get errors through pub/sub.

    Polls ``pub_sub.get_message()`` until ``num`` matching errors have been
    collected or ``timeout`` seconds have elapsed. Each message's "data" is
    parsed as a PubSubMessage whose payload is an ErrorTableData.

    Args:
        pub_sub: Object with a ``get_message()`` method (e.g. the result of
            init_error_pubsub); ``None`` means "nothing available yet".
        num: Number of errors to collect.
        error_type: If given, only errors whose ``type`` equals it are kept.
        timeout: Maximum polling time in seconds.

    Returns:
        List of the collected ErrorTableData messages (may be shorter than
        ``num`` on timeout).
    """
    collected = []
    started = time.time()
    while time.time() - started < timeout and len(collected) < num:
        raw = pub_sub.get_message()
        if raw is None:
            time.sleep(0.01)
            continue
        envelope = ray.gcs_utils.PubSubMessage.FromString(raw["data"])
        error_data = ray.gcs_utils.ErrorTableData.FromString(envelope.data)
        if error_type is None or error_type == error_data.type:
            collected.append(error_data)
            continue
        time.sleep(0.01)
    return collected
|
|
|
|
|
|
def format_web_url(url):
    """Format web url.

    Replaces "localhost" with "127.0.0.1" and prepends "http://" unless the
    URL already carries an http:// or https:// scheme. Previously
    "http://localhost:8265" became "http://http://127.0.0.1:8265", and
    https URLs were given an extra "http://" prefix.
    """
    url = url.replace("localhost", "127.0.0.1")
    if not url.startswith(("http://", "https://")):
        return "http://" + url
    return url
|
|
|
|
|
|
def new_scheduler_enabled():
    """Return True when RAY_ENABLE_NEW_SCHEDULER is "1" (the default if unset)."""
    flag = os.environ.get("RAY_ENABLE_NEW_SCHEDULER", "1")
    return flag == "1"
|
|
|
|
|
|
def client_test_enabled() -> bool:
    """Return True only when RAY_CLIENT_MODE is set to exactly "1"."""
    mode = os.environ.get("RAY_CLIENT_MODE")
    return mode == "1"
|
|
|
|
|
|
def fetch_prometheus(prom_addresses):
    """Scrape ``http://<address>/metrics`` for each address.

    Args:
        prom_addresses: Iterable of "host:port" strings.

    Returns:
        A tuple ``(components_dict, metric_names, metric_samples)``:
        a dict mapping each address to the set of "Component" label values
        seen there, the set of all sample names, and the list of all parsed
        samples. Addresses that refuse the connection are skipped, but they
        still appear in ``components_dict`` with an empty set.
    """
    components_dict = {}
    metric_names = set()
    metric_samples = []
    for address in prom_addresses:
        components = components_dict.setdefault(address, set())
        try:
            response = requests.get(f"http://{address}/metrics")
        except requests.exceptions.ConnectionError:
            continue

        # The exposition text is parsed one line at a time.
        for line in response.text.split("\n"):
            for family in text_string_to_metric_families(line):
                for sample in family.samples:
                    metric_names.add(sample.name)
                    metric_samples.append(sample)
                    if "Component" in sample.labels:
                        components.add(sample.labels["Component"])
    return components_dict, metric_names, metric_samples
|