Files
ray/python/ray/test_utils.py
T
2020-07-13 20:44:27 -05:00

345 lines
9.5 KiB
Python

import asyncio
import errno
import json
import fnmatch
import os
import subprocess
import sys
import time
import socket
import math
import ray
import ray.services
import psutil # We must import psutil after ray because we bundle it with ray.
if sys.platform == "win32":
import _winapi
class RayTestTimeoutException(Exception):
    """Raised by the test helpers in this module when a wait times out."""
def _pid_alive(pid):
"""Check if the process with this PID is alive or not.
Args:
pid: The pid to check.
Returns:
This returns false if the process is dead. Otherwise, it returns true.
"""
no_such_process = errno.EINVAL if sys.platform == "win32" else errno.ESRCH
alive = True
try:
if sys.platform == "win32":
SYNCHRONIZE = 0x00100000 # access mask defined in <winnt.h>
handle = _winapi.OpenProcess(SYNCHRONIZE, False, pid)
try:
alive = (_winapi.WaitForSingleObject(handle, 0) !=
_winapi.WAIT_OBJECT_0)
finally:
_winapi.CloseHandle(handle)
else:
os.kill(pid, 0)
except OSError as ex:
if ex.errno != no_such_process:
raise
alive = False
return alive
def wait_for_pid_to_exit(pid, timeout=20):
    """Poll until the process with the given pid is no longer alive.

    Args:
        pid: The pid of the process to wait for.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RayTestTimeoutException: If the process is still alive when the
            timeout has elapsed.
    """
    began = time.time()
    while time.time() - began < timeout:
        if _pid_alive(pid):
            # Still running; pause briefly before checking again.
            time.sleep(0.1)
            continue
        return
    message = "Timed out while waiting for process {} to exit.".format(pid)
    raise RayTestTimeoutException(message)
def wait_for_children_of_pid(pid, num_children=1, timeout=20):
    """Wait until a process has at least `num_children` children.

    Children are counted with `psutil.Process.children(recursive=False)`.
    The count is always checked at least once, even when `timeout` is zero
    or negative, so the timeout message always reports a real count.

    Args:
        pid: The pid of the parent process.
        num_children: Minimum number of children to wait for.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RayTestTimeoutException: If fewer than `num_children` children
            were seen before the timeout elapsed.
    """
    p = psutil.Process(pid)
    start_time = time.time()
    while True:
        num_alive = len(p.children(recursive=False))
        if num_alive >= num_children:
            return
        # Check the deadline only after counting, so `num_alive` is always
        # bound by the time the exception message below is built.
        if not (time.time() - start_time < timeout):
            break
        time.sleep(0.1)
    raise RayTestTimeoutException(
        "Timed out while waiting for process {} children to start "
        "({}/{} started).".format(pid, num_alive, num_children))
def wait_for_children_of_pid_to_exit(pid, timeout=20):
    """Wait for the current children of a process to exit.

    Args:
        pid: The pid of the parent process.
        timeout: Timeout passed to `psutil.wait_procs`.

    Raises:
        RayTestTimeoutException: If any child is still alive after the
            wait; the message lists the names of the survivors.
    """
    child_procs = psutil.Process(pid).children()
    if not child_procs:
        return
    # wait_procs returns a (gone, alive) pair; only the latter matters here.
    survivors = psutil.wait_procs(child_procs, timeout=timeout)[1]
    if not survivors:
        return
    names = [proc.name() for proc in survivors]
    raise RayTestTimeoutException(
        "Timed out while waiting for process children to exit."
        " Children still alive: {}.".format(names))
def kill_process_by_name(name, SIGKILL=False):
    """Stop every process whose name is `name` plus the executable suffix.

    Args:
        name: Process name without `ray.services.EXE_SUFFIX`.
        SIGKILL: If true, call `kill()` on each match; otherwise call
            `terminate()`.
    """
    for proc in psutil.process_iter(attrs=["name"]):
        if proc.info["name"] != name + ray.services.EXE_SUFFIX:
            continue
        stop = proc.kill if SIGKILL else proc.terminate
        stop()
