mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 13:11:03 +08:00
[ray_client] Implement a gRPC streaming logs API for the client (#13001)
This commit is contained in:
@@ -26,6 +26,7 @@ class DataClient:
|
||||
|
||||
Args:
|
||||
channel: connected gRPC channel
|
||||
client_id: the generated ID representing this client
|
||||
"""
|
||||
self.channel = channel
|
||||
self.request_queue = queue.Queue()
|
||||
@@ -68,18 +69,14 @@ class DataClient:
|
||||
logger.info("Cancelling data channel")
|
||||
else:
|
||||
logger.error(
|
||||
f"Got Error from rpc channel -- shutting down: {e}")
|
||||
f"Got Error from data channel -- shutting down: {e}")
|
||||
raise e
|
||||
|
||||
def close(self, close_channel: bool = False) -> None:
|
||||
def close(self) -> None:
|
||||
if self.request_queue is not None:
|
||||
self.request_queue.put(None)
|
||||
self.request_queue = None
|
||||
if close_channel:
|
||||
self.channel.close()
|
||||
if self.data_thread is not None:
|
||||
self.data_thread.join()
|
||||
self.data_thread = None
|
||||
|
||||
def _blocking_send(self, req: ray_client_pb2.DataRequest
|
||||
) -> ray_client_pb2.DataResponse:
|
||||
|
||||
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
This file implements a threaded stream controller to return logs back from
|
||||
the ray clientserver.
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogstreamClient:
    """Client end of the Ray Client log stream.

    A daemon thread opens the ``Logstream`` RPC on the supplied channel,
    feeds it the settings requests placed on ``request_queue`` and hands
    every received record to either :meth:`log` or :meth:`stdstream`.
    """

    def __init__(self, channel: "grpc._channel.Channel"):
        """Initializes a thread-safe log stream over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.log_thread = self._start_logthread()
        self.log_thread.start()

    def _start_logthread(self) -> threading.Thread:
        """Create (without starting) the daemon thread that runs _log_main."""
        return threading.Thread(target=self._log_main, args=(), daemon=True)

    def _log_main(self) -> None:
        """Thread body: run the Logstream RPC and dispatch incoming records.

        Raises:
            grpc.RpcError: re-raised when the stream fails with any status
                other than CANCELLED.
        """
        stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
        # iter(get, None) makes the request stream end as soon as close()
        # puts the None sentinel on the queue.
        log_stream = stub.Logstream(iter(self.request_queue.get, None))
        try:
            for record in log_stream:
                self._dispatch_record(record)
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED != e.code():
                # Not just shutting down normally
                logger.error(
                    f"Got Error from logger channel -- shutting down: {e}")
                raise e

    def _dispatch_record(self, record) -> None:
        """Route one received record to exactly one of the two hooks.

        Records with a negative level are stdout/stderr entries and go to
        stdstream() only; previously they were also passed on to log() with
        their negative level, so a log() override saw every stdout/stderr
        entry a second time.

        Args:
            record: object exposing ``level`` and ``msg`` attributes
        """
        if record.level < 0:
            self.stdstream(level=record.level, msg=record.msg)
        else:
            self.log(level=record.level, msg=record.msg)

    def log(self, level: int, msg: str):
        """
        Log the message from the log stream.
        By default, calls logger.log but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        logger.log(level=level, msg=msg)

    def stdstream(self, level: int, msg: str):
        """
        Log the stdout/stderr entry from the log stream.
        By default, calls print but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        # Level -2 selects stderr; every other level goes to stdout.
        print_file = sys.stderr if level == -2 else sys.stdout
        print(msg, file=print_file)

    def set_logstream_level(self, level: int):
        """Queue a request that enables log streaming at the given level.

        Args:
            level: the loglevel to put in the settings request
        """
        req = ray_client_pb2.LogSettingsRequest()
        req.enabled = True
        req.loglevel = level
        self.request_queue.put(req)

    def close(self) -> None:
        """End the request stream and wait for the log thread to finish."""
        # None is the sentinel that terminates the request iterator.
        self.request_queue.put(None)
        if self.log_thread is not None:
            self.log_thread.join()

    def disable_logs(self) -> None:
        """Queue a request that turns log streaming off."""
        req = ray_client_pb2.LogSettingsRequest()
        req.enabled = False
        self.request_queue.put(req)
|
||||
@@ -48,7 +48,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
|
||||
resp.req_id = req.req_id
|
||||
yield resp
|
||||
except grpc.RpcError as e:
|
||||
logger.debug(f"Closing channel: {e}")
|
||||
logger.debug(f"Closing data channel: {e}")
|
||||
finally:
|
||||
logger.info(f"Lost data connection from client {client_id}")
|
||||
self.basic_service.release_all(client_id)
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
This file responds to log stream requests and forwards logs
|
||||
with its handler.
|
||||
"""
|
||||
import io
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
import grpc
|
||||
import uuid
|
||||
|
||||
from ray.worker import print_worker_logs
|
||||
from ray.ray_logging import global_worker_stdstream_dispatcher
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogstreamHandler(logging.Handler):
    """Logging handler that turns each record into a LogData message and
    places it on a queue."""

    def __init__(self, queue, level):
        """
        Args:
            queue: object whose put() receives one LogData per record
            level: value stored as the handler's level
        """
        super().__init__()
        self.queue, self.level = queue, level

    def emit(self, record: logging.LogRecord):
        """Enqueue the record's message, level number and logger name."""
        self.queue.put(
            ray_client_pb2.LogData(
                msg=record.getMessage(),
                level=record.levelno,
                name=record.name))
|
||||
|
||||
|
||||
class StdStreamHandler:
    """Forwards worker stdout/stderr payloads onto a queue as LogData."""

    def __init__(self, queue):
        """
        Args:
            queue: object whose put() receives one LogData per payload
        """
        # A random id is the name this handler registers under.
        self.id = str(uuid.uuid4())
        self.queue = queue

    def handle(self, data):
        """Render one payload with print_worker_logs and enqueue it.

        Args:
            data: mapping with at least an "is_err" key; it is also passed
                unchanged to print_worker_logs
        """
        is_err = data["is_err"]
        with io.StringIO() as rendered:
            print_worker_logs(data, rendered)
            text = rendered.getvalue()
        # -2 / "stderr" marks error output, -1 / "stdout" everything else.
        self.queue.put(
            ray_client_pb2.LogData(
                level=-2 if is_err else -1,
                name="stderr" if is_err else "stdout",
                msg=text))

    def register_global(self):
        """Add handle() to the global stdstream dispatcher under self.id."""
        global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)

    def unregister_global(self):
        """Remove this handler's entry from the global dispatcher."""
        global_worker_stdstream_dispatcher.remove_handler(self.id)
