mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:15:35 +08:00
Ray debugger stepping between tasks (#12075)
This commit is contained in:
+32
-4
@@ -335,6 +335,7 @@ cdef execute_task(
|
||||
const c_vector[shared_ptr[CRayObject]] &c_args,
|
||||
const c_vector[CObjectID] &c_arg_reference_ids,
|
||||
const c_vector[CObjectID] &c_return_ids,
|
||||
const c_string debugger_breakpoint,
|
||||
c_vector[shared_ptr[CRayObject]] *returns):
|
||||
|
||||
worker = ray.worker.global_worker
|
||||
@@ -456,7 +457,23 @@ cdef execute_task(
|
||||
task_exception = True
|
||||
try:
|
||||
with ray.worker._changeproctitle(title, next_title):
|
||||
if debugger_breakpoint != b"":
|
||||
ray.util.pdb.set_trace(
|
||||
breakpoint_uuid=debugger_breakpoint)
|
||||
outputs = function_executor(*args, **kwargs)
|
||||
next_breakpoint = (
|
||||
ray.worker.global_worker.debugger_breakpoint)
|
||||
if next_breakpoint != b"":
|
||||
# If this happens, the user typed "remote" and
|
||||
# there were no more remote calls left in this
|
||||
# task. In that case we just exit the debugger.
|
||||
ray.experimental.internal_kv._internal_kv_put(
|
||||
"RAY_PDB_{}".format(next_breakpoint),
|
||||
"{\"exit_debugger\": true}")
|
||||
ray.experimental.internal_kv._internal_kv_del(
|
||||
"RAY_PDB_CONTINUE_{}".format(next_breakpoint)
|
||||
)
|
||||
ray.worker.global_worker.debugger_breakpoint = b""
|
||||
task_exception = False
|
||||
except KeyboardInterrupt as e:
|
||||
raise TaskCancelledError(
|
||||
@@ -522,6 +539,7 @@ cdef CRayStatus task_execution_handler(
|
||||
const c_vector[shared_ptr[CRayObject]] &c_args,
|
||||
const c_vector[CObjectID] &c_arg_reference_ids,
|
||||
const c_vector[CObjectID] &c_return_ids,
|
||||
const c_string debugger_breakpoint,
|
||||
c_vector[shared_ptr[CRayObject]] *returns) nogil:
|
||||
|
||||
with gil:
|
||||
@@ -531,7 +549,7 @@ cdef CRayStatus task_execution_handler(
|
||||
# it does, that indicates that there was an internal error.
|
||||
execute_task(task_type, task_name, ray_function, c_resources,
|
||||
c_args, c_arg_reference_ids, c_return_ids,
|
||||
returns)
|
||||
debugger_breakpoint, returns)
|
||||
except Exception:
|
||||
traceback_str = traceback.format_exc() + (
|
||||
"An unexpected internal error occurred while the worker "
|
||||
@@ -1040,6 +1058,7 @@ cdef class CoreWorker:
|
||||
PlacementGroupID placement_group_id,
|
||||
int64_t placement_group_bundle_index,
|
||||
c_bool placement_group_capture_child_tasks,
|
||||
c_string debugger_breakpoint,
|
||||
override_environment_variables):
|
||||
cdef:
|
||||
unordered_map[c_string, double] c_resources
|
||||
@@ -1066,7 +1085,8 @@ cdef class CoreWorker:
|
||||
&return_ids, max_retries,
|
||||
c_pair[CPlacementGroupID, int64_t](
|
||||
c_placement_group_id, placement_group_bundle_index),
|
||||
placement_group_capture_child_tasks)
|
||||
placement_group_capture_child_tasks,
|
||||
debugger_breakpoint)
|
||||
|
||||
return VectorToObjectRefs(return_ids)
|
||||
|
||||
@@ -1411,8 +1431,16 @@ cdef class CoreWorker:
|
||||
context = worker.get_serialization_context()
|
||||
serialized_object = context.serialize(output)
|
||||
data_sizes.push_back(serialized_object.total_bytes)
|
||||
metadatas.push_back(
|
||||
string_to_buffer(serialized_object.metadata))
|
||||
metadata = serialized_object.metadata
|
||||
if ray.worker.global_worker.debugger_get_breakpoint:
|
||||
breakpoint = (
|
||||
ray.worker.global_worker.debugger_get_breakpoint())
|
||||
metadata += (
|
||||
b"," + ray_constants.OBJECT_METADATA_DEBUG_PREFIX +
|
||||
breakpoint.encode())
|
||||
# Reset debugging context of this worker.
|
||||
ray.worker.global_worker.debugger_get_breakpoint = b""
|
||||
metadatas.push_back(string_to_buffer(metadata))
|
||||
serialized_objects.append(serialized_object)
|
||||
contained_ids.push_back(
|
||||
ObjectRefsToVector(serialized_object.contained_object_refs)
|
||||
|
||||
@@ -90,7 +90,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const CTaskOptions &options, c_vector[CObjectID] *return_ids,
|
||||
int max_retries,
|
||||
c_pair[CPlacementGroupID, int64_t] placement_options,
|
||||
c_bool placement_group_capture_child_tasks)
|
||||
c_bool placement_group_capture_child_tasks,
|
||||
c_string debugger_breakpoint)
|
||||
CRayStatus CreateActor(
|
||||
const CRayFunction &function,
|
||||
const c_vector[unique_ptr[CTaskArg]] &args,
|
||||
@@ -224,6 +225,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
const c_vector[shared_ptr[CRayObject]] &args,
|
||||
const c_vector[CObjectID] &arg_reference_ids,
|
||||
const c_vector[CObjectID] &return_ids,
|
||||
const c_string debugger_breakpoint,
|
||||
c_vector[shared_ptr[CRayObject]] *returns) nogil
|
||||
) task_execution_callback
|
||||
(void(const CWorkerID &) nogil) on_worker_shutdown
|
||||
|
||||
@@ -197,7 +197,8 @@ LOG_MONITOR_MAX_OPEN_FILES = 200
|
||||
# The object metadata field uses the following format: It is a comma
|
||||
# separated list of fields. The first field is mandatory and is the
|
||||
# type of the object (see types below) or an integer, which is interpreted
|
||||
# as an error value.
|
||||
# as an error value. The second part is optional and if present has the
|
||||
# form DEBUG:<breakpoint_id>. It is used for implementing the debugger.
|
||||
|
||||
# A constant used as object metadata to indicate the object is cross language.
|
||||
OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG"
|
||||
@@ -213,6 +214,9 @@ OBJECT_METADATA_TYPE_RAW = b"RAW"
|
||||
# of XLANG.
|
||||
OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE"
|
||||
|
||||
# A constant indicating the debugging part of the metadata (see above).
|
||||
OBJECT_METADATA_DEBUG_PREFIX = b"DEBUG:"
|
||||
|
||||
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
|
||||
|
||||
# The default password to prevent redis port scanning attack.
|
||||
|
||||
@@ -258,8 +258,12 @@ class RemoteFunction:
|
||||
placement_group.id,
|
||||
placement_group_bundle_index,
|
||||
placement_group_capture_child_tasks,
|
||||
worker.debugger_breakpoint,
|
||||
override_environment_variables=override_environment_variables
|
||||
or dict())
|
||||
# Reset worker's debug context from the last "remote" command
|
||||
# (which applies only to this .remote call).
|
||||
worker.debugger_breakpoint = b""
|
||||
if len(object_refs) == 1:
|
||||
return object_refs[0]
|
||||
elif len(object_refs) > 1:
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from telnetlib import Telnet
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
@@ -150,6 +151,35 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
|
||||
from None
|
||||
|
||||
|
||||
def continue_debug_session():
    """Follow an in-progress debugging session into its next process.

    Looks through the internal KV store for a ``RAY_PDB_CONTINUE_<id>``
    marker. For the first marker found, it polls the matching
    ``RAY_PDB_<id>`` record once per second until that record has data:

    * If the record contains ``"exit_debugger"``, the record is deleted
      and the function returns.
    * Otherwise the record's ``"pdb_address"`` (``host:port``) is opened
      in an interactive telnet session. When that session ends the record
      is deleted and the search starts over, so that a further hand-off
      is followed too.

    Returns immediately when no continue marker is present.
    """
    kv = ray.experimental.internal_kv

    for session_key in kv._internal_kv_list("RAY_PDB_"):
        if not session_key.startswith(b"RAY_PDB_CONTINUE"):
            continue

        print("Continuing pdb session in different process...")
        # Map RAY_PDB_CONTINUE_<id> to the session record RAY_PDB_<id>.
        key = b"RAY_PDB_" + session_key[len("RAY_PDB_CONTINUE_"):]

        # Wait for the session record to be populated.
        data = kv._internal_kv_get(key)
        while not data:
            time.sleep(1.0)
            data = kv._internal_kv_get(key)

        session = json.loads(data)
        if "exit_debugger" in session:
            kv._internal_kv_del(key)
            return

        host, port = session["pdb_address"].split(":")
        with Telnet(host, int(port)) as tn:
            tn.interact()
        kv._internal_kv_del(key)

        # The session that just ended may itself have handed off again.
        continue_debug_session()
        return
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option(
|
||||
"--address",
|
||||
@@ -158,12 +188,13 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
|
||||
help="Override the address to connect to.")
|
||||
def debug(address):
|
||||
"""Show all active breakpoints and exceptions in the Ray debugger."""
|
||||
from telnetlib import Telnet
|
||||
if not address:
|
||||
address = services.get_ray_address_to_use_or_die()
|
||||
logger.info(f"Connecting to Ray instance at {address}.")
|
||||
ray.init(address=address)
|
||||
ray.init(address=address, log_to_driver=False)
|
||||
while True:
|
||||
continue_debug_session()
|
||||
|
||||
active_sessions = ray.experimental.internal_kv._internal_kv_list(
|
||||
"RAY_PDB_")
|
||||
print("Active breakpoints:")
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import sys
|
||||
from telnetlib import Telnet
|
||||
|
||||
import pexpect
|
||||
import ray
|
||||
|
||||
|
||||
@@ -34,6 +35,67 @@ def test_ray_debugger_breakpoint(shutdown_only):
|
||||
ray.get(result)
|
||||
|
||||
|
||||
def test_ray_debugger_stepping(shutdown_only):
    """Drive ``ray debug`` through the ``remote`` and ``get`` commands.

    Task ``f`` stops at a breakpoint just before calling ``g.remote()``.
    The scripted session picks breakpoint 0, issues ``remote``, then
    ``get``, then ``continue``. Afterwards ``f`` must finish so that its
    result can be fetched.
    """
    ray.init(num_cpus=1)

    @ray.remote
    def g():
        return None

    @ray.remote
    def f():
        ray.util.pdb.set_trace()
        x = g.remote()
        return ray.get(x)

    result = f.remote()

    # (output to wait for, command to send once it has appeared)
    script = [
        ("Enter breakpoint index or press enter to refresh: ", "0"),
        ("-> x = g.remote()", "remote"),
        ("(Pdb)", "get"),
        ("(Pdb)", "continue"),
    ]

    debugger = pexpect.spawn("ray debug")
    for awaited_output, command in script:
        debugger.expect(awaited_output)
        debugger.sendline(command)

    # This should succeed now!
    ray.get(result)
