mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
Ray PDB support (#11739)
This commit is contained in:
@@ -492,6 +492,10 @@ cdef execute_task(
|
||||
core_worker.store_task_outputs(
|
||||
worker, outputs, c_return_ids, returns)
|
||||
except Exception as error:
|
||||
# If the debugger is enabled, drop into the remote pdb here.
|
||||
if "RAY_PDB" in os.environ:
|
||||
ray.util.pdb.post_mortem()
|
||||
|
||||
if (<int>task_type == <int>TASK_TYPE_ACTOR_CREATION_TASK):
|
||||
worker.mark_actor_init_failed(error)
|
||||
|
||||
|
||||
@@ -32,3 +32,12 @@ def _internal_kv_put(key, value, overwrite=False):
|
||||
|
||||
def _internal_kv_del(key):
|
||||
return ray.worker.global_worker.redis_client.delete(key)
|
||||
|
||||
|
||||
def _internal_kv_list(prefix):
|
||||
"""List all keys in the internal KV store that start with the prefix."""
|
||||
if isinstance(prefix, bytes):
|
||||
pattern = prefix + b"*"
|
||||
else:
|
||||
pattern = prefix + "*"
|
||||
return ray.worker.global_worker.redis_client.keys(pattern=pattern)
|
||||
|
||||
@@ -150,6 +150,44 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
|
||||
from None
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"--address",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Override the address to connect to.")
|
||||
def debug(address):
|
||||
"""Debug Ray program."""
|
||||
from telnetlib import Telnet
|
||||
if not address:
|
||||
address = services.find_redis_address_or_die()
|
||||
logger.info(f"Connecting to Ray instance at {address}.")
|
||||
ray.init(address=address)
|
||||
while True:
|
||||
active_sessions = ray.experimental.internal_kv._internal_kv_list(
|
||||
"RAY_PDB_")
|
||||
print("Active breakpoints:")
|
||||
for i, active_session in enumerate(active_sessions):
|
||||
data = json.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get(active_session))
|
||||
print(
|
||||
str(i) + ": " + data["proctitle"] + " | " + data["filename"] +
|
||||
":" + str(data["lineno"]))
|
||||
print(data["traceback"])
|
||||
inp = input("Enter breakpoint index or press enter to refresh: ")
|
||||
if inp == "":
|
||||
print()
|
||||
continue
|
||||
else:
|
||||
index = int(inp)
|
||||
session = json.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get(
|
||||
active_sessions[index]))
|
||||
host, port = session["pdb_address"].split(":")
|
||||
with Telnet(host, int(port)) as tn:
|
||||
tn.interact()
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"--node-ip-address",
|
||||
@@ -1397,6 +1435,7 @@ def add_command_alias(command, name, hidden):
|
||||
|
||||
|
||||
cli.add_command(dashboard)
|
||||
cli.add_command(debug)
|
||||
cli.add_command(start)
|
||||
cli.add_command(stop)
|
||||
cli.add_command(up)
|
||||
|
||||
@@ -244,6 +244,9 @@ class RayServeWorker:
|
||||
result = await method_to_call(arg)
|
||||
self.request_counter.record(1)
|
||||
except Exception as e:
|
||||
import os
|
||||
if "RAY_PDB" in os.environ:
|
||||
ray.util.pdb.post_mortem()
|
||||
result = wrap_to_ray_error(e)
|
||||
self.error_counter.record(1)
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ py_test_module_list(
|
||||
"test_node_manager.py",
|
||||
"test_numba.py",
|
||||
"test_queue.py",
|
||||
"test_ray_debugger.py",
|
||||
"test_ray_init.py",
|
||||
"test_tempfile.py",
|
||||
],
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from telnetlib import Telnet
|
||||
|
||||
import ray
|
||||
|
||||
|
||||
def test_ray_debugger_breakpoint(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
ray.util.pdb.set_trace()
|
||||
return 1
|
||||
|
||||
result = f.remote()
|
||||
|
||||
# Wait until the breakpoint is hit:
|
||||
while True:
|
||||
active_sessions = ray.experimental.internal_kv._internal_kv_list(
|
||||
"RAY_PDB_")
|
||||
if len(active_sessions) > 0:
|
||||
break
|
||||
|
||||
# Now continue execution:
|
||||
session = json.loads(
|
||||
ray.experimental.internal_kv._internal_kv_get(active_sessions[0]))
|
||||
host, port = session["pdb_address"].split(":")
|
||||
tn = Telnet(host, int(port))
|
||||
tn.write(b"c\n")
|
||||
|
||||
# This should succeed now!
|
||||
ray.get(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
# Make subprocess happy in bazel.
|
||||
os.environ["LC_ALL"] = "en_US.UTF-8"
|
||||
os.environ["LANG"] = "en_US.UTF-8"
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -4,9 +4,10 @@ from ray.util.debug import log_once, disable_log_once_globally, \
|
||||
enable_periodic_logging
|
||||
from ray.util.placement_group import (placement_group, placement_group_table,
|
||||
remove_placement_group)
|
||||
from ray.util import rpdb as pdb
|
||||
|
||||
__all__ = [
|
||||
"ActorPool", "disable_log_once_globally", "enable_periodic_logging",
|
||||
"iter", "log_once", "placement_group", "placement_group_table",
|
||||
"iter", "log_once", "pdb", "placement_group", "placement_group_table",
|
||||
"remove_placement_group"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,181 @@
|
||||
# Some code in this file is from
|
||||
# https://github.com/ionelmc/python-remote-pdb/blob/07d563331c4ab9eb45731bb272b158816d98236e/src/remote_pdb.py
|
||||
# (BSD 2-Clause "Simplified" License)
|
||||
|
||||
import errno
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import uuid
|
||||
from pdb import Pdb
|
||||
import setproctitle
|
||||
import traceback
|
||||
|
||||
from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cry(message, stderr=sys.__stderr__):
|
||||
print(message, file=stderr)
|
||||
stderr.flush()
|
||||
|
||||
|
||||
class LF2CRLF_FileWrapper(object):
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self.stream = fh = connection.makefile("rw")
|
||||
self.read = fh.read
|
||||
self.readline = fh.readline
|
||||
self.readlines = fh.readlines
|
||||
self.close = fh.close
|
||||
self.flush = fh.flush
|
||||
self.fileno = fh.fileno
|
||||
if hasattr(fh, "encoding"):
|
||||
self._send = lambda data: connection.sendall(
|
||||
data.encode(fh.encoding))
|
||||
else:
|
||||
self._send = connection.sendall
|
||||
|
||||
@property
|
||||
def encoding(self):
|
||||
return self.stream.encoding
|
||||
|
||||
def __iter__(self):
|
||||
return self.stream.__iter__()
|
||||
|
||||
def write(self, data, nl_rex=re.compile("\r?\n")):
|
||||
data = nl_rex.sub("\r\n", data)
|
||||
self._send(data)
|
||||
|
||||
def writelines(self, lines, nl_rex=re.compile("\r?\n")):
|
||||
for line in lines:
|
||||
self.write(line, nl_rex)
|
||||
|
||||
|
||||
class RemotePdb(Pdb):
|
||||
"""
|
||||
This will run pdb as a ephemeral telnet service. Once you connect no one
|
||||
else can connect. On construction this object will block execution till a
|
||||
client has connected.
|
||||
Based on https://github.com/tamentis/rpdb I think ...
|
||||
To use this::
|
||||
RemotePdb(host="0.0.0.0", port=4444).set_trace()
|
||||
Then run: telnet 127.0.0.1 4444
|
||||
"""
|
||||
active_instance = None
|
||||
|
||||
def __init__(self, host, port, patch_stdstreams=False, quiet=False):
|
||||
self._quiet = quiet
|
||||
self._patch_stdstreams = patch_stdstreams
|
||||
self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._listen_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
|
||||
True)
|
||||
self._listen_socket.bind((host, port))
|
||||
|
||||
def listen(self):
|
||||
if not self._quiet:
|
||||
cry("RemotePdb session open at %s:%s, "
|
||||
"use 'ray debug' to connect..." %
|
||||
self._listen_socket.getsockname())
|
||||
self._listen_socket.listen(1)
|
||||
connection, address = self._listen_socket.accept()
|
||||
if not self._quiet:
|
||||
cry("RemotePdb accepted connection from %s." % repr(address))
|
||||
self.handle = LF2CRLF_FileWrapper(connection)
|
||||
Pdb.__init__(
|
||||
self, completekey="tab", stdin=self.handle, stdout=self.handle)
|
||||
self.backup = []
|
||||
if self._patch_stdstreams:
|
||||
for name in (
|
||||
"stderr",
|
||||
"stdout",
|
||||
"__stderr__",
|
||||
"__stdout__",
|
||||
"stdin",
|
||||
"__stdin__",
|
||||
):
|
||||
self.backup.append((name, getattr(sys, name)))
|
||||
setattr(sys, name, self.handle)
|
||||
RemotePdb.active_instance = self
|
||||
|
||||
def __restore(self):
|
||||
if self.backup and not self._quiet:
|
||||
cry("Restoring streams: %s ..." % self.backup)
|
||||
for name, fh in self.backup:
|
||||
setattr(sys, name, fh)
|
||||
self.handle.close()
|
||||
RemotePdb.active_instance = None
|
||||
|
||||
def do_quit(self, arg):
|
||||
self.__restore()
|
||||
return Pdb.do_quit(self, arg)
|
||||
|
||||
do_q = do_exit = do_quit
|
||||
|
||||
def set_trace(self, frame=None):
|
||||
if frame is None:
|
||||
frame = sys._getframe().f_back
|
||||
try:
|
||||
Pdb.set_trace(self, frame)
|
||||
except IOError as exc:
|
||||
if exc.errno != errno.ECONNRESET:
|
||||
raise
|
||||
|
||||
def post_mortem(self, traceback=None):
|
||||
# See https://github.com/python/cpython/blob/
|
||||
# 022bc7572f061e1d1132a4db9d085b29707701e7/Lib/pdb.py#L1617
|
||||
try:
|
||||
t = sys.exc_info()[2]
|
||||
self.reset()
|
||||
Pdb.interaction(self, None, t)
|
||||
except IOError as exc:
|
||||
if exc.errno != errno.ECONNRESET:
|
||||
raise
|
||||
|
||||
|
||||
def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
"""
|
||||
Opens a remote PDB on first available port.
|
||||
"""
|
||||
if host is None:
|
||||
host = os.environ.get("REMOTE_PDB_HOST", "127.0.0.1")
|
||||
if port is None:
|
||||
port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
|
||||
if quiet is None:
|
||||
quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
|
||||
rdb = RemotePdb(
|
||||
host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet)
|
||||
sockname = rdb._listen_socket.getsockname()
|
||||
pdb_address = "{}:{}".format(sockname[0], sockname[1])
|
||||
parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
|
||||
data = {
|
||||
"proctitle": setproctitle.getproctitle(),
|
||||
"pdb_address": pdb_address,
|
||||
"filename": parentframeinfo.filename,
|
||||
"lineno": parentframeinfo.lineno,
|
||||
"traceback": "\n".join(traceback.format_exception(*sys.exc_info()))
|
||||
}
|
||||
breakpoint_uuid = uuid.uuid4()
|
||||
_internal_kv_put(
|
||||
"RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True)
|
||||
rdb.listen()
|
||||
_internal_kv_del("RAY_PDB_{}".format(breakpoint_uuid))
|
||||
|
||||
return rdb
|
||||
|
||||
|
||||
def set_trace(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
frame = sys._getframe().f_back
|
||||
rdb = connect_ray_pdb(host, port, patch_stdstreams, quiet)
|
||||
rdb.set_trace(frame=frame)
|
||||
|
||||
|
||||
def post_mortem(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
rdb = connect_ray_pdb(host, port, patch_stdstreams, quiet)
|
||||
rdb.post_mortem()
|
||||
Reference in New Issue
Block a user