|
||||
|
||||
|
||||
def log_status_change_thread(log_queue, request_iterator):
    """Apply each incoming log-settings request to the "ray" logger.

    For every request the previously installed forwarding (if any) is
    removed first; an enabling request then installs a new LogstreamHandler
    and registers the stdstream handler. When the request iterator ends or
    raises, forwarding is removed and None is put on ``log_queue``.

    Args:
        log_queue: queue that receives the forwarded LogData messages and
            the final None
        request_iterator: iterable of requests exposing ``enabled`` and
            ``loglevel``
    """
    ray_logger = logging.getLogger("ray")
    original_level = ray_logger.getEffectiveLevel()
    std_handler = StdStreamHandler(log_queue)
    active_handler = None

    def _stop_forwarding(handler):
        # Restore the level seen at start-up and detach both handlers.
        if handler is None:
            return
        ray_logger.setLevel(original_level)
        ray_logger.removeHandler(handler)
        std_handler.unregister_global()

    try:
        for req in request_iterator:
            _stop_forwarding(active_handler)
            if req.enabled:
                active_handler = LogstreamHandler(log_queue, req.loglevel)
                std_handler.register_global()
                ray_logger.addHandler(active_handler)
                ray_logger.setLevel(req.loglevel)
            else:
                active_handler = None
    finally:
        _stop_forwarding(active_handler)
        log_queue.put(None)
|
||||
|
||||
|
||||
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """Serves the Logstream RPC by relaying queued log records."""

    def Logstream(self, request_iterator, context):
        """Yield log records for one client connection.

        A daemon thread runs log_status_change_thread on the incoming
        requests and fills a queue; this generator yields whatever arrives
        on that queue.

        Args:
            request_iterator: iterator over the client's settings requests
            context: RPC context (unused here)

        Yields:
            The records placed on the queue, in order.
        """
        logger.info("New logs connection")
        log_queue = queue.Queue()
        thread = threading.Thread(
            target=log_status_change_thread,
            args=(log_queue, request_iterator),
            daemon=True)
        thread.start()
        try:
            # iter() with a None sentinel already stops at the None that
            # log_status_change_thread enqueues when it finishes, so no
            # explicit None check is needed inside the loop.
            for record in iter(log_queue.get, None):
                yield record
        except grpc.RpcError as e:
            logger.debug(f"Closing log channel: {e}")
        finally:
            thread.join()
