mirror of
https://github.com/wassname/ray.git
synced 2026-07-03 01:09:13 +08:00
Refactor code about ray.ObjectID. (#3674)
* Refactor code about ray.ObjectID.
* Remove from_random and use nil_id instead of the constructor.
* Remove id() in hash.
* Lint and fix.
* Change driver id to ObjectID.
* Replace binary_to_hex(ObjectID.id()) with ObjectID.hex().
This commit is contained in:
committed by
Philipp Moritz
parent
c4b058739b
commit
d2cf8561f2
+64
-71
@@ -36,6 +36,7 @@ import ray.raylet
|
||||
import ray.plasma
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray import import_thread
|
||||
from ray import ObjectID
|
||||
from ray import profiling
|
||||
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
|
||||
import ray.parameter
|
||||
@@ -53,13 +54,6 @@ PYTHON_MODE = 3
|
||||
|
||||
ERROR_KEY_PREFIX = b"Error:"
|
||||
|
||||
# This must match the definition of NIL_ACTOR_ID in task.h.
|
||||
NIL_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
NIL_LOCAL_SCHEDULER_ID = NIL_ID
|
||||
NIL_ACTOR_ID = NIL_ID
|
||||
NIL_ACTOR_HANDLE_ID = NIL_ID
|
||||
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
|
||||
# Default resource requirements for actors when no resource requirements are
|
||||
# specified.
|
||||
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
|
||||
@@ -168,7 +162,7 @@ class Worker(object):
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.task_driver_id = ObjectID.nil_id()
|
||||
self._task_context = threading.local()
|
||||
|
||||
@property
|
||||
@@ -189,14 +183,13 @@ class Worker(object):
|
||||
# If this is running on the main thread, initialize it to
|
||||
# NIL. The actual value will be set when the worker receives
|
||||
# a task from raylet backend.
|
||||
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self._task_context.current_task_id = ObjectID.nil_id()
|
||||
else:
|
||||
# If this is running on a separate thread, then the mapping
|
||||
# to the current task ID may not be correct. Generate a
|
||||
# random task ID so that the backend can differentiate
|
||||
# between different threads.
|
||||
self._task_context.current_task_id = ray.ObjectID(
|
||||
random_string())
|
||||
self._task_context.current_task_id = ObjectID(random_string())
|
||||
if getattr(self, '_multithreading_warned', False) is not True:
|
||||
logger.warning(
|
||||
"Calling ray.get or ray.wait in a separate thread "
|
||||
@@ -353,12 +346,13 @@ class Worker(object):
|
||||
full.
|
||||
"""
|
||||
# Make sure that the value is not an object ID.
|
||||
if isinstance(value, ray.ObjectID):
|
||||
raise Exception("Calling 'put' on an ObjectID is not allowed "
|
||||
"(similarly, returning an ObjectID from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ObjectID in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
if isinstance(value, ObjectID):
|
||||
raise Exception(
|
||||
"Calling 'put' on an ray.ObjectID is not allowed "
|
||||
"(similarly, returning an ray.ObjectID from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ray.ObjectID in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
|
||||
# Serialize and put the object in the object store.
|
||||
try:
|
||||
@@ -433,7 +427,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=self.task_driver_id.id())
|
||||
driver_id=self.task_driver_id)
|
||||
warning_sent = True
|
||||
|
||||
def get_object(self, object_ids):
|
||||
@@ -449,9 +443,10 @@ class Worker(object):
|
||||
"""
|
||||
# Make sure that the values are object IDs.
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise Exception("Attempting to call `get` on the value {}, "
|
||||
"which is not an ObjectID.".format(object_id))
|
||||
if not isinstance(object_id, ObjectID):
|
||||
raise Exception(
|
||||
"Attempting to call `get` on the value {}, "
|
||||
"which is not an ray.ObjectID.".format(object_id))
|
||||
# Do an initial fetch for remote objects. We divide the fetch into
|
||||
# smaller fetches so as to not block the manager for a prolonged period
|
||||
# of time in a single call.
|
||||
@@ -484,8 +479,7 @@ class Worker(object):
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
ray_object_ids_to_fetch = [
|
||||
ray.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
ObjectID(unready_id) for unready_id in unready_ids.keys()
|
||||
]
|
||||
fetch_request_size = ray._config.worker_fetch_request_size()
|
||||
for i in range(0, len(object_ids_to_fetch),
|
||||
@@ -574,22 +568,22 @@ class Worker(object):
|
||||
with profiling.profile("submit_task", worker=self):
|
||||
if actor_id is None:
|
||||
assert actor_handle_id is None
|
||||
actor_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
|
||||
actor_id = ObjectID.nil_id()
|
||||
actor_handle_id = ObjectID.nil_id()
|
||||
else:
|
||||
assert actor_handle_id is not None
|
||||
|
||||
if actor_creation_id is None:
|
||||
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
actor_creation_id = ObjectID.nil_id()
|
||||
|
||||
if actor_creation_dummy_object_id is None:
|
||||
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
|
||||
actor_creation_dummy_object_id = ObjectID.nil_id()
|
||||
|
||||
# Put large or complex arguments that are passed by value in the
|
||||
# object store first.
|
||||
args_for_local_scheduler = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
if isinstance(arg, ObjectID):
|
||||
args_for_local_scheduler.append(arg)
|
||||
elif ray.raylet.check_simple_value(arg):
|
||||
args_for_local_scheduler.append(arg)
|
||||
@@ -722,7 +716,7 @@ class Worker(object):
|
||||
arguments are being retrieved.
|
||||
serialized_args (List): The arguments to the function. These are
|
||||
either strings representing serialized objects passed by value
|
||||
or they are ObjectIDs.
|
||||
or they are ray.ObjectIDs.
|
||||
|
||||
Returns:
|
||||
The retrieved arguments in addition to the arguments that were
|
||||
@@ -734,7 +728,7 @@ class Worker(object):
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
if isinstance(arg, ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = self.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
@@ -838,9 +832,9 @@ class Worker(object):
|
||||
outputs = function_executor(*arguments)
|
||||
else:
|
||||
if not task.actor_id().is_nil():
|
||||
key = task.actor_id().id()
|
||||
key = task.actor_id()
|
||||
else:
|
||||
key = task.actor_creation_id().id()
|
||||
key = task.actor_creation_id()
|
||||
outputs = function_executor(dummy_return_id,
|
||||
self.actors[key], *arguments)
|
||||
except Exception as e:
|
||||
@@ -882,7 +876,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.TASK_PUSH_ERROR,
|
||||
str(failure_object),
|
||||
driver_id=self.task_driver_id.id(),
|
||||
driver_id=self.task_driver_id,
|
||||
data={
|
||||
"function_id": function_id.id(),
|
||||
"function_name": function_name,
|
||||
@@ -890,7 +884,7 @@ class Worker(object):
|
||||
"class_name": function_descriptor.class_name
|
||||
})
|
||||
# Mark the actor init as failed
|
||||
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
@@ -901,13 +895,13 @@ class Worker(object):
|
||||
"""
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
driver_id = task.driver_id().id()
|
||||
driver_id = task.driver_id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
if not task.actor_creation_id().is_nil():
|
||||
assert self.actor_id == NIL_ACTOR_ID
|
||||
self.actor_id = task.actor_creation_id().id()
|
||||
assert self.actor_id.is_nil()
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.function_actor_manager.load_actor(driver_id,
|
||||
function_descriptor)
|
||||
|
||||
@@ -930,12 +924,12 @@ class Worker(object):
|
||||
title = "ray_worker:{}()".format(function_name)
|
||||
next_title = "ray_worker"
|
||||
else:
|
||||
actor = self.actors[task.actor_creation_id().id()]
|
||||
actor = self.actors[task.actor_creation_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
else:
|
||||
actor = self.actors[task.actor_id().id()]
|
||||
actor = self.actors[task.actor_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
@@ -943,14 +937,14 @@ class Worker(object):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_context.current_task_id = ObjectID.nil_id()
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id == NIL_ACTOR_ID:
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# actor, because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.task_driver_id = ObjectID.nil_id()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
@@ -1104,17 +1098,17 @@ def error_applies_to_driver(error_key, worker=global_worker):
|
||||
+ ray_constants.ID_SIZE), error_key
|
||||
# If the driver ID in the error message is the nil ID, then
|
||||
# the message is intended for all drivers.
|
||||
driver_id = error_key[len(ERROR_KEY_PREFIX):(
|
||||
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)]
|
||||
return (driver_id == worker.task_driver_id.id()
|
||||
or driver_id == ray.ray_constants.NIL_JOB_ID.id())
|
||||
driver_id = ObjectID(error_key[len(ERROR_KEY_PREFIX):(
|
||||
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)])
|
||||
return (driver_id == worker.task_driver_id
|
||||
or driver_id == ObjectID.nil_id())
|
||||
|
||||
|
||||
def error_info(worker=global_worker):
|
||||
"""Return information about failed tasks."""
|
||||
worker.check_connected()
|
||||
return (global_state.error_messages(job_id=worker.task_driver_id) +
|
||||
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
|
||||
global_state.error_messages(job_id=ObjectID.nil_id()))
|
||||
|
||||
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
@@ -1134,13 +1128,13 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
return obj.id()
|
||||
|
||||
def object_id_custom_deserializer(serialized_obj):
|
||||
return ray.ObjectID(serialized_obj)
|
||||
return ObjectID(serialized_obj)
|
||||
|
||||
# We register this serializer on each worker instead of calling
|
||||
# register_custom_serializer from the driver so that isinstance still
|
||||
# works.
|
||||
serialization_context.register_type(
|
||||
ray.ObjectID,
|
||||
ObjectID,
|
||||
"ray.ObjectID",
|
||||
pickle=False,
|
||||
custom_serializer=object_id_custom_serializer,
|
||||
@@ -1661,7 +1655,7 @@ def listen_error_messages_raylet(worker, task_error_queue):
|
||||
job_id = error_data.JobId()
|
||||
if job_id not in [
|
||||
worker.task_driver_id.id(),
|
||||
ray_constants.NIL_JOB_ID.id()
|
||||
ObjectID.nil_id().id()
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -1772,11 +1766,10 @@ def connect(info,
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = ray.ObjectID(random_string())
|
||||
driver_id = ObjectID(random_string())
|
||||
|
||||
if not isinstance(driver_id, ray.ObjectID):
|
||||
raise Exception(
|
||||
"The type of given driver id must be ray.ObjectID.")
|
||||
if not isinstance(driver_id, ObjectID):
|
||||
raise Exception("The type of given driver id must be ObjectID.")
|
||||
|
||||
worker.worker_id = driver_id.id()
|
||||
|
||||
@@ -1785,11 +1778,11 @@ def connect(info,
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = ray.ObjectID(worker.worker_id)
|
||||
worker.task_driver_id = ObjectID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
worker.actor_id = NIL_ACTOR_ID
|
||||
worker.actor_id = ObjectID.nil_id()
|
||||
worker.connected = True
|
||||
worker.set_mode(mode)
|
||||
|
||||
@@ -1920,13 +1913,13 @@ def connect(info,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
ray.ObjectID(random_string()), # parent_task_id.
|
||||
ObjectID(random_string()), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
|
||||
ObjectID.nil_id(), # actor_creation_id.
|
||||
ObjectID.nil_id(), # actor_creation_dummy_object_id.
|
||||
0, # max_actor_reconstructions.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
|
||||
ObjectID.nil_id(), # actor_id.
|
||||
ObjectID.nil_id(), # actor_handle_id.
|
||||
nil_actor_counter, # actor_counter.
|
||||
[], # new_actor_handles.
|
||||
[], # execution_dependencies.
|
||||
@@ -2148,9 +2141,7 @@ def register_custom_serializer(cls,
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if driver_id is None:
|
||||
driver_id_bytes = worker.task_driver_id.id()
|
||||
else:
|
||||
driver_id_bytes = driver_id.id()
|
||||
driver_id = worker.task_driver_id
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2160,7 +2151,7 @@ def register_custom_serializer(cls,
|
||||
# system.
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
|
||||
"worker"].get_serialization_context(driver_id)
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
@@ -2279,13 +2270,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
IDs.
|
||||
"""
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
if isinstance(object_ids, ObjectID):
|
||||
raise TypeError(
|
||||
"wait() expected a list of ObjectID, got a single ObjectID")
|
||||
"wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
|
||||
)
|
||||
|
||||
if not isinstance(object_ids, list):
|
||||
raise TypeError("wait() expected a list of ObjectID, got {}".format(
|
||||
type(object_ids)))
|
||||
raise TypeError(
|
||||
"wait() expected a list of ray.ObjectID, got {}".format(
|
||||
type(object_ids)))
|
||||
|
||||
if isinstance(timeout, int) and timeout != 0:
|
||||
logger.warning("The 'timeout' argument now requires seconds instead "
|
||||
@@ -2298,8 +2291,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
|
||||
if worker.mode != LOCAL_MODE:
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise TypeError("wait() expected a list of ObjectID, "
|
||||
if not isinstance(object_id, ObjectID):
|
||||
raise TypeError("wait() expected a list of ray.ObjectID, "
|
||||
"got list containing {}".format(
|
||||
type(object_id)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user