diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 699733104..7f6a8e22f 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -5,6 +5,7 @@ import json import logging import os import platform +import re import shutil import time import traceback @@ -18,19 +19,24 @@ import ray.utils # entry/init points. logger = logging.getLogger(__name__) +# First group is worker id. Second group is job id. +JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)") + class LogFileInfo: def __init__(self, filename=None, size_when_last_opened=None, file_position=None, - file_handle=None): + file_handle=None, + job_id=None): assert (filename is not None and size_when_last_opened is not None and file_position is not None) self.filename = filename self.size_when_last_opened = size_when_last_opened self.file_position = file_position self.file_handle = file_handle + self.job_id = job_id self.worker_pid = None @@ -116,13 +122,20 @@ class LogMonitor: for file_path in log_file_paths + raylet_err_paths: if os.path.isfile( file_path) and file_path not in self.log_filenames: + job_match = JOB_LOG_PATTERN.match(file_path) + if job_match: + job_id = job_match.group(2) + else: + job_id = None + self.log_filenames.add(file_path) self.closed_file_infos.append( LogFileInfo( filename=file_path, size_when_last_opened=0, file_position=0, - file_handle=None)) + file_handle=None, + job_id=job_id)) log_filename = os.path.basename(file_path) logger.info("Beginning to track file {}".format(log_filename)) @@ -231,6 +244,7 @@ class LogMonitor: json.dumps({ "ip": self.ip, "pid": file_info.worker_pid, + "job": file_info.job_id, "lines": lines_to_publish })) anything_published = True diff --git a/python/ray/test_utils.py b/python/ray/test_utils.py index 7d696da1d..2e2d03a5b 100644 --- a/python/ray/test_utils.py +++ b/python/ray/test_utils.py @@ -130,7 +130,9 @@ def run_string_as_driver_nonblocking(driver_script): f.write(driver_script.encode("ascii")) f.flush() return subprocess.Popen( - [sys.executable, f.name], stdout=subprocess.PIPE) + [sys.executable, f.name], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) def flat_errors(): diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index e31c1f574..397ecc3d1 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -5,13 +5,9 @@ import time import ray from ray.test_utils import ( - RayTestTimeoutException, - run_string_as_driver, - run_string_as_driver_nonblocking, - wait_for_children_of_pid, - wait_for_children_of_pid_to_exit, - kill_process_by_name, -) + RayTestTimeoutException, run_string_as_driver, + run_string_as_driver_nonblocking, wait_for_children_of_pid, + wait_for_children_of_pid_to_exit, kill_process_by_name, Semaphore) def test_error_isolation(call_ray_start): @@ -634,6 +630,77 @@ print("success") ray.get(f.remote()) +def test_multi_driver_logging(ray_start_regular): + address_info = ray_start_regular + address = address_info["redis_address"] + + # ray.init(address=address) + driver1_wait = Semaphore.options(name="driver1_wait").remote(value=0) + driver2_wait = Semaphore.options(name="driver2_wait").remote(value=0) + main_wait = Semaphore.options(name="main_wait").remote(value=0) + + # Params are address, semaphore name, output1, output2 + driver_script_template = """ +import ray +import sys +from ray.test_utils import Semaphore + +@ray.remote(num_cpus=0) +def remote_print(s, file=None): + print(s, file=file) + +ray.init(address="{}") + +driver_wait = ray.get_actor("{}") +main_wait = ray.get_actor("main_wait") + +ray.get(main_wait.release.remote()) +ray.get(driver_wait.acquire.remote()) + +s1 = "{}" +ray.get(remote_print.remote(s1)) + +ray.get(main_wait.release.remote()) +ray.get(driver_wait.acquire.remote()) + +s2 = "{}" +ray.get(remote_print.remote(s2)) + +ray.get(main_wait.release.remote()) + """ + + p1 = run_string_as_driver_nonblocking( + driver_script_template.format(address, "driver1_wait", "1", "2")) + p2 = run_string_as_driver_nonblocking( + driver_script_template.format(address, "driver2_wait", "3", "4")) + + ray.get(main_wait.acquire.remote()) + ray.get(main_wait.acquire.remote()) + # At this point both of the other drivers are fully initialized. + + ray.get(driver1_wait.release.remote()) + ray.get(driver2_wait.release.remote()) + + # At this point driver1 should receive '1' and driver2 '3' + ray.get(main_wait.acquire.remote()) + ray.get(main_wait.acquire.remote()) + + ray.get(driver1_wait.release.remote()) + ray.get(driver2_wait.release.remote()) + + # At this point driver1 should receive '2' and driver2 '4' + ray.get(main_wait.acquire.remote()) + ray.get(main_wait.acquire.remote()) + + driver1_out = p1.stdout.read().decode("ascii").split("\n") + driver2_out = p2.stdout.read().decode("ascii").split("\n") + + assert driver1_out[0][-1] == "1" + assert driver1_out[1][-1] == "2" + assert driver2_out[0][-1] == "3" + assert driver2_out[1][-1] == "4" + + if __name__ == "__main__": import pytest import sys diff --git a/python/ray/utils.py b/python/ray/utils.py index 692e819d2..369ce5cf4 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -403,6 +403,7 @@ def create_and_init_new_worker_log(path, worker_pid): Args: path (str): The name/path of the file to be opened. + worker_pid (int): The pid of the worker process. Returns: A file-like object which can be written to. diff --git a/python/ray/worker.py b/python/ray/worker.py index 4b71b10fa..e7fa68574 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -953,13 +953,15 @@ def set_log_file(stdout_name, stderr_name): return stdout_path, stderr_path -def print_logs(redis_client, threads_stopped): +def print_logs(redis_client, threads_stopped, job_id): """Prints log messages from workers on all of the nodes. Args: redis_client: A client to the primary Redis shard. threads_stopped (threading.Event): A threading event used to signal to the thread that it should exit. + job_id (JobID): The id of the driver's job + """ pubsub_client = redis_client.pubsub(ignore_subscribe_messages=True) pubsub_client.subscribe(ray.gcs_utils.LOG_FILE_CHANNEL) @@ -981,9 +983,20 @@ def print_logs(redis_client, threads_stopped): threads_stopped.wait(timeout=0.01) continue num_consecutive_messages_received += 1 + if (num_consecutive_messages_received % 100 == 0 + and num_consecutive_messages_received > 0): + logger.warning( + "The driver may not be able to keep up with the " + "stdout/stderr of the workers. To avoid forwarding logs " + "to the driver, use 'ray.init(log_to_driver=False)'.") data = json.loads(ray.utils.decode(msg["data"])) + # Don't show logs from other drivers. + if data["job"] and ray.utils.binary_to_hex( + job_id.binary()) != data["job"]: + continue + def color_for(data): if data["pid"] == "raylet": return colorama.Fore.YELLOW @@ -1001,12 +1014,6 @@ def print_logs(redis_client, threads_stopped): colorama.Style.DIM, color_for(data), data["pid"], data["ip"], colorama.Style.RESET_ALL, line)) - if (num_consecutive_messages_received % 100 == 0 - and num_consecutive_messages_received > 0): - logger.warning( - "The driver may not be able to keep up with the " - "stdout/stderr of the workers. To avoid forwarding logs " - "to the driver, use 'ray.init(log_to_driver=False)'.") except (OSError, redis.exceptions.ConnectionError) as e: logger.error("print_logs: {}".format(e)) finally: @@ -1310,7 +1317,7 @@ def connect(node, worker.logger_thread = threading.Thread( target=print_logs, name="ray_print_logs", - args=(worker.redis_client, worker.threads_stopped)) + args=(worker.redis_client, worker.threads_stopped, job_id)) worker.logger_thread.daemon = True worker.logger_thread.start()