|
||||
@@ -23,6 +23,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server
|
||||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.logservicer import LogstreamServicer
|
||||
from ray.experimental.client.server.server_stubs import current_remote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -372,11 +373,14 @@ def serve(connection_str, test_mode=False):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
logs_servicer = LogstreamServicer()
|
||||
_set_server_api(RayServerAPI(task_servicer))
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
data_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
|
||||
logs_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
server.start()
|
||||
return server
|
||||
|
||||
@@ -30,6 +30,7 @@ from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
from ray.experimental.client.logsclient import LogstreamClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,9 +55,13 @@ class Worker:
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
self.data_client = DataClient(self.channel, self._client_id)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
self.log_client = LogstreamClient(self.channel)
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
@@ -197,14 +202,16 @@ class Worker:
|
||||
ray_client_pb2.ReleaseRequest(ids=[id]))
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
logger.debug(f"Retaining {id}")
|
||||
logger.debug(f"Retaining {id.hex()}")
|
||||
self.reference_count[id] += 1
|
||||
|
||||
def close(self):
|
||||
self.data_client.close(close_channel=True)
|
||||
self.server = None
|
||||
self.log_client.close()
|
||||
self.data_client.close()
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
self.server = None
|
||||
|
||||
def get_actor(self, name: str) -> ClientActorHandle:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import ray
|
||||
from ray.utils import binary_to_hex
|
||||
|
||||
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
|
||||
# logger to add a newline at the end of string.
|
||||
handler.terminator = ""
|
||||
return logger
|
||||
|
||||
|
||||
class WorkerStandardStreamDispatcher:
    """Fans payloads out to a list of named handler callables.

    Every access to the handler list happens while holding one lock.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self.handlers = []

    def add_handler(self, name: str, handler: Callable) -> None:
        """Append ``handler`` under ``name``; existing names are not checked.

        Args:
            name: key later used by remove_handler
            handler: callable invoked with each emitted payload
        """
        entry = (name, handler)
        with self._lock:
            self.handlers.append(entry)

    def remove_handler(self, name: str) -> None:
        """Drop every handler registered under ``name`` (no-op if absent).

        Args:
            name: key the handler was registered with
        """
        with self._lock:
            self.handlers = [
                entry for entry in self.handlers if entry[0] != name
            ]

    def emit(self, data):
        """Call each registered handler with ``data`` in registration order.

        Args:
            data: payload passed unchanged to every handler
        """
        # Handlers run while the lock is held.
        with self._lock:
            for _, callback in self.handlers:
                callback(data)
|
||||
|
||||
|
||||
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import pytest
|
||||
import time
|
||||
import sys
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
@@ -234,6 +237,47 @@ def test_pass_handles(ray_start_regular_shared):
|
||||
4)) == local_fact(4)
|
||||
|
||||
|
||||
def test_basic_log_stream(ray_start_regular_shared):
    """Streamed DEBUG logs mention the object id for both put and get."""
    with ray_start_client_server() as ray:
        captured = []

        def record_log(level, msg):
            captured.append(msg)

        # Replace the client's log hook so streamed messages are collected.
        ray.worker.log_client.log = record_log
        ray.worker.log_client.set_logstream_level(logging.DEBUG)
        # Allow some time to propagate
        time.sleep(1)
        x = ray.put("Foo")
        assert ray.get(x) == "Foo"
        time.sleep(1)
        object_hex = x.id.hex()
        logs_with_id = [msg for msg in captured if object_hex in msg]
        assert len(logs_with_id) >= 2
        assert any("get" in msg for msg in logs_with_id)
        assert any("put" in msg for msg in logs_with_id)
|
||||
|
||||
|
||||
def test_stdout_log_stream(ray_start_regular_shared):
    """A remote task's stdout and stderr output reaches the client's
    stdstream hook as two messages that contain the printed text."""
    with ray_start_client_server() as ray:
        log_msgs = []

        def test_log(level, msg):
            log_msgs.append(msg)

        # Replace the client's stdstream hook so entries are collected.
        ray.worker.log_client.stdstream = test_log

        @ray.remote
        def print_on_stderr_and_stdout(s):
            print(s)
            print(s, file=sys.stderr)

        time.sleep(1)
        print_on_stderr_and_stdout.remote("Hello world")
        time.sleep(1)
        assert len(log_msgs) == 2
        # str.find() returns -1 (truthy) when the text is missing and 0
        # (falsy) when it is at the start, so use a containment check.
        assert all("Hello world" in msg for msg in log_msgs)
|
||||
|
||||
|
||||
def test_basic_named_actor(ray_start_regular_shared):
|
||||
"""
|
||||
Test that ray.get_actor() can create and return a detached actor.
|
||||
@@ -264,5 +308,4 @@ def test_basic_named_actor(ray_start_regular_shared):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
|
||||
+33
-23
@@ -48,6 +48,7 @@ from ray.exceptions import (
|
||||
)
|
||||
from ray.function_manager import FunctionActorManager
|
||||
from ray.ray_logging import setup_logger
|
||||
from ray.ray_logging import global_worker_stdstream_dispatcher
|
||||
from ray.utils import _random_string, check_oversized_pickle
|
||||
from ray.util.inspect import is_cython
|
||||
|
||||
@@ -910,29 +911,8 @@ def print_logs(redis_client, threads_stopped, job_id):
|
||||
if data["job"] and ray.utils.binary_to_hex(
|
||||
job_id.binary()) != data["job"]:
|
||||
continue
|
||||
|
||||
print_file = sys.stderr if data["is_err"] else sys.stdout
|
||||
|
||||
def color_for(data):
|
||||
if data["pid"] == "raylet":
|
||||
return colorama.Fore.YELLOW
|
||||
else:
|
||||
return colorama.Fore.CYAN
|
||||
|
||||
if data["ip"] == localhost:
|
||||
for line in data["lines"]:
|
||||
print(
|
||||
"{}{}(pid={}){} {}".format(
|
||||
colorama.Style.DIM, color_for(data), data["pid"],
|
||||
colorama.Style.RESET_ALL, line),
|
||||
file=print_file)
|
||||
else:
|
||||
for line in data["lines"]:
|
||||
print(
|
||||
"{}{}(pid={}, ip={}){} {}".format(
|
||||
colorama.Style.DIM, color_for(data), data["pid"],
|
||||
data["ip"], colorama.Style.RESET_ALL, line),
|
||||
file=print_file)
|
||||
data["localhost"] = localhost
|
||||
global_worker_stdstream_dispatcher.emit(data)
|
||||
|
||||
except (OSError, redis.exceptions.ConnectionError) as e:
|
||||
logger.error(f"print_logs: {e}")
|
||||
@@ -941,6 +921,34 @@ def print_logs(redis_client, threads_stopped, job_id):
|
||||
pubsub_client.close()
|
||||
|
||||
|
||||
def print_to_stdstream(data):
    """Print a worker log payload to this process's stderr or stdout.

    Args:
        data: mapping with an "is_err" key selecting stderr when truthy;
            it is passed unchanged to print_worker_logs
    """
    if data["is_err"]:
        target = sys.stderr
    else:
        target = sys.stdout
    print_worker_logs(data, target)
