mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 14:31:15 +08:00
Refactor code about ray.ObjectID. (#3674)
* Refactor code about ray.ObjectID. * Remove from_random and use nil_id instead of the constructor. * Remove id() in hash. * Lint and fix. * Change driver id to ObjectID. * Replace binary_to_hex(ObjectID.id()) with ObjectID.hex().
This commit is contained in:
committed by
Philipp Moritz
parent
c4b058739b
commit
d2cf8561f2
+17
-19
@@ -17,6 +17,7 @@ import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
import ray.worker
|
||||
from ray.utils import _random_string
|
||||
from ray import ObjectID
|
||||
|
||||
DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS = 1
|
||||
|
||||
@@ -41,8 +42,7 @@ def compute_actor_handle_id(actor_handle_id, num_forks):
|
||||
handle_id_hash.update(actor_handle_id.id())
|
||||
handle_id_hash.update(str(num_forks).encode("ascii"))
|
||||
handle_id = handle_id_hash.digest()
|
||||
assert len(handle_id) == ray_constants.ID_SIZE
|
||||
return ray.ObjectID(handle_id)
|
||||
return ObjectID(handle_id)
|
||||
|
||||
|
||||
def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
|
||||
@@ -69,8 +69,7 @@ def compute_actor_handle_id_non_forked(actor_handle_id, current_task_id):
|
||||
handle_id_hash.update(actor_handle_id.id())
|
||||
handle_id_hash.update(current_task_id.id())
|
||||
handle_id = handle_id_hash.digest()
|
||||
assert len(handle_id) == ray_constants.ID_SIZE
|
||||
return ray.ObjectID(handle_id)
|
||||
return ObjectID(handle_id)
|
||||
|
||||
|
||||
def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
@@ -84,7 +83,7 @@ def set_actor_checkpoint(worker, actor_id, checkpoint_index, checkpoint,
|
||||
checkpoint: The state object to save.
|
||||
frontier: The task frontier at the time of the checkpoint.
|
||||
"""
|
||||
actor_key = b"Actor:" + actor_id
|
||||
actor_key = b"Actor:" + actor_id.id()
|
||||
worker.redis_client.hmset(
|
||||
actor_key, {
|
||||
"checkpoint_index": checkpoint_index,
|
||||
@@ -110,7 +109,7 @@ def save_and_log_checkpoint(worker, actor):
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id.id(),
|
||||
driver_id=worker.task_driver_id,
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint__.__name__
|
||||
@@ -134,7 +133,7 @@ def restore_and_log_checkpoint(worker, actor):
|
||||
worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=worker.task_driver_id.id(),
|
||||
driver_id=worker.task_driver_id,
|
||||
data={
|
||||
"actor_class": actor.__class__.__name__,
|
||||
"function_name": actor.__ray_checkpoint_restore__.__name__
|
||||
@@ -156,7 +155,7 @@ def get_actor_checkpoint(worker, actor_id):
|
||||
exists, all objects are set to None. The checkpoint index is the number of tasks
|
||||
executed on the actor before the checkpoint was made.
|
||||
"""
|
||||
actor_key = b"Actor:" + actor_id
|
||||
actor_key = b"Actor:" + actor_id.id()
|
||||
checkpoint_index, checkpoint, frontier = worker.redis_client.hmget(
|
||||
actor_key, ["checkpoint_index", "checkpoint", "frontier"])
|
||||
if checkpoint_index is not None:
|
||||
@@ -371,7 +370,7 @@ class ActorClass(object):
|
||||
raise Exception("Actors cannot be created before ray.init() "
|
||||
"has been called.")
|
||||
|
||||
actor_id = ray.ObjectID(_random_string())
|
||||
actor_id = ObjectID(_random_string())
|
||||
# The actor cursor is a dummy object representing the most recent
|
||||
# actor method invocation. For each subsequent method invocation,
|
||||
# the current cursor should be added as a dependency, and then
|
||||
@@ -509,8 +508,7 @@ class ActorHandle(object):
|
||||
# if it was created by the _serialization_helper function.
|
||||
self._ray_original_handle = actor_handle_id is None
|
||||
if self._ray_original_handle:
|
||||
self._ray_actor_handle_id = ray.ObjectID(
|
||||
ray.worker.NIL_ACTOR_HANDLE_ID)
|
||||
self._ray_actor_handle_id = ObjectID.nil_id()
|
||||
else:
|
||||
self._ray_actor_handle_id = actor_handle_id
|
||||
self._ray_actor_cursor = actor_cursor
|
||||
@@ -713,7 +711,7 @@ class ActorHandle(object):
|
||||
# to release, since it could be unpickled and submit another
|
||||
# dependent task at any time. Therefore, we notify the backend of a
|
||||
# random handle ID that will never actually be used.
|
||||
new_actor_handle_id = ray.ObjectID(_random_string())
|
||||
new_actor_handle_id = ObjectID(_random_string())
|
||||
# Notify the backend to expect this new actor handle. The backend will
|
||||
# not release the cursor for any new handles until the first task for
|
||||
# each of the new handles is submitted.
|
||||
@@ -735,7 +733,7 @@ class ActorHandle(object):
|
||||
worker.check_connected()
|
||||
|
||||
if state["ray_forking"]:
|
||||
actor_handle_id = ray.ObjectID(state["actor_handle_id"])
|
||||
actor_handle_id = ObjectID(state["actor_handle_id"])
|
||||
else:
|
||||
# Right now, if the actor handle has been pickled, we create a
|
||||
# temporary actor handle id for invocations.
|
||||
@@ -749,22 +747,22 @@ class ActorHandle(object):
|
||||
# same actor is likely a performance bug. We should consider
|
||||
# logging a warning in these cases.
|
||||
actor_handle_id = compute_actor_handle_id_non_forked(
|
||||
ray.ObjectID(state["actor_handle_id"]), worker.current_task_id)
|
||||
ObjectID(state["actor_handle_id"]), worker.current_task_id)
|
||||
|
||||
# This is the driver ID of the driver that owns the actor, not
|
||||
# necessarily the driver that owns this actor handle.
|
||||
actor_driver_id = ray.ObjectID(state["actor_driver_id"])
|
||||
actor_driver_id = ObjectID(state["actor_driver_id"])
|
||||
|
||||
self.__init__(
|
||||
ray.ObjectID(state["actor_id"]),
|
||||
ObjectID(state["actor_id"]),
|
||||
state["module_name"],
|
||||
state["class_name"],
|
||||
ray.ObjectID(state["actor_cursor"])
|
||||
ObjectID(state["actor_cursor"])
|
||||
if state["actor_cursor"] is not None else None,
|
||||
state["actor_method_names"],
|
||||
state["method_signatures"],
|
||||
state["method_num_return_vals"],
|
||||
ray.ObjectID(state["actor_creation_dummy_object_id"])
|
||||
ObjectID(state["actor_creation_dummy_object_id"])
|
||||
if state["actor_creation_dummy_object_id"] is not None else None,
|
||||
state["actor_method_cpus"],
|
||||
actor_driver_id,
|
||||
@@ -843,7 +841,7 @@ def make_actor(cls, num_cpus, num_gpus, resources, actor_method_cpus,
|
||||
# scheduler has seen. Handle IDs for which no task has yet reached
|
||||
# the local scheduler will not be included, and may not be runnable
|
||||
# on checkpoint resumption.
|
||||
actor_id = ray.ObjectID(worker.actor_id)
|
||||
actor_id = worker.actor_id
|
||||
frontier = worker.raylet_client.get_actor_frontier(actor_id)
|
||||
# Save the checkpoint in Redis. TODO(rkn): Checkpoints
|
||||
# should not be stored in Redis. Fix this.
|
||||
|
||||
@@ -25,7 +25,7 @@ def parse_client_table(redis_client):
|
||||
Returns:
|
||||
A list of information about the nodes in the cluster.
|
||||
"""
|
||||
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
NIL_CLIENT_ID = ray.ObjectID.nil_id().id()
|
||||
message = redis_client.execute_command("RAY.TABLE_LOOKUP",
|
||||
ray.gcs_utils.TablePrefix.CLIENT,
|
||||
"", NIL_CLIENT_ID)
|
||||
@@ -308,20 +308,19 @@ class GlobalState(object):
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
function_descriptor_list)
|
||||
task_spec_info = {
|
||||
"DriverID": binary_to_hex(task_spec.driver_id().id()),
|
||||
"TaskID": binary_to_hex(task_spec.task_id().id()),
|
||||
"ParentTaskID": binary_to_hex(task_spec.parent_task_id().id()),
|
||||
"DriverID": task_spec.driver_id().hex(),
|
||||
"TaskID": task_spec.task_id().hex(),
|
||||
"ParentTaskID": task_spec.parent_task_id().hex(),
|
||||
"ParentCounter": task_spec.parent_counter(),
|
||||
"ActorID": binary_to_hex(task_spec.actor_id().id()),
|
||||
"ActorCreationID": binary_to_hex(
|
||||
task_spec.actor_creation_id().id()),
|
||||
"ActorCreationDummyObjectID": binary_to_hex(
|
||||
task_spec.actor_creation_dummy_object_id().id()),
|
||||
"ActorID": (task_spec.actor_id().hex()),
|
||||
"ActorCreationID": task_spec.actor_creation_id().hex(),
|
||||
"ActorCreationDummyObjectID": (
|
||||
task_spec.actor_creation_dummy_object_id().hex()),
|
||||
"ActorCounter": task_spec.actor_counter(),
|
||||
"Args": task_spec.arguments(),
|
||||
"ReturnObjectIDs": task_spec.returns(),
|
||||
"RequiredResources": task_spec.required_resources(),
|
||||
"FunctionID": binary_to_hex(function_descriptor.function_id.id()),
|
||||
"FunctionID": function_descriptor.function_id.hex(),
|
||||
"FunctionHash": binary_to_hex(function_descriptor.function_hash),
|
||||
"ModuleName": function_descriptor.module_name,
|
||||
"ClassName": function_descriptor.class_name,
|
||||
|
||||
@@ -211,7 +211,7 @@ class FunctionDescriptor(object):
|
||||
Returns:
|
||||
The value of ray.ObjectID that represents the function id.
|
||||
"""
|
||||
return ray.ObjectID(self._function_id)
|
||||
return self._function_id
|
||||
|
||||
def _get_function_id(self):
|
||||
"""Calculate the function id of current function descriptor.
|
||||
@@ -220,10 +220,10 @@ class FunctionDescriptor(object):
|
||||
descriptor.
|
||||
|
||||
Returns:
|
||||
bytes with length of ray_constants.ID_SIZE.
|
||||
ray.ObjectID to represent the function descriptor.
|
||||
"""
|
||||
if self.is_for_driver_task:
|
||||
return ray_constants.NIL_FUNCTION_ID.id()
|
||||
return ray.ObjectID.nil_id()
|
||||
function_id_hash = hashlib.sha1()
|
||||
# Include the function module and name in the hash.
|
||||
function_id_hash.update(self.module_name.encode("ascii"))
|
||||
@@ -232,8 +232,7 @@ class FunctionDescriptor(object):
|
||||
function_id_hash.update(self._function_source_hash)
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
assert len(function_id) == ray_constants.ID_SIZE
|
||||
return function_id
|
||||
return ray.ObjectID(function_id)
|
||||
|
||||
def get_function_descriptor_list(self):
|
||||
"""Return a list of bytes representing the function descriptor.
|
||||
@@ -290,11 +289,11 @@ class FunctionActorManager(object):
|
||||
self.imported_actor_classes = set()
|
||||
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id.id()
|
||||
function_id = function_descriptor.function_id
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id.id()
|
||||
function_id = function_descriptor.function_id
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
@@ -372,13 +371,14 @@ class FunctionActorManager(object):
|
||||
|
||||
def fetch_and_register_remote_function(self, key):
|
||||
"""Import a remote function."""
|
||||
(driver_id, function_id_str, function_name, serialized_function,
|
||||
(driver_id_str, function_id_str, function_name, serialized_function,
|
||||
num_return_vals, module, resources,
|
||||
max_calls) = self._worker.redis_client.hmget(key, [
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.ObjectID(function_id_str)
|
||||
driver_id = ray.ObjectID(driver_id_str)
|
||||
function_name = decode(function_name)
|
||||
max_calls = int(max_calls)
|
||||
module = decode(module)
|
||||
@@ -388,10 +388,10 @@ class FunctionActorManager(object):
|
||||
def f():
|
||||
raise Exception("This function was not imported properly.")
|
||||
|
||||
self._function_execution_info[driver_id][function_id.id()] = (
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f, function_name=function_name, max_calls=max_calls))
|
||||
self._num_task_executions[driver_id][function_id.id()] = 0
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
@@ -416,7 +416,7 @@ class FunctionActorManager(object):
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
self._function_execution_info[driver_id][function_id.id()] = (
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
@@ -435,7 +435,7 @@ class FunctionActorManager(object):
|
||||
Returns:
|
||||
A FunctionExecutionInfo object.
|
||||
"""
|
||||
function_id = function_descriptor.function_id.id()
|
||||
function_id = function_descriptor.function_id
|
||||
|
||||
# Wait until the function to be executed has actually been
|
||||
# registered on this worker. We will push warnings to the user if
|
||||
@@ -449,7 +449,7 @@ class FunctionActorManager(object):
|
||||
except KeyError as e:
|
||||
message = ("Error occurs in get_execution_info: "
|
||||
"driver_id: %s, function_descriptor: %s. Message: %s" %
|
||||
(binary_to_hex(driver_id), function_descriptor, e))
|
||||
driver_id, function_descriptor, e)
|
||||
raise KeyError(message)
|
||||
return info
|
||||
|
||||
@@ -474,11 +474,11 @@ class FunctionActorManager(object):
|
||||
warning_sent = False
|
||||
while True:
|
||||
with self._worker.lock:
|
||||
if (self._worker.actor_id == ray.worker.NIL_ACTOR_ID
|
||||
and (function_descriptor.function_id.id() in
|
||||
if (self._worker.actor_id.is_nil()
|
||||
and (function_descriptor.function_id in
|
||||
self._function_execution_info[driver_id])):
|
||||
break
|
||||
elif self._worker.actor_id != ray.worker.NIL_ACTOR_ID and (
|
||||
elif not self._worker.actor_id.is_nil() and (
|
||||
self._worker.actor_id in self._worker.actors):
|
||||
break
|
||||
if time.time() - start_time > timeout:
|
||||
@@ -556,7 +556,7 @@ class FunctionActorManager(object):
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def load_actor(self, driver_id, function_descriptor):
|
||||
key = (b"ActorClass:" + driver_id + b":" +
|
||||
key = (b"ActorClass:" + driver_id.id() + b":" +
|
||||
function_descriptor.function_id.id())
|
||||
# Wait for the actor class key to have been imported by the
|
||||
# import thread. TODO(rkn): It shouldn't be possible to end
|
||||
@@ -578,8 +578,8 @@ class FunctionActorManager(object):
|
||||
actor_class_key: The key in Redis to use to fetch the actor.
|
||||
worker: The worker to use.
|
||||
"""
|
||||
actor_id_str = self._worker.actor_id
|
||||
(driver_id, class_name, module, pickled_class, checkpoint_interval,
|
||||
actor_id = self._worker.actor_id
|
||||
(driver_id_str, class_name, module, pickled_class, checkpoint_interval,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
actor_class_key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
@@ -588,6 +588,7 @@ class FunctionActorManager(object):
|
||||
|
||||
class_name = decode(class_name)
|
||||
module = decode(module)
|
||||
driver_id = ray.ObjectID(driver_id_str)
|
||||
checkpoint_interval = int(checkpoint_interval)
|
||||
actor_method_names = json.loads(decode(actor_method_names))
|
||||
|
||||
@@ -606,7 +607,7 @@ class FunctionActorManager(object):
|
||||
class TemporaryActor(object):
|
||||
pass
|
||||
|
||||
self._worker.actors[actor_id_str] = TemporaryActor()
|
||||
self._worker.actors[actor_id] = TemporaryActor()
|
||||
self._worker.actor_checkpoint_interval = checkpoint_interval
|
||||
|
||||
def temporary_actor_method(*xs):
|
||||
@@ -618,7 +619,7 @@ class FunctionActorManager(object):
|
||||
for actor_method_name in actor_method_names:
|
||||
function_descriptor = FunctionDescriptor(module, actor_method_name,
|
||||
class_name)
|
||||
function_id = function_descriptor.function_id.id()
|
||||
function_id = function_descriptor.function_id
|
||||
temporary_executor = self._make_actor_method_executor(
|
||||
actor_method_name,
|
||||
temporary_actor_method,
|
||||
@@ -644,14 +645,14 @@ class FunctionActorManager(object):
|
||||
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id,
|
||||
data={"actor_id": actor_id_str})
|
||||
data={"actor_id": actor_id.id()})
|
||||
# TODO(rkn): In the future, it might make sense to have the worker
|
||||
# exit here. However, currently that would lead to hanging if
|
||||
# someone calls ray.get on a method invoked on the actor.
|
||||
else:
|
||||
# TODO(pcm): Why is the below line necessary?
|
||||
unpickled_class.__module__ = module
|
||||
self._worker.actors[actor_id_str] = unpickled_class.__new__(
|
||||
self._worker.actors[actor_id] = unpickled_class.__new__(
|
||||
unpickled_class)
|
||||
|
||||
actor_methods = inspect.getmembers(
|
||||
@@ -659,7 +660,7 @@ class FunctionActorManager(object):
|
||||
for actor_method_name, actor_method in actor_methods:
|
||||
function_descriptor = FunctionDescriptor(
|
||||
module, actor_method_name, class_name)
|
||||
function_id = function_descriptor.function_id.id()
|
||||
function_id = function_descriptor.function_id
|
||||
executor = self._make_actor_method_executor(
|
||||
actor_method_name, actor_method, actor_imported=True)
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
|
||||
@@ -58,7 +58,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
The serialized object.
|
||||
"""
|
||||
builder = flatbuffers.Builder(0)
|
||||
driver_offset = builder.CreateString(driver_id)
|
||||
driver_offset = builder.CreateString(driver_id.id())
|
||||
error_type_offset = builder.CreateString(error_type)
|
||||
message_offset = builder.CreateString(message)
|
||||
|
||||
|
||||
@@ -131,5 +131,5 @@ class ImportThread(object):
|
||||
self.worker,
|
||||
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=driver_id,
|
||||
driver_id=ray.ObjectID(driver_id),
|
||||
data={"name": name})
|
||||
|
||||
@@ -5,8 +5,6 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ray.raylet import ObjectID
|
||||
|
||||
|
||||
def env_integer(key, default):
|
||||
if key in os.environ:
|
||||
@@ -15,8 +13,6 @@ def env_integer(key, default):
|
||||
|
||||
|
||||
ID_SIZE = 20
|
||||
NIL_JOB_ID = ObjectID(ID_SIZE * b"\xff")
|
||||
NIL_FUNCTION_ID = NIL_JOB_ID
|
||||
|
||||
# The default maximum number of bytes to allocate to the object store unless
|
||||
# overridden by the user.
|
||||
|
||||
@@ -4,10 +4,10 @@ from __future__ import print_function
|
||||
|
||||
from ray.core.src.ray.raylet.libraylet_library_python import (
|
||||
Task, RayletClient, ObjectID, check_simple_value, compute_task_id,
|
||||
task_from_string, task_to_string, _config, common_error)
|
||||
task_from_string, task_to_string, _config, RayCommonError)
|
||||
|
||||
__all__ = [
|
||||
"Task", "RayletClient", "ObjectID", "check_simple_value",
|
||||
"compute_task_id", "task_from_string", "task_to_string",
|
||||
"start_local_scheduler", "_config", "common_error"
|
||||
"start_local_scheduler", "_config", "RayCommonError"
|
||||
]
|
||||
|
||||
+9
-8
@@ -67,10 +67,10 @@ def push_error_to_driver(worker,
|
||||
will be serialized with json and stored in Redis.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray_constants.NIL_JOB_ID.id()
|
||||
driver_id = ray.ObjectID.nil_id()
|
||||
data = {} if data is None else data
|
||||
worker.raylet_client.push_error(
|
||||
ray.ObjectID(driver_id), error_type, message, time.time())
|
||||
worker.raylet_client.push_error(driver_id, error_type, message,
|
||||
time.time())
|
||||
|
||||
|
||||
def push_error_to_driver_through_redis(redis_client,
|
||||
@@ -96,15 +96,16 @@ def push_error_to_driver_through_redis(redis_client,
|
||||
will be serialized with json and stored in Redis.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray_constants.NIL_JOB_ID.id()
|
||||
driver_id = ray.ObjectID.nil_id()
|
||||
data = {} if data is None else data
|
||||
# Do everything in Python and through the Python Redis client instead
|
||||
# of through the raylet.
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
|
||||
message, time.time())
|
||||
redis_client.execute_command(
|
||||
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.ERROR_INFO,
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO, driver_id, error_data)
|
||||
redis_client.execute_command("RAY.TABLE_APPEND",
|
||||
ray.gcs_utils.TablePrefix.ERROR_INFO,
|
||||
ray.gcs_utils.TablePubsub.ERROR_INFO,
|
||||
driver_id.id(), error_data)
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
@@ -400,7 +401,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
|
||||
worker,
|
||||
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=worker.task_driver_id.id())
|
||||
driver_id=worker.task_driver_id)
|
||||
|
||||
|
||||
class _ThreadSafeProxy(object):
|
||||
|
||||
+64
-71
@@ -36,6 +36,7 @@ import ray.raylet
|
||||
import ray.plasma
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray import import_thread
|
||||
from ray import ObjectID
|
||||
from ray import profiling
|
||||
from ray.function_manager import (FunctionActorManager, FunctionDescriptor)
|
||||
import ray.parameter
|
||||
@@ -53,13 +54,6 @@ PYTHON_MODE = 3
|
||||
|
||||
ERROR_KEY_PREFIX = b"Error:"
|
||||
|
||||
# This must match the definition of NIL_ACTOR_ID in task.h.
|
||||
NIL_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
NIL_LOCAL_SCHEDULER_ID = NIL_ID
|
||||
NIL_ACTOR_ID = NIL_ID
|
||||
NIL_ACTOR_HANDLE_ID = NIL_ID
|
||||
NIL_CLIENT_ID = ray_constants.ID_SIZE * b"\xff"
|
||||
|
||||
# Default resource requirements for actors when no resource requirements are
|
||||
# specified.
|
||||
DEFAULT_ACTOR_METHOD_CPUS_SIMPLE_CASE = 1
|
||||
@@ -168,7 +162,7 @@ class Worker(object):
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.task_driver_id = ObjectID.nil_id()
|
||||
self._task_context = threading.local()
|
||||
|
||||
@property
|
||||
@@ -189,14 +183,13 @@ class Worker(object):
|
||||
# If this is running on the main thread, initialize it to
|
||||
# NIL. The actual value will be set when the worker receives
|
||||
# a task from raylet backend.
|
||||
self._task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self._task_context.current_task_id = ObjectID.nil_id()
|
||||
else:
|
||||
# If this is running on a separate thread, then the mapping
|
||||
# to the current task ID may not be correct. Generate a
|
||||
# random task ID so that the backend can differentiate
|
||||
# between different threads.
|
||||
self._task_context.current_task_id = ray.ObjectID(
|
||||
random_string())
|
||||
self._task_context.current_task_id = ObjectID(random_string())
|
||||
if getattr(self, '_multithreading_warned', False) is not True:
|
||||
logger.warning(
|
||||
"Calling ray.get or ray.wait in a separate thread "
|
||||
@@ -353,12 +346,13 @@ class Worker(object):
|
||||
full.
|
||||
"""
|
||||
# Make sure that the value is not an object ID.
|
||||
if isinstance(value, ray.ObjectID):
|
||||
raise Exception("Calling 'put' on an ObjectID is not allowed "
|
||||
"(similarly, returning an ObjectID from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ObjectID in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
if isinstance(value, ObjectID):
|
||||
raise Exception(
|
||||
"Calling 'put' on an ray.ObjectID is not allowed "
|
||||
"(similarly, returning an ray.ObjectID from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ray.ObjectID in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
|
||||
# Serialize and put the object in the object store.
|
||||
try:
|
||||
@@ -433,7 +427,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=self.task_driver_id.id())
|
||||
driver_id=self.task_driver_id)
|
||||
warning_sent = True
|
||||
|
||||
def get_object(self, object_ids):
|
||||
@@ -449,9 +443,10 @@ class Worker(object):
|
||||
"""
|
||||
# Make sure that the values are object IDs.
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise Exception("Attempting to call `get` on the value {}, "
|
||||
"which is not an ObjectID.".format(object_id))
|
||||
if not isinstance(object_id, ObjectID):
|
||||
raise Exception(
|
||||
"Attempting to call `get` on the value {}, "
|
||||
"which is not an ray.ObjectID.".format(object_id))
|
||||
# Do an initial fetch for remote objects. We divide the fetch into
|
||||
# smaller fetches so as to not block the manager for a prolonged period
|
||||
# of time in a single call.
|
||||
@@ -484,8 +479,7 @@ class Worker(object):
|
||||
for unready_id in unready_ids.keys()
|
||||
]
|
||||
ray_object_ids_to_fetch = [
|
||||
ray.ObjectID(unready_id)
|
||||
for unready_id in unready_ids.keys()
|
||||
ObjectID(unready_id) for unready_id in unready_ids.keys()
|
||||
]
|
||||
fetch_request_size = ray._config.worker_fetch_request_size()
|
||||
for i in range(0, len(object_ids_to_fetch),
|
||||
@@ -574,22 +568,22 @@ class Worker(object):
|
||||
with profiling.profile("submit_task", worker=self):
|
||||
if actor_id is None:
|
||||
assert actor_handle_id is None
|
||||
actor_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
actor_handle_id = ray.ObjectID(NIL_ACTOR_HANDLE_ID)
|
||||
actor_id = ObjectID.nil_id()
|
||||
actor_handle_id = ObjectID.nil_id()
|
||||
else:
|
||||
assert actor_handle_id is not None
|
||||
|
||||
if actor_creation_id is None:
|
||||
actor_creation_id = ray.ObjectID(NIL_ACTOR_ID)
|
||||
actor_creation_id = ObjectID.nil_id()
|
||||
|
||||
if actor_creation_dummy_object_id is None:
|
||||
actor_creation_dummy_object_id = (ray.ObjectID(NIL_ID))
|
||||
actor_creation_dummy_object_id = ObjectID.nil_id()
|
||||
|
||||
# Put large or complex arguments that are passed by value in the
|
||||
# object store first.
|
||||
args_for_local_scheduler = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
if isinstance(arg, ObjectID):
|
||||
args_for_local_scheduler.append(arg)
|
||||
elif ray.raylet.check_simple_value(arg):
|
||||
args_for_local_scheduler.append(arg)
|
||||
@@ -722,7 +716,7 @@ class Worker(object):
|
||||
arguments are being retrieved.
|
||||
serialized_args (List): The arguments to the function. These are
|
||||
either strings representing serialized objects passed by value
|
||||
or they are ObjectIDs.
|
||||
or they are ray.ObjectIDs.
|
||||
|
||||
Returns:
|
||||
The retrieved arguments in addition to the arguments that were
|
||||
@@ -734,7 +728,7 @@ class Worker(object):
|
||||
"""
|
||||
arguments = []
|
||||
for (i, arg) in enumerate(serialized_args):
|
||||
if isinstance(arg, ray.ObjectID):
|
||||
if isinstance(arg, ObjectID):
|
||||
# get the object from the local object store
|
||||
argument = self.get_object([arg])[0]
|
||||
if isinstance(argument, RayTaskError):
|
||||
@@ -838,9 +832,9 @@ class Worker(object):
|
||||
outputs = function_executor(*arguments)
|
||||
else:
|
||||
if not task.actor_id().is_nil():
|
||||
key = task.actor_id().id()
|
||||
key = task.actor_id()
|
||||
else:
|
||||
key = task.actor_creation_id().id()
|
||||
key = task.actor_creation_id()
|
||||
outputs = function_executor(dummy_return_id,
|
||||
self.actors[key], *arguments)
|
||||
except Exception as e:
|
||||
@@ -882,7 +876,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.TASK_PUSH_ERROR,
|
||||
str(failure_object),
|
||||
driver_id=self.task_driver_id.id(),
|
||||
driver_id=self.task_driver_id,
|
||||
data={
|
||||
"function_id": function_id.id(),
|
||||
"function_name": function_name,
|
||||
@@ -890,7 +884,7 @@ class Worker(object):
|
||||
"class_name": function_descriptor.class_name
|
||||
})
|
||||
# Mark the actor init as failed
|
||||
if self.actor_id != NIL_ACTOR_ID and function_name == "__init__":
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
|
||||
def _wait_for_and_process_task(self, task):
|
||||
@@ -901,13 +895,13 @@ class Worker(object):
|
||||
"""
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
driver_id = task.driver_id().id()
|
||||
driver_id = task.driver_id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
if not task.actor_creation_id().is_nil():
|
||||
assert self.actor_id == NIL_ACTOR_ID
|
||||
self.actor_id = task.actor_creation_id().id()
|
||||
assert self.actor_id.is_nil()
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.function_actor_manager.load_actor(driver_id,
|
||||
function_descriptor)
|
||||
|
||||
@@ -930,12 +924,12 @@ class Worker(object):
|
||||
title = "ray_worker:{}()".format(function_name)
|
||||
next_title = "ray_worker"
|
||||
else:
|
||||
actor = self.actors[task.actor_creation_id().id()]
|
||||
actor = self.actors[task.actor_creation_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
else:
|
||||
actor = self.actors[task.actor_id().id()]
|
||||
actor = self.actors[task.actor_id()]
|
||||
title = "ray_{}:{}()".format(actor.__class__.__name__,
|
||||
function_name)
|
||||
next_title = "ray_{}".format(actor.__class__.__name__)
|
||||
@@ -943,14 +937,14 @@ class Worker(object):
|
||||
with _changeproctitle(title, next_title):
|
||||
self._process_task(task, execution_info)
|
||||
# Reset the state fields so the next task can run.
|
||||
self.task_context.current_task_id = ray.ObjectID(NIL_ID)
|
||||
self.task_context.current_task_id = ObjectID.nil_id()
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id == NIL_ACTOR_ID:
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = ray.ObjectID(NIL_ID)
|
||||
self.task_driver_id = ObjectID.nil_id()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
@@ -1104,17 +1098,17 @@ def error_applies_to_driver(error_key, worker=global_worker):
|
||||
+ ray_constants.ID_SIZE), error_key
|
||||
# If the driver ID in the error message is the nil ID (all 0xff bytes), then
|
||||
# the message is intended for all drivers.
|
||||
driver_id = error_key[len(ERROR_KEY_PREFIX):(
|
||||
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)]
|
||||
return (driver_id == worker.task_driver_id.id()
|
||||
or driver_id == ray.ray_constants.NIL_JOB_ID.id())
|
||||
driver_id = ObjectID(error_key[len(ERROR_KEY_PREFIX):(
|
||||
len(ERROR_KEY_PREFIX) + ray_constants.ID_SIZE)])
|
||||
return (driver_id == worker.task_driver_id
|
||||
or driver_id == ObjectID.nil_id())
|
||||
|
||||
|
||||
def error_info(worker=global_worker):
|
||||
"""Return information about failed tasks."""
|
||||
worker.check_connected()
|
||||
return (global_state.error_messages(job_id=worker.task_driver_id) +
|
||||
global_state.error_messages(job_id=ray_constants.NIL_JOB_ID))
|
||||
global_state.error_messages(job_id=ObjectID.nil_id()))
|
||||
|
||||
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
@@ -1134,13 +1128,13 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
return obj.id()
|
||||
|
||||
def object_id_custom_deserializer(serialized_obj):
|
||||
return ray.ObjectID(serialized_obj)
|
||||
return ObjectID(serialized_obj)
|
||||
|
||||
# We register this serializer on each worker instead of calling
|
||||
# register_custom_serializer from the driver so that isinstance still
|
||||
# works.
|
||||
serialization_context.register_type(
|
||||
ray.ObjectID,
|
||||
ObjectID,
|
||||
"ray.ObjectID",
|
||||
pickle=False,
|
||||
custom_serializer=object_id_custom_serializer,
|
||||
@@ -1661,7 +1655,7 @@ def listen_error_messages_raylet(worker, task_error_queue):
|
||||
job_id = error_data.JobId()
|
||||
if job_id not in [
|
||||
worker.task_driver_id.id(),
|
||||
ray_constants.NIL_JOB_ID.id()
|
||||
ObjectID.nil_id().id()
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -1772,11 +1766,10 @@ def connect(info,
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = ray.ObjectID(random_string())
|
||||
driver_id = ObjectID(random_string())
|
||||
|
||||
if not isinstance(driver_id, ray.ObjectID):
|
||||
raise Exception(
|
||||
"The type of given driver id must be ray.ObjectID.")
|
||||
if not isinstance(driver_id, ObjectID):
|
||||
raise Exception("The type of given driver id must be ObjectID.")
|
||||
|
||||
worker.worker_id = driver_id.id()
|
||||
|
||||
@@ -1785,11 +1778,11 @@ def connect(info,
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = ray.ObjectID(worker.worker_id)
|
||||
worker.task_driver_id = ObjectID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
worker.actor_id = NIL_ACTOR_ID
|
||||
worker.actor_id = ObjectID.nil_id()
|
||||
worker.connected = True
|
||||
worker.set_mode(mode)
|
||||
|
||||
@@ -1920,13 +1913,13 @@ def connect(info,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
ray.ObjectID(random_string()), # parent_task_id.
|
||||
ObjectID(random_string()), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_creation_dummy_object_id.
|
||||
ObjectID.nil_id(), # actor_creation_id.
|
||||
ObjectID.nil_id(), # actor_creation_dummy_object_id.
|
||||
0, # max_actor_reconstructions.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_id.
|
||||
ray.ObjectID(NIL_ACTOR_ID), # actor_handle_id.
|
||||
ObjectID.nil_id(), # actor_id.
|
||||
ObjectID.nil_id(), # actor_handle_id.
|
||||
nil_actor_counter, # actor_counter.
|
||||
[], # new_actor_handles.
|
||||
[], # execution_dependencies.
|
||||
@@ -2148,9 +2141,7 @@ def register_custom_serializer(cls,
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if driver_id is None:
|
||||
driver_id_bytes = worker.task_driver_id.id()
|
||||
else:
|
||||
driver_id_bytes = driver_id.id()
|
||||
driver_id = worker.task_driver_id
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2160,7 +2151,7 @@ def register_custom_serializer(cls,
|
||||
# system.
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(ray.ObjectID(driver_id_bytes))
|
||||
"worker"].get_serialization_context(driver_id)
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
@@ -2279,13 +2270,15 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
IDs.
|
||||
"""
|
||||
|
||||
if isinstance(object_ids, ray.ObjectID):
|
||||
if isinstance(object_ids, ObjectID):
|
||||
raise TypeError(
|
||||
"wait() expected a list of ObjectID, got a single ObjectID")
|
||||
"wait() expected a list of ray.ObjectID, got a single ray.ObjectID"
|
||||
)
|
||||
|
||||
if not isinstance(object_ids, list):
|
||||
raise TypeError("wait() expected a list of ObjectID, got {}".format(
|
||||
type(object_ids)))
|
||||
raise TypeError(
|
||||
"wait() expected a list of ray.ObjectID, got {}".format(
|
||||
type(object_ids)))
|
||||
|
||||
if isinstance(timeout, int) and timeout != 0:
|
||||
logger.warning("The 'timeout' argument now requires seconds instead "
|
||||
@@ -2298,8 +2291,8 @@ def wait(object_ids, num_returns=1, timeout=None, worker=global_worker):
|
||||
|
||||
if worker.mode != LOCAL_MODE:
|
||||
for object_id in object_ids:
|
||||
if not isinstance(object_id, ray.ObjectID):
|
||||
raise TypeError("wait() expected a list of ObjectID, "
|
||||
if not isinstance(object_id, ObjectID):
|
||||
raise TypeError("wait() expected a list of ray.ObjectID, "
|
||||
"got list containing {}".format(
|
||||
type(object_id)))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user