diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 47cee86ba..75a5045e4 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -492,6 +492,10 @@ cdef execute_task( core_worker.store_task_outputs( worker, outputs, c_return_ids, returns) except Exception as error: + # If the debugger is enabled, drop into the remote pdb here. + if "RAY_PDB" in os.environ: + ray.util.pdb.post_mortem() + if (task_type == TASK_TYPE_ACTOR_CREATION_TASK): worker.mark_actor_init_failed(error) diff --git a/python/ray/experimental/internal_kv.py b/python/ray/experimental/internal_kv.py index 14434b558..6ce2ad162 100644 --- a/python/ray/experimental/internal_kv.py +++ b/python/ray/experimental/internal_kv.py @@ -32,3 +32,12 @@ def _internal_kv_put(key, value, overwrite=False): def _internal_kv_del(key): return ray.worker.global_worker.redis_client.delete(key) + + +def _internal_kv_list(prefix): + """List all keys in the internal KV store that start with the prefix.""" + if isinstance(prefix, bytes): + pattern = prefix + b"*" + else: + pattern = prefix + "*" + return ray.worker.global_worker.redis_client.keys(pattern=pattern) diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index d8558ba37..ab09dec8e 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -150,6 +150,44 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port): from None +@cli.command() +@click.option( + "--address", + required=False, + type=str, + help="Override the address to connect to.") +def debug(address): + """Debug Ray program.""" + from telnetlib import Telnet + if not address: + address = services.find_redis_address_or_die() + logger.info(f"Connecting to Ray instance at {address}.") + ray.init(address=address) + while True: + active_sessions = ray.experimental.internal_kv._internal_kv_list( + "RAY_PDB_") + print("Active breakpoints:") + for i, active_session in enumerate(active_sessions): + data = json.loads( + ray.experimental.internal_kv._internal_kv_get(active_session)) + print( + str(i) + ": " + data["proctitle"] + " | " + data["filename"] + + ":" + str(data["lineno"])) + print(data["traceback"]) + inp = input("Enter breakpoint index or press enter to refresh: ") + if inp == "": + print() + continue + else: + index = int(inp) + session = json.loads( + ray.experimental.internal_kv._internal_kv_get( + active_sessions[index])) + host, port = session["pdb_address"].split(":") + with Telnet(host, int(port)) as tn: + tn.interact() + + @cli.command() @click.option( "--node-ip-address", @@ -1397,6 +1435,7 @@ def add_command_alias(command, name, hidden): cli.add_command(dashboard) +cli.add_command(debug) cli.add_command(start) cli.add_command(stop) cli.add_command(up) diff --git a/python/ray/serve/backend_worker.py b/python/ray/serve/backend_worker.py index 510663932..8be81721a 100644 --- a/python/ray/serve/backend_worker.py +++ b/python/ray/serve/backend_worker.py @@ -244,6 +244,9 @@ class RayServeWorker: result = await method_to_call(arg) self.request_counter.record(1) except Exception as e: + import os + if "RAY_PDB" in os.environ: + ray.util.pdb.post_mortem() result = wrap_to_ray_error(e) self.error_counter.record(1) diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index d9676022c..f4255827c 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -107,6 +107,7 @@ py_test_module_list( "test_node_manager.py", "test_numba.py", "test_queue.py", + "test_ray_debugger.py", "test_ray_init.py", "test_tempfile.py", ], diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py new file mode 100644 index 000000000..67df794f8 --- /dev/null +++ b/python/ray/tests/test_ray_debugger.py @@ -0,0 +1,42 @@ +import json +import os +import sys +from telnetlib import Telnet + +import ray + + +def test_ray_debugger_breakpoint(shutdown_only): + ray.init(num_cpus=1) + + @ray.remote + def f(): + ray.util.pdb.set_trace() + return 1 + + result = f.remote() + + # Wait until the breakpoint is hit: + while True: + active_sessions = ray.experimental.internal_kv._internal_kv_list( + "RAY_PDB_") + if len(active_sessions) > 0: + break + + # Now continue execution: + session = json.loads( + ray.experimental.internal_kv._internal_kv_get(active_sessions[0])) + host, port = session["pdb_address"].split(":") + tn = Telnet(host, int(port)) + tn.write(b"c\n") + + # This should succeed now! + ray.get(result) + + +if __name__ == "__main__": + import pytest + # Make subprocess happy in bazel. + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["LANG"] = "en_US.UTF-8" + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/util/__init__.py b/python/ray/util/__init__.py index 055f37697..2a6d0a029 100644 --- a/python/ray/util/__init__.py +++ b/python/ray/util/__init__.py @@ -4,9 +4,10 @@ from ray.util.debug import log_once, disable_log_once_globally, \ enable_periodic_logging from ray.util.placement_group import (placement_group, placement_group_table, remove_placement_group) +from ray.util import rpdb as pdb __all__ = [ "ActorPool", "disable_log_once_globally", "enable_periodic_logging", - "iter", "log_once", "placement_group", "placement_group_table", + "iter", "log_once", "pdb", "placement_group", "placement_group_table", "remove_placement_group" ] diff --git a/python/ray/util/rpdb.py b/python/ray/util/rpdb.py new file mode 100644 index 000000000..21d2ec07d --- /dev/null +++ b/python/ray/util/rpdb.py @@ -0,0 +1,181 @@ +# Some code in this file is from +# https://github.com/ionelmc/python-remote-pdb/blob/07d563331c4ab9eb45731bb272b158816d98236e/src/remote_pdb.py +# (BSD 2-Clause "Simplified" License) + +import errno +import inspect +import json +import logging +import os +import re +import socket +import sys +import uuid +from pdb import Pdb +import setproctitle +import traceback + +from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put + +PY3 = sys.version_info[0] == 3 +log = logging.getLogger(__name__) + + +def cry(message, stderr=sys.__stderr__): + print(message, file=stderr) + stderr.flush() + + +class LF2CRLF_FileWrapper(object): + def __init__(self, connection): + self.connection = connection + self.stream = fh = connection.makefile("rw") + self.read = fh.read + self.readline = fh.readline + self.readlines = fh.readlines + self.close = fh.close + self.flush = fh.flush + self.fileno = fh.fileno + if hasattr(fh, "encoding"): + self._send = lambda data: connection.sendall( + data.encode(fh.encoding)) + else: + self._send = connection.sendall + + @property + def encoding(self): + return self.stream.encoding + + def __iter__(self): + return self.stream.__iter__() + + def write(self, data, nl_rex=re.compile("\r?\n")): + data = nl_rex.sub("\r\n", data) + self._send(data) + + def writelines(self, lines, nl_rex=re.compile("\r?\n")): + for line in lines: + self.write(line, nl_rex) + + +class RemotePdb(Pdb): + """ + This will run pdb as a ephemeral telnet service. Once you connect no one + else can connect. On construction this object will block execution till a + client has connected. + Based on https://github.com/tamentis/rpdb I think ... + To use this:: + RemotePdb(host="0.0.0.0", port=4444).set_trace() + Then run: telnet 127.0.0.1 4444 + """ + active_instance = None + + def __init__(self, host, port, patch_stdstreams=False, quiet=False): + self._quiet = quiet + self._patch_stdstreams = patch_stdstreams + self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + True) + self._listen_socket.bind((host, port)) + + def listen(self): + if not self._quiet: + cry("RemotePdb session open at %s:%s, " + "use 'ray debug' to connect..." % + self._listen_socket.getsockname()) + self._listen_socket.listen(1) + connection, address = self._listen_socket.accept() + if not self._quiet: + cry("RemotePdb accepted connection from %s." % repr(address)) + self.handle = LF2CRLF_FileWrapper(connection) + Pdb.__init__( + self, completekey="tab", stdin=self.handle, stdout=self.handle) + self.backup = [] + if self._patch_stdstreams: + for name in ( + "stderr", + "stdout", + "__stderr__", + "__stdout__", + "stdin", + "__stdin__", + ): + self.backup.append((name, getattr(sys, name))) + setattr(sys, name, self.handle) + RemotePdb.active_instance = self + + def __restore(self): + if self.backup and not self._quiet: + cry("Restoring streams: %s ..." % self.backup) + for name, fh in self.backup: + setattr(sys, name, fh) + self.handle.close() + RemotePdb.active_instance = None + + def do_quit(self, arg): + self.__restore() + return Pdb.do_quit(self, arg) + + do_q = do_exit = do_quit + + def set_trace(self, frame=None): + if frame is None: + frame = sys._getframe().f_back + try: + Pdb.set_trace(self, frame) + except IOError as exc: + if exc.errno != errno.ECONNRESET: + raise + + def post_mortem(self, traceback=None): + # See https://github.com/python/cpython/blob/ + # 022bc7572f061e1d1132a4db9d085b29707701e7/Lib/pdb.py#L1617 + try: + t = sys.exc_info()[2] + self.reset() + Pdb.interaction(self, None, t) + except IOError as exc: + if exc.errno != errno.ECONNRESET: + raise + + +def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None): + """ + Opens a remote PDB on first available port. + """ + if host is None: + host = os.environ.get("REMOTE_PDB_HOST", "127.0.0.1") + if port is None: + port = int(os.environ.get("REMOTE_PDB_PORT", "0")) + if quiet is None: + quiet = bool(os.environ.get("REMOTE_PDB_QUIET", "")) + rdb = RemotePdb( + host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet) + sockname = rdb._listen_socket.getsockname() + pdb_address = "{}:{}".format(sockname[0], sockname[1]) + parentframeinfo = inspect.getouterframes(inspect.currentframe())[2] + data = { + "proctitle": setproctitle.getproctitle(), + "pdb_address": pdb_address, + "filename": parentframeinfo.filename, + "lineno": parentframeinfo.lineno, + "traceback": "\n".join(traceback.format_exception(*sys.exc_info())) + } + breakpoint_uuid = uuid.uuid4() + _internal_kv_put( + "RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True) + rdb.listen() + _internal_kv_del("RAY_PDB_{}".format(breakpoint_uuid)) + + return rdb + + +def set_trace(host=None, port=None, patch_stdstreams=False, quiet=None): + frame = sys._getframe().f_back + rdb = connect_ray_pdb(host, port, patch_stdstreams, quiet) + rdb.set_trace(frame=frame) + + +def post_mortem(host=None, port=None, patch_stdstreams=False, quiet=None): + rdb = connect_ray_pdb(host, port, patch_stdstreams, quiet) + rdb.post_mortem()