mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 01:37:40 +08:00
[ID Refactor] Rename DriverID to JobID (#5004)
* WIP WIP WIP Rename Driver -> Job Fix compilation Fix Rename in Java In py WIP Fix WIP Fix Fix test Fix Fix C++ linting Fix * Update java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Update src/ray/core_worker/core_worker.cc Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Address comments * Fix * Fix CI * Fix cpp linting * Fix py lint * Fix * Address comments and fix * Address comments * Address * Fix import_threading
This commit is contained in:
+77
-80
@@ -40,7 +40,8 @@ from ray import (
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
DriverID,
|
||||
WorkerID,
|
||||
JobID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
)
|
||||
@@ -145,9 +146,9 @@ class Worker(object):
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
# It is a DriverID.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Identity of the job that this worker is processing.
|
||||
# It is a JobID.
|
||||
self.current_job_id = JobID.nil()
|
||||
self._task_context = threading.local()
|
||||
# This event is checked regularly by all of the threads so that they
|
||||
# know when to exit.
|
||||
@@ -227,24 +228,24 @@ class Worker(object):
|
||||
if self.actor_init_error is not None:
|
||||
raise self.actor_init_error
|
||||
|
||||
def get_serialization_context(self, driver_id):
|
||||
"""Get the SerializationContext of the driver that this worker is processing.
|
||||
def get_serialization_context(self, job_id):
|
||||
"""Get the SerializationContext of the job that this worker is processing.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that indicates which driver to get
|
||||
job_id: The ID of the job that indicates which job to get
|
||||
the serialization context for.
|
||||
|
||||
Returns:
|
||||
The serialization context of the given driver.
|
||||
The serialization context of the given job.
|
||||
"""
|
||||
# This function needs to be protected by a lock, because it will be
|
||||
# called by `register_class_for_serialization`, as well as the import
|
||||
# thread, from different threads. Also, this function will recursively
|
||||
# call itself, so we use RLock here.
|
||||
with self.lock:
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
if job_id not in self.serialization_context_map:
|
||||
_initialize_serialization(job_id)
|
||||
return self.serialization_context_map[job_id]
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
@@ -314,7 +315,7 @@ class Worker(object):
|
||||
object_id=pyarrow.plasma.ObjectID(object_id.binary()),
|
||||
memcopy_threads=self.memcopy_threads,
|
||||
serialization_context=self.get_serialization_context(
|
||||
self.task_driver_id))
|
||||
self.current_job_id))
|
||||
break
|
||||
except pyarrow.SerializationCallbackError as e:
|
||||
try:
|
||||
@@ -388,17 +389,17 @@ class Worker(object):
|
||||
# should return an error code to the caller instead of printing a
|
||||
# message.
|
||||
logger.info(
|
||||
"The object with ID {} already exists in the object store."
|
||||
.format(object_id))
|
||||
"The object with ID {} already exists in the object store.".
|
||||
format(object_id))
|
||||
except TypeError:
|
||||
# This error can happen because one of the members of the object
|
||||
# may not be serializable for cloudpickle. So we need these extra
|
||||
# fallbacks here to start from the beginning. Hopefully the object
|
||||
# could have a `__reduce__` method.
|
||||
register_custom_serializer(type(value), use_pickle=True)
|
||||
warning_message = ("WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle."
|
||||
.format(type(value)))
|
||||
warning_message = (
|
||||
"WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle.".format(type(value)))
|
||||
logger.warning(warning_message)
|
||||
self.store_and_register(object_id, value)
|
||||
|
||||
@@ -407,7 +408,7 @@ class Worker(object):
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
serialization_context = self.get_serialization_context(
|
||||
self.task_driver_id)
|
||||
self.current_job_id)
|
||||
while True:
|
||||
try:
|
||||
# We divide very large get requests into smaller get requests
|
||||
@@ -449,7 +450,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
warning_sent = True
|
||||
|
||||
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
||||
@@ -575,7 +576,7 @@ class Worker(object):
|
||||
num_return_vals=None,
|
||||
resources=None,
|
||||
placement_resources=None,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with
|
||||
@@ -601,11 +602,11 @@ class Worker(object):
|
||||
placement_resources: The resources required for placing the task.
|
||||
If this is not provided or if it is an empty dictionary, then
|
||||
the placement resources will be equal to resources.
|
||||
driver_id: The ID of the relevant driver. This is almost always the
|
||||
driver ID of the driver that is currently running. However, in
|
||||
job_id: The ID of the relevant job. This is almost always the
|
||||
job ID of the job that is currently running. However, in
|
||||
the exceptional case that an actor task is being dispatched to
|
||||
an actor created by a different driver, this should be the
|
||||
driver ID of the driver that created the actor.
|
||||
an actor created by a different job, this should be the
|
||||
job ID of the job that created the actor.
|
||||
|
||||
Returns:
|
||||
The return object IDs for this task.
|
||||
@@ -642,8 +643,8 @@ class Worker(object):
|
||||
if new_actor_handles is None:
|
||||
new_actor_handles = []
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = self.task_driver_id
|
||||
if job_id is None:
|
||||
job_id = self.current_job_id
|
||||
|
||||
if resources is None:
|
||||
raise ValueError("The resources dictionary is required.")
|
||||
@@ -674,13 +675,13 @@ class Worker(object):
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Current driver id must not be nil when submitting a task.
|
||||
# Because every task must belong to a driver.
|
||||
assert not self.task_driver_id.is_nil()
|
||||
assert not self.current_job_id.is_nil()
|
||||
# Submit the task to raylet.
|
||||
function_descriptor_list = (
|
||||
function_descriptor.get_function_descriptor_list())
|
||||
assert isinstance(driver_id, DriverID)
|
||||
assert isinstance(job_id, JobID)
|
||||
task = ray._raylet.Task(
|
||||
driver_id,
|
||||
job_id,
|
||||
function_descriptor_list,
|
||||
args_for_raylet,
|
||||
num_return_vals,
|
||||
@@ -747,7 +748,7 @@ class Worker(object):
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.binary(),
|
||||
"job_id": self.current_job_id.binary(),
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickled_function,
|
||||
"run_on_other_drivers": str(run_on_other_drivers)
|
||||
@@ -853,17 +854,17 @@ class Worker(object):
|
||||
assert self.task_context.task_index == 0
|
||||
assert self.task_context.put_index == 1
|
||||
if task.actor_id().is_nil():
|
||||
# If this worker is not an actor, check that `task_driver_id`
|
||||
# If this worker is not an actor, check that `current_job_id`
|
||||
# was reset when the worker finished the previous task.
|
||||
assert self.task_driver_id.is_nil()
|
||||
assert self.current_job_id.is_nil()
|
||||
# Set the driver ID of the current running task. This is
|
||||
# needed so that if the task throws an exception, we propagate
|
||||
# the error message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_job_id = task.job_id()
|
||||
else:
|
||||
# If this worker is an actor, task_driver_id wasn't reset.
|
||||
# If this worker is an actor, current_job_id wasn't reset.
|
||||
# Check that current task's driver ID equals the previous one.
|
||||
assert self.task_driver_id == task.driver_id()
|
||||
assert self.current_job_id == task.job_id()
|
||||
|
||||
self.task_context.current_task_id = task.task_id()
|
||||
|
||||
@@ -945,7 +946,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.TASK_PUSH_ERROR,
|
||||
str(failure_object),
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
# Mark the actor init as failed
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
@@ -960,7 +961,7 @@ class Worker(object):
|
||||
"""
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
driver_id = task.driver_id()
|
||||
job_id = task.job_id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
@@ -969,7 +970,7 @@ class Worker(object):
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
actor_class = self.function_actor_manager.load_actor_class(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
self.actors[self.actor_id] = actor_class.__new__(actor_class)
|
||||
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
|
||||
num_tasks_since_last_checkpoint=0,
|
||||
@@ -978,7 +979,7 @@ class Worker(object):
|
||||
)
|
||||
|
||||
execution_info = self.function_actor_manager.get_execution_info(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
# Execute the task.
|
||||
function_name = execution_info.function_name
|
||||
@@ -1005,20 +1006,20 @@ class Worker(object):
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# Don't need to reset `current_job_id` if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
self.current_job_id = WorkerID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
reached_max_executions = (self.function_actor_manager.get_task_counter(
|
||||
driver_id, function_descriptor) == execution_info.max_calls)
|
||||
job_id, function_descriptor) == execution_info.max_calls)
|
||||
if reached_max_executions:
|
||||
self.raylet_client.disconnect()
|
||||
sys.exit(0)
|
||||
@@ -1141,7 +1142,7 @@ def print_failed_task(task_status):
|
||||
task_status["error_message"]))
|
||||
|
||||
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
def _initialize_serialization(job_id, worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
@@ -1177,7 +1178,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
custom_serializer=actor_handle_serializer,
|
||||
custom_deserializer=actor_handle_deserializer)
|
||||
|
||||
worker.serialization_context_map[driver_id] = serialization_context
|
||||
worker.serialization_context_map[job_id] = serialization_context
|
||||
|
||||
# Register exception types.
|
||||
for error_cls in RAY_EXCEPTION_TYPES:
|
||||
@@ -1185,7 +1186,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
error_cls,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id=error_cls.__module__ + ". " + error_cls.__name__,
|
||||
)
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
@@ -1193,22 +1194,18 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
type(lambda: 0),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="lambda")
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_custom_serializer(
|
||||
type(int),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="type")
|
||||
type(int), use_pickle=True, local=True, job_id=job_id, class_id="type")
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
register_custom_serializer(
|
||||
ray.signature.FunctionSignature,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="ray.signature.FunctionSignature")
|
||||
|
||||
|
||||
@@ -1231,7 +1228,7 @@ def init(redis_address=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
@@ -1302,7 +1299,7 @@ def init(redis_address=None,
|
||||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which displays the status of the Ray cluster.
|
||||
driver_id: The ID of driver.
|
||||
job_id: The ID of this job.
|
||||
configure_logging: True to allow the logging configuration here.
|
||||
Otherwise, users may want to configure it on their own.
|
||||
logging_level: Logging level, default will be logging.INFO.
|
||||
@@ -1449,7 +1446,7 @@ def init(redis_address=None,
|
||||
mode=driver_mode,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
@@ -1660,10 +1657,10 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
assert len(gcs_entry.entries) == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.FromString(
|
||||
gcs_entry.entries[0])
|
||||
driver_id = error_data.driver_id
|
||||
if driver_id not in [
|
||||
worker.task_driver_id.binary(),
|
||||
DriverID.nil().binary()
|
||||
job_id = error_data.job_id
|
||||
if job_id not in [
|
||||
worker.current_job_id.binary(),
|
||||
JobID.nil().binary()
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -1691,7 +1688,7 @@ def connect(node,
|
||||
mode=WORKER_MODE,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Connect this worker to the raylet, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
@@ -1701,7 +1698,7 @@ def connect(node,
|
||||
log_to_driver (bool): If true, then output from all of the worker
|
||||
processes on all nodes will be directed to the driver.
|
||||
worker: The ray.Worker instance.
|
||||
driver_id: The ID of driver. If it's None, then we will generate one.
|
||||
job_id: The ID of job. If it's None, then we will generate one.
|
||||
"""
|
||||
# Do some basic checking to make sure we didn't call ray.init twice.
|
||||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
@@ -1721,20 +1718,20 @@ def connect(node,
|
||||
setproctitle.setproctitle("ray_worker")
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = DriverID.from_random()
|
||||
if job_id is None:
|
||||
job_id = JobID.from_random()
|
||||
|
||||
if not isinstance(driver_id, DriverID):
|
||||
raise TypeError("The type of given driver id must be DriverID.")
|
||||
if not isinstance(job_id, JobID):
|
||||
raise TypeError("The type of given job id must be JobID.")
|
||||
|
||||
worker.worker_id = driver_id.binary()
|
||||
worker.worker_id = job_id.binary()
|
||||
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
# drivers, the current job ID is used to keep track of which driver is
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = DriverID(worker.worker_id)
|
||||
worker.current_job_id = JobID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
@@ -1766,7 +1763,7 @@ def connect(node,
|
||||
worker.redis_client,
|
||||
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
job_id=None)
|
||||
|
||||
worker.lock = threading.RLock()
|
||||
|
||||
@@ -1831,7 +1828,7 @@ def connect(node,
|
||||
# Create an object store client.
|
||||
worker.plasma_client = thread_safe_client(
|
||||
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
|
||||
driver_id_str = _random_string()
|
||||
job_id_str = _random_string()
|
||||
|
||||
# If this is a driver, set the current task ID, the task driver ID, and set
|
||||
# the task index to 0.
|
||||
@@ -1859,11 +1856,11 @@ def connect(node,
|
||||
|
||||
function_descriptor = FunctionDescriptor.for_driver_task()
|
||||
driver_task = ray._raylet.Task(
|
||||
worker.task_driver_id,
|
||||
worker.current_job_id,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
TaskID(driver_id_str[:TaskID.size()]), # parent_task_id.
|
||||
TaskID(job_id_str[:TaskID.size()]), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ActorID.nil(), # actor_creation_id.
|
||||
ObjectID.nil(), # actor_creation_dummy_object_id.
|
||||
@@ -1895,7 +1892,7 @@ def connect(node,
|
||||
node.raylet_socket_name,
|
||||
ClientID(worker.worker_id),
|
||||
(mode == WORKER_MODE),
|
||||
DriverID(driver_id_str),
|
||||
JobID(job_id_str),
|
||||
)
|
||||
|
||||
# Start the import thread
|
||||
@@ -2057,7 +2054,7 @@ def register_custom_serializer(cls,
|
||||
serializer=None,
|
||||
deserializer=None,
|
||||
local=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
@@ -2078,7 +2075,7 @@ def register_custom_serializer(cls,
|
||||
if and only if use_pickle and use_dict are False.
|
||||
local: True if the serializers should only be registered on the current
|
||||
worker. This should usually be False.
|
||||
driver_id: ID of the driver that we want to register the class for.
|
||||
job_id: ID of the job that we want to register the class for.
|
||||
class_id: ID of the class that we are registering. If this is not
|
||||
specified, we will calculate a new one inside the function.
|
||||
|
||||
@@ -2126,9 +2123,9 @@ def register_custom_serializer(cls,
|
||||
# Make sure class_id is a string.
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = worker.task_driver_id
|
||||
assert isinstance(driver_id, DriverID)
|
||||
if job_id is None:
|
||||
job_id = worker.current_job_id
|
||||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2138,7 +2135,7 @@ def register_custom_serializer(cls,
|
||||
# system.
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(driver_id)
|
||||
"worker"].get_serialization_context(job_id)
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
|
||||
Reference in New Issue
Block a user