def run_string_as_driver(driver_script):
    """Run a driver script in a separate Python process and wait for it.

    The script is fed to the child interpreter over stdin, and the child's
    stderr is merged into its stdout.

    Args:
        driver_script: A string to run as a Python script.

    Returns:
        The decoded combined output of the script.

    Raises:
        subprocess.CalledProcessError: If the child exits with a non-zero
            return code. The output is printed before raising.
    """
    proc = subprocess.Popen(
        [sys.executable, "-"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    with proc:
        raw_output, _ = proc.communicate(driver_script.encode("ascii"))
        if proc.returncode:
            print(ray.utils.decode(raw_output))
            raise subprocess.CalledProcessError(proc.returncode, proc.args,
                                                raw_output, proc.stderr)
        return ray.utils.decode(raw_output)
def run_string_as_driver_nonblocking(driver_script):
    """Launch a driver script in a separate process without waiting for it.

    The child runs a small bootstrap that reads the whole script from
    stdin, closes stdin, and then executes the script.

    Args:
        driver_script: A string to run as a Python script.

    Returns:
        The `subprocess.Popen` handle of the driver process. Its stdin is
        already closed; stdout and stderr are pipes.
    """
    bootstrap_steps = [
        "import sys",
        "script = sys.stdin.read()",
        "sys.stdin.close()",
        "del sys",
        "exec(\"del script\\n\" + script)",
    ]
    bootstrap = "; ".join(bootstrap_steps)
    child = subprocess.Popen(
        [sys.executable, "-c", bootstrap],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)
    child_stdin = child.stdin
    child_stdin.write(driver_script.encode("ascii"))
    # Closing stdin lets the bootstrap's read() reach end-of-file.
    child_stdin.close()
    return child
def flat_errors():
    """Return the errors of all jobs as one flat list.

    Concatenates the values of `ray.errors(all_jobs=True)`.
    """
    return [
        error
        for job_errors in ray.errors(all_jobs=True).values()
        for error in job_errors
    ]
def relevant_errors(error_type):
    """Return the errors from `flat_errors()` whose "type" is `error_type`."""
    matching = []
    for error in flat_errors():
        if error["type"] == error_type:
            matching.append(error)
    return matching
def wait_for_num_actors(num_actors, timeout=10):
    """Poll `ray.actors()` until it has at least `num_actors` entries.

    Args:
        num_actors: Minimum number of actors to wait for.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RayTestTimeoutException: If the count is not reached in time.
    """
    began = time.time()
    while time.time() - began < timeout:
        if len(ray.actors()) < num_actors:
            time.sleep(0.1)
            continue
        return
    raise RayTestTimeoutException("Timed out while waiting for global state.")
def wait_for_errors(error_type, num_errors, timeout=20):
    """Poll until at least `num_errors` errors of `error_type` exist.

    Args:
        error_type: The error type passed to `relevant_errors`.
        num_errors: Minimum number of matching errors to wait for.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RayTestTimeoutException: If too few errors appeared in time.
    """
    began = time.time()
    while True:
        if not (time.time() - began < timeout):
            raise RayTestTimeoutException(
                "Timed out waiting for {} {} errors.".format(
                    num_errors, error_type))
        if len(relevant_errors(error_type)) >= num_errors:
            return
        time.sleep(0.1)
def wait_for_condition(condition_predictor, timeout=30, retry_interval_ms=100):
    """Repeatedly evaluate a predicate until it is truthy or time runs out.

    Args:
        condition_predictor: A zero-argument callable; a truthy result
            means the condition is met.
        timeout: Maximum time to keep trying, in seconds.
        retry_interval_ms: Pause between attempts, in milliseconds.

    Return:
        True if the predicate returned a truthy value within the timeout,
        False otherwise.
    """
    began = time.time()
    while True:
        if not (time.time() - began <= timeout):
            return False
        if condition_predictor():
            return True
        time.sleep(retry_interval_ms / 1000.0)
def wait_until_succeeded_without_exception(func,
                                           exceptions,
                                           *args,
                                           timeout_ms=1000,
                                           retry_interval_ms=100):
    """Retry a function until it completes without an expected exception.

    Args:
        func: A function to run.
        exceptions(tuple): Exceptions that are expected while the function
            is not yet succeeding; any other exception propagates.
        args: Arguments to pass to `func`.
        timeout_ms: Maximum timeout in milliseconds.
        retry_interval_ms: Retry interval in milliseconds.

    Return:
        True if `func` completed without raising within the timeout.
        False if it kept raising one of `exceptions` until the timeout, or
        if `exceptions` is not a tuple (a message is printed in that case).
    """
    # isinstance (rather than an exact type comparison) also accepts tuple
    # subclasses such as named tuples.
    if not isinstance(exceptions, tuple):
        print("exceptions arguments should be given as a tuple")
        return False

    time_elapsed = 0
    start = time.time()
    while time_elapsed <= timeout_ms:
        try:
            func(*args)
        except exceptions:
            # Elapsed time is measured before sleeping, as in the original.
            time_elapsed = (time.time() - start) * 1000
            time.sleep(retry_interval_ms / 1000.0)
        else:
            return True
    return False
def recursive_fnmatch(dirpath, pattern):
    """Collect the paths of files under `dirpath` whose names match `pattern`.

    Walks the directory tree with `os.walk` and matches file names (not
    full paths) using `fnmatch`.

    Args:
        dirpath: Root directory of the subtree to search.
        pattern: An fnmatch-style file name pattern.

    Returns:
        A list of matching paths, each joined with the directory it was
        found in.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(dirpath)
        for name in fnmatch.filter(names, pattern)
    ]
def generate_internal_config_map(**kwargs):
    """Wrap keyword arguments as a JSON `_internal_config` entry.

    Returns:
        A dict with the single key "_internal_config" whose value is the
        JSON serialization of the keyword arguments.
    """
    return {"_internal_config": json.dumps(kwargs)}
@ray.remote(num_cpus=0)
class SignalActor:
    """Actor wrapping an `asyncio.Event` that callers can set and await."""

    def __init__(self):
        self.ready_event = asyncio.Event()

    def send(self, clear=False):
        """Set the event; if `clear` is true, clear it again right after."""
        event = self.ready_event
        event.set()
        if clear:
            event.clear()

    async def wait(self, should_wait=True):
        """Wait for the event to be set unless `should_wait` is false."""
        if not should_wait:
            return
        await self.ready_event.wait()
@ray.remote(num_cpus=0)
class Semaphore:
    """Actor exposing the operations of an `asyncio.Semaphore`."""

    def __init__(self, value=1):
        self._sema = asyncio.Semaphore(value)

    async def acquire(self):
        """Wait until the underlying semaphore has been acquired."""
        await self._sema.acquire()

    async def release(self):
        """Release the underlying semaphore."""
        self._sema.release()

    async def locked(self):
        """Return the result of the underlying semaphore's `locked()`."""
        return self._sema.locked()
def dicts_equal(dict1, dict2, abs_tol=1e-4):
    """Compare two dicts whose values may be floating point numbers.

    Args:
        dict1: The first dict.
        dict2: The second dict.
        abs_tol: Absolute tolerance used when both values are floats.

    Returns:
        True if both dicts have the same keys and every pair of values is
        either a pair of floats that `math.isclose` accepts, or equal.
    """
    if dict1.keys() != dict2.keys():
        return False

    def _differs(left, right):
        # Float pairs within tolerance count as equal; everything else
        # falls back to the ordinary inequality check.
        both_floats = isinstance(left, float) and isinstance(right, float)
        if both_floats and math.isclose(left, right, abs_tol=abs_tol):
            return False
        return left != right

    return not any(_differs(dict1[key], dict2[key]) for key in dict1)
@ray.remote
def _put(obj):
    """Remote function that returns its argument unchanged.

    Used by `put_object` as the alternative to calling `ray.put` directly.
    """
    return obj
def put_object(obj, use_ray_put):
    """Create an object reference for `obj` in one of two ways.

    Args:
        obj: The object to store.
        use_ray_put: If truthy, call `ray.put(obj)`; otherwise invoke the
            `_put` remote function on `obj`.

    Returns:
        Whatever `ray.put` or `_put.remote` returns for the object.
    """
    if not use_ray_put:
        return _put.remote(obj)
    return ray.put(obj)
def wait_until_server_available(address,
                                timeout_ms=5000,
                                retry_interval_ms=100):
    """Retry a TCP connection to `address` until it succeeds or times out.

    Args:
        address: A "host:port" string; the parts are obtained by splitting
            on ":".
        timeout_ms: Maximum time to keep retrying, in milliseconds.
        retry_interval_ms: Pause between attempts, in milliseconds.

    Returns:
        True as soon as a connection attempt succeeds, False if every
        attempt failed until the timeout elapsed.
    """
    parts = address.split(":")
    target = (parts[0], int(parts[1]))
    began = time.time()
    elapsed_ms = 0
    while elapsed_ms <= timeout_ms:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Each individual connection attempt gets a one-second timeout.
        probe.settimeout(1)
        try:
            probe.connect(target)
        except Exception:
            elapsed_ms = (time.time() - began) * 1000
            time.sleep(retry_interval_ms / 1000.0)
        else:
            return True
        finally:
            probe.close()
    return False
def get_other_nodes(cluster, exclude_head=False):
    """Get all nodes except the one that we're connected to.

    A node is treated as the connected one when its `_raylet_socket_name`
    equals that of `ray.worker._global_node`.

    Args:
        cluster: Object providing `list_all_nodes()`.
        exclude_head: Unless this is exactly False, nodes are kept only
            when their `head` attribute is exactly False.

    Returns:
        A list of the remaining nodes.
    """
    others = []
    for node in cluster.list_all_nodes():
        own_socket = ray.worker._global_node._raylet_socket_name
        if node._raylet_socket_name == own_socket:
            continue
        if exclude_head is False or node.head is False:
            others.append(node)
    return others
def get_non_head_nodes(cluster):
    """Return the nodes of `cluster` whose `head` attribute is False."""
    return [node for node in cluster.list_all_nodes() if node.head is False]