From c576f0b0737370507bfa1a55075977e7df6e82a1 Mon Sep 17 00:00:00 2001 From: Barak Michener Date: Sun, 20 Dec 2020 19:35:34 -0800 Subject: [PATCH] [ray_client] Implement a gRPC streaming logs API for the client (#13001) --- python/ray/experimental/client/dataclient.py | 9 +- python/ray/experimental/client/logsclient.py | 84 ++++++++++++++++ .../client/server/dataservicer.py | 2 +- .../experimental/client/server/logservicer.py | 99 +++++++++++++++++++ .../ray/experimental/client/server/server.py | 4 + python/ray/experimental/client/worker.py | 13 ++- python/ray/ray_logging.py | 27 +++++ python/ray/tests/test_experimental_client.py | 45 ++++++++- python/ray/worker.py | 56 ++++++----- src/ray/protobuf/ray_client.proto | 27 +++++ 10 files changed, 332 insertions(+), 34 deletions(-) create mode 100644 python/ray/experimental/client/logsclient.py create mode 100644 python/ray/experimental/client/server/logservicer.py diff --git a/python/ray/experimental/client/dataclient.py b/python/ray/experimental/client/dataclient.py index c6a745df8..b0dda0a1b 100644 --- a/python/ray/experimental/client/dataclient.py +++ b/python/ray/experimental/client/dataclient.py @@ -26,6 +26,7 @@ class DataClient: Args: channel: connected gRPC channel + client_id: the generated ID representing this client """ self.channel = channel self.request_queue = queue.Queue() @@ -68,18 +69,14 @@ class DataClient: logger.info("Cancelling data channel") else: logger.error( - f"Got Error from rpc channel -- shutting down: {e}") + f"Got Error from data channel -- shutting down: {e}") raise e - def close(self, close_channel: bool = False) -> None: + def close(self) -> None: if self.request_queue is not None: self.request_queue.put(None) - self.request_queue = None - if close_channel: - self.channel.close() if self.data_thread is not None: self.data_thread.join() - self.data_thread = None def _blocking_send(self, req: ray_client_pb2.DataRequest ) -> ray_client_pb2.DataResponse: diff --git a/python/ray/experimental/client/logsclient.py b/python/ray/experimental/client/logsclient.py new file mode 100644 index 000000000..f26417e7e --- /dev/null +++ b/python/ray/experimental/client/logsclient.py @@ -0,0 +1,84 @@ +""" +This file implements a threaded stream controller to return logs back from +the ray clientserver. +""" +import sys +import logging +import queue +import threading +import grpc + +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + + +class LogstreamClient: + def __init__(self, channel: "grpc._channel.Channel"): + """Initializes a thread-safe log stream over a Ray Client gRPC channel. + + Args: + channel: connected gRPC channel + """ + self.channel = channel + self.request_queue = queue.Queue() + self.log_thread = self._start_logthread() + self.log_thread.start() + + def _start_logthread(self) -> threading.Thread: + return threading.Thread(target=self._log_main, args=(), daemon=True) + + def _log_main(self) -> None: + stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel) + log_stream = stub.Logstream(iter(self.request_queue.get, None)) + try: + for record in log_stream: + if record.level < 0: + self.stdstream(level=record.level, msg=record.msg) + self.log(level=record.level, msg=record.msg) + except grpc.RpcError as e: + if grpc.StatusCode.CANCELLED != e.code(): + # Not just shutting down normally + logger.error( + f"Got Error from logger channel -- shutting down: {e}") + raise e + + def log(self, level: int, msg: str): + """ + Log the message from the log stream. + By default, calls logger.log but this can be overridden. + + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + logger.log(level=level, msg=msg) + + def stdstream(self, level: int, msg: str): + """ + Log the stdout/stderr entry from the log stream. + By default, calls print but this can be overridden. + + Args: + level: The loglevel of the received log message + msg: The content of the message + """ + print_file = sys.stderr if level == -2 else sys.stdout + print(msg, file=print_file) + + def set_logstream_level(self, level: int): + req = ray_client_pb2.LogSettingsRequest() + req.enabled = True + req.loglevel = level + self.request_queue.put(req) + + def close(self) -> None: + self.request_queue.put(None) + if self.log_thread is not None: + self.log_thread.join() + + def disable_logs(self) -> None: + req = ray_client_pb2.LogSettingsRequest() + req.enabled = False + self.request_queue.put(req) diff --git a/python/ray/experimental/client/server/dataservicer.py b/python/ray/experimental/client/server/dataservicer.py index 874e741d9..925adca28 100644 --- a/python/ray/experimental/client/server/dataservicer.py +++ b/python/ray/experimental/client/server/dataservicer.py @@ -48,7 +48,7 @@ class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer): resp.req_id = req.req_id yield resp except grpc.RpcError as e: - logger.debug(f"Closing channel: {e}") + logger.debug(f"Closing data channel: {e}") finally: logger.info(f"Lost data connection from client {client_id}") self.basic_service.release_all(client_id) diff --git a/python/ray/experimental/client/server/logservicer.py b/python/ray/experimental/client/server/logservicer.py new file mode 100644 index 000000000..9b2fa24bf --- /dev/null +++ b/python/ray/experimental/client/server/logservicer.py @@ -0,0 +1,99 @@ +""" +This file responds to log stream requests and forwards logs +with its handler. +""" +import io +import threading +import queue +import logging +import grpc +import uuid + +from ray.worker import print_worker_logs +from ray.ray_logging import global_worker_stdstream_dispatcher +import ray.core.generated.ray_client_pb2 as ray_client_pb2 +import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc + +logger = logging.getLogger(__name__) + + +class LogstreamHandler(logging.Handler): + def __init__(self, queue, level): + super().__init__() + self.queue = queue + self.level = level + + def emit(self, record: logging.LogRecord): + logdata = ray_client_pb2.LogData() + logdata.msg = record.getMessage() + logdata.level = record.levelno + logdata.name = record.name + self.queue.put(logdata) + + +class StdStreamHandler: + def __init__(self, queue): + self.queue = queue + self.id = str(uuid.uuid4()) + + def handle(self, data): + logdata = ray_client_pb2.LogData() + logdata.level = -2 if data["is_err"] else -1 + logdata.name = "stderr" if data["is_err"] else "stdout" + with io.StringIO() as file: + print_worker_logs(data, file) + logdata.msg = file.getvalue() + self.queue.put(logdata) + + def register_global(self): + global_worker_stdstream_dispatcher.add_handler(self.id, self.handle) + + def unregister_global(self): + global_worker_stdstream_dispatcher.remove_handler(self.id) + + +def log_status_change_thread(log_queue, request_iterator): + std_handler = StdStreamHandler(log_queue) + current_handler = None + root_logger = logging.getLogger("ray") + default_level = root_logger.getEffectiveLevel() + try: + for req in request_iterator: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + if not req.enabled: + current_handler = None + continue + current_handler = LogstreamHandler(log_queue, req.loglevel) + std_handler.register_global() + root_logger.addHandler(current_handler) + root_logger.setLevel(req.loglevel) + finally: + if current_handler is not None: + root_logger.setLevel(default_level) + root_logger.removeHandler(current_handler) + std_handler.unregister_global() + log_queue.put(None) + + +class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer): + def Logstream(self, request_iterator, context): + logger.info("New logs connection") + log_queue = queue.Queue() + thread = threading.Thread( + target=log_status_change_thread, + args=(log_queue, request_iterator), + daemon=True) + thread.start() + try: + queue_iter = iter(log_queue.get, None) + for record in queue_iter: + if record is None: + break + yield record + except grpc.RpcError as e: + logger.debug(f"Closing log channel: {e}") + finally: + thread.join() diff --git a/python/ray/experimental/client/server/server.py b/python/ray/experimental/client/server/server.py index 442cf1afa..7cc286de8 100644 --- a/python/ray/experimental/client/server/server.py +++ b/python/ray/experimental/client/server/server.py @@ -23,6 +23,7 @@ from ray.experimental.client.server.server_pickler import dumps_from_server from ray.experimental.client.server.server_pickler import loads_from_client from ray.experimental.client.server.core_ray_api import RayServerAPI from ray.experimental.client.server.dataservicer import DataServicer +from ray.experimental.client.server.logservicer import LogstreamServicer from ray.experimental.client.server.server_stubs import current_remote logger = logging.getLogger(__name__) @@ -372,11 +373,14 @@ def serve(connection_str, test_mode=False): server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) task_servicer = RayletServicer(test_mode=test_mode) data_servicer = DataServicer(task_servicer) + logs_servicer = LogstreamServicer() _set_server_api(RayServerAPI(task_servicer)) ray_client_pb2_grpc.add_RayletDriverServicer_to_server( task_servicer, server) ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server( data_servicer, server) + ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server( + logs_servicer, server) server.add_insecure_port(connection_str) server.start() return server diff --git a/python/ray/experimental/client/worker.py b/python/ray/experimental/client/worker.py index bba23584b..8ed41bff4 100644 --- a/python/ray/experimental/client/worker.py +++ b/python/ray/experimental/client/worker.py @@ -30,6 +30,7 @@ from ray.experimental.client.common import ClientObjectRef from ray.experimental.client.common import ClientRemoteFunc from ray.experimental.client.common import ClientStub from ray.experimental.client.dataclient import DataClient +from ray.experimental.client.logsclient import LogstreamClient logger = logging.getLogger(__name__) @@ -54,9 +55,13 @@ class Worker: else: self.channel = grpc.insecure_channel(conn_str) self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel) + self.data_client = DataClient(self.channel, self._client_id) self.reference_count: Dict[bytes, int] = defaultdict(int) + self.log_client = LogstreamClient(self.channel) + self.log_client.set_logstream_level(logging.INFO) + def get(self, vals, *, timeout: Optional[float] = None) -> Any: to_get = [] single = False @@ -197,14 +202,16 @@ class Worker: ray_client_pb2.ReleaseRequest(ids=[id])) def call_retain(self, id: bytes) -> None: - logger.debug(f"Retaining {id}") + logger.debug(f"Retaining {id.hex()}") self.reference_count[id] += 1 def close(self): - self.data_client.close(close_channel=True) - self.server = None + self.log_client.close() + self.data_client.close() if self.channel: + self.channel.close() self.channel = None + self.server = None def get_actor(self, name: str) -> ClientActorHandle: task = ray_client_pb2.ClientTask() diff --git a/python/ray/ray_logging.py b/python/ray/ray_logging.py index 0668f397f..56df7b5c2 100644 --- a/python/ray/ray_logging.py +++ b/python/ray/ray_logging.py @@ -1,8 +1,11 @@ import logging import os import sys +import threading from logging.handlers import RotatingFileHandler +from typing import Callable + import ray from ray.utils import binary_to_hex @@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args, # logger to add a newline at the end of string. handler.terminator = "" return logger + + +class WorkerStandardStreamDispatcher: + def __init__(self): + self.handlers = [] + self._lock = threading.Lock() + + def add_handler(self, name: str, handler: Callable) -> None: + with self._lock: + self.handlers.append((name, handler)) + + def remove_handler(self, name: str) -> None: + with self._lock: + new_handlers = [pair for pair in self.handlers if pair[0] != name] + self.handlers = new_handlers + + def emit(self, data): + with self._lock: + for pair in self.handlers: + _, handle = pair + handle(data) + + +global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher() diff --git a/python/ray/tests/test_experimental_client.py b/python/ray/tests/test_experimental_client.py index cc15e7272..e6afee042 100644 --- a/python/ray/tests/test_experimental_client.py +++ b/python/ray/tests/test_experimental_client.py @@ -1,4 +1,7 @@ import pytest +import time +import sys +import logging from contextlib import contextmanager import ray.experimental.client.server.server as ray_client_server @@ -234,6 +237,47 @@ def test_pass_handles(ray_start_regular_shared): 4)) == local_fact(4) +def test_basic_log_stream(ray_start_regular_shared): + with ray_start_client_server() as ray: + log_msgs = [] + + def test_log(level, msg): + log_msgs.append(msg) + + ray.worker.log_client.log = test_log + ray.worker.log_client.set_logstream_level(logging.DEBUG) + # Allow some time to propogate + time.sleep(1) + x = ray.put("Foo") + assert ray.get(x) == "Foo" + time.sleep(1) + logs_with_id = [msg for msg in log_msgs if msg.find(x.id.hex()) >= 0] + assert len(logs_with_id) >= 2 + assert any((msg.find("get") >= 0 for msg in logs_with_id)) + assert any((msg.find("put") >= 0 for msg in logs_with_id)) + + +def test_stdout_log_stream(ray_start_regular_shared): + with ray_start_client_server() as ray: + log_msgs = [] + + def test_log(level, msg): + log_msgs.append(msg) + + ray.worker.log_client.stdstream = test_log + + @ray.remote + def print_on_stderr_and_stdout(s): + print(s) + print(s, file=sys.stderr) + + time.sleep(1) + print_on_stderr_and_stdout.remote("Hello world") + time.sleep(1) + assert len(log_msgs) == 2 + assert all((msg.find("Hello world") for msg in log_msgs)) + + def test_basic_named_actor(ray_start_regular_shared): """ Test that ray.get_actor() can create and return a detached actor. @@ -264,5 +308,4 @@ def test_basic_named_actor(ray_start_regular_shared): if __name__ == "__main__": - import sys sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/worker.py b/python/ray/worker.py index 495478ad7..631a82767 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -48,6 +48,7 @@ from ray.exceptions import ( ) from ray.function_manager import FunctionActorManager from ray.ray_logging import setup_logger +from ray.ray_logging import global_worker_stdstream_dispatcher from ray.utils import _random_string, check_oversized_pickle from ray.util.inspect import is_cython @@ -910,29 +911,8 @@ def print_logs(redis_client, threads_stopped, job_id): if data["job"] and ray.utils.binary_to_hex( job_id.binary()) != data["job"]: continue - - print_file = sys.stderr if data["is_err"] else sys.stdout - - def color_for(data): - if data["pid"] == "raylet": - return colorama.Fore.YELLOW - else: - return colorama.Fore.CYAN - - if data["ip"] == localhost: - for line in data["lines"]: - print( - "{}{}(pid={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - colorama.Style.RESET_ALL, line), - file=print_file) - else: - for line in data["lines"]: - print( - "{}{}(pid={}, ip={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - data["ip"], colorama.Style.RESET_ALL, line), - file=print_file) + data["localhost"] = localhost + global_worker_stdstream_dispatcher.emit(data) except (OSError, redis.exceptions.ConnectionError) as e: logger.error(f"print_logs: {e}") @@ -941,6 +921,34 @@ def print_logs(redis_client, threads_stopped, job_id): pubsub_client.close() +def print_to_stdstream(data): + print_file = sys.stderr if data["is_err"] else sys.stdout + print_worker_logs(data, print_file) + + +def print_worker_logs(data, print_file): + def color_for(data): + if data["pid"] == "raylet": + return colorama.Fore.YELLOW + else: + return colorama.Fore.CYAN + + if data["ip"] == data["localhost"]: + for line in data["lines"]: + print( + "{}{}(pid={}){} {}".format(colorama.Style.DIM, color_for(data), + data["pid"], + colorama.Style.RESET_ALL, line), + file=print_file) + else: + for line in data["lines"]: + print( + "{}{}(pid={}, ip={}){} {}".format( + colorama.Style.DIM, color_for(data), data["pid"], + data["ip"], colorama.Style.RESET_ALL, line), + file=print_file) + + def print_error_messages_raylet(task_error_queue, threads_stopped): """Prints message received in the given output queue. @@ -1201,6 +1209,8 @@ def connect(node, worker.printer_thread.daemon = True worker.printer_thread.start() if log_to_driver: + global_worker_stdstream_dispatcher.add_handler( + "ray_print_logs", print_to_stdstream) worker.logger_thread = threading.Thread( target=print_logs, name="ray_print_logs", diff --git a/src/ray/protobuf/ray_client.proto b/src/ray/protobuf/ray_client.proto index a566f8031..3dd3128b2 100644 --- a/src/ray/protobuf/ray_client.proto +++ b/src/ray/protobuf/ray_client.proto @@ -229,3 +229,30 @@ service RayletDataStreamer { rpc Datapath(stream DataRequest) returns (stream DataResponse) { } } + +// A request to change the quantity or type of the logs +// currently being streamed. Initially, all logs are disabled. +message LogSettingsRequest { + // Set to recieve logs. + bool enabled = 1; + // At what loglevel should logs be forwarded on the stream. + int32 loglevel = 2; + // TODO(barakmich): More log filtering options. +} + +message LogData { + // The message data in the log + string msg = 1; + // The loglevel at which this log should be displayed. + // * level > 0: Log leveling as per python's logging library + // * level == -1: stdout (fd 1) + // * level == -2: stderr (fd 2) + int32 level = 2; + // The name of the logger that generated this message. + string name = 3; +} + +service RayletLogStreamer { + rpc Logstream(stream LogSettingsRequest) returns (stream LogData) { + } +}