[ray_client] Implement a gRPC streaming logs API for the client (#13001)

This commit is contained in:
Barak Michener
2020-12-20 19:35:34 -08:00
committed by GitHub
parent 4caa6c6d78
commit c576f0b073
10 changed files with 332 additions and 34 deletions
+3 -6
View File
@@ -26,6 +26,7 @@ class DataClient:
Args:
channel: connected gRPC channel
client_id: the generated ID representing this client
"""
self.channel = channel
self.request_queue = queue.Queue()
@@ -68,18 +69,14 @@ class DataClient:
logger.info("Cancelling data channel")
else:
logger.error(
f"Got Error from rpc channel -- shutting down: {e}")
f"Got Error from data channel -- shutting down: {e}")
raise e
def close(self, close_channel: bool = False) -> None:
def close(self) -> None:
if self.request_queue is not None:
self.request_queue.put(None)
self.request_queue = None
if close_channel:
self.channel.close()
if self.data_thread is not None:
self.data_thread.join()
self.data_thread = None
def _blocking_send(self, req: ray_client_pb2.DataRequest
) -> ray_client_pb2.DataResponse:
@@ -0,0 +1,84 @@
"""
This file implements a threaded stream controller to return logs back from
the ray clientserver.
"""
import sys
import logging
import queue
import threading
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
class LogstreamClient:
def __init__(self, channel: "grpc._channel.Channel"):
"""Initializes a thread-safe log stream over a Ray Client gRPC channel.
Args:
channel: connected gRPC channel
"""
self.channel = channel
self.request_queue = queue.Queue()
self.log_thread = self._start_logthread()
self.log_thread.start()
def _start_logthread(self) -> threading.Thread:
return threading.Thread(target=self._log_main, args=(), daemon=True)
def _log_main(self) -> None:
stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
log_stream = stub.Logstream(iter(self.request_queue.get, None))
try:
for record in log_stream:
if record.level < 0:
self.stdstream(level=record.level, msg=record.msg)
self.log(level=record.level, msg=record.msg)
except grpc.RpcError as e:
if grpc.StatusCode.CANCELLED != e.code():
# Not just shutting down normally
logger.error(
f"Got Error from logger channel -- shutting down: {e}")
raise e
def log(self, level: int, msg: str):
"""
Log the message from the log stream.
By default, calls logger.log but this can be overridden.
Args:
level: The loglevel of the received log message
msg: The content of the message
"""
logger.log(level=level, msg=msg)
def stdstream(self, level: int, msg: str):
"""
Log the stdout/stderr entry from the log stream.
By default, calls print but this can be overridden.
Args:
level: The loglevel of the received log message
msg: The content of the message
"""
print_file = sys.stderr if level == -2 else sys.stdout
print(msg, file=print_file)
def set_logstream_level(self, level: int):
req = ray_client_pb2.LogSettingsRequest()
req.enabled = True
req.loglevel = level
self.request_queue.put(req)
def close(self) -> None:
self.request_queue.put(None)
if self.log_thread is not None:
self.log_thread.join()
def disable_logs(self) -> None:
req = ray_client_pb2.LogSettingsRequest()
req.enabled = False
self.request_queue.put(req)
@@ -48,7 +48,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
resp.req_id = req.req_id
yield resp
except grpc.RpcError as e:
logger.debug(f"Closing channel: {e}")
logger.debug(f"Closing data channel: {e}")
finally:
logger.info(f"Lost data connection from client {client_id}")
self.basic_service.release_all(client_id)
@@ -0,0 +1,99 @@
"""
This file responds to log stream requests and forwards logs
with its handler.
"""
import io
import threading
import queue
import logging
import grpc
import uuid
from ray.worker import print_worker_logs
from ray.ray_logging import global_worker_stdstream_dispatcher
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
class LogstreamHandler(logging.Handler):
def __init__(self, queue, level):
super().__init__()
self.queue = queue
self.level = level
def emit(self, record: logging.LogRecord):
logdata = ray_client_pb2.LogData()
logdata.msg = record.getMessage()
logdata.level = record.levelno
logdata.name = record.name
self.queue.put(logdata)
class StdStreamHandler:
def __init__(self, queue):
self.queue = queue
self.id = str(uuid.uuid4())
def handle(self, data):
logdata = ray_client_pb2.LogData()
logdata.level = -2 if data["is_err"] else -1
logdata.name = "stderr" if data["is_err"] else "stdout"
with io.StringIO() as file:
print_worker_logs(data, file)
logdata.msg = file.getvalue()
self.queue.put(logdata)
def register_global(self):
global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)
def unregister_global(self):
global_worker_stdstream_dispatcher.remove_handler(self.id)
def log_status_change_thread(log_queue, request_iterator):
std_handler = StdStreamHandler(log_queue)
current_handler = None
root_logger = logging.getLogger("ray")
default_level = root_logger.getEffectiveLevel()
try:
for req in request_iterator:
if current_handler is not None:
root_logger.setLevel(default_level)
root_logger.removeHandler(current_handler)
std_handler.unregister_global()
if not req.enabled:
current_handler = None
continue
current_handler = LogstreamHandler(log_queue, req.loglevel)
std_handler.register_global()
root_logger.addHandler(current_handler)
root_logger.setLevel(req.loglevel)
finally:
if current_handler is not None:
root_logger.setLevel(default_level)
root_logger.removeHandler(current_handler)
std_handler.unregister_global()
log_queue.put(None)
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
def Logstream(self, request_iterator, context):
logger.info("New logs connection")
log_queue = queue.Queue()
thread = threading.Thread(
target=log_status_change_thread,
args=(log_queue, request_iterator),
daemon=True)
thread.start()
try:
queue_iter = iter(log_queue.get, None)
for record in queue_iter:
if record is None:
break
yield record
except grpc.RpcError as e:
logger.debug(f"Closing log channel: {e}")
finally:
thread.join()
@@ -23,6 +23,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.logservicer import LogstreamServicer
from ray.experimental.client.server.server_stubs import current_remote
logger = logging.getLogger(__name__)
@@ -372,11 +373,14 @@ def serve(connection_str, test_mode=False):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
data_servicer = DataServicer(task_servicer)
logs_servicer = LogstreamServicer()
_set_server_api(RayServerAPI(task_servicer))
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
data_servicer, server)
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
logs_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return server
+10 -3
View File
@@ -30,6 +30,7 @@ from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientStub
from ray.experimental.client.dataclient import DataClient
from ray.experimental.client.logsclient import LogstreamClient
logger = logging.getLogger(__name__)
@@ -54,9 +55,13 @@ class Worker:
else:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self.data_client = DataClient(self.channel, self._client_id)
self.reference_count: Dict[bytes, int] = defaultdict(int)
self.log_client = LogstreamClient(self.channel)
self.log_client.set_logstream_level(logging.INFO)
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
@@ -197,14 +202,16 @@ class Worker:
ray_client_pb2.ReleaseRequest(ids=[id]))
def call_retain(self, id: bytes) -> None:
logger.debug(f"Retaining {id}")
logger.debug(f"Retaining {id.hex()}")
self.reference_count[id] += 1
def close(self):
self.data_client.close(close_channel=True)
self.server = None
self.log_client.close()
self.data_client.close()
if self.channel:
self.channel.close()
self.channel = None
self.server = None
def get_actor(self, name: str) -> ClientActorHandle:
task = ray_client_pb2.ClientTask()
+27
View File
@@ -1,8 +1,11 @@
import logging
import os
import sys
import threading
from logging.handlers import RotatingFileHandler
from typing import Callable
import ray
from ray.utils import binary_to_hex
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
# logger to add a newline at the end of string.
handler.terminator = ""
return logger
class WorkerStandardStreamDispatcher:
def __init__(self):
self.handlers = []
self._lock = threading.Lock()
def add_handler(self, name: str, handler: Callable) -> None:
with self._lock:
self.handlers.append((name, handler))
def remove_handler(self, name: str) -> None:
with self._lock:
new_handlers = [pair for pair in self.handlers if pair[0] != name]
self.handlers = new_handlers
def emit(self, data):
with self._lock:
for pair in self.handlers:
_, handle = pair
handle(data)
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
+44 -1
View File
@@ -1,4 +1,7 @@
import pytest
import time
import sys
import logging
from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server
@@ -234,6 +237,47 @@ def test_pass_handles(ray_start_regular_shared):
4)) == local_fact(4)
def test_basic_log_stream(ray_start_regular_shared):
with ray_start_client_server() as ray:
log_msgs = []
def test_log(level, msg):
log_msgs.append(msg)
ray.worker.log_client.log = test_log
ray.worker.log_client.set_logstream_level(logging.DEBUG)
# Allow some time to propogate
time.sleep(1)
x = ray.put("Foo")
assert ray.get(x) == "Foo"
time.sleep(1)
logs_with_id = [msg for msg in log_msgs if msg.find(x.id.hex()) >= 0]
assert len(logs_with_id) >= 2
assert any((msg.find("get") >= 0 for msg in logs_with_id))
assert any((msg.find("put") >= 0 for msg in logs_with_id))
def test_stdout_log_stream(ray_start_regular_shared):
with ray_start_client_server() as ray:
log_msgs = []
def test_log(level, msg):
log_msgs.append(msg)
ray.worker.log_client.stdstream = test_log
@ray.remote
def print_on_stderr_and_stdout(s):
print(s)
print(s, file=sys.stderr)
time.sleep(1)
print_on_stderr_and_stdout.remote("Hello world")
time.sleep(1)
assert len(log_msgs) == 2
assert all((msg.find("Hello world") for msg in log_msgs))
def test_basic_named_actor(ray_start_regular_shared):
"""
Test that ray.get_actor() can create and return a detached actor.
@@ -264,5 +308,4 @@ def test_basic_named_actor(ray_start_regular_shared):
if __name__ == "__main__":
import sys
sys.exit(pytest.main(["-v", __file__]))
+33 -23
View File
@@ -48,6 +48,7 @@ from ray.exceptions import (
)
from ray.function_manager import FunctionActorManager
from ray.ray_logging import setup_logger
from ray.ray_logging import global_worker_stdstream_dispatcher
from ray.utils import _random_string, check_oversized_pickle
from ray.util.inspect import is_cython
@@ -910,29 +911,8 @@ def print_logs(redis_client, threads_stopped, job_id):
if data["job"] and ray.utils.binary_to_hex(
job_id.binary()) != data["job"]:
continue
print_file = sys.stderr if data["is_err"] else sys.stdout
def color_for(data):
if data["pid"] == "raylet":
return colorama.Fore.YELLOW
else:
return colorama.Fore.CYAN
if data["ip"] == localhost:
for line in data["lines"]:
print(
"{}{}(pid={}){} {}".format(
colorama.Style.DIM, color_for(data), data["pid"],
colorama.Style.RESET_ALL, line),
file=print_file)
else:
for line in data["lines"]:
print(
"{}{}(pid={}, ip={}){} {}".format(
colorama.Style.DIM, color_for(data), data["pid"],
data["ip"], colorama.Style.RESET_ALL, line),
file=print_file)
data["localhost"] = localhost
global_worker_stdstream_dispatcher.emit(data)
except (OSError, redis.exceptions.ConnectionError) as e:
logger.error(f"print_logs: {e}")
@@ -941,6 +921,34 @@ def print_logs(redis_client, threads_stopped, job_id):
pubsub_client.close()
def print_to_stdstream(data):
print_file = sys.stderr if data["is_err"] else sys.stdout
print_worker_logs(data, print_file)
def print_worker_logs(data, print_file):
def color_for(data):
if data["pid"] == "raylet":
return colorama.Fore.YELLOW
else:
return colorama.Fore.CYAN
if data["ip"] == data["localhost"]:
for line in data["lines"]:
print(
"{}{}(pid={}){} {}".format(colorama.Style.DIM, color_for(data),
data["pid"],
colorama.Style.RESET_ALL, line),
file=print_file)
else:
for line in data["lines"]:
print(
"{}{}(pid={}, ip={}){} {}".format(
colorama.Style.DIM, color_for(data), data["pid"],
data["ip"], colorama.Style.RESET_ALL, line),
file=print_file)
def print_error_messages_raylet(task_error_queue, threads_stopped):
"""Prints message received in the given output queue.
@@ -1201,6 +1209,8 @@ def connect(node,
worker.printer_thread.daemon = True
worker.printer_thread.start()
if log_to_driver:
global_worker_stdstream_dispatcher.add_handler(
"ray_print_logs", print_to_stdstream)
worker.logger_thread = threading.Thread(
target=print_logs,
name="ray_print_logs",