diff --git a/python/ray/tests/test_multi_node.py b/python/ray/tests/test_multi_node.py index 719e47380..8e6cd7fcf 100644 --- a/python/ray/tests/test_multi_node.py +++ b/python/ray/tests/test_multi_node.py @@ -183,6 +183,41 @@ print("success") assert "success" in out +def test_receive_late_worker_logs(): + # Make sure that log messages from tasks appear in the stdout even if the + # script exits quickly. + log_message = "some helpful debugging message" + + # Define a driver that creates a task that prints something, ensures that + # the task runs, and then exits. + driver_script = """ +import ray +import random +import time + +log_message = "{}" + +@ray.remote +class Actor(object): + def log(self): + print(log_message) + +@ray.remote +def f(): + print(log_message) + +ray.init(num_cpus=2) + +a = Actor.remote() +ray.get([a.log.remote(), f.remote()]) +ray.get([a.log.remote(), f.remote()]) +""".format(log_message) + + for _ in range(2): + out = run_string_as_driver(driver_script) + assert out.count(log_message) == 4 + + @pytest.fixture def ray_start_head_with_resources(): out = run_and_get_output( diff --git a/python/ray/worker.py b/python/ray/worker.py index bcd4a681c..c7ee4e4d7 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1541,7 +1541,7 @@ def init(redis_address=None, _post_init_hooks = [] -def shutdown(): +def shutdown(exiting_interpreter=False): """Disconnect the worker, and terminate processes started by ray.init(). This will automatically run at the end when a Python process that uses Ray @@ -1553,7 +1553,17 @@ def shutdown(): defined remote functions or actors after calling ray.shutdown(), then you need to redefine them. If they were defined in an imported module, then you will need to reload the module. + + Args: + exiting_interpreter (bool): True if this is called by the atexit hook + and false otherwise. If we are exiting the interpreter, we will + wait a little while to print any extra error messages. """ + if exiting_interpreter and global_worker.mode == SCRIPT_MODE: + # This is a duration to sleep before shutting down everything in order + # to make sure that log messages finish printing. + time.sleep(0.5) + disconnect() # Shut down the Ray processes. @@ -1565,7 +1575,7 @@ def shutdown(): global_worker.set_mode(None) -atexit.register(shutdown) +atexit.register(shutdown, True) # Define a custom excepthook so that if the driver exits with an exception, we # can push that exception to Redis. @@ -1670,6 +1680,8 @@ def print_error_messages_raylet(task_error_queue, threads_stopped): # messages originating from the worker. while t + UNCAUGHT_ERROR_GRACE_PERIOD > time.time(): threads_stopped.wait(timeout=1) + if threads_stopped.is_set(): + break if t < last_task_error_raise_time + UNCAUGHT_ERROR_GRACE_PERIOD: logger.debug("Suppressing error from worker: {}".format(error)) else: