diff --git a/python/ray/log_monitor.py b/python/ray/log_monitor.py index 7f6a8e22f..31e83c8c3 100644 --- a/python/ray/log_monitor.py +++ b/python/ray/log_monitor.py @@ -29,6 +29,7 @@ class LogFileInfo: size_when_last_opened=None, file_position=None, file_handle=None, + is_err_file=False, job_id=None): assert (filename is not None and size_when_last_opened is not None and file_position is not None) @@ -36,6 +37,7 @@ class LogFileInfo: self.size_when_last_opened = size_when_last_opened self.file_position = file_position self.file_handle = file_handle + self.is_err_file = is_err_file self.job_id = job_id self.worker_pid = None @@ -128,6 +130,8 @@ class LogMonitor: else: job_id = None + is_err_file = file_path.endswith("err") + self.log_filenames.add(file_path) self.closed_file_infos.append( LogFileInfo( @@ -135,6 +139,7 @@ class LogMonitor: size_when_last_opened=0, file_position=0, file_handle=None, + is_err_file=is_err_file, job_id=job_id)) log_filename = os.path.basename(file_path) logger.info("Beginning to track file {}".format(log_filename)) @@ -245,6 +250,7 @@ class LogMonitor: "ip": self.ip, "pid": file_info.worker_pid, "job": file_info.job_id, + "is_err": file_info.is_err_file, "lines": lines_to_publish })) anything_published = True diff --git a/python/ray/tests/test_output.py b/python/ray/tests/test_output.py index 5b0e986bf..2e9437d52 100644 --- a/python/ray/tests/test_output.py +++ b/python/ray/tests/test_output.py @@ -4,6 +4,31 @@ import sys import pytest import ray +from ray.test_utils import run_string_as_driver_nonblocking + + +def test_worker_stdout(): + script = """ +import ray +import sys + +ray.init(num_cpus=2) + +@ray.remote +def foo(out_str, err_str): + print(out_str) + print(err_str, file=sys.stderr) + +ray.get(foo.remote("abc", "def")) + """ + + proc = run_string_as_driver_nonblocking(script) + out_str = proc.stdout.read().decode("ascii") + err_str = proc.stderr.read().decode("ascii") + + assert out_str.endswith("abc\n") + assert err_str.split("\n")[-2].endswith("def") + def test_output(): # Use subprocess to execute the __main__ below. diff --git a/python/ray/worker.py b/python/ray/worker.py index 6e62b3bd4..7c1256855 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1009,6 +1009,8 @@ def print_logs(redis_client, threads_stopped, job_id): job_id.binary()) != data["job"]: continue + print_file = sys.stderr if data["is_err"] else sys.stdout + def color_for(data): if data["pid"] == "raylet": return colorama.Fore.YELLOW @@ -1017,14 +1019,18 @@ def print_logs(redis_client, threads_stopped, job_id): if data["ip"] == localhost: for line in data["lines"]: - print("{}{}(pid={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - colorama.Style.RESET_ALL, line)) + print( + "{}{}(pid={}){} {}".format( + colorama.Style.DIM, color_for(data), data["pid"], + colorama.Style.RESET_ALL, line), + file=print_file) else: for line in data["lines"]: - print("{}{}(pid={}, ip={}){} {}".format( - colorama.Style.DIM, color_for(data), data["pid"], - data["ip"], colorama.Style.RESET_ALL, line)) + print( + "{}{}(pid={}, ip={}){} {}".format( + colorama.Style.DIM, color_for(data), data["pid"], + data["ip"], colorama.Style.RESET_ALL, line), + file=print_file) except (OSError, redis.exceptions.ConnectionError) as e: logger.error("print_logs: {}".format(e))