Refactor code about ray.ObjectID. (#3674)

* Refactor code about ray.ObjectID.

* remove from_random and use nil_id instead of constructor

* remove id() in hash

* Lint and fix

* Change driver id to ObjectID

* Replace binary_to_hex(ObjectID.id()) to ObjectID.hex()
This commit is contained in:
Yuhong Guo
2019-01-13 17:47:29 +08:00
committed by Philipp Moritz
parent c4b058739b
commit d2cf8561f2
14 changed files with 191 additions and 169 deletions
+17 -19
View File
@@ -17,6 +17,7 @@ import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
from ray.utils import _random_string
from ray import ObjectID
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
@@ -41,8 +42,7 @@ def compute_actor_handle_id(actor_handle_id, num_forks):
handle_id_hash.update(actor_handle_id.id())
handle_id_hash.update(str(num_forks).encode("ascii"))
handle_id = handle_id_hash.digest()
assert len(handle_id) == ray_constants.ID_SIZE
return ray.ObjectID(handle_id)
return ObjectID(handle_id)
def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
@@ -69,8 +69,7 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
handle_id_hash.update(actor_handle_id.id())
handle_id_hash.update(current_task_id.id())
handle_id = handle_id_hash.digest()
assert len(handle_id) == ray_constants.ID_SIZE
return ray.ObjectID(handle_id)
return ObjectID(handle_id)
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
@@ -84,7 +83,7 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
checkpoint: The state object to save.
frontier: The task frontier at the time of the checkpoint.
"""
actor_key = b"Actor:" + actor_id
actor_key = b"Actor:" + actor_id.id()
worker.redis_client.hmset(
actor_key, {
"checkpoint_index": checkpoint_index,
@@ -110,7 +109,7 @@ def save_and_log_checkpoint(worker, actor):
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
driver_id=worker.task_driver_id,
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint__.__name__
@@ -134,7 +133,7 @@ def restore_and_log_checkpoint(worker, actor):
worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=worker.task_driver_id.id(),
driver_id=worker.task_driver_id,
data={
"actor_class": actor.__class__.__name__,
"function_name": actor.__ray_checkpoint_restore__.__name__
@@ -156,7 +155,7 @@ def get_actor_checkpoint(worker, actor_id):
exists, all objects are set to None. The checkpoint index is the .
executed on the actor before the checkpoint was made.
"""
actor_key = b"Actor:" + actor_id
actor_key = b"Actor:" + actor_id.id()
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
if checkpoint_index is not None:
@@ -371,7 +370,7 @@ class ActorClass(object):
raise Exception("Actors cannot be created before ray.init() "
"has been called.")
actor_id = ray.ObjectID(_random_string())
actor_id = ObjectID(_random_string())
# The actor cursor is a dummy object representing the most recent
# actor method invocation. For each subsequent method invocation,
# the current cursor should be added as a dependency, and then
@@ -509,8 +508,7 @@ class ActorHandle(object):
# if it was created by the _serialization_helper function.
self._ray_original_handle = actor_handle_id is None
if self._ray_original_handle:
self._ray_actor_handle_id = ray.ObjectID(
ray.worker.NIL_ACTOR_HANDLE_ID)
self._ray_actor_handle_id = ObjectID.nil_id()
else:
self._ray_actor_handle_id = actor_handle_id
self._ray_actor_cursor = actor_cursor
@@ -713,7 +711,7 @@ class ActorHandle(object):
# to release, since it could be unpickled and submit another
# dependent task at any time. Therefore, we notify the backend of a
# random handle ID that will never actually be used.
new_actor_handle_id = ray.ObjectID(_random_string())
new_actor_handle_id = ObjectID(_random_string())
# Notify the backend to expect this new actor handle. The backend will
# not release the cursor for any new handles until the first task for
# each of the new handles is submitted.
@@ -735,7 +733,7 @@ class ActorHandle(object):
worker.check_connected()
if state["ray_forking"]:
actor_handle_id = ray.ObjectID(state["actor_handle_id"])
actor_handle_id = ObjectID(state["actor_handle_id"])
else:
# Right now, if the actor handle has been pickled, we create a
# temporary actor handle id for invocations.
@@ -749,22 +747,22 @@ class ActorHandle(object):
# same actor is likely a performance bug. We should consider
# logging a warning in these cases.
actor_handle_id = compute_actor_handle_id_non_forked(
ray.ObjectID(state["actor_handle_id"]), worker.current_task_id)
ObjectID(state["actor_handle_id"]), worker.current_task_id)
# This is the driver ID of the driver that owns the actor, not
# necessarily the driver that owns this actor handle.
actor_driver_id = ray.ObjectID(state["actor_driver_id"])
actor_driver_id = ObjectID(state["actor_driver_id"])
self.__init__(
ray.ObjectID(state["actor_id"]),
ObjectID(state["actor_id"]),
state["module_name"],
state["class_name"],
ray.ObjectID(state["actor_cursor"])
ObjectID(state["actor_cursor"])
if state["actor_cursor"] is not None else None,
state["actor_method_names"],
state["method_signatures"],
state["method_num_return_vals"],
ray.ObjectID(state["actor_creation_dummy_object_id"])
ObjectID(state["actor_creation_dummy_object_id"])
if state["actor_creation_dummy_object_id"] is not None else None,
state["actor_method_cpus"],
actor_driver_id,
@@ -843,7 +841,7 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
# scheduler has seen. Handle IDs for which no task has yet reached
# the local scheduler will not be included, and may not be runnable
# on checkpoint resumption.
actor_id = ray.ObjectID(worker.actor_id)
actor_id = worker.actor_id
frontier = worker.raylet_client.get_actor_frontier(actor_id)
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
# should not be stored in Redis. Fix this.
+9 -10
View File
@@ -25,7 +25,7 @@ def parse_client_table(redis_client):
Returns:
A list of information about the nodes in the cluster.
"""
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
NIL_CLIENT_ID = ray.ObjectID.nil_id().id()
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
ray.gcs_utils.TablePrefix.CLIENT,
"", NIL_CLIENT_ID)
@@ -308,20 +308,19 @@ class GlobalState(object):
function_descriptor = FunctionDescriptor.from_bytes_list(
function_descriptor_list)
task_spec_info = {
"DriverID": binary_to_hex(task_spec.driver_id().id()),
"TaskID": binary_to_hex(task_spec.task_id().id()),
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
"DriverID": task_spec.driver_id().hex(),
"TaskID": task_spec.task_id().hex(),
"ParentTaskID": task_spec.parent_task_id().hex(),
"ParentCounter": task_spec.parent_counter(),
"ActorID": binary_to_hex(task_spec.actor_id().id()),
"ActorCreationID": binary_to_hex(
task_spec.actor_creation_id().id()),
"ActorCreationDummyObjectID": binary_to_hex(
task_spec.actor_creation_dummy_object_id().id()),
"ActorID": (task_spec.actor_id().hex()),
"ActorCreationID": task_spec.actor_creation_id().hex(),
"ActorCreationDummyObjectID": (
task_spec.actor_creation_dummy_object_id().hex()),
"ActorCounter": task_spec.actor_counter(),
"Args": task_spec.arguments(),
"ReturnObjectIDs": task_spec.returns(),
"RequiredResources": task_spec.required_resources(),
"FunctionID": binary_to_hex(function_descriptor.function_id.id()),
"FunctionID": function_descriptor.function_id.hex(),
"FunctionHash": binary_to_hex(function_descriptor.function_hash),
"ModuleName": function_descriptor.module_name,
"ClassName": function_descriptor.class_name,
+25 -24
View File
@@ -211,7 +211,7 @@ class FunctionDescriptor(object):
Returns:
The value of ray.ObjectID that represents the function id.
"""
return ray.ObjectID(self._function_id)
return self._function_id
def _get_function_id(self):
"""Calculate the function id of current function descriptor.
@@ -220,10 +220,10 @@ class FunctionDescriptor(object):
descriptor.
Returns:
bytes with length of ray_constants.ID_SIZE.
ray.ObjectID to represent the function descriptor.
"""
if self.is_for_driver_task:
return ray_constants.NIL_FUNCTION_ID.id()
return ray.ObjectID.nil_id()
function_id_hash = hashlib.sha1()
# Include the function module and name in the hash.
function_id_hash.update(self.module_name.encode("ascii"))
@@ -232,8 +232,7 @@ class FunctionDescriptor(object):
function_id_hash.update(self._function_source_hash)
# Compute the function ID.
function_id = function_id_hash.digest()
assert len(function_id) == ray_constants.ID_SIZE
return function_id
return ray.ObjectID(function_id)
def get_function_descriptor_list(self):
"""Return a list of bytes representing the function descriptor.
@@ -290,11 +289,11 @@ class FunctionActorManager(object):
self.imported_actor_classes = set()
def increase_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
self._num_task_executions[driver_id][function_id] += 1
def get_task_counter(self, driver_id, function_descriptor):
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
return self._num_task_executions[driver_id][function_id]
def export_cached(self):
@@ -372,13 +371,14 @@ class FunctionActorManager(object):
def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
(driver_id, function_id_str, function_name, serialized_function,
(driver_id_str, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self._worker.redis_client.hmget(key, [
"driver_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.ObjectID(function_id_str)
driver_id = ray.ObjectID(driver_id_str)
function_name = decode(function_name)
max_calls = int(max_calls)
module = decode(module)
@@ -388,10 +388,10 @@ class FunctionActorManager(object):
def f():
raise Exception("This function was not imported properly.")
self._function_execution_info[driver_id][function_id.id()] = (
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=f, function_name=function_name, max_calls=max_calls))
self._num_task_executions[driver_id][function_id.id()] = 0
self._num_task_executions[driver_id][function_id] = 0
try:
function = pickle.loads(serialized_function)
@@ -416,7 +416,7 @@ class FunctionActorManager(object):
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
function.__module__ = module
self._function_execution_info[driver_id][function_id.id()] = (
self._function_execution_info[driver_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
@@ -435,7 +435,7 @@ class FunctionActorManager(object):
Returns:
A FunctionExecutionInfo object.
"""
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
# Wait until the function to be executed has actually been
# registered on this worker. We will push warnings to the user if
@@ -449,7 +449,7 @@ class FunctionActorManager(object):
except KeyError as e:
message = ("Error occurs in get_execution_info: "
"driver_id: %s, function_descriptor: %s. Message: %s" %
(binary_to_hex(driver_id), function_descriptor, e))
driver_id, function_descriptor, e)
raise KeyError(message)
return info
@@ -474,11 +474,11 @@ class FunctionActorManager(object):
warning_sent = False
while True:
with self._worker.lock:
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
and (function_descriptor.function_id.id() in
if (self._worker.actor_id.is_nil()
and (function_descriptor.function_id in
self._function_execution_info[driver_id])):
break
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
elif not self._worker.actor_id.is_nil() and (
self._worker.actor_id in self._worker.actors):
break
if time.time() - start_time > timeout:
@@ -556,7 +556,7 @@ class FunctionActorManager(object):
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor(self, driver_id, function_descriptor):
key = (b"ActorClass:" + driver_id + b":" +
key = (b"ActorClass:" + driver_id.id() + b":" +
function_descriptor.function_id.id())
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
@@ -578,8 +578,8 @@ class FunctionActorManager(object):
actor_class_key: The key in Redis to use to fetch the actor.
worker: The worker to use.
"""
actor_id_str = self._worker.actor_id
(driver_id, class_name, module, pickled_class, checkpoint_interval,
actor_id = self._worker.actor_id
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
actor_method_names) = self._worker.redis_client.hmget(
actor_class_key, [
"driver_id", "class_name", "module", "class",
@@ -588,6 +588,7 @@ class FunctionActorManager(object):
class_name = decode(class_name)
module = decode(module)
driver_id = ray.ObjectID(driver_id_str)
checkpoint_interval = int(checkpoint_interval)
actor_method_names = json.loads(decode(actor_method_names))
@@ -606,7 +607,7 @@ class FunctionActorManager(object):
class TemporaryActor(object):
pass
self._worker.actors[actor_id_str] = TemporaryActor()
self._worker.actors[actor_id] = TemporaryActor()
self._worker.actor_checkpoint_interval = checkpoint_interval
def temporary_actor_method(*xs):
@@ -618,7 +619,7 @@ class FunctionActorManager(object):
for actor_method_name in actor_method_names:
function_descriptor = FunctionDescriptor(module, actor_method_name,
class_name)
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
temporary_executor = self._make_actor_method_executor(
actor_method_name,
temporary_actor_method,
@@ -644,14 +645,14 @@ class FunctionActorManager(object):
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
traceback_str,
driver_id,
data={"actor_id": actor_id_str})
data={"actor_id": actor_id.id()})
# TODO(rkn): In the future, it might make sense to have the worker
# exit here. However, currently that would lead to hanging if
# someone calls ray.get on a method invoked on the actor.
else:
# TODO(pcm): Why is the below line necessary?
unpickled_class.__module__ = module
self._worker.actors[actor_id_str] = unpickled_class.__new__(
self._worker.actors[actor_id] = unpickled_class.__new__(
unpickled_class)
actor_methods = inspect.getmembers(
@@ -659,7 +660,7 @@ class FunctionActorManager(object):
for actor_method_name, actor_method in actor_methods:
function_descriptor = FunctionDescriptor(
module, actor_method_name, class_name)
function_id = function_descriptor.function_id.id()
function_id = function_descriptor.function_id
executor = self._make_actor_method_executor(
actor_method_name, actor_method, actor_imported=True)
self._function_execution_info[driver_id][function_id] = (
+1 -1
View File
@@ -58,7 +58,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
The serialized object.
"""
builder = flatbuffers.Builder(0)
driver_offset = builder.CreateString(driver_id)
driver_offset = builder.CreateString(driver_id.id())
error_type_offset = builder.CreateString(error_type)
message_offset = builder.CreateString(message)
+1 -1
View File
@@ -131,5 +131,5 @@ class ImportThread(object):
self.worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=driver_id,
driver_id=ray.ObjectID(driver_id),
data={"name": name})
-4
View File
@@ -5,8 +5,6 @@ from __future__ import print_function
import os
from ray.raylet import ObjectID
def env_integer(key, default):
if key in os.environ:
@@ -15,8 +13,6 @@ def env_integer(key, default):
ID_SIZE = 20
NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff")
NIL_FUNCTION_ID = NIL_JOB_ID
# The default maximum number of bytes to allocate to the object store unless
# overridden by the user.
+2 -2
View File
@@ -4,10 +4,10 @@ from __future__ import print_function
from ray.core.src.ray.raylet.libraylet_library_python import (
Task, RayletClient, ObjectID, check_simple_value, compute_task_id,
task_from_string, task_to_string, _config, common_error)
task_from_string, task_to_string, _config, RayCommonError)
__all__ = [
"Task", "RayletClient", "ObjectID", "check_simple_value",
"compute_task_id", "task_from_string", "task_to_string",
"start_local_scheduler", "_config", "common_error"
"start_local_scheduler", "_config", "RayCommonError"
]
+9 -8
View File
@@ -67,10 +67,10 @@ def push_error_to_driver(worker,
will be serialized with json and stored in Redis.
"""
if driver_id is None:
driver_id = ray_constants.NIL_JOB_ID.id()
driver_id = ray.ObjectID.nil_id()
data = {} if data is None else data
worker.raylet_client.push_error(
ray.ObjectID(driver_id), error_type, message, time.time())
worker.raylet_client.push_error(driver_id, error_type, message,
time.time())
def push_error_to_driver_through_redis(redis_client,
@@ -96,15 +96,16 @@ def push_error_to_driver_through_redis(redis_client,
will be serialized with json and stored in Redis.
"""
if driver_id is None:
driver_id = ray_constants.NIL_JOB_ID.id()
driver_id = ray.ObjectID.nil_id()
data = {} if data is None else data
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
message, time.time())
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
redis_client.execute_command("RAY.TABLE_APPEND",
ray.gcs_utils.TablePrefix.ERROR_INFO,
ray.gcs_utils.TablePubsub.ERROR_INFO,
driver_id.id(), error_data)
def is_cython(obj):
@@ -400,7 +401,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
worker,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id.id())
driver_id=worker.task_driver_id)
class _ThreadSafeProxy(object):
+64 -71
View File
@@ -36,6 +36,7 @@ import ray.raylet
import ray.plasma
import ray.ray_constants as ray_constants
from ray import import_thread
from ray import ObjectID
from ray import profiling
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
import ray.parameter
@@ -53,13 +54,6 @@ PYTHON_MODE = 3
ERROR_KEY_PREFIX = b"Error:"
# This must match the definition of NIL_ACTOR_ID in task.h.
NIL_ID = ray_constants.ID_SIZE * b"\xff"
NIL_LOCAL_SCHEDULER_ID = NIL_ID
NIL_ACTOR_ID = NIL_ID
NIL_ACTOR_HANDLE_ID = NIL_ID
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
# Default resource requirements for actors when no resource requirements are
# specified.
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
@@ -168,7 +162,7 @@ class Worker(object):
self.serialization_context_map = {}
self.function_actor_manager = FunctionActorManager(self)
# Identity of the driver that this worker is processing.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.task_driver_id = ObjectID.nil_id()
self._task_context = threading.local()
@property
@@ -189,14 +183,13 @@ class Worker(object):
# If this is running on the main thread, initialize it to
# NIL. The actual value will set when the worker receives
# a task from raylet backend.
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
self._task_context.current_task_id = ObjectID.nil_id()
else:
# If this is running on a separate thread, then the mapping
# to the current task ID may not be correct. Generate a
# random task ID so that the backend can differentiate
# between different threads.
self._task_context.current_task_id = ray.ObjectID(
random_string())
self._task_context.current_task_id = ObjectID(random_string())
if getattr(self, '_multithreading_warned', False) is not True:
logger.warning(
"Calling ray.get or ray.wait in a separate thread "
@@ -353,12 +346,13 @@ class Worker(object):
full.
"""
# Make sure that the value is not an object ID.
if isinstance(value, ray.ObjectID):
raise Exception("Calling 'put' on an ObjectID is not allowed "
"(similarly, returning an ObjectID from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ObjectID in a list and "
"call 'put' on it (or return it).")
if isinstance(value, ObjectID):
raise Exception(
"Calling 'put' on an ray.ObjectID is not allowed "
"(similarly, returning an ray.ObjectID from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ray.ObjectID in a list and "
"call 'put' on it (or return it).")
# Serialize and put the object in the object store.
try:
@@ -433,7 +427,7 @@ class Worker(object):
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id.id())
driver_id=self.task_driver_id)
warning_sent = True
def get_object(self, object_ids):
@@ -449,9 +443,10 @@ class Worker(object):
"""
# Make sure that the values are object IDs.
for object_id in object_ids:
if not isinstance(object_id, ray.ObjectID):
raise Exception("Attempting to call `get` on the value {}, "
"which is not an ObjectID.".format(object_id))
if not isinstance(object_id, ObjectID):
raise Exception(
"Attempting to call `get` on the value {}, "
"which is not an ray.ObjectID.".format(object_id))
# Do an initial fetch for remote objects. We divide the fetch into
# smaller fetches so as to not block the manager for a prolonged period
# of time in a single call.
@@ -484,8 +479,7 @@ class Worker(object):
for unready_id in unready_ids.keys()
]
ray_object_ids_to_fetch = [
ray.ObjectID(unready_id)
for unready_id in unready_ids.keys()
ObjectID(unready_id) for unready_id in unready_ids.keys()
]
fetch_request_size = ray._config.worker_fetch_request_size()
for i in range(0, len(object_ids_to_fetch),
@@ -574,22 +568,22 @@ class Worker(object):
with profiling.profile("submit_task", worker=self):
if actor_id is None:
assert actor_handle_id is None
actor_id = ray.ObjectID(NIL_ACTOR_ID)
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
actor_id = ObjectID.nil_id()
actor_handle_id = ObjectID.nil_id()
else:
assert actor_handle_id is not None
if actor_creation_id is None:
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
actor_creation_id = ObjectID.nil_id()
if actor_creation_dummy_object_id is None:
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
actor_creation_dummy_object_id = ObjectID.nil_id()
# Put large or complex arguments that are passed by value in the
# object store first.
args_for_local_scheduler = []
for arg in args:
if isinstance(arg, ray.ObjectID):
if isinstance(arg, ObjectID):
args_for_local_scheduler.append(arg)
elif ray.raylet.check_simple_value(arg):
args_for_local_scheduler.append(arg)
@@ -722,7 +716,7 @@ class Worker(object):
arguments are being retrieved.
serialized_args (List): The arguments to the function. These are
either strings representing serialized objects passed by value
or they are ObjectIDs.
or they are ray.ObjectIDs.
Returns:
The retrieved arguments in addition to the arguments that were
@@ -734,7 +728,7 @@ class Worker(object):
"""
arguments = []
for (i, arg) in enumerate(serialized_args):
if isinstance(arg, ray.ObjectID):
if isinstance(arg, ObjectID):
# get the object from the local object store
argument = self.get_object([arg])[0]
if isinstance(argument, RayTaskError):
@@ -838,9 +832,9 @@ class Worker(object):
outputs = function_executor(*arguments)
else:
if not task.actor_id().is_nil():
key = task.actor_id().id()
key = task.actor_id()
else:
key = task.actor_creation_id().id()
key = task.actor_creation_id()
outputs = function_executor(dummy_return_id,
self.actors[key], *arguments)
except Exception as e:
@@ -882,7 +876,7 @@ class Worker(object):
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
driver_id=self.task_driver_id.id(),
driver_id=self.task_driver_id,
data={
"function_id": function_id.id(),
"function_name": function_name,
@@ -890,7 +884,7 @@ class Worker(object):
"class_name": function_descriptor.class_name
})
# Mark the actor init as failed
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
if not self.actor_id.is_nil() and function_name == "__init__":
self.mark_actor_init_failed(error)
def _wait_for_and_process_task(self, task):
@@ -901,13 +895,13 @@ class Worker(object):
"""
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
driver_id = task.driver_id().id()
driver_id = task.driver_id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
if not task.actor_creation_id().is_nil():
assert self.actor_id == NIL_ACTOR_ID
self.actor_id = task.actor_creation_id().id()
assert self.actor_id.is_nil()
self.actor_id = task.actor_creation_id()
self.function_actor_manager.load_actor(driver_id,
function_descriptor)
@@ -930,12 +924,12 @@ class Worker(object):
title = "ray_worker:{}()".format(function_name)
next_title = "ray_worker"
else:
actor = self.actors[task.actor_creation_id().id()]
actor = self.actors[task.actor_creation_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
else:
actor = self.actors[task.actor_id().id()]
actor = self.actors[task.actor_id()]
title = "ray_{}:{}()".format(actor.__class__.__name__,
function_name)
next_title = "ray_{}".format(actor.__class__.__name__)
@@ -943,14 +937,14 @@ class Worker(object):
with _changeproctitle(title, next_title):
self._process_task(task, execution_info)
# Reset the state fields so the next task can run.
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
self.task_context.current_task_id = ObjectID.nil_id()
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id == NIL_ACTOR_ID:
if self.actor_id.is_nil():
# Don't need to reset task_driver_id if the worker is an
# actor. Because the following tasks should all have the
# same driver id.
self.task_driver_id = ray.ObjectID(NIL_ID)
self.task_driver_id = ObjectID.nil_id()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
@@ -1104,17 +1098,17 @@ def error_applies_to_driver(error_key, worker=global_worker):
+ ray_constants.ID_SIZE), error_key
# If the driver ID in the error message is a sequence of all zeros, then
# the message is intended for all drivers.
driver_id = error_key[len(ERROR_KEY_PREFIX):(
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)]
return (driver_id == worker.task_driver_id.id()
or driver_id == ray.ray_constants.NIL_JOB_ID.id())
driver_id = ObjectID(error_key[len(ERROR_KEY_PREFIX):(
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)])
return (driver_id == worker.task_driver_id
or driver_id == ObjectID.nil_id())
def error_info(worker=global_worker):
"""Return information about failed tasks."""
worker.check_connected()
return (global_state.error_messages(job_id=worker.task_driver_id) +
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
global_state.error_messages(job_id=ObjectID.nil_id()))
def _initialize_serialization(driver_id, worker=global_worker):
@@ -1134,13 +1128,13 @@ def _initialize_serialization(driver_id, worker=global_worker):
return obj.id()
def object_id_custom_deserializer(serialized_obj):
return ray.ObjectID(serialized_obj)
return ObjectID(serialized_obj)
# We register this serializer on each worker instead of calling
# register_custom_serializer from the driver so that isinstance still
# works.
serialization_context.register_type(
ray.ObjectID,
ObjectID,
"ray.ObjectID",
pickle=False,
custom_serializer=object_id_custom_serializer,
@@ -1661,7 +1655,7 @@ def listen_error_messages_raylet(worker, task_error_queue):
job_id = error_data.JobId()
if job_id not in [
worker.task_driver_id.id(),
ray_constants.NIL_JOB_ID.id()
ObjectID.nil_id().id()
]:
continue
@@ -1772,11 +1766,10 @@ def connect(info,
else:
# This is the code path of driver mode.
if driver_id is None:
driver_id = ray.ObjectID(random_string())
driver_id = ObjectID(random_string())
if not isinstance(driver_id, ray.ObjectID):
raise Exception(
"The type of given driver id must be ray.ObjectID.")
if not isinstance(driver_id, ObjectID):
raise Exception("The type of given driver id must be ObjectID.")
worker.worker_id = driver_id.id()
@@ -1785,11 +1778,11 @@ def connect(info,
# responsible for the task so that error messages will be propagated to
# the correct driver.
if mode != WORKER_MODE:
worker.task_driver_id = ray.ObjectID(worker.worker_id)
worker.task_driver_id = ObjectID(worker.worker_id)
# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
worker.actor_id = NIL_ACTOR_ID
worker.actor_id = ObjectID.nil_id()
worker.connected = True
worker.set_mode(mode)
@@ -1920,13 +1913,13 @@ def connect(info,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
ray.ObjectID(random_string()), # parent_task_id.
ObjectID(random_string()), # parent_task_id.
0, # parent_counter.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
ObjectID.nil_id(), # actor_creation_id.
ObjectID.nil_id(), # actor_creation_dummy_object_id.
0, # max_actor_reconstructions.
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
ObjectID.nil_id(), # actor_id.
ObjectID.nil_id(), # actor_handle_id.
nil_actor_counter, # actor_counter.
[], # new_actor_handles.
[], # execution_dependencies.
@@ -2148,9 +2141,7 @@ def register_custom_serializer(cls,
class_id = ray.utils.binary_to_hex(class_id)
if driver_id is None:
driver_id_bytes = worker.task_driver_id.id()
else:
driver_id_bytes = driver_id.id()
driver_id = worker.task_driver_id
def register_class_for_serialization(worker_info):
# TODO(rkn): We need to be more thoughtful about what to do if custom
@@ -2160,7 +2151,7 @@ def register_custom_serializer(cls,
# system.
serialization_context = worker_info[
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
"worker"].get_serialization_context(driver_id)
serialization_context.register_type(
cls,
class_id,
@@ -2279,13 +2270,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
IDs.
"""
if isinstance(object_ids, ray.ObjectID):
if isinstance(object_ids, ObjectID):
raise TypeError(
"wait() expected a list of ObjectID, got a single ObjectID")
"wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
)
if not isinstance(object_ids, list):
raise TypeError("wait() expected a list of ObjectID, got {}".format(
type(object_ids)))
raise TypeError(
"wait() expected a list of ray.ObjectID, got {}".format(
type(object_ids)))
if isinstance(timeout, int) and timeout != 0:
logger.warning("The 'timeout' argument now requires seconds instead "
@@ -2298,8 +2291,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
if worker.mode != LOCAL_MODE:
for object_id in object_ids:
if not isinstance(object_id, ray.ObjectID):
raise TypeError("wait() expected a list of ObjectID, "
if not isinstance(object_id, ObjectID):
raise TypeError("wait() expected a list of ray.ObjectID, "
"got list containing {}".format(
type(object_id)))