|
||||
|
||||
|
||||
def print_worker_logs(data, print_file):
    """Write each line of a worker log payload with a dimmed, coloured prefix.

    The prefix is ``(pid=...)`` when the payload's "ip" equals its
    "localhost" entry and ``(pid=..., ip=...)`` otherwise; it is yellow when
    the pid is "raylet" and cyan for anything else.

    Args:
        data: mapping with "ip", "localhost", "pid" and "lines" keys
        print_file: file object the formatted lines are printed to
    """
    same_host = data["ip"] == data["localhost"]
    for line in data["lines"]:
        if data["pid"] == "raylet":
            color = colorama.Fore.YELLOW
        else:
            color = colorama.Fore.CYAN
        if same_host:
            origin = "(pid={})".format(data["pid"])
        else:
            origin = "(pid={}, ip={})".format(data["pid"], data["ip"])
        print(
            "{}{}{}{} {}".format(colorama.Style.DIM, color, origin,
                                 colorama.Style.RESET_ALL, line),
            file=print_file)
|
||||
|
||||
|
||||
def print_error_messages_raylet(task_error_queue, threads_stopped):
|
||||
"""Prints message received in the given output queue.
|
||||
|
||||
@@ -1201,6 +1209,8 @@ def connect(node,
|
||||
worker.printer_thread.daemon = True
|
||||
worker.printer_thread.start()
|
||||
if log_to_driver:
|
||||
global_worker_stdstream_dispatcher.add_handler(
|
||||
"ray_print_logs", print_to_stdstream)
|
||||
worker.logger_thread = threading.Thread(
|
||||
target=print_logs,
|
||||
name="ray_print_logs",
|
||||
|
||||
Reference in New Issue
Block a user