Refactor code about ray.ObjectID. (#3674)

* Refactor code about ray.ObjectID.

* remove from_random and use nil_id instead of constructor

* remove id() in hash

* Lint and fix

* Change driver id to ObjectID

* Replace binary_to_hex(ObjectID.id()) with ObjectID.hex()
This commit is contained in:
Yuhong Guo
2019-01-13 17:47:29 +08:00
committed by Philipp Moritz
parent c4b058739b
commit d2cf8561f2
14 changed files with 191 additions and 169 deletions
+64 -71
View File
@@ -36,6 +36,7 @@ import ray.raylet
import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import ObjectID
from ray import profiling
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
import ray.parameter
@@ -53,13 +54,6 @@ PYTHON_MODE = 3
ERROR_KEY_PREFIX = b"Error:"
# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ID = ray_constants.ID_SIZE * b"\xff"
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID
NIL_ACTOR_HANDLE_ID = NIL_ID
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
# Default resource requirements for actors when no resource requirements are
# specified.
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
@@ -168,7 +162,7 @@ class Worker(object):
self.serialization_context_map = {}
self.function_actor_manager = FunctionActorManager(self)
# Identity of the driver that this worker is processing.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.task_driver_id = ObjectID.nil_id()
self._task_context = threading.local()
@property
@@ -189,14 +183,13 @@ class Worker(object):
# If this is running on the main thread, initialize it to
# NIL. The actual value will be set when the worker receives
# a task from raylet backend.
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
self._task_context.current_task_id = ObjectID.nil_id()
else:
# If this is running on a separate thread, then the mapping
# to the current task ID may not be correct. Generate a
# random task ID so that the backend can differentiate
# between different threads.
self._task_context.current_task_id = ray.ObjectID(
random_string())
self._task_context.current_task_id = ObjectID(random_string())
if getattr(self, '_multithreading_warned', False) is not True:
logger.warning(
"Calling ray.get or ray.wait in a separate thread "
@@ -353,12 +346,13 @@ class Worker(object):
full.
"""
# Make sure that the value is not an object ID.
if isinstance(value, ray.ObjectID):
raise Exception("Calling 'put' on an ObjectID is not allowed "
"(similarly, returning an ObjectID from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ObjectID in a list and "
"call 'put' on it (or return it).")
if isinstance(value, ObjectID):
raise Exception(
"Calling 'put' on an ray.ObjectID is not allowed "
"(similarly, returning an ray.ObjectID from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ray.ObjectID in a list and "
"call 'put' on it (or return it).")
# Serialize and put the object in the object store.
try:
@@ -433,7 +427,7 @@ class Worker(object):
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
driver_id=self.task_driver_id)
warning_sent = True
def get_object(self, object_ids):
@@ -449,9 +443,10 @@ class Worker(object):
"""
# Make sure that the values are object IDs.
for object_id in object_ids:
if not isinstance(object_id, ray.ObjectID):
raise Exception("Attempting to call `get` on the value {}, "
"which is not an ObjectID.".format(object_id))
if not isinstance(object_id, ObjectID):
raise Exception(
"Attempting to call `get` on the value {}, "
"which is not an ray.ObjectID.".format(object_id))
# Do an initial fetch for remote objects. We divide the fetch into
# smaller fetches so as to not block the manager for a prolonged period
# of time in a single call.
@@ -484,8 +479,7 @@ class Worker(object):
for unready_id in unready_ids.keys()
]
ray_object_ids_to_fetch = [
ray.ObjectID(unready_id)
for unready_id in unready_ids.keys()
ObjectID(unready_id) for unready_id in unready_ids.keys()
]
fetch_request_size = ray._config.worker_fetch_request_size()
for i in range(0, len(object_ids_to_fetch),
@@ -574,22 +568,22 @@ class Worker(object):
with profiling.profile("submit_task", worker=self):
if actor_id is None:
assert actor_handle_id is None
actor_id = ray.ObjectID(NIL_ACTOR_ID)
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
actor_id = ObjectID.nil_id()
actor_handle_id = ObjectID.nil_id()
else:
assert actor_handle_id is not None
if actor_creation_id is None:
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
actor_creation_id = ObjectID.nil_id()
if actor_creation_dummy_object_id is None:
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
actor_creation_dummy_object_id = ObjectID.nil_id()
# Put large or complex arguments that are passed by value in the
# object store first.
args_for_local_scheduler = []
for arg in args:
if isinstance(arg, ray.ObjectID):
if isinstance(arg, ObjectID):
args_for_local_scheduler.append(arg)
elif ray.raylet.check_simple_value(arg):
args_for_local_scheduler.append(arg)
@@ -722,7 +716,7 @@ class Worker(object):
arguments are being retrieved.
serialized_args (List): The arguments to the function. These are
either strings representing serialized objects passed by value
or they are ObjectIDs.
or they are ray.ObjectIDs.
Returns:
The retrieved arguments in addition to the arguments that were
@@ -734,7 +728,7 @@ class Worker(object):
"""
arguments = []
for (i, arg) in enumerate(serialized_args):
if isinstance(arg, ray.ObjectID):
if isinstance(arg, ObjectID):
# get the object from the local object store
argument = self.get_object([arg])[0]
if isinstance(argument, RayTaskError):
@@ -838,9 +832,9 @@ class Worker(object):
outputs = function_executor(*arguments)
else:
if not task.actor_id().is_nil():
key = task.actor_id().id()
key = task.actor_id()
else:
key = task.actor_creation_id().id()
key = task.actor_creation_id()
outputs = function_executor(dummy_return_id,
self.actors[key], *arguments)
except Exception as e:
@@ -882,7 +876,7 @@ class Worker(object):
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
driver_id=self.task_driver_id.id(),
driver_id=self.task_driver_id,
data={
"function_id": function_id.id(),
"function_name": function_name,
@@ -890,7 +884,7 @@ class Worker(object):
"class_name": function_descriptor.class_name
})
# Mark the actor init as failed
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
if not self.actor_id.is_nil() and function_name == "__init__":
self.mark_actor_init_failed(error)
def _wait_for_and_process_task(self, task):
@@ -901,13 +895,13 @@ class Worker(object):
"""
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
driver_id = task.driver_id().id()
driver_id = task.driver_id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if not task.actor_creation_id().is_nil():
assert self.actor_id == NIL_ACTOR_ID
self.actor_id = task.actor_creation_id().id()
assert self.actor_id.is_nil()
self.actor_id = task.actor_creation_id()
self.function_actor_manager.load_actor(driver_id,
function_descriptor)
@@ -930,12 +924,12 @@ class Worker(object):
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id().id()]
actor = self.actors[task.actor_creation_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
else:
actor = self.actors[task.actor_id().id()]
actor = self.actors[task.actor_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
@@ -943,14 +937,14 @@ class Worker(object):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
self.task_context.current_task_id = ObjectID.nil_id()
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id == NIL_ACTOR_ID:
if self.actor_id.is_nil():
# Don't need to reset task_driver_id if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.task_driver_id = ObjectID.nil_id()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
@@ -1104,17 +1098,17 @@ def error_applies_to_driver(error_key, worker=global_worker):
+ ray_constants.ID_SIZE), error_key
# If the driver ID in the error message is a sequence of all zeros, then
# the message is intended for all drivers.
driver_id = error_key[len(ERROR_KEY_PREFIX):(
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)]
return (driver_id == worker.task_driver_id.id()
or driver_id == ray.ray_constants.NIL_JOB_ID.id())
driver_id = ObjectID(error_key[len(ERROR_KEY_PREFIX):(
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)])
return (driver_id == worker.task_driver_id
or driver_id == ObjectID.nil_id())
def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
global_state.error_messages(job_id=ObjectID.nil_id()))
def _initialize_serialization(driver_id, worker=global_worker):
@@ -1134,13 +1128,13 @@ def _initialize_serialization(driver_id, worker=global_worker):
return obj.id()
def object_id_custom_deserializer(serialized_obj):
return ray.ObjectID(serialized_obj)
return ObjectID(serialized_obj)
# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
serialization_context.register_type(
ray.ObjectID,
ObjectID,
"ray.ObjectID",
pickle=False,
custom_serializer=object_id_custom_serializer,
@@ -1661,7 +1655,7 @@ def listen_error_messages_raylet(worker, task_error_queue):
job_id = error_data.JobId()
if job_id not in [
worker.task_driver_id.id(),
ray_constants.NIL_JOB_ID.id()
ObjectID.nil_id().id()
]:
continue
@@ -1772,11 +1766,10 @@ def connect(info,
else:
# This is the code path of driver mode.
if driver_id is None:
driver_id = ray.ObjectID(random_string())
driver_id = ObjectID(random_string())
if not isinstance(driver_id, ray.ObjectID):
raise Exception(
"The type of given driver id must be ray.ObjectID.")
if not isinstance(driver_id, ObjectID):
raise Exception("The type of given driver id must be ObjectID.")
worker.worker_id = driver_id.id()
@@ -1785,11 +1778,11 @@ def connect(info,
# responsible for the task so that error messages will be propagated to
# the correct driver.
if mode != WORKER_MODE:
worker.task_driver_id = ray.ObjectID(worker.worker_id)
worker.task_driver_id = ObjectID(worker.worker_id)
# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
worker.actor_id = NIL_ACTOR_ID
worker.actor_id = ObjectID.nil_id()
worker.connected = True
worker.set_mode(mode)
@@ -1920,13 +1913,13 @@ def connect(info,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
ray.ObjectID(random_string()), # parent_task_id.
ObjectID(random_string()), # parent_task_id.
0, # parent_counter.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
ObjectID.nil_id(), # actor_creation_id.
ObjectID.nil_id(), # actor_creation_dummy_object_id.
0, # max_actor_reconstructions.
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
ObjectID.nil_id(), # actor_id.
ObjectID.nil_id(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
[], # execution_dependencies.
@@ -2148,9 +2141,7 @@ def register_custom_serializer(cls,
class_id = ray.utils.binary_to_hex(class_id)
if driver_id is None:
driver_id_bytes = worker.task_driver_id.id()
else:
driver_id_bytes = driver_id.id()
driver_id = worker.task_driver_id
def register_class_for_serialization(worker_info):
# TODO(rkn): We need to be more thoughtful about what to do if custom
@@ -2160,7 +2151,7 @@ def register_custom_serializer(cls,
# system.
serialization_context = worker_info[
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
"worker"].get_serialization_context(driver_id)
serialization_context.register_type(
cls,
class_id,
@@ -2279,13 +2270,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
IDs.
"""
if isinstance(object_ids, ray.ObjectID):
if isinstance(object_ids, ObjectID):
raise TypeError(
"wait() expected a list of ObjectID, got a single ObjectID")
"wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
)
if not isinstance(object_ids, list):
raise TypeError("wait() expected a list of ObjectID, got {}".format(
type(object_ids)))
raise TypeError(
"wait() expected a list of ray.ObjectID, got {}".format(
type(object_ids)))
if isinstance(timeout, int) and timeout != 0:
logger.warning("The 'timeout' argument now requires seconds instead "
@@ -2298,8 +2291,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if worker.mode != LOCAL_MODE:
for object_id in object_ids:
if not isinstance(object_id, ray.ObjectID):
raise TypeError("wait() expected a list of ObjectID, "
if not isinstance(object_id, ObjectID):
raise TypeError("wait() expected a list of ray.ObjectID, "
"got list containing {}".format(
type(object_id)))