diff --git a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc index 3239c44a9..4d2bf1c7c 100644 --- a/cpp/src/ray/runtime/task/local_mode_task_submitter.cc +++ b/cpp/src/ray/runtime/task/local_mode_task_submitter.cc @@ -44,7 +44,7 @@ ObjectID LocalModeTaskSubmitter::Submit(InvocationSpec &invocation) { local_mode_ray_tuntime_.GetCurrentTaskId(), 0, local_mode_ray_tuntime_.GetCurrentTaskId(), address, 1, required_resources, required_placement_resources, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); if (invocation.task_type == TaskType::NORMAL_TASK) { } else if (invocation.task_type == TaskType::ACTOR_CREATION_TASK) { invocation.actor_id = local_mode_ray_tuntime_.GetNextActorID(); diff --git a/cpp/src/ray/runtime/task/native_task_submitter.cc b/cpp/src/ray/runtime/task/native_task_submitter.cc index 7fb4dacc8..4e1649a51 100644 --- a/cpp/src/ray/runtime/task/native_task_submitter.cc +++ b/cpp/src/ray/runtime/task/native_task_submitter.cc @@ -27,7 +27,7 @@ ObjectID NativeTaskSubmitter::Submit(InvocationSpec &invocation) { } else { core_worker.SubmitTask(BuildRayFunction(invocation), invocation.args, TaskOptions(), &return_ids, 1, std::make_pair(PlacementGroupID::Nil(), -1), - true); + true, ""); } return return_ids[0]; } diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index 029d08ebc..f2b06af09 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -27,7 +27,7 @@ Status TaskExecutor::ExecuteTask( const std::unordered_map &required_resources, const std::vector> &args_buffer, const std::vector &arg_reference_ids, - const std::vector &return_ids, + const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results) { RAY_LOG(INFO) << "TaskExecutor::ExecuteTask"; RAY_CHECK(ray_function.GetLanguage() == Language::CPP); diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index 5d02f3e00..cf2d24745 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -38,7 +38,7 @@ class TaskExecutor { const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_reference_ids, - const std::vector &return_ids, + const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results); virtual ~TaskExecutor(){}; diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index c13776a92..15e5d3aff 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -335,6 +335,7 @@ cdef execute_task( const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, + const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns): worker = ray.worker.global_worker @@ -456,7 +457,23 @@ cdef execute_task( task_exception = True try: with ray.worker._changeproctitle(title, next_title): + if debugger_breakpoint != b"": + ray.util.pdb.set_trace( + breakpoint_uuid=debugger_breakpoint) outputs = function_executor(*args, **kwargs) + next_breakpoint = ( + ray.worker.global_worker.debugger_breakpoint) + if next_breakpoint != b"": + # If this happens, the user typed "remote" and + # there were no more remote calls left in this + # task. In that case we just exit the debugger. + ray.experimental.internal_kv._internal_kv_put( + "RAY_PDB_{}".format(next_breakpoint), + "{\"exit_debugger\": true}") + ray.experimental.internal_kv._internal_kv_del( + "RAY_PDB_CONTINUE_{}".format(next_breakpoint) + ) + ray.worker.global_worker.debugger_breakpoint = b"" task_exception = False except KeyboardInterrupt as e: raise TaskCancelledError( @@ -522,6 +539,7 @@ cdef CRayStatus task_execution_handler( const c_vector[shared_ptr[CRayObject]] &c_args, const c_vector[CObjectID] &c_arg_reference_ids, const c_vector[CObjectID] &c_return_ids, + const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns) nogil: with gil: @@ -531,7 +549,7 @@ cdef CRayStatus task_execution_handler( # it does, that indicates that there was an internal error. execute_task(task_type, task_name, ray_function, c_resources, c_args, c_arg_reference_ids, c_return_ids, - returns) + debugger_breakpoint, returns) except Exception: traceback_str = traceback.format_exc() + ( "An unexpected internal error occurred while the worker " @@ -1040,6 +1058,7 @@ cdef class CoreWorker: PlacementGroupID placement_group_id, int64_t placement_group_bundle_index, c_bool placement_group_capture_child_tasks, + c_string debugger_breakpoint, override_environment_variables): cdef: unordered_map[c_string, double] c_resources @@ -1066,7 +1085,8 @@ cdef class CoreWorker: &return_ids, max_retries, c_pair[CPlacementGroupID, int64_t]( c_placement_group_id, placement_group_bundle_index), - placement_group_capture_child_tasks) + placement_group_capture_child_tasks, + debugger_breakpoint) return VectorToObjectRefs(return_ids) @@ -1411,8 +1431,16 @@ cdef class CoreWorker: context = worker.get_serialization_context() serialized_object = context.serialize(output) data_sizes.push_back(serialized_object.total_bytes) - metadatas.push_back( - string_to_buffer(serialized_object.metadata)) + metadata = serialized_object.metadata + if ray.worker.global_worker.debugger_get_breakpoint: + breakpoint = ( + ray.worker.global_worker.debugger_get_breakpoint()) + metadata += ( + b"," + ray_constants.OBJECT_METADATA_DEBUG_PREFIX + + breakpoint.encode()) + # Reset debugging context of this worker. + ray.worker.global_worker.debugger_get_breakpoint = b"" + metadatas.push_back(string_to_buffer(metadata)) serialized_objects.append(serialized_object) contained_ids.push_back( ObjectRefsToVector(serialized_object.contained_object_refs) diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 8abb45b49..abf1290b9 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -90,7 +90,8 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const CTaskOptions &options, c_vector[CObjectID] *return_ids, int max_retries, c_pair[CPlacementGroupID, int64_t] placement_options, - c_bool placement_group_capture_child_tasks) + c_bool placement_group_capture_child_tasks, + c_string debugger_breakpoint) CRayStatus CreateActor( const CRayFunction &function, const c_vector[unique_ptr[CTaskArg]] &args, @@ -224,6 +225,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: const c_vector[shared_ptr[CRayObject]] &args, const c_vector[CObjectID] &arg_reference_ids, const c_vector[CObjectID] &return_ids, + const c_string debugger_breakpoint, c_vector[shared_ptr[CRayObject]] *returns) nogil ) task_execution_callback (void(const CWorkerID &) nogil) on_worker_shutdown diff --git a/python/ray/ray_constants.py b/python/ray/ray_constants.py index 12e429ff5..be717ca3c 100644 --- a/python/ray/ray_constants.py +++ b/python/ray/ray_constants.py @@ -197,7 +197,8 @@ LOG_MONITOR_MAX_OPEN_FILES = 200 # The object metadata field uses the following format: It is a comma # separated list of fields. The first field is mandatory and is the # type of the object (see types below) or an integer, which is interpreted -# as an error value. +# as an error value. The second part is optional and if present has the +# form DEBUG:, it is used for implementing the debugger. # A constant used as object metadata to indicate the object is cross language. OBJECT_METADATA_TYPE_CROSS_LANGUAGE = b"XLANG" @@ -213,6 +214,9 @@ OBJECT_METADATA_TYPE_RAW = b"RAW" # of XLANG. OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE" +# A constant indicating the debugging part of the metadata (see above). +OBJECT_METADATA_DEBUG_PREFIX = b"DEBUG:" + AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request" # The default password to prevent redis port scanning attack. diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 68a0eef84..e717e2d28 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -258,8 +258,12 @@ class RemoteFunction: placement_group.id, placement_group_bundle_index, placement_group_capture_child_tasks, + worker.debugger_breakpoint, override_environment_variables=override_environment_variables or dict()) + # Reset worker's debug context from the last "remote" command + # (which applies only to this .remote call). + worker.debugger_breakpoint = b"" if len(object_refs) == 1: return object_refs[0] elif len(object_refs) > 1: diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index e3e1ea52f..fd0c286f2 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -6,6 +6,7 @@ import logging import os import subprocess import sys +from telnetlib import Telnet import time import urllib import urllib.parse @@ -150,6 +151,35 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port): from None +def continue_debug_session(): + """Continue active debugging session. + + This function will connect 'ray debug' to the right debugger + when a user is stepping between Ray tasks. + """ + active_sessions = ray.experimental.internal_kv._internal_kv_list( + "RAY_PDB_") + + for active_session in active_sessions: + if active_session.startswith(b"RAY_PDB_CONTINUE"): + print("Continuing pdb session in different process...") + key = b"RAY_PDB_" + active_session[len("RAY_PDB_CONTINUE_"):] + while True: + data = ray.experimental.internal_kv._internal_kv_get(key) + if data: + session = json.loads(data) + if "exit_debugger" in session: + ray.experimental.internal_kv._internal_kv_del(key) + return + host, port = session["pdb_address"].split(":") + with Telnet(host, int(port)) as tn: + tn.interact() + ray.experimental.internal_kv._internal_kv_del(key) + continue_debug_session() + return + time.sleep(1.0) + + @cli.command() @click.option( "--address", @@ -158,12 +188,13 @@ def dashboard(cluster_config_file, cluster_name, port, remote_port): help="Override the address to connect to.") def debug(address): """Show all active breakpoints and exceptions in the Ray debugger.""" - from telnetlib import Telnet if not address: address = services.get_ray_address_to_use_or_die() logger.info(f"Connecting to Ray instance at {address}.") - ray.init(address=address) + ray.init(address=address, log_to_driver=False) while True: + continue_debug_session() + active_sessions = ray.experimental.internal_kv._internal_kv_list( "RAY_PDB_") print("Active breakpoints:") diff --git a/python/ray/tests/test_ray_debugger.py b/python/ray/tests/test_ray_debugger.py index 67df794f8..df9954995 100644 --- a/python/ray/tests/test_ray_debugger.py +++ b/python/ray/tests/test_ray_debugger.py @@ -3,6 +3,7 @@ import os import sys from telnetlib import Telnet +import pexpect import ray @@ -34,6 +35,67 @@ def test_ray_debugger_breakpoint(shutdown_only): ray.get(result) +def test_ray_debugger_stepping(shutdown_only): + ray.init(num_cpus=1) + + @ray.remote + def g(): + return None + + @ray.remote + def f(): + ray.util.pdb.set_trace() + x = g.remote() + return ray.get(x) + + result = f.remote() + + p = pexpect.spawn("ray debug") + p.expect("Enter breakpoint index or press enter to refresh: ") + p.sendline("0") + p.expect("-> x = g.remote()") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("get") + p.expect("(Pdb)") + p.sendline("continue") + + # This should succeed now! + ray.get(result) + + +def test_ray_debugger_recursive(shutdown_only): + ray.init(num_cpus=1) + + @ray.remote + def fact(n): + if n < 1: + return n + ray.util.pdb.set_trace() + n_id = fact.remote(n - 1) + return n * ray.get(n_id) + + result = fact.remote(5) + + p = pexpect.spawn("ray debug") + p.expect("Enter breakpoint index or press enter to refresh: ") + p.sendline("0") + p.expect("(Pdb)") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("remote") + p.expect("(Pdb)") + p.sendline("remote") + + ray.get(result) + + if __name__ == "__main__": import pytest # Make subprocess happy in bazel. diff --git a/python/ray/util/rpdb.py b/python/ray/util/rpdb.py index 3ffd3626f..33d430ea0 100644 --- a/python/ray/util/rpdb.py +++ b/python/ray/util/rpdb.py @@ -15,6 +15,7 @@ from pdb import Pdb import setproctitle import traceback +import ray from ray.experimental.internal_kv import _internal_kv_del, _internal_kv_put PY3 = sys.version_info[0] == 3 @@ -70,7 +71,13 @@ class RemotePdb(Pdb): """ active_instance = None - def __init__(self, host, port, patch_stdstreams=False, quiet=False): + def __init__(self, + breakpoint_uuid, + host, + port, + patch_stdstreams=False, + quiet=False): + self._breakpoint_uuid = breakpoint_uuid self._quiet = quiet self._patch_stdstreams = patch_stdstreams self._listen_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -138,8 +145,35 @@ class RemotePdb(Pdb): if exc.errno != errno.ECONNRESET: raise + def do_remote(self, arg): + """remote + Skip into the next remote call. + """ + # Tell the next task to drop into the debugger. + ray.worker.global_worker.debugger_breakpoint = self._breakpoint_uuid + # Tell the debug loop to connect to the next task. + _internal_kv_put("RAY_PDB_CONTINUE_{}".format(self._breakpoint_uuid), + "") + self.__restore() + self.handle.connection.close() + return Pdb.do_continue(self, arg) -def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None): + def do_get(self, arg): + """get + Skip to where the current task returns to. + """ + ray.worker.global_worker.debugger_get_breakpoint = ( + self._breakpoint_uuid) + self.__restore() + self.handle.connection.close() + return Pdb.do_continue(self, arg) + + +def connect_ray_pdb(host=None, + port=None, + patch_stdstreams=False, + quiet=None, + breakpoint_uuid=None): """ Opens a remote PDB on first available port. """ @@ -149,8 +183,14 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None): port = int(os.environ.get("REMOTE_PDB_PORT", "0")) if quiet is None: quiet = bool(os.environ.get("REMOTE_PDB_QUIET", "")) + if not breakpoint_uuid: + breakpoint_uuid = uuid.uuid4().hex rdb = RemotePdb( - host=host, port=port, patch_stdstreams=patch_stdstreams, quiet=quiet) + breakpoint_uuid=breakpoint_uuid, + host=host, + port=port, + patch_stdstreams=patch_stdstreams, + quiet=quiet) sockname = rdb._listen_socket.getsockname() pdb_address = "{}:{}".format(sockname[0], sockname[1]) parentframeinfo = inspect.getouterframes(inspect.currentframe())[2] @@ -161,7 +201,6 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None): "lineno": parentframeinfo.lineno, "traceback": "\n".join(traceback.format_exception(*sys.exc_info())) } - breakpoint_uuid = uuid.uuid4() _internal_kv_put( "RAY_PDB_{}".format(breakpoint_uuid), json.dumps(data), overwrite=True) rdb.listen() @@ -170,14 +209,19 @@ def connect_ray_pdb(host=None, port=None, patch_stdstreams=False, quiet=None): return rdb -def set_trace(): +def set_trace(breakpoint_uuid=None): """Interrupt the flow of the program and drop into the Ray debugger. Can be used within a Ray task or actor. """ - frame = sys._getframe().f_back - rdb = connect_ray_pdb(None, None, False, None) - rdb.set_trace(frame=frame) + # If there is an active debugger already, we do not want to + # start another one, so "set_trace" is just a no-op in that case. + if ray.worker.global_worker.debugger_breakpoint == b"": + frame = sys._getframe().f_back + rdb = connect_ray_pdb( + None, None, False, None, + breakpoint_uuid.decode() if breakpoint_uuid else None) + rdb.set_trace(frame=frame) def post_mortem(): diff --git a/python/ray/worker.py b/python/ray/worker.py index d5093e360..a5b5559aa 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -102,6 +102,13 @@ class Worker: # Index of the current session. This number will # increment every time when `ray.shutdown` is called. self._session_index = 0 + # If this is set, the next .remote call should drop into the + # debugger, at the specified breakpoint ID. + self.debugger_breakpoint = b"" + # If this is set, ray.get calls invoked on the object ID returned + # by the worker should drop into the debugger at the specified + # breakpoint ID. + self.debugger_get_breakpoint = b"" @property def connected(self): @@ -280,6 +287,10 @@ class Worker: whose values should be retrieved. timeout (float): timeout (float): The maximum amount of time in seconds to wait before returning. + Returns: + list: List of deserialized objects + bytes: UUID of the debugger breakpoint we should drop + into or b"" if there is no breakpoint. """ # Make sure that the values are object refs. for object_ref in object_refs: @@ -291,7 +302,16 @@ class Worker: timeout_ms = int(timeout * 1000) if timeout else -1 data_metadata_pairs = self.core_worker.get_objects( object_refs, self.current_task_id, timeout_ms) - return self.deserialize_objects(data_metadata_pairs, object_refs) + debugger_breakpoint = b"" + for (data, metadata) in data_metadata_pairs: + if metadata: + metadata_fields = metadata.split(b",") + if len(metadata_fields) >= 2 and metadata_fields[1].startswith( + ray_constants.OBJECT_METADATA_DEBUG_PREFIX): + debugger_breakpoint = metadata_fields[1][len( + ray_constants.OBJECT_METADATA_DEBUG_PREFIX):] + return self.deserialize_objects(data_metadata_pairs, + object_refs), debugger_breakpoint def run_function_on_all_workers(self, function, run_on_other_drivers=False): @@ -1345,7 +1365,8 @@ def get(object_refs, *, timeout=None): global last_task_error_raise_time # TODO(ujvl): Consider how to allow user to retrieve the ready objects. - values = worker.get_objects(object_refs, timeout=timeout) + values, debugger_breakpoint = worker.get_objects( + object_refs, timeout=timeout) for i, value in enumerate(values): if isinstance(value, RayError): last_task_error_raise_time = time.time() @@ -1358,6 +1379,14 @@ def get(object_refs, *, timeout=None): if is_individual_id: values = values[0] + + if debugger_breakpoint != b"": + frame = sys._getframe().f_back + rdb = ray.util.pdb.connect_ray_pdb( + None, None, False, None, + debugger_breakpoint.decode() if debugger_breakpoint else None) + rdb.set_trace(frame=frame) + return values diff --git a/python/requirements.txt b/python/requirements.txt index 19b01dfc1..76d0799fe 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -57,6 +57,7 @@ numba # higher version of llvmlite breaks windows llvmlite==0.34.0 openpyxl +pexpect Pillow; platform_system != "Windows" pygments pytest==5.4.3 diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 875b61276..a07f5bade 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -193,6 +193,10 @@ const ResourceSet &TaskSpecification::GetRequiredPlacementResources() const { return *required_placement_resources_; } +std::string TaskSpecification::GetDebuggerBreakpoint() const { + return message_->debugger_breakpoint(); +} + std::unordered_map TaskSpecification::OverrideEnvironmentVariables() const { return MapFromProtobuf(message_->override_environment_variables()); diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 2dec30283..bee66464c 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -131,6 +131,8 @@ class TaskSpecification : public MessageWrapper { /// \return The recomputed dependencies for the task. std::vector GetDependencies() const; + std::string GetDebuggerBreakpoint() const; + std::unordered_map OverrideEnvironmentVariables() const; bool IsDriverTask() const; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index 825ae659e..32ef5e3c8 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -87,6 +87,7 @@ class TaskSpecBuilder { const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, const BundleID &bundle_id, bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint, const std::unordered_map &override_environment_variables = {}) { message_->set_type(TaskType::NORMAL_TASK); @@ -108,6 +109,7 @@ class TaskSpecBuilder { message_->set_placement_group_bundle_index(bundle_id.second); message_->set_placement_group_capture_child_tasks( placement_group_capture_child_tasks); + message_->set_debugger_breakpoint(debugger_breakpoint); for (const auto &env : override_environment_variables) { (*message_->mutable_override_environment_variables())[env.first] = env.second; } diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 8c4bf546f..c18a011a4 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -39,14 +39,14 @@ void BuildCommonTaskSpec( const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources, std::vector *return_ids, const ray::BundleID &bundle_id, - bool placement_group_capture_child_tasks, + bool placement_group_capture_child_tasks, const std::string debugger_breakpoint, const std::unordered_map &override_environment_variables) { // Build common task spec. builder.SetCommonTaskSpec( task_id, name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, current_task_id, task_index, caller_id, address, num_returns, required_resources, required_placement_resources, bundle_id, placement_group_capture_child_tasks, - override_environment_variables); + debugger_breakpoint, override_environment_variables); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1294,7 +1294,8 @@ void CoreWorker::SubmitTask(const RayFunction &function, const TaskOptions &task_options, std::vector *return_ids, int max_retries, BundleID placement_options, - bool placement_group_capture_child_tasks) { + bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint) { TaskSpecBuilder builder; const int next_task_index = worker_context_.GetNextTaskIndex(); const auto task_id = @@ -1320,7 +1321,7 @@ void CoreWorker::SubmitTask(const RayFunction &function, rpc_address_, function, args, task_options.num_returns, constrained_resources, required_resources, return_ids, placement_options, placement_group_capture_child_tasks, - override_environment_variables); + debugger_breakpoint, override_environment_variables); TaskSpecification task_spec = builder.Build(); if (options_.is_local_mode) { ExecuteTaskLocalMode(task_spec); @@ -1376,6 +1377,7 @@ Status CoreWorker::CreateActor(const RayFunction &function, new_placement_resources, &return_ids, actor_creation_options.placement_options, actor_creation_options.placement_group_capture_child_tasks, + "", /* debugger_breakpoint */ override_environment_variables); builder.SetActorCreationTaskSpec(actor_id, actor_creation_options.max_restarts, actor_creation_options.dynamic_worker_options, @@ -1504,6 +1506,7 @@ void CoreWorker::SubmitActorTask(const ActorID &actor_id, const RayFunction &fun required_resources, return_ids, std::make_pair(PlacementGroupID::Nil(), -1), true, /* placement_group_capture_child_tasks */ + "", /* debugger_breakpoint */ override_environment_variables); // NOTE: placement_group_capture_child_tasks and override_environment_variables will be // ignored in the actor because we should always follow the actor's option. @@ -1802,7 +1805,7 @@ Status CoreWorker::ExecuteTask(const TaskSpecification &task_spec, status = options_.task_execution_callback( task_type, task_spec.GetName(), func, task_spec.GetRequiredResources().GetResourceMap(), args, arg_reference_ids, - return_ids, return_objects); + return_ids, task_spec.GetDebuggerBreakpoint(), return_objects); absl::optional caller_address( options_.is_local_mode ? absl::optional() diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index beea3a874..1b877ec2f 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -61,7 +61,7 @@ struct CoreWorkerOptions { const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_reference_ids, - const std::vector &return_ids, + const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results)>; CoreWorkerOptions() @@ -632,12 +632,15 @@ class CoreWorker : public rpc::CoreWorkerServiceHandler { /// \param[in] max_retires max number of retry when the task fails. /// \param[in] placement_options placement group options. /// \param[in] placement_group_capture_child_tasks whether or not the submitted task + /// \param[in] debugger_breakpoint breakpoint to drop into for the debugger after this + /// task starts executing, or "" if we do not want to drop into the debugger. /// should capture parent's placement group implicilty. void SubmitTask(const RayFunction &function, const std::vector> &args, const TaskOptions &task_options, std::vector *return_ids, int max_retries, BundleID placement_options, - bool placement_group_capture_child_tasks); + bool placement_group_capture_child_tasks, + const std::string &debugger_breakpoint); /// Create an actor. /// diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index ee8c76a29..b7fb72310 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -110,7 +110,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( const std::unordered_map &required_resources, const std::vector> &args, const std::vector &arg_reference_ids, - const std::vector &return_ids, + const std::vector &return_ids, const std::string &debugger_breakpoint, std::vector> *results) { JNIEnv *env = GetJNIEnv(); RAY_CHECK(java_task_executor); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index c11f782c2..031f62c44 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -222,7 +222,9 @@ JNIEXPORT jobject JNICALL Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSub ray_function, task_args, task_options, &return_ids, /*max_retries=*/0, /*placement_options=*/ - std::pair(ray::PlacementGroupID::Nil(), 0), true); + std::pair(ray::PlacementGroupID::Nil(), 0), + /*placement_group_capture_child_tasks=*/true, + /*debugger_breakpoint*/ ""); // This is to avoid creating an empty java list and boost performance. if (return_ids.empty()) { diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 23d16890b..626fb2d59 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -257,7 +257,8 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res TaskOptions options; std::vector return_ids; driver.SubmitTask(func, args, options, &return_ids, /*max_retries=*/0, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, + /*debugger_breakpoint=*/""); ASSERT_EQ(return_ids.size(), 1); @@ -533,7 +534,7 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { builder.SetCommonTaskSpec(RandomTaskId(), options.name, function.GetLanguage(), function.GetFunctionDescriptor(), job_id, RandomTaskId(), 0, RandomTaskId(), address, num_returns, resources, resources, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); diff --git a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index a6056d45a..27af163b2 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -328,7 +328,7 @@ TaskSpecification BuildTaskSpec(const std::unordered_map &r builder.SetCommonTaskSpec(TaskID::Nil(), "dummy_task", Language::PYTHON, function_descriptor, JobID::Nil(), TaskID::Nil(), 0, TaskID::Nil(), empty_address, 1, resources, resources, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); return builder.Build(); } diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 47f111dfd..4439519bb 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -46,7 +46,7 @@ class MockWorker { options.node_manager_port = node_manager_port; options.raylet_ip_address = "127.0.0.1"; options.task_execution_callback = - std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8); + std::bind(&MockWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8, _9); options.ref_counting_enabled = true; options.num_workers = 1; options.metrics_agent_port = -1; @@ -62,6 +62,7 @@ class MockWorker { const std::vector> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results) { // Note that this doesn't include dummy object id. const ray::FunctionDescriptor function_descriptor = diff --git a/src/ray/gcs/test/gcs_test_util.h b/src/ray/gcs/test/gcs_test_util.h index 40c478c37..5d152c1f2 100644 --- a/src/ray/gcs/test/gcs_test_util.h +++ b/src/ray/gcs/test/gcs_test_util.h @@ -41,7 +41,7 @@ struct Mocker { builder.SetCommonTaskSpec(task_id, name + ":" + empty_descriptor->CallString(), Language::PYTHON, empty_descriptor, job_id, TaskID::Nil(), 0, TaskID::Nil(), owner_address, 1, resource, resource, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); builder.SetActorCreationTaskSpec(actor_id, max_restarts, {}, 1, detached, name); return builder.Build(); } diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index b894f92aa..ec64b6674 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -200,6 +200,9 @@ message TaskSpec { // the receiver will not execute the task. This field is used by async actors // to guarantee task submission order after restart. bool skip_execution = 22; + // Breakpoint if this task should drop into the debugger when it starts executing + // and "" if the task should not drop into the debugger. + bytes debugger_breakpoint = 23; } message Bundle { diff --git a/src/ray/raylet/scheduling/cluster_task_manager_test.cc b/src/ray/raylet/scheduling/cluster_task_manager_test.cc index f20515353..9c046630b 100644 --- a/src/ray/raylet/scheduling/cluster_task_manager_test.cc +++ b/src/ray/raylet/scheduling/cluster_task_manager_test.cc @@ -80,7 +80,7 @@ Task CreateTask(const std::unordered_map &required_resource FunctionDescriptorBuilder::BuildPython("", "", "", ""), job_id, TaskID::Nil(), 0, TaskID::Nil(), address, 0, required_resources, {}, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); for (int i = 0; i < num_args; i++) { ObjectID put_id = ObjectID::FromIndex(TaskID::Nil(), /*index=*/i + 1); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index 3f53b5a09..b8c1c9b31 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -69,7 +69,7 @@ static inline Task ExampleTask(const std::vector &arguments, FunctionDescriptorBuilder::BuildPython("", "", "", ""), JobID::Nil(), RandomTaskId(), 0, RandomTaskId(), address, num_returns, {}, {}, - std::make_pair(PlacementGroupID::Nil(), -1), true); + std::make_pair(PlacementGroupID::Nil(), -1), true, ""); builder.SetActorCreationTaskSpec(ActorID::Nil(), 1, {}, 1, false, "", false); for (const auto &arg : arguments) { builder.AddArg(TaskArgByReference(arg, rpc::Address())); diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 28911758e..09e255995 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -499,8 +499,8 @@ class StreamingWorker { options.node_ip_address = "127.0.0.1"; options.node_manager_port = node_manager_port; options.raylet_ip_address = "127.0.0.1"; - options.task_execution_callback = - std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7, _8); + options.task_execution_callback = std::bind(&StreamingWorker::ExecuteTask, this, _1, + _2, _3, _4, _5, _6, _7, _8, _9); options.ref_counting_enabled = true; options.num_workers = 1; options.metrics_agent_port = -1; @@ -520,6 +520,7 @@ class StreamingWorker { const std::vector> &args, const std::vector &arg_reference_ids, const std::vector &return_ids, + const std::string &debugger_breakpoint, std::vector> *results) { // Only one arg param used in streaming. STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size();