|
||||
|
||||
|
||||
def test_ray_debugger_recursive(shutdown_only):
    """Step through a recursive remote function with repeated ``remote``.

    ``fact`` sets a breakpoint and then calls itself remotely while
    ``n >= 1``. The scripted session picks breakpoint 0 and sends
    ``remote`` six times, waiting for a Pdb prompt before each one. The
    outermost call must then complete.
    """
    ray.init(num_cpus=1)

    @ray.remote
    def fact(n):
        if n < 1:
            return n
        ray.util.pdb.set_trace()
        n_id = fact.remote(n - 1)
        return n * ray.get(n_id)

    result = fact.remote(5)

    debugger = pexpect.spawn("ray debug")
    debugger.expect("Enter breakpoint index or press enter to refresh: ")
    debugger.sendline("0")

    remote_steps = 6
    for _ in range(remote_steps):
        debugger.expect("(Pdb)")
        debugger.sendline("remote")

    ray.get(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
# Make subprocess happy in bazel.
|
||||
|
||||
+52
-8
@@ -15,6 +15,7 @@ from pdb import Pdb
|
||||
import setproctitle
|
||||
import traceback
|
||||
|
||||
import ray
|
||||
from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
|
||||
|
||||
PY3 = sys.version_info[0] == 3
|
||||
@@ -70,7 +71,13 @@ class RemotePdb(Pdb):
|
||||
"""
|
||||
active_instance = None
|
||||
|
||||
def __init__(self, host, port, patch_stdstreams=False, quiet=False):
|
||||
def __init__(self,
|
||||
breakpoint_uuid,
|
||||
host,
|
||||
port,
|
||||
patch_stdstreams=False,
|
||||
quiet=False):
|
||||
self._breakpoint_uuid = breakpoint_uuid
|
||||
self._quiet = quiet
|
||||
self._patch_stdstreams = patch_stdstreams
|
||||
self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
@@ -138,8 +145,35 @@ class RemotePdb(Pdb):
|
||||
if exc.errno != errno.ECONNRESET:
|
||||
raise
|
||||
|
||||
def do_remote(self, arg):
    """remote
    Skip into the next remote call.
    """
    breakpoint_id = self._breakpoint_uuid

    # Tell the next task to drop into the debugger.
    worker = ray.worker.global_worker
    worker.debugger_breakpoint = breakpoint_id

    # Tell the debug loop to connect to the next task.
    continue_key = "RAY_PDB_CONTINUE_{}".format(breakpoint_id)
    _internal_kv_put(continue_key, "")

    # Tear down this session, then let the program run on.
    # NOTE(review): __restore is defined outside this view; it presumably
    # undoes this debugger's stream patching. Confirm against the class.
    self.__restore()
    self.handle.connection.close()
    return Pdb.do_continue(self, arg)
|
||||
|
||||
def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
def do_get(self, arg):
    """get
    Skip to where the current task returns to.
    """
    # Record this session's breakpoint ID on the worker as its pending
    # "get" breakpoint.
    worker = ray.worker.global_worker
    worker.debugger_get_breakpoint = self._breakpoint_uuid

    # Tear down this session, then let the program run on.
    # NOTE(review): __restore is defined outside this view; it presumably
    # undoes this debugger's stream patching. Confirm against the class.
    self.__restore()
    self.handle.connection.close()
    return Pdb.do_continue(self, arg)
|
||||
|
||||
|
||||
def connect_ray_pdb(host=None,
|
||||
port=None,
|
||||
patch_stdstreams=False,
|
||||
quiet=None,
|
||||
breakpoint_uuid=None):
|
||||
"""
|
||||
Opens a remote PDB on first available port.
|
||||
"""
|
||||
@@ -149,8 +183,14 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
|
||||
if quiet is None:
|
||||
quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
|
||||
if not breakpoint_uuid:
|
||||
breakpoint_uuid = uuid.uuid4().hex
|
||||
rdb = RemotePdb(
|
||||
host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet)
|
||||
breakpoint_uuid=breakpoint_uuid,
|
||||
host=host,
|
||||
port=port,
|
||||
patch_stdstreams=patch_stdstreams,
|
||||
quiet=quiet)
|
||||
sockname = rdb._listen_socket.getsockname()
|
||||
pdb_address = "{}:{}".format(sockname[0], sockname[1])
|
||||
parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
|
||||
@@ -161,7 +201,6 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
"lineno": parentframeinfo.lineno,
|
||||
"traceback": "\n".join(traceback.format_exception(*sys.exc_info()))
|
||||
}
|
||||
breakpoint_uuid = uuid.uuid4()
|
||||
_internal_kv_put(
|
||||
"RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True)
|
||||
rdb.listen()
|
||||
@@ -170,14 +209,19 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
|
||||
return rdb
|
||||
|
||||
|
||||
def set_trace():
|
||||
def set_trace(breakpoint_uuid=None):
|
||||
"""Interrupt the flow of the program and drop into the Ray debugger.
|
||||
|
||||
Can be used within a Ray task or actor.
|
||||
"""
|
||||
frame = sys._getframe().f_back
|
||||
rdb = connect_ray_pdb(None, None, False, None)
|
||||
rdb.set_trace(frame=frame)
|
||||
# If there is an active debugger already, we do not want to
|
||||
# start another one, so "set_trace" is just a no-op in that case.
|
||||
if ray.worker.global_worker.debugger_breakpoint == b"":
|
||||
frame = sys._getframe().f_back
|
||||
rdb = connect_ray_pdb(
|
||||
None, None, False, None,
|
||||
breakpoint_uuid.decode() if breakpoint_uuid else None)
|
||||
rdb.set_trace(frame=frame)
|
||||
|
||||
|
||||
def post_mortem():
|
||||
|
||||
+31
-2
@@ -102,6 +102,13 @@ class Worker:
|
||||
# Index of the current session. This number will
|
||||
# increment every time when `ray.shutdown` is called.
|
||||
self._session_index = 0
|
||||
# If this is set, the next .remote call should drop into the
|
||||
# debugger, at the specified breakpoint ID.
|
||||
self.debugger_breakpoint = b""
|
||||
# If this is set, ray.get calls invoked on the object ID returned
|
||||
# by the worker should drop into the debugger at the specified
|
||||
# breakpoint ID.
|
||||
self.debugger_get_breakpoint = b""
|
||||
|
||||
@property
|
||||
def connected(self):
|
||||
@@ -280,6 +287,10 @@ class Worker:
|
||||
whose values should be retrieved.
|
||||
timeout (float): The maximum amount of time in
|
||||
seconds to wait before returning.
|
||||
Returns:
|
||||
list: List of deserialized objects
|
||||
bytes: UUID of the debugger breakpoint we should drop
|
||||
into or b"" if there is no breakpoint.
|
||||
"""
|
||||
# Make sure that the values are object refs.
|
||||
for object_ref in object_refs:
|
||||
@@ -291,7 +302,16 @@ class Worker:
|
||||
timeout_ms = int(timeout * 1000) if timeout else -1
|
||||
data_metadata_pairs = self.core_worker.get_objects(
|
||||
object_refs, self.current_task_id, timeout_ms)
|
||||
return self.deserialize_objects(data_metadata_pairs, object_refs)
|
||||
debugger_breakpoint = b""
|
||||
for (data, metadata) in data_metadata_pairs:
|
||||
if metadata:
|
||||
metadata_fields = metadata.split(b",")
|
||||
if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
|
||||
ray_constants.OBJECT_METADATA_DEBUG_PREFIX):
|
||||
debugger_breakpoint = metadata_fields[1][len(
|
||||
ray_constants.OBJECT_METADATA_DEBUG_PREFIX):]
|
||||
return self.deserialize_objects(data_metadata_pairs,
|
||||
object_refs), debugger_breakpoint
|
||||
|
||||
def run_function_on_all_workers(self, function,
|
||||
run_on_other_drivers=False):
|
||||
@@ -1345,7 +1365,8 @@ def get(object_refs, *, timeout=None):
|
||||
|
||||
global last_task_error_raise_time
|
||||
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
|
||||
values = worker.get_objects(object_refs, timeout=timeout)
|
||||
values, debugger_breakpoint = worker.get_objects(
|
||||
object_refs, timeout=timeout)
|
||||
for i, value in enumerate(values):
|
||||
if isinstance(value, RayError):
|
||||
last_task_error_raise_time = time.time()
|
||||
@@ -1358,6 +1379,14 @@ def get(object_refs, *, timeout=None):
|
||||
|
||||
if is_individual_id:
|
||||
values = values[0]
|
||||
|
||||
if debugger_breakpoint != b"":
|
||||
frame = sys._getframe().f_back
|
||||
rdb = ray.util.pdb.connect_ray_pdb(
|
||||
None, None, False, None,
|
||||
debugger_breakpoint.decode() if debugger_breakpoint else None)
|
||||
rdb.set_trace(frame=frame)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ numba
|
||||
# higher version of llvmlite breaks windows
|
||||
llvmlite==0.34.0
|
||||
openpyxl
|
||||
pexpect
|
||||
Pillow; platform_system != "Windows"
|
||||
pygments
|
||||
pytest==5.4.3
|
||||
|
||||
Reference in New Issue
Block a user