Ray debugger stepping between tasks (#12075)

This commit is contained in:
Philipp Moritz
2020-12-06 21:50:18 -08:00
committed by GitHub
parent 260b07cf0c
commit 73a1a232b9
28 changed files with 267 additions and 40 deletions
+32 -4
View File
@@ -335,6 +335,7 @@ cdef execute_task(
const c_vector[shared_ptr[CRayObject]] &c_args,
const c_vector[CObjectID] &c_arg_reference_ids,
const c_vector[CObjectID] &c_return_ids,
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns):
worker = ray.worker.global_worker
@@ -456,7 +457,23 @@ cdef execute_task(
task_exception = True
try:
with ray.worker._changeproctitle(title, next_title):
if debugger_breakpoint != b"":
ray.util.pdb.set_trace(
breakpoint_uuid=debugger_breakpoint)
outputs = function_executor(*args, **kwargs)
next_breakpoint = (
ray.worker.global_worker.debugger_breakpoint)
if next_breakpoint != b"":
# If this happens, the user typed "remote" and
# there were no more remote calls left in this
# task. In that case we just exit the debugger.
ray.experimental.internal_kv._internal_kv_put(
"RAY_PDB_{}".format(next_breakpoint),
"{\"exit_debugger\": true}")
ray.experimental.internal_kv._internal_kv_del(
"RAY_PDB_CONTINUE_{}".format(next_breakpoint)
)
ray.worker.global_worker.debugger_breakpoint = b""
task_exception = False
except KeyboardInterrupt as e:
raise TaskCancelledError(
@@ -522,6 +539,7 @@ cdef CRayStatus task_execution_handler(
const c_vector[shared_ptr[CRayObject]] &c_args,
const c_vector[CObjectID] &c_arg_reference_ids,
const c_vector[CObjectID] &c_return_ids,
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns) nogil:
with gil:
@@ -531,7 +549,7 @@ cdef CRayStatus task_execution_handler(
# it does, that indicates that there was an internal error.
execute_task(task_type, task_name, ray_function, c_resources,
c_args, c_arg_reference_ids, c_return_ids,
returns)
debugger_breakpoint, returns)
except Exception:
traceback_str = traceback.format_exc() + (
"An unexpected internal error occurred while the worker "
@@ -1040,6 +1058,7 @@ cdef class CoreWorker:
PlacementGroupID placement_group_id,
int64_t placement_group_bundle_index,
c_bool placement_group_capture_child_tasks,
c_string debugger_breakpoint,
override_environment_variables):
cdef:
unordered_map[c_string, double] c_resources
@@ -1066,7 +1085,8 @@ cdef class CoreWorker:
&return_ids, max_retries,
c_pair[CPlacementGroupID, int64_t](
c_placement_group_id, placement_group_bundle_index),
placement_group_capture_child_tasks)
placement_group_capture_child_tasks,
debugger_breakpoint)
return VectorToObjectRefs(return_ids)
@@ -1411,8 +1431,16 @@ cdef class CoreWorker:
context = worker.get_serialization_context()
serialized_object = context.serialize(output)
data_sizes.push_back(serialized_object.total_bytes)
metadatas.push_back(
string_to_buffer(serialized_object.metadata))
metadata = serialized_object.metadata
if ray.worker.global_worker.debugger_get_breakpoint:
breakpoint = (
ray.worker.global_worker.debugger_get_breakpoint())
metadata += (
b"," + ray_constants.OBJECT_METADATA_DEBUG_PREFIX +
breakpoint.encode())
# Reset debugging context of this worker.
ray.worker.global_worker.debugger_get_breakpoint = b""
metadatas.push_back(string_to_buffer(metadata))
serialized_objects.append(serialized_object)
contained_ids.push_back(
ObjectRefsToVector(serialized_object.contained_object_refs)
+3 -1
View File
@@ -90,7 +90,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const CTaskOptions &options, c_vector[CObjectID] *return_ids,
int max_retries,
c_pair[CPlacementGroupID, int64_t] placement_options,
c_bool placement_group_capture_child_tasks)
c_bool placement_group_capture_child_tasks,
c_string debugger_breakpoint)
CRayStatus CreateActor(
const CRayFunction &function,
const c_vector[unique_ptr[CTaskArg]] &args,
@@ -224,6 +225,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
const c_vector[shared_ptr[CRayObject]] &args,
const c_vector[CObjectID] &arg_reference_ids,
const c_vector[CObjectID] &return_ids,
const c_string debugger_breakpoint,
c_vector[shared_ptr[CRayObject]] *returns) nogil
) task_execution_callback
(void(const CWorkerID &) nogil) on_worker_shutdown
+5 -1
View File
@@ -197,7 +197,8 @@ LOG_MONITOR_MAX_OPEN_FILES = 200
# The object metadata field uses the following format: It is a comma
# separated list of fields. The first field is mandatory and is the
# type of the object (see types below) or an integer, which is interpreted
# as an error value.
# as an error value. The second part is optional and if present has the
# form DEBUG:<breakpoint_id>, it is used for implementing the debugger.
# A constant used as object metadata to indicate the object is cross language.
OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG"
@@ -213,6 +214,9 @@ OBJECT_METADATA_TYPE_RAW = b"RAW"
# of XLANG.
OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE"
# A constant indicating the debugging part of the metadata (see above).
OBJECT_METADATA_DEBUG_PREFIX = b"DEBUG:"
AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"
# The default password to prevent redis port scanning attack.
+4
View File
@@ -258,8 +258,12 @@ class RemoteFunction:
placement_group.id,
placement_group_bundle_index,
placement_group_capture_child_tasks,
worker.debugger_breakpoint,
override_environment_variables=override_environment_variables
or dict())
# Reset worker's debug context from the last "remote" command
# (which applies only to this .remote call).
worker.debugger_breakpoint = b""
if len(object_refs) == 1:
return object_refs[0]
elif len(object_refs) > 1:
+33 -2
View File
@@ -6,6 +6,7 @@ import logging
import os
import subprocess
import sys
from telnetlib import Telnet
import time
import urllib
import urllib.parse
@@ -150,6 +151,35 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
from None
def continue_debug_session():
"""Continue active debugging session.
This function will connect 'ray debug' to the right debugger
when a user is stepping between Ray tasks.
"""
active_sessions = ray.experimental.internal_kv._internal_kv_list(
"RAY_PDB_")
for active_session in active_sessions:
if active_session.startswith(b"RAY_PDB_CONTINUE"):
print("Continuing pdb session in different process...")
key = b"RAY_PDB_" + active_session[len("RAY_PDB_CONTINUE_"):]
while True:
data = ray.experimental.internal_kv._internal_kv_get(key)
if data:
session = json.loads(data)
if "exit_debugger" in session:
ray.experimental.internal_kv._internal_kv_del(key)
return
host, port = session["pdb_address"].split(":")
with Telnet(host, int(port)) as tn:
tn.interact()
ray.experimental.internal_kv._internal_kv_del(key)
continue_debug_session()
return
time.sleep(1.0)
@cli.command()
@click.option(
"--address",
@@ -158,12 +188,13 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port):
help="Override the address to connect to.")
def debug(address):
"""Show all active breakpoints and exceptions in the Ray debugger."""
from telnetlib import Telnet
if not address:
address = services.get_ray_address_to_use_or_die()
logger.info(f"Connecting to Ray instance at {address}.")
ray.init(address=address)
ray.init(address=address, log_to_driver=False)
while True:
continue_debug_session()
active_sessions = ray.experimental.internal_kv._internal_kv_list(
"RAY_PDB_")
print("Active breakpoints:")
+62
View File
@@ -3,6 +3,7 @@ import os
import sys
from telnetlib import Telnet
import pexpect
import ray
@@ -34,6 +35,67 @@ def test_ray_debugger_breakpoint(shutdown_only):
ray.get(result)
def test_ray_debugger_stepping(shutdown_only):
ray.init(num_cpus=1)
@ray.remote
def g():
return None
@ray.remote
def f():
ray.util.pdb.set_trace()
x = g.remote()
return ray.get(x)
result = f.remote()
p = pexpect.spawn("ray debug")
p.expect("Enter breakpoint index or press enter to refresh: ")
p.sendline("0")
p.expect("-> x = g.remote()")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("get")
p.expect("(Pdb)")
p.sendline("continue")
# This should succeed now!
ray.get(result)
def test_ray_debugger_recursive(shutdown_only):
ray.init(num_cpus=1)
@ray.remote
def fact(n):
if n < 1:
return n
ray.util.pdb.set_trace()
n_id = fact.remote(n - 1)
return n * ray.get(n_id)
result = fact.remote(5)
p = pexpect.spawn("ray debug")
p.expect("Enter breakpoint index or press enter to refresh: ")
p.sendline("0")
p.expect("(Pdb)")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("remote")
p.expect("(Pdb)")
p.sendline("remote")
ray.get(result)
if __name__ == "__main__":
import pytest
# Make subprocess happy in bazel.
+52 -8
View File
@@ -15,6 +15,7 @@ from pdb import Pdb
import setproctitle
import traceback
import ray
from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put
PY3 = sys.version_info[0] == 3
@@ -70,7 +71,13 @@ class RemotePdb(Pdb):
"""
active_instance = None
def __init__(self, host, port, patch_stdstreams=False, quiet=False):
def __init__(self,
breakpoint_uuid,
host,
port,
patch_stdstreams=False,
quiet=False):
self._breakpoint_uuid = breakpoint_uuid
self._quiet = quiet
self._patch_stdstreams = patch_stdstreams
self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -138,8 +145,35 @@ class RemotePdb(Pdb):
if exc.errno != errno.ECONNRESET:
raise
def do_remote(self, arg):
"""remote
Skip into the next remote call.
"""
# Tell the next task to drop into the debugger.
ray.worker.global_worker.debugger_breakpoint = self._breakpoint_uuid
# Tell the debug loop to connect to the next task.
_internal_kv_put("RAY_PDB_CONTINUE_{}".format(self._breakpoint_uuid),
"")
self.__restore()
self.handle.connection.close()
return Pdb.do_continue(self, arg)
def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
def do_get(self, arg):
"""get
Skip to where the current task returns to.
"""
ray.worker.global_worker.debugger_get_breakpoint = (
self._breakpoint_uuid)
self.__restore()
self.handle.connection.close()
return Pdb.do_continue(self, arg)
def connect_ray_pdb(host=None,
port=None,
patch_stdstreams=False,
quiet=None,
breakpoint_uuid=None):
"""
Opens a remote PDB on first available port.
"""
@@ -149,8 +183,14 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
port = int(os.environ.get("REMOTE_PDB_PORT", "0"))
if quiet is None:
quiet = bool(os.environ.get("REMOTE_PDB_QUIET", ""))
if not breakpoint_uuid:
breakpoint_uuid = uuid.uuid4().hex
rdb = RemotePdb(
host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet)
breakpoint_uuid=breakpoint_uuid,
host=host,
port=port,
patch_stdstreams=patch_stdstreams,
quiet=quiet)
sockname = rdb._listen_socket.getsockname()
pdb_address = "{}:{}".format(sockname[0], sockname[1])
parentframeinfo = inspect.getouterframes(inspect.currentframe())[2]
@@ -161,7 +201,6 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
"lineno": parentframeinfo.lineno,
"traceback": "\n".join(traceback.format_exception(*sys.exc_info()))
}
breakpoint_uuid = uuid.uuid4()
_internal_kv_put(
"RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True)
rdb.listen()
@@ -170,14 +209,19 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None):
return rdb
def set_trace():
def set_trace(breakpoint_uuid=None):
"""Interrupt the flow of the program and drop into the Ray debugger.
Can be used within a Ray task or actor.
"""
frame = sys._getframe().f_back
rdb = connect_ray_pdb(None, None, False, None)
rdb.set_trace(frame=frame)
# If there is an active debugger already, we do not want to
# start another one, so "set_trace" is just a no-op in that case.
if ray.worker.global_worker.debugger_breakpoint == b"":
frame = sys._getframe().f_back
rdb = connect_ray_pdb(
None, None, False, None,
breakpoint_uuid.decode() if breakpoint_uuid else None)
rdb.set_trace(frame=frame)
def post_mortem():
+31 -2
View File
@@ -102,6 +102,13 @@ class Worker:
# Index of the current session. This number will
# increment every time when `ray.shutdown` is called.
self._session_index = 0
# If this is set, the next .remote call should drop into the
# debugger, at the specified breakpoint ID.
self.debugger_breakpoint = b""
# If this is set, ray.get calls invoked on the object ID returned
# by the worker should drop into the debugger at the specified
# breakpoint ID.
self.debugger_get_breakpoint = b""
@property
def connected(self):
@@ -280,6 +287,10 @@ class Worker:
whose values should be retrieved.
timeout (float): timeout (float): The maximum amount of time in
seconds to wait before returning.
Returns:
list: List of deserialized objects
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
@@ -291,7 +302,16 @@ class Worker:
timeout_ms = int(timeout * 1000) if timeout else -1
data_metadata_pairs = self.core_worker.get_objects(
object_refs, self.current_task_id, timeout_ms)
return self.deserialize_objects(data_metadata_pairs, object_refs)
debugger_breakpoint = b""
for (data, metadata) in data_metadata_pairs:
if metadata:
metadata_fields = metadata.split(b",")
if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
ray_constants.OBJECT_METADATA_DEBUG_PREFIX):
debugger_breakpoint = metadata_fields[1][len(
ray_constants.OBJECT_METADATA_DEBUG_PREFIX):]
return self.deserialize_objects(data_metadata_pairs,
object_refs), debugger_breakpoint
def run_function_on_all_workers(self, function,
run_on_other_drivers=False):
@@ -1345,7 +1365,8 @@ def get(object_refs, *, timeout=None):
global last_task_error_raise_time
# TODO(ujvl): Consider how to allow user to retrieve the ready objects.
values = worker.get_objects(object_refs, timeout=timeout)
values, debugger_breakpoint = worker.get_objects(
object_refs, timeout=timeout)
for i, value in enumerate(values):
if isinstance(value, RayError):
last_task_error_raise_time = time.time()
@@ -1358,6 +1379,14 @@ def get(object_refs, *, timeout=None):
if is_individual_id:
values = values[0]
if debugger_breakpoint != b"":
frame = sys._getframe().f_back
rdb = ray.util.pdb.connect_ray_pdb(
None, None, False, None,
debugger_breakpoint.decode() if debugger_breakpoint else None)
rdb.set_trace(frame=frame)
return values
+1
View File
@@ -57,6 +57,7 @@ numba
# higher version of llvmlite breaks windows
llvmlite==0.34.0
openpyxl
pexpect
Pillow; platform_system != "Windows"
pygments
pytest==5.4.3