Ray PDB support (#11739)

This commit is contained in:
Philipp Moritz
2020-11-03 09:49:23 -08:00
committed by GitHub
parent 952b71dc94
commit 39ce0eadbe
8 changed files with 281 additions and 1 deletions
+4
View File
@@ -492,6 +492,10 @@ cdef execute_task(
core_worker.store_task_outputs(
worker, outputs, c_return_ids, returns)
except Exception as error:
# If the debugger is enabled, drop into the remote pdb here.
if "RAY_PDB" in os.environ:
ray.util.pdb.post_mortem()
if (<int>task_type == <int>TASK_TYPE_ACTOR_CREATION_TASK):
worker.mark_actor_init_failed(error)
+9
View File
@@ -32,3 +32,12 @@ def _internal_kv_put(key, value, overwrite=False):
def _internal_kv_del(key):
    """Delete ``key`` from the internal KV store.

    Returns whatever the global worker's Redis client returns from
    ``delete``.
    """
    redis_client = ray.worker.global_worker.redis_client
    return redis_client.delete(key)
def _internal_kv_list(prefix):
    """List all keys in the internal KV store that start with the prefix.

    The prefix is matched literally: characters that are special in Redis
    glob-style patterns are backslash-escaped before the trailing wildcard
    is appended, so a prefix such as ``"a*"`` only matches keys that really
    start with ``a*``.

    Args:
        prefix: The key prefix, either ``str`` or ``bytes``.

    Returns:
        The result of the Redis client's ``keys`` call for the pattern.
    """
    if isinstance(prefix, bytes):
        backslash, wildcard = b"\\", b"*"
        specials = (b"\\", b"*", b"?", b"[", b"]")
    else:
        backslash, wildcard = "\\", "*"
        specials = ("\\", "*", "?", "[", "]")
    escaped = prefix
    # The backslash is first in `specials` so that the escapes added for the
    # other characters are not escaped a second time.
    for special in specials:
        escaped = escaped.replace(special, backslash + special)
    pattern = escaped + wildcard
    return ray.worker.global_worker.redis_client.keys(pattern=pattern)
+39
View File
@@ -150,6 +150,44 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
from None
@cli.command()
@click.option(
    "--address",
    required=False,
    type=str,
    help="Override the address to connect to.")
def debug(address):
    """Debug Ray program."""
    # NOTE: the docstring above is shown by click as the command's help
    # text, so the longer description lives in these comments instead.
    #
    # Connects to the Ray instance at `address` (or the one found by
    # `services.find_redis_address_or_die()`), then loops forever: list every
    # "RAY_PDB_" entry in the internal KV store, ask the user for an index,
    # and attach an interactive telnet session to the chosen breakpoint.
    from telnetlib import Telnet
    if not address:
        address = services.find_redis_address_or_die()
    logger.info(f"Connecting to Ray instance at {address}.")
    ray.init(address=address)
    internal_kv = ray.experimental.internal_kv
    while True:
        active_sessions = internal_kv._internal_kv_list("RAY_PDB_")
        print("Active breakpoints:")
        for i, active_session in enumerate(active_sessions):
            raw = internal_kv._internal_kv_get(active_session)
            if raw is None:
                # The entry vanished between listing and reading it; skip
                # it but keep the indices of the other entries unchanged.
                continue
            data = json.loads(raw)
            print(
                str(i) + ": " + data["proctitle"] + " | " + data["filename"] +
                ":" + str(data["lineno"]))
            print(data["traceback"])
        inp = input("Enter breakpoint index or press enter to refresh: ")
        if inp == "":
            print()
            continue
        try:
            index = int(inp)
        except ValueError:
            print(f"Invalid breakpoint index: {inp!r}")
            continue
        # Reject negative values explicitly: Python indexing would otherwise
        # silently pick a breakpoint counted from the end of the list.
        if not 0 <= index < len(active_sessions):
            print(f"Breakpoint index out of range: {index}")
            continue
        raw = internal_kv._internal_kv_get(active_sessions[index])
        if raw is None:
            print("Breakpoint is no longer active.")
            continue
        session = json.loads(raw)
        # Split on the last colon only so a host containing colons still
        # parses into (host, port).
        host, port = session["pdb_address"].rsplit(":", 1)
        with Telnet(host, int(port)) as tn:
            tn.interact()
@cli.command()
@click.option(
"--node-ip-address",
@@ -1397,6 +1435,7 @@ def add_command_alias(command, name, hidden):
cli.add_command(dashboard)
cli.add_command(debug)
cli.add_command(start)
cli.add_command(stop)
cli.add_command(up)
+3
View File
@@ -244,6 +244,9 @@ class RayServeWorker:
result = await method_to_call(arg)
self.request_counter.record(1)
except Exception as e:
import os
if "RAY_PDB" in os.environ:
ray.util.pdb.post_mortem()
result = wrap_to_ray_error(e)
self.error_counter.record(1)
+1
View File
@@ -107,6 +107,7 @@ py_test_module_list(
"test_node_manager.py",
"test_numba.py",
"test_queue.py",
"test_ray_debugger.py",
"test_ray_init.py",
"test_tempfile.py",
],
+42
View File
@@ -0,0 +1,42 @@
import json
import os
import sys
from telnetlib import Telnet
import ray
def test_ray_debugger_breakpoint(shutdown_only):
    """A task stopped in ``ray.util.pdb.set_trace()`` resumes after ``c``.

    The test starts a remote task that hits a breakpoint, waits for the
    breakpoint to appear under the "RAY_PDB_" prefix in the internal KV
    store, connects to the advertised address, sends pdb's ``c`` command and
    checks that the task then returns its result.
    """
    import time

    ray.init(num_cpus=1)

    @ray.remote
    def f():
        ray.util.pdb.set_trace()
        return 1

    result = f.remote()

    # Wait until the breakpoint is hit. Poll with a short sleep and an upper
    # bound so a breakpoint that never registers fails the test instead of
    # spinning forever.
    deadline = time.monotonic() + 60
    while True:
        active_sessions = ray.experimental.internal_kv._internal_kv_list(
            "RAY_PDB_")
        if len(active_sessions) > 0:
            break
        assert time.monotonic() < deadline, (
            "Timed out waiting for the breakpoint to be registered.")
        time.sleep(0.1)

    # Now continue execution:
    session = json.loads(
        ray.experimental.internal_kv._internal_kv_get(active_sessions[0]))
    host, port = session["pdb_address"].split(":")

    # Keep the connection open until the task has finished, then close it.
    with Telnet(host, int(port)) as tn:
        tn.write(b"c\n")

        # This should succeed now!
        assert ray.get(result) == 1
if __name__ == "__main__":
    import pytest

    # Make subprocess happy in bazel.
    os.environ.update({"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"})

    # Run this file's tests verbosely and use pytest's result as exit code.
    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)
+2 -1
View File
@@ -4,9 +4,10 @@ from ray.util.debug import log_once, disable_log_once_globally, \
enable_periodic_logging
from ray.util.placement_group import (placement_group, placement_group_table,
remove_placement_group)
from ray.util import rpdb as pdb
__all__ = [
"ActorPool", "disable_log_once_globally", "enable_periodic_logging",
"iter", "log_once", "placement_group", "placement_group_table",
"iter", "log_once", "pdb", "placement_group", "placement_group_table",
"remove_placement_group"
]
+181
View File
@@ -0,0 +1,181 @@
# Some code in this file is from
# https://github.com/ionelmc/python-remote-pdb/blob/07d563331c4ab9eb45731bb272b158816d98236e/src/remote_pdb.py
# (BSD 2-Clause "Simplified" License)
import errno
import inspect
import json
import logging
import os
import re
import socket
import sys
import uuid
from pdb import Pdb
import setproctitle
import traceback
from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
PY3 = sys.version_info[0] == 3
log = logging.getLogger(__name__)
def cry(message, stderr=sys.__stderr__):
    """Print ``message`` to ``stderr`` and flush the stream right away.

    The default stream is ``sys.__stderr__`` as captured when this function
    is defined.
    """
    sink = stderr
    print(message, file=sink)
    sink.flush()
class LF2CRLF_FileWrapper(object):
    """File-like wrapper around a socket connection.

    Reading is delegated to a text stream created with
    ``connection.makefile("rw")``. Writing goes directly to the connection
    via ``sendall`` after every LF or CRLF in the data has been rewritten to
    CRLF.
    """

    def __init__(self, connection):
        self.connection = connection
        stream = connection.makefile("rw")
        self.stream = stream

        # Expose the stream's reading and housekeeping methods directly on
        # the wrapper.
        for name in ("read", "readline", "readlines", "close", "flush",
                     "fileno"):
            setattr(self, name, getattr(stream, name))

        if hasattr(stream, "encoding"):

            def send_text(data):
                # Encode with the stream's encoding before sending.
                return connection.sendall(data.encode(stream.encoding))

            self._send = send_text
        else:
            self._send = connection.sendall

    @property
    def encoding(self):
        """Encoding reported by the underlying stream."""
        return self.stream.encoding

    def __iter__(self):
        return iter(self.stream)

    def write(self, data, nl_rex=re.compile("\r?\n")):
        """Send ``data`` with each line ending normalized to CRLF."""
        self._send(nl_rex.sub("\r\n", data))

    def writelines(self, lines, nl_rex=re.compile("\r?\n")):
        """Send every item of ``lines`` through ``write``."""
        for chunk in lines:
            self.write(chunk, nl_rex)
class RemotePdb(Pdb):
    """
    This will run pdb as an ephemeral telnet service. Once you connect no one
    else can connect. Constructing the object only binds the listening
    socket; ``listen()`` blocks execution till a client has connected and
    then initializes the underlying ``Pdb`` on that connection.

    Based on https://github.com/tamentis/rpdb I think ...

    To use this::

        rdb = RemotePdb(host="0.0.0.0", port=4444)
        rdb.listen()
        rdb.set_trace()

    Then run: telnet 127.0.0.1 4444
    """
    # The instance that currently has a client attached, or None.
    active_instance = None

    def __init__(self, host, port, patch_stdstreams=False, quiet=False):
        """Bind a TCP listening socket on ``(host, port)``.

        Args:
            host: Address to bind to.
            port: Port to bind to.
            patch_stdstreams: If true, ``listen()`` redirects the ``sys``
                standard streams to the client connection.
            quiet: If true, suppress the status messages sent via ``cry``.
        """
        self._quiet = quiet
        self._patch_stdstreams = patch_stdstreams
        self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self._listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
                                       True)
        self._listen_socket.bind((host, port))

    def listen(self):
        """Block until one client connects, then set up pdb on it."""
        if not self._quiet:
            cry("RemotePdb session open at %s:%s, "
                "use 'ray debug' to connect..." %
                self._listen_socket.getsockname())
        self._listen_socket.listen(1)
        connection, address = self._listen_socket.accept()
        if not self._quiet:
            cry("RemotePdb accepted connection from %s." % repr(address))
        self.handle = LF2CRLF_FileWrapper(connection)
        # Pdb is initialized here, not in __init__, because its stdin and
        # stdout are the client connection, which only exists after accept.
        Pdb.__init__(
            self, completekey="tab", stdin=self.handle, stdout=self.handle)
        # (name, original stream) pairs to put back in __restore.
        self.backup = []
        if self._patch_stdstreams:
            for name in (
                    "stderr",
                    "stdout",
                    "__stderr__",
                    "__stdout__",
                    "stdin",
                    "__stdin__",
            ):
                self.backup.append((name, getattr(sys, name)))
                setattr(sys, name, self.handle)
        RemotePdb.active_instance = self

    def __restore(self):
        """Undo stream patching and release the session's sockets."""
        if self.backup and not self._quiet:
            cry("Restoring streams: %s ..." % self.backup)
        for name, fh in self.backup:
            setattr(sys, name, fh)
        self.handle.close()
        # Closing the wrapper's stream does not close the sockets themselves;
        # close them explicitly so a finished session does not leak them.
        self.handle.connection.close()
        self._listen_socket.close()
        RemotePdb.active_instance = None

    def do_quit(self, arg):
        """Handle pdb's quit command: clean up, then quit as usual."""
        self.__restore()
        return Pdb.do_quit(self, arg)

    do_q = do_exit = do_quit

    def set_trace(self, frame=None):
        """Start tracing at ``frame`` (default: the caller's frame).

        A connection reset by the client is swallowed; any other
        ``IOError`` propagates.
        """
        if frame is None:
            frame = sys._getframe().f_back
        try:
            Pdb.set_trace(self, frame)
        except IOError as exc:
            if exc.errno != errno.ECONNRESET:
                raise

    def post_mortem(self, traceback=None):
        """Start post-mortem debugging of ``traceback``.

        Args:
            traceback: The traceback to inspect. If None, the traceback of
                the exception currently being handled is used.

        Raises:
            ValueError: If no traceback is given and no exception is being
                handled.
        """
        # See https://github.com/python/cpython/blob/
        # 022bc7572f061e1d1132a4db9d085b29707701e7/Lib/pdb.py#L1617
        if traceback is None:
            traceback = sys.exc_info()[2]
        if traceback is None:
            raise ValueError("A valid traceback must be passed if no "
                             "exception is being handled")
        try:
            self.reset()
            Pdb.interaction(self, None, traceback)
        except IOError as exc:
            if exc.errno != errno.ECONNRESET:
                raise
def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
    """
    Opens a remote PDB on first available port.

    The breakpoint is advertised under a unique "RAY_PDB_<uuid>" key in the
    internal KV store while the debugger waits for a client, and the key is
    removed again once ``listen()`` returns or raises.

    Args:
        host: Address to bind to. Defaults to the ``REMOTE_PDB_HOST``
            environment variable, or "127.0.0.1".
        port: Port to bind to. Defaults to the ``REMOTE_PDB_PORT``
            environment variable, or 0.
        patch_stdstreams: Passed through to ``RemotePdb``.
        quiet: Passed through to ``RemotePdb``. Defaults to true if the
            ``REMOTE_PDB_QUIET`` environment variable is non-empty.

    Returns:
        The ``RemotePdb`` instance, with a client already connected.
    """
    if host is None:
        host = os.environ.get("REMOTE_PDB_HOST", "127.0.0.1")
    if port is None:
        port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
    if quiet is None:
        quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
    rdb = RemotePdb(
        host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet)
    sockname = rdb._listen_socket.getsockname()
    pdb_address = "{}:{}".format(sockname[0], sockname[1])
    # Index 2 skips this function and its direct caller (the module-level
    # set_trace / post_mortem) to reach the code that asked for the debugger.
    parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
    data = {
        "proctitle": setproctitle.getproctitle(),
        "pdb_address": pdb_address,
        "filename": parentframeinfo.filename,
        "lineno": parentframeinfo.lineno,
        # format_exception() lines already end in a newline, so join them
        # without a separator to avoid a blank line after every entry.
        "traceback": "".join(traceback.format_exception(*sys.exc_info()))
    }
    breakpoint_key = "RAY_PDB_{}".format(uuid.uuid4())
    _internal_kv_put(breakpoint_key, json.dumps(data), overwrite=True)
    try:
        rdb.listen()
    finally:
        # Remove the entry even if listen() fails, so a stale breakpoint is
        # not left behind in the KV store.
        _internal_kv_del(breakpoint_key)

    return rdb
def set_trace(host=None, port=None, patch_stdstreams=False, quiet=None):
    """Open a remote pdb session and start tracing in the caller's frame.

    All arguments are forwarded to ``connect_ray_pdb``.
    """
    # Capture the caller's frame before connecting.
    caller_frame = sys._getframe().f_back
    debugger = connect_ray_pdb(host, port, patch_stdstreams, quiet)
    debugger.set_trace(frame=caller_frame)
def post_mortem(host=None, port=None, patch_stdstreams=False, quiet=None):
    """Open a remote pdb session and run its post-mortem debugger.

    All arguments are forwarded to ``connect_ray_pdb``.
    """
    debugger = connect_ray_pdb(host, port, patch_stdstreams, quiet)
    debugger.post_mortem()