mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 12:28:10 +08:00
[ID Refactor] Rename DriverID to JobID (#5004)
* WIP WIP WIP Rename Driver -> Job Fix compilation Fix Rename in Java In py WIP Fix WIP Fix Fix test Fix Fix C++ linting Fix * Update java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Update src/ray/core_worker/core_worker.cc Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Address comments * Fix * Fix CI * Fix cpp linting * Fix py lint * Fix * Address comments and fix * Address comments * Address * Fix import_threading
This commit is contained in:
@@ -56,7 +56,8 @@ from ray._raylet import (
|
||||
ActorID,
|
||||
ClientID,
|
||||
Config as _Config,
|
||||
DriverID,
|
||||
JobID,
|
||||
WorkerID,
|
||||
FunctionID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
@@ -141,7 +142,8 @@ __all__ += [
|
||||
"ActorHandleID",
|
||||
"ActorID",
|
||||
"ClientID",
|
||||
"DriverID",
|
||||
"JobID",
|
||||
"WorkerID",
|
||||
"FunctionID",
|
||||
"ObjectID",
|
||||
"TaskID",
|
||||
|
||||
@@ -221,13 +221,13 @@ cdef class RayletClient:
|
||||
def __cinit__(self, raylet_socket,
|
||||
ClientID client_id,
|
||||
c_bool is_worker,
|
||||
DriverID driver_id):
|
||||
JobID job_id):
|
||||
# We know that we are using Python, so just skip the language
|
||||
# parameter.
|
||||
# TODO(suquark): Should we allow unicode chars in "raylet_socket"?
|
||||
self.client.reset(new CRayletClient(
|
||||
raylet_socket.encode("ascii"), client_id.native(), is_worker,
|
||||
driver_id.native(), LANGUAGE_PYTHON))
|
||||
job_id.native(), LANGUAGE_PYTHON))
|
||||
|
||||
def disconnect(self):
|
||||
check_status(self.client.get().Disconnect())
|
||||
@@ -293,9 +293,9 @@ cdef class RayletClient:
|
||||
postincrement(iterator)
|
||||
return resources_dict
|
||||
|
||||
def push_error(self, DriverID driver_id, error_type, error_message,
|
||||
def push_error(self, JobID job_id, error_type, error_message,
|
||||
double timestamp):
|
||||
check_status(self.client.get().PushError(driver_id.native(),
|
||||
check_status(self.client.get().PushError(job_id.native(),
|
||||
error_type.encode("ascii"),
|
||||
error_message.encode("ascii"),
|
||||
timestamp))
|
||||
@@ -381,8 +381,8 @@ cdef class RayletClient:
|
||||
return ClientID(self.client.get().GetClientID().Binary())
|
||||
|
||||
@property
|
||||
def driver_id(self):
|
||||
return DriverID(self.client.get().GetDriverID().Binary())
|
||||
def job_id(self):
|
||||
return JobID(self.client.get().GetJobID().Binary())
|
||||
|
||||
@property
|
||||
def is_worker(self):
|
||||
|
||||
+19
-21
@@ -17,8 +17,7 @@ from ray.function_manager import FunctionDescriptor
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
import ray.worker
|
||||
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID,
|
||||
DriverID)
|
||||
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -186,7 +185,7 @@ class ActorClass(object):
|
||||
task.
|
||||
_resources: The default resources required by the actor creation task.
|
||||
_actor_method_cpus: The number of CPUs required by actor method tasks.
|
||||
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
|
||||
_last_job_id_exported_for: The ID of the job of the last Ray
|
||||
session during which this actor class definition was exported. This
|
||||
is an imperfect mechanism used to determine if we need to export
|
||||
the remote function again. It is imperfect in the sense that the
|
||||
@@ -212,7 +211,7 @@ class ActorClass(object):
|
||||
self._num_cpus = num_cpus
|
||||
self._num_gpus = num_gpus
|
||||
self._resources = resources
|
||||
self._last_driver_id_exported_for = None
|
||||
self._last_job_id_exported_for = None
|
||||
|
||||
self._actor_methods = inspect.getmembers(
|
||||
self._modified_class, ray.utils.is_function_or_method)
|
||||
@@ -345,13 +344,12 @@ class ActorClass(object):
|
||||
*copy.deepcopy(args), **copy.deepcopy(kwargs))
|
||||
else:
|
||||
# Export the actor.
|
||||
if (self._last_driver_id_exported_for is None
|
||||
or self._last_driver_id_exported_for !=
|
||||
worker.task_driver_id):
|
||||
if (self._last_job_id_exported_for is None or
|
||||
self._last_job_id_exported_for != worker.current_job_id):
|
||||
# If this actor class was exported in a previous session, we
|
||||
# need to export this function again, because current GCS
|
||||
# doesn't have it.
|
||||
self._last_driver_id_exported_for = worker.task_driver_id
|
||||
self._last_job_id_exported_for = worker.current_job_id
|
||||
worker.function_actor_manager.export_actor_class(
|
||||
self._modified_class, self._actor_method_names)
|
||||
|
||||
@@ -389,7 +387,7 @@ class ActorClass(object):
|
||||
actor_id, self._modified_class.__module__, self._class_name,
|
||||
actor_cursor, self._actor_method_names, self._method_decorators,
|
||||
self._method_signatures, self._actor_method_num_return_vals,
|
||||
actor_cursor, actor_method_cpu, worker.task_driver_id)
|
||||
actor_cursor, actor_method_cpu, worker.current_job_id)
|
||||
# We increment the actor counter by 1 to account for the actor creation
|
||||
# task.
|
||||
actor_handle._ray_actor_counter += 1
|
||||
@@ -446,9 +444,9 @@ class ActorHandle(object):
|
||||
_ray_original_handle: True if this is the original actor handle for a
|
||||
given actor. If this is true, then the actor will be destroyed when
|
||||
this handle goes out of scope.
|
||||
_ray_actor_driver_id: The driver ID of the job that created the actor
|
||||
(it is possible that this ActorHandle exists on a driver with a
|
||||
different driver ID).
|
||||
_ray_actor_job_id: The ID of the job that created the actor
|
||||
(it is possible that this ActorHandle exists on a job with a
|
||||
different job ID).
|
||||
_ray_new_actor_handles: The new actor handles that were created from
|
||||
this handle since the last task on this handle was submitted. This
|
||||
is used to garbage-collect dummy objects that are no longer
|
||||
@@ -466,10 +464,10 @@ class ActorHandle(object):
|
||||
method_num_return_vals,
|
||||
actor_creation_dummy_object_id,
|
||||
actor_method_cpus,
|
||||
actor_driver_id,
|
||||
actor_job_id,
|
||||
actor_handle_id=None):
|
||||
assert isinstance(actor_id, ActorID)
|
||||
assert isinstance(actor_driver_id, DriverID)
|
||||
assert isinstance(actor_job_id, ray.JobID)
|
||||
self._ray_actor_id = actor_id
|
||||
self._ray_module_name = module_name
|
||||
# False if this actor handle was created by forking or pickling. True
|
||||
@@ -491,7 +489,7 @@ class ActorHandle(object):
|
||||
self._ray_actor_creation_dummy_object_id = (
|
||||
actor_creation_dummy_object_id)
|
||||
self._ray_actor_method_cpus = actor_method_cpus
|
||||
self._ray_actor_driver_id = actor_driver_id
|
||||
self._ray_actor_job_id = actor_job_id
|
||||
self._ray_new_actor_handles = []
|
||||
self._ray_actor_lock = threading.Lock()
|
||||
|
||||
@@ -551,7 +549,7 @@ class ActorHandle(object):
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": self._ray_actor_method_cpus},
|
||||
placement_resources={},
|
||||
driver_id=self._ray_actor_driver_id,
|
||||
job_id=self._ray_actor_job_id,
|
||||
)
|
||||
# Update the actor counter and cursor to reflect the most recent
|
||||
# invocation.
|
||||
@@ -612,7 +610,7 @@ class ActorHandle(object):
|
||||
# not just the first one.
|
||||
worker = ray.worker.get_global_worker()
|
||||
if (worker.mode == ray.worker.SCRIPT_MODE
|
||||
and self._ray_actor_driver_id.binary() != worker.worker_id):
|
||||
and self._ray_actor_job_id.binary() != worker.worker_id):
|
||||
# If the worker is a driver and driver id has changed because
|
||||
# Ray was shut down and re-initialized, the actor is already cleaned up
|
||||
# and we don't need to send `__ray_terminate__` again.
|
||||
@@ -666,7 +664,7 @@ class ActorHandle(object):
|
||||
"actor_creation_dummy_object_id": self.
|
||||
_ray_actor_creation_dummy_object_id,
|
||||
"actor_method_cpus": self._ray_actor_method_cpus,
|
||||
"actor_driver_id": self._ray_actor_driver_id,
|
||||
"actor_job_id": self._ray_actor_job_id,
|
||||
"ray_forking": ray_forking
|
||||
}
|
||||
|
||||
@@ -727,9 +725,9 @@ class ActorHandle(object):
|
||||
state["method_num_return_vals"],
|
||||
state["actor_creation_dummy_object_id"],
|
||||
state["actor_method_cpus"],
|
||||
# This is the driver ID of the driver that owns the actor, not
|
||||
# necessarily the driver that owns this actor handle.
|
||||
state["actor_driver_id"],
|
||||
# This is the ID of the job that owns the actor, not
|
||||
# necessarily the job that owns this actor handle.
|
||||
state["actor_job_id"],
|
||||
actor_handle_id=actor_handle_id)
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
@@ -277,9 +277,9 @@ class FunctionActorManager(object):
|
||||
the worker gets connected.
|
||||
_actors_to_export: The actors to export when the worker gets
|
||||
connected.
|
||||
_function_execution_info: The map from driver_id to finction_id
|
||||
_function_execution_info: The map from job_id to function_id
|
||||
and execution_info.
|
||||
_num_task_executions: The map from driver_id to function
|
||||
_num_task_executions: The map from job_id to function
|
||||
execution times.
|
||||
imported_actor_classes: The set of actor classes keys (format:
|
||||
ActorClass:function_id) that are already in GCS.
|
||||
@@ -303,17 +303,17 @@ class FunctionActorManager(object):
|
||||
self._loaded_actor_classes = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
def increase_task_counter(self, job_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
job_id = ray.JobID.nil()
|
||||
self._num_task_executions[job_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_descriptor):
|
||||
def get_task_counter(self, job_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
job_id = ray.JobID.nil()
|
||||
return self._num_task_executions[job_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
"""Export cached remote functions
|
||||
@@ -376,11 +376,11 @@ class FunctionActorManager(object):
|
||||
check_oversized_pickle(pickled_function,
|
||||
remote_function._function_name,
|
||||
"remote function", self._worker)
|
||||
key = (b"RemoteFunction:" + self._worker.task_driver_id.binary() + b":"
|
||||
key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
|
||||
+ remote_function._function_descriptor.function_id.binary())
|
||||
self._worker.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self._worker.task_driver_id.binary(),
|
||||
"job_id": self._worker.current_job_id.binary(),
|
||||
"function_id": remote_function._function_descriptor.
|
||||
function_id.binary(),
|
||||
"name": remote_function._function_name,
|
||||
@@ -392,14 +392,14 @@ class FunctionActorManager(object):
|
||||
|
||||
def fetch_and_register_remote_function(self, key):
|
||||
"""Import a remote function."""
|
||||
(driver_id_str, function_id_str, function_name, serialized_function,
|
||||
(job_id_str, function_id_str, function_name, serialized_function,
|
||||
num_return_vals, module, resources,
|
||||
max_calls) = self._worker.redis_client.hmget(key, [
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"job_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.FunctionID(function_id_str)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
job_id = ray.JobID(job_id_str)
|
||||
function_name = decode(function_name)
|
||||
max_calls = int(max_calls)
|
||||
module = decode(module)
|
||||
@@ -413,12 +413,12 @@ class FunctionActorManager(object):
|
||||
# atomic. Otherwise, there is race condition. Another thread may use
|
||||
# the temporary function above before the real function is ready.
|
||||
with self.lock:
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
@@ -434,7 +434,7 @@ class FunctionActorManager(object):
|
||||
"Failed to unpickle the remote function '{}' with "
|
||||
"function ID {}. Traceback:\n{}".format(
|
||||
function_name, function_id.hex(), traceback_str),
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
else:
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python
|
||||
@@ -442,7 +442,7 @@ class FunctionActorManager(object):
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
@@ -452,11 +452,11 @@ class FunctionActorManager(object):
|
||||
b"FunctionTable:" + function_id.binary(),
|
||||
self._worker.worker_id)
|
||||
|
||||
def get_execution_info(self, driver_id, function_descriptor):
|
||||
def get_execution_info(self, job_id, function_descriptor):
|
||||
"""Get the FunctionExecutionInfo of a remote function.
|
||||
|
||||
Args:
|
||||
driver_id: ID of the driver that the function belongs to.
|
||||
job_id: ID of the job that the function belongs to.
|
||||
function_descriptor: The FunctionDescriptor of the function to get.
|
||||
|
||||
Returns:
|
||||
@@ -464,11 +464,11 @@ class FunctionActorManager(object):
|
||||
"""
|
||||
if self._worker.load_code_from_local:
|
||||
# Load function from local code.
|
||||
# Currently, we don't support isolating code by drivers,
|
||||
# thus always set driver ID to NIL here.
|
||||
driver_id = ray.DriverID.nil()
|
||||
# Currently, we don't support isolating code by jobs,
|
||||
# thus always set job ID to NIL here.
|
||||
job_id = ray.JobID.nil()
|
||||
if not function_descriptor.is_actor_method():
|
||||
self._load_function_from_local(driver_id, function_descriptor)
|
||||
self._load_function_from_local(job_id, function_descriptor)
|
||||
else:
|
||||
# Load function from GCS.
|
||||
# Wait until the function to be executed has actually been
|
||||
@@ -477,21 +477,21 @@ class FunctionActorManager(object):
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function"):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
self._wait_for_function(function_descriptor, job_id)
|
||||
try:
|
||||
function_id = function_descriptor.function_id
|
||||
info = self._function_execution_info[driver_id][function_id]
|
||||
info = self._function_execution_info[job_id][function_id]
|
||||
except KeyError as e:
|
||||
message = ("Error occurs in get_execution_info: "
|
||||
"driver_id: %s, function_descriptor: %s. Message: %s" %
|
||||
(driver_id, function_descriptor, e))
|
||||
"job_id: %s, function_descriptor: %s. Message: %s" %
|
||||
(job_id, function_descriptor, e))
|
||||
raise KeyError(message)
|
||||
return info
|
||||
|
||||
def _load_function_from_local(self, driver_id, function_descriptor):
|
||||
def _load_function_from_local(self, job_id, function_descriptor):
|
||||
assert not function_descriptor.is_actor_method()
|
||||
function_id = function_descriptor.function_id
|
||||
if (driver_id in self._function_execution_info
|
||||
if (job_id in self._function_execution_info
|
||||
and function_id in self._function_execution_info[function_id]):
|
||||
return
|
||||
module_name, function_name = (
|
||||
@@ -501,13 +501,13 @@ class FunctionActorManager(object):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
function = getattr(module, function_name)._function
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load function %s.".format(function_name))
|
||||
@@ -515,7 +515,7 @@ class FunctionActorManager(object):
|
||||
"Function {} failed to be loaded from local code.".format(
|
||||
function_descriptor))
|
||||
|
||||
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
|
||||
def _wait_for_function(self, function_descriptor, job_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
@@ -528,7 +528,7 @@ class FunctionActorManager(object):
|
||||
Args:
|
||||
function_descriptor : The FunctionDescriptor of the function that
|
||||
we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
job_id (str): The ID of the job to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
@@ -538,7 +538,7 @@ class FunctionActorManager(object):
|
||||
with self.lock:
|
||||
if (self._worker.actor_id.is_nil()
|
||||
and (function_descriptor.function_id in
|
||||
self._function_execution_info[driver_id])):
|
||||
self._function_execution_info[job_id])):
|
||||
break
|
||||
elif not self._worker.actor_id.is_nil() and (
|
||||
self._worker.actor_id in self._worker.actors):
|
||||
@@ -553,7 +553,7 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
@@ -577,22 +577,22 @@ class FunctionActorManager(object):
|
||||
if self._worker.load_code_from_local:
|
||||
return
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
# `task_driver_id` shouldn't be NIL, unless:
|
||||
# `current_job_id` shouldn't be NIL, unless:
|
||||
# 1) This worker isn't an actor;
|
||||
# 2) And a previous task started a background thread, which didn't
|
||||
# finish before the task finished, and still uses Ray API
|
||||
# after that.
|
||||
assert not self._worker.task_driver_id.is_nil(), (
|
||||
assert not self._worker.current_job_id.is_nil(), (
|
||||
"You might have started a background thread in a non-actor task, "
|
||||
"please make sure the thread finishes before the task finishes.")
|
||||
driver_id = self._worker.task_driver_id
|
||||
key = (b"ActorClass:" + driver_id.binary() + b":" +
|
||||
job_id = self._worker.current_job_id
|
||||
key = (b"ActorClass:" + job_id.binary() + b":" +
|
||||
function_descriptor.function_id.binary())
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"driver_id": driver_id.binary(),
|
||||
"job_id": job_id.binary(),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
|
||||
@@ -616,11 +616,11 @@ class FunctionActorManager(object):
|
||||
# within tasks. I tried to disable this, but it may be necessary
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def load_actor_class(self, driver_id, function_descriptor):
|
||||
def load_actor_class(self, job_id, function_descriptor):
|
||||
"""Load the actor class.
|
||||
|
||||
Args:
|
||||
driver_id: Driver ID of the actor.
|
||||
job_id: job ID of the actor.
|
||||
function_descriptor: Function descriptor of the actor constructor.
|
||||
|
||||
Returns:
|
||||
@@ -632,14 +632,14 @@ class FunctionActorManager(object):
|
||||
if actor_class is None:
|
||||
# Load actor class.
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
job_id = ray.JobID.nil()
|
||||
# Load actor class from local code.
|
||||
actor_class = self._load_actor_from_local(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
else:
|
||||
# Load actor class from GCS.
|
||||
actor_class = self._load_actor_class_from_gcs(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
# Save the loaded actor class in cache.
|
||||
self._loaded_actor_classes[function_id] = actor_class
|
||||
|
||||
@@ -657,18 +657,19 @@ class FunctionActorManager(object):
|
||||
actor_method,
|
||||
actor_imported=True,
|
||||
)
|
||||
self._function_execution_info[driver_id][method_id] = (
|
||||
self._function_execution_info[job_id][method_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][method_id] = 0
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][method_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
return actor_class
|
||||
|
||||
def _load_actor_from_local(self, driver_id, function_descriptor):
|
||||
def _load_actor_from_local(self, job_id, function_descriptor):
|
||||
"""Load actor class from local code."""
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
module_name, class_name = (function_descriptor.module_name,
|
||||
function_descriptor.class_name)
|
||||
try:
|
||||
@@ -699,9 +700,9 @@ class FunctionActorManager(object):
|
||||
|
||||
return TemporaryActor
|
||||
|
||||
def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
|
||||
def _load_actor_class_from_gcs(self, job_id, function_descriptor):
|
||||
"""Load actor class from GCS."""
|
||||
key = (b"ActorClass:" + driver_id.binary() + b":" +
|
||||
key = (b"ActorClass:" + job_id.binary() + b":" +
|
||||
function_descriptor.function_id.binary())
|
||||
# Wait for the actor class key to have been imported by the
|
||||
# import thread. TODO(rkn): It shouldn't be possible to end
|
||||
@@ -711,16 +712,14 @@ class FunctionActorManager(object):
|
||||
time.sleep(0.001)
|
||||
|
||||
# Fetch raw data from GCS.
|
||||
(driver_id_str, class_name, module, pickled_class,
|
||||
(job_id_str, class_name, module, pickled_class,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"actor_method_names"
|
||||
])
|
||||
key,
|
||||
["job_id", "class_name", "module", "class", "actor_method_names"])
|
||||
|
||||
class_name = ensure_str(class_name)
|
||||
module_name = ensure_str(module)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
job_id = ray.JobID(job_id_str)
|
||||
actor_method_names = json.loads(ensure_str(actor_method_names))
|
||||
|
||||
actor_class = None
|
||||
@@ -741,11 +740,12 @@ class FunctionActorManager(object):
|
||||
traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
self._worker,
|
||||
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
"Failed to unpickle actor class '{}' for actor ID {}. "
|
||||
"Traceback:\n{}".format(class_name,
|
||||
self._worker.actor_id.hex(),
|
||||
traceback_str), driver_id)
|
||||
"Traceback:\n{}".format(
|
||||
class_name, self._worker.actor_id.hex(), traceback_str),
|
||||
job_id=job_id)
|
||||
# TODO(rkn): In the future, it might make sense to have the worker
|
||||
# exit here. However, currently that would lead to hanging if
|
||||
# someone calls ray.get on a method invoked on the actor.
|
||||
@@ -859,7 +859,7 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
job_id=self._worker.current_job_id)
|
||||
|
||||
def _restore_and_log_checkpoint(self, actor):
|
||||
"""Restore an actor from a checkpoint if available and log any errors.
|
||||
@@ -898,4 +898,4 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
job_id=self._worker.current_job_id)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.core.generated.ray.protocol.Task import Task
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
ActorCheckpointIdData,
|
||||
ClientTableData,
|
||||
DriverTableData,
|
||||
JobTableData,
|
||||
ErrorTableData,
|
||||
ErrorType,
|
||||
GcsEntry,
|
||||
@@ -23,7 +23,7 @@ from ray.core.generated.gcs_pb2 import (
|
||||
__all__ = [
|
||||
"ActorCheckpointIdData",
|
||||
"ClientTableData",
|
||||
"DriverTableData",
|
||||
"JobTableData",
|
||||
"ErrorTableData",
|
||||
"ErrorType",
|
||||
"GcsEntry",
|
||||
@@ -48,8 +48,8 @@ XRAY_HEARTBEAT_CHANNEL = str(
|
||||
XRAY_HEARTBEAT_BATCH_CHANNEL = str(
|
||||
TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
|
||||
|
||||
# xray driver updates
|
||||
XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
|
||||
# xray job updates
|
||||
XRAY_JOB_CHANNEL = str(TablePubsub.Value("JOB_PUBSUB")).encode("ascii")
|
||||
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in
|
||||
# gcs.proto.
|
||||
@@ -61,11 +61,11 @@ TablePrefix_ERROR_INFO_string = "ERROR_INFO"
|
||||
TablePrefix_PROFILE_string = "PROFILE"
|
||||
|
||||
|
||||
def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
def construct_error_message(job_id, error_type, message, timestamp):
|
||||
"""Construct a serialized ErrorTableData object.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that the error should go to. If this is
|
||||
job_id: The ID of the job that the error should go to. If this is
|
||||
nil, then the error will go to all drivers.
|
||||
error_type: The type of the error.
|
||||
message: The error message.
|
||||
@@ -75,7 +75,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
The serialized object.
|
||||
"""
|
||||
data = ErrorTableData()
|
||||
data.driver_id = driver_id.binary()
|
||||
data.job_id = job_id.binary()
|
||||
data.type = error_type
|
||||
data.error_message = message
|
||||
data.timestamp = timestamp
|
||||
|
||||
@@ -114,13 +114,13 @@ class ImportThread(object):
|
||||
|
||||
def fetch_and_execute_function_to_run(self, key):
|
||||
"""Run on arbitrary function on the worker."""
|
||||
(driver_id, serialized_function,
|
||||
(job_id, serialized_function,
|
||||
run_on_other_drivers) = self.redis_client.hmget(
|
||||
key, ["driver_id", "function", "run_on_other_drivers"])
|
||||
key, ["job_id", "function", "run_on_other_drivers"])
|
||||
|
||||
if (utils.decode(run_on_other_drivers) == "False"
|
||||
and self.worker.mode == ray.SCRIPT_MODE
|
||||
and driver_id != self.worker.task_driver_id.binary()):
|
||||
and job_id != self.worker.current_job_id.binary()):
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -140,4 +140,4 @@ class ImportThread(object):
|
||||
self.worker,
|
||||
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=ray.DriverID(driver_id))
|
||||
job_id=ray.JobID(job_id))
|
||||
|
||||
@@ -6,7 +6,8 @@ from libcpp.unordered_map cimport unordered_map
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
|
||||
from ray.includes.unique_ids cimport (
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CWorkerID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -81,7 +82,7 @@ cdef extern from "ray/common/status.h" namespace "ray::StatusCode" nogil:
|
||||
|
||||
|
||||
cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
const CTaskID GenerateTaskId(const CDriverID &driver_id,
|
||||
const CTaskID GenerateTaskId(const CJobID &job_id,
|
||||
const CTaskID &parent_task_id,
|
||||
int parent_task_counter)
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ from ray.includes.unique_ids cimport (
|
||||
CActorCheckpointID,
|
||||
CActorID,
|
||||
CClientID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CWorkerID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -46,7 +47,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
cdef cppclass CRayletClient "RayletClient":
|
||||
CRayletClient(const c_string &raylet_socket,
|
||||
const CClientID &client_id,
|
||||
c_bool is_worker, const CDriverID &driver_id,
|
||||
c_bool is_worker, const CJobID &job_id,
|
||||
const CLanguage &language)
|
||||
CRayStatus Disconnect()
|
||||
CRayStatus SubmitTask(
|
||||
@@ -62,7 +63,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
int num_returns, int64_t timeout_milliseconds,
|
||||
c_bool wait_local, const CTaskID &current_task_id,
|
||||
WaitResultPair *result)
|
||||
CRayStatus PushError(const CDriverID &driver_id, const c_string &type,
|
||||
CRayStatus PushError(const CJobID &job_id, const c_string &type,
|
||||
const c_string &error_message, double timestamp)
|
||||
CRayStatus PushProfileEvents(
|
||||
const GCSProfileTableDataT &profile_events)
|
||||
@@ -75,6 +76,6 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id)
|
||||
CLanguage GetLanguage() const
|
||||
CClientID GetClientID() const
|
||||
CDriverID GetDriverID() const
|
||||
CJobID GetJobID() const
|
||||
c_bool IsWorker() const
|
||||
const ResourceMappingType &GetResourceIDs() const
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.includes.common cimport (
|
||||
from ray.includes.unique_ids cimport (
|
||||
CActorHandleID,
|
||||
CActorID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -46,7 +46,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
|
||||
cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification":
|
||||
CTaskSpecification(
|
||||
const CDriverID &driver_id, const CTaskID &parent_task_id,
|
||||
const CJobID &job_id, const CTaskID &parent_task_id,
|
||||
int64_t parent_counter,
|
||||
const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
|
||||
int64_t num_returns,
|
||||
@@ -54,7 +54,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
const CLanguage &language,
|
||||
const c_vector[c_string] &function_descriptor)
|
||||
CTaskSpecification(
|
||||
const CDriverID &driver_id, const CTaskID &parent_task_id,
|
||||
const CJobID &job_id, const CTaskID &parent_task_id,
|
||||
int64_t parent_counter, const CActorID &actor_creation_id,
|
||||
const CObjectID &actor_creation_dummy_object_id,
|
||||
int64_t max_actor_reconstructions, const CActorID &actor_id,
|
||||
@@ -70,7 +70,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
c_string SerializeAsString() const
|
||||
|
||||
CTaskID TaskId() const
|
||||
CDriverID DriverId() const
|
||||
CJobID JobId() const
|
||||
CTaskID ParentTaskId() const
|
||||
int64_t ParentCounter() const
|
||||
c_vector[c_string] FunctionDescriptor() const
|
||||
|
||||
@@ -18,7 +18,7 @@ cdef class Task:
|
||||
unique_ptr[CTaskSpecification] task_spec
|
||||
unique_ptr[c_vector[CObjectID]] execution_dependencies
|
||||
|
||||
def __init__(self, DriverID driver_id, function_descriptor, arguments,
|
||||
def __init__(self, JobID job_id, function_descriptor, arguments,
|
||||
int num_returns, TaskID parent_task_id, int parent_counter,
|
||||
ActorID actor_creation_id,
|
||||
ObjectID actor_creation_dummy_object_id,
|
||||
@@ -72,7 +72,7 @@ cdef class Task:
|
||||
(<ActorHandleID?>new_actor_handle).native())
|
||||
|
||||
self.task_spec.reset(new CTaskSpecification(
|
||||
driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
|
||||
job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
|
||||
actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
|
||||
actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
|
||||
required_resources, required_placement_resources, LANGUAGE_PYTHON,
|
||||
@@ -122,9 +122,9 @@ cdef class Task:
|
||||
return SerializeTaskAsString(
|
||||
self.execution_dependencies.get(), self.task_spec.get())
|
||||
|
||||
def driver_id(self):
|
||||
"""Return the driver ID for this task."""
|
||||
return DriverID(self.task_spec.get().DriverId().Binary())
|
||||
def job_id(self):
|
||||
"""Return the job ID for this task."""
|
||||
return JobID(self.task_spec.get().JobId().Binary())
|
||||
|
||||
def task_id(self):
|
||||
"""Return the task ID for this task."""
|
||||
|
||||
@@ -78,10 +78,10 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
@staticmethod
|
||||
CFunctionID FromBinary(const c_string &binary)
|
||||
|
||||
cdef cppclass CDriverID "ray::DriverID"(CUniqueID):
|
||||
cdef cppclass CJobID "ray::JobID"(CUniqueID):
|
||||
|
||||
@staticmethod
|
||||
CDriverID FromBinary(const c_string &binary)
|
||||
CJobID FromBinary(const c_string &binary)
|
||||
|
||||
cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]):
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.includes.unique_ids cimport (
|
||||
CActorID,
|
||||
CClientID,
|
||||
CConfigID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CFunctionID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
@@ -212,15 +212,23 @@ cdef class ClientID(UniqueID):
|
||||
return <CClientID>self.data
|
||||
|
||||
|
||||
cdef class DriverID(UniqueID):
|
||||
cdef class JobID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
check_id(id)
|
||||
self.data = CDriverID.FromBinary(<c_string>id)
|
||||
self.data = CJobID.FromBinary(<c_string>id)
|
||||
|
||||
cdef CDriverID native(self):
|
||||
return <CDriverID>self.data
|
||||
cdef CJobID native(self):
|
||||
return <CJobID>self.data
|
||||
|
||||
cdef class WorkerID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
check_id(id)
|
||||
self.data = CWorkerID.FromBinary(<c_string>id)
|
||||
|
||||
cdef CWorkerID native(self):
|
||||
return <CWorkerID>self.data
|
||||
|
||||
cdef class ActorID(UniqueID):
|
||||
|
||||
@@ -277,7 +285,8 @@ _ID_TYPES = [
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
DriverID,
|
||||
JobID,
|
||||
WorkerID,
|
||||
FunctionID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
|
||||
+23
-23
@@ -130,14 +130,14 @@ class Monitor(object):
|
||||
"Monitor: "
|
||||
"could not find ip for client {}".format(client_id))
|
||||
|
||||
def _xray_clean_up_entries_for_driver(self, driver_id):
|
||||
"""Remove this driver's object/task entries from redis.
|
||||
def _xray_clean_up_entries_for_job(self, job_id):
|
||||
"""Remove this job's object/task entries from redis.
|
||||
|
||||
Removes control-state entries of all tasks and task return
|
||||
objects belonging to the job.
|
||||
|
||||
Args:
|
||||
driver_id: The driver id.
|
||||
job_id: The job id.
|
||||
"""
|
||||
|
||||
xray_task_table_prefix = (
|
||||
@@ -146,23 +146,23 @@ class Monitor(object):
|
||||
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
|
||||
|
||||
task_table_objects = ray.tasks()
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
driver_task_id_bins = set()
|
||||
job_id_hex = binary_to_hex(job_id)
|
||||
job_task_id_bins = set()
|
||||
for task_id_hex, task_info in task_table_objects.items():
|
||||
task_table_object = task_info["TaskSpec"]
|
||||
task_driver_id_hex = task_table_object["DriverID"]
|
||||
if driver_id_hex != task_driver_id_hex:
|
||||
task_job_id_hex = task_table_object["JobID"]
|
||||
if job_id_hex != task_job_id_hex:
|
||||
# Ignore tasks that aren't from this job.
|
||||
continue
|
||||
driver_task_id_bins.add(hex_to_binary(task_id_hex))
|
||||
job_task_id_bins.add(hex_to_binary(task_id_hex))
|
||||
|
||||
# Get objects associated with the job.
|
||||
object_table_objects = ray.objects()
|
||||
driver_object_id_bins = set()
|
||||
job_object_id_bins = set()
|
||||
for object_id, _ in object_table_objects.items():
|
||||
task_id_bin = ray._raylet.compute_task_id(object_id).binary()
|
||||
if task_id_bin in driver_task_id_bins:
|
||||
driver_object_id_bins.add(object_id.binary())
|
||||
if task_id_bin in job_task_id_bins:
|
||||
job_object_id_bins.add(object_id.binary())
|
||||
|
||||
def to_shard_index(id_bin):
|
||||
if len(id_bin) == ray.TaskID.size():
|
||||
@@ -174,10 +174,10 @@ class Monitor(object):
|
||||
|
||||
# Form the redis keys to delete.
|
||||
sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
|
||||
for task_id_bin in driver_task_id_bins:
|
||||
for task_id_bin in job_task_id_bins:
|
||||
sharded_keys[to_shard_index(task_id_bin)].append(
|
||||
xray_task_table_prefix + task_id_bin)
|
||||
for object_id_bin in driver_object_id_bins:
|
||||
for object_id_bin in job_object_id_bins:
|
||||
sharded_keys[to_shard_index(object_id_bin)].append(
|
||||
xray_object_table_prefix + object_id_bin)
|
||||
|
||||
@@ -198,21 +198,21 @@ class Monitor(object):
|
||||
"entries from redis shard {}.".format(
|
||||
len(keys) - num_deleted, shard_index))
|
||||
|
||||
def xray_driver_removed_handler(self, unused_channel, data):
|
||||
"""Handle a notification that a driver has been removed.
|
||||
def xray_job_removed_handler(self, unused_channel, data):
|
||||
"""Handle a notification that a job has been removed.
|
||||
|
||||
Args:
|
||||
unused_channel: The message channel.
|
||||
data: The message data.
|
||||
"""
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
|
||||
driver_data = gcs_entries.entries[0]
|
||||
message = ray.gcs_utils.DriverTableData.FromString(driver_data)
|
||||
driver_id = message.driver_id
|
||||
job_data = gcs_entries.entries[0]
|
||||
message = ray.gcs_utils.JobTableData.FromString(job_data)
|
||||
job_id = message.job_id
|
||||
logger.info("Monitor: "
|
||||
"XRay Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
self._xray_clean_up_entries_for_driver(driver_id)
|
||||
binary_to_hex(job_id)))
|
||||
self._xray_clean_up_entries_for_job(job_id)
|
||||
|
||||
def process_messages(self, max_messages=10000):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
@@ -240,9 +240,9 @@ class Monitor(object):
|
||||
if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
|
||||
# Similar functionality as raylet info channel
|
||||
message_handler = self.xray_heartbeat_batch_handler
|
||||
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
|
||||
elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL:
|
||||
# Handles driver death.
|
||||
message_handler = self.xray_driver_removed_handler
|
||||
message_handler = self.xray_job_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
@@ -298,7 +298,7 @@ class Monitor(object):
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)
|
||||
|
||||
# TODO(rkn): If there were any dead clients at startup, we should clean
|
||||
# up the associated state in the state tables.
|
||||
|
||||
@@ -44,7 +44,7 @@ class RemoteFunction(object):
|
||||
return the resulting ObjectIDs. For an example, see
|
||||
"test_decorated_function" in "python/ray/tests/test_basic.py".
|
||||
_function_signature: The function signature.
|
||||
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
|
||||
_last_job_id_exported_for: The job ID of the last Ray
|
||||
session during which this remote function definition was exported.
|
||||
This is an imperfect mechanism used to determine if we need to
|
||||
export the remote function again. It is imperfect in the sense that
|
||||
@@ -73,7 +73,7 @@ class RemoteFunction(object):
|
||||
self._function_signature = ray.signature.extract_signature(
|
||||
self._function)
|
||||
|
||||
self._last_driver_id_exported_for = None
|
||||
self._last_job_id_exported_for = None
|
||||
|
||||
# Override task.remote's signature and docstring
|
||||
@wraps(function)
|
||||
@@ -115,11 +115,11 @@ class RemoteFunction(object):
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.check_connected()
|
||||
|
||||
if (self._last_driver_id_exported_for is None
|
||||
or self._last_driver_id_exported_for != worker.task_driver_id):
|
||||
if (self._last_job_id_exported_for is None
|
||||
or self._last_job_id_exported_for != worker.current_job_id):
|
||||
# If this function was exported in a previous session, we need to
|
||||
# export this function again, because current GCS doesn't have it.
|
||||
self._last_driver_id_exported_for = worker.task_driver_id
|
||||
self._last_job_id_exported_for = worker.current_job_id
|
||||
worker.function_actor_manager.export(self)
|
||||
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
|
||||
@@ -20,7 +20,7 @@ class RuntimeContext(object):
|
||||
a task, return the driver ID of the associated driver.
|
||||
"""
|
||||
assert self.worker is not None
|
||||
return self.worker.task_driver_id
|
||||
return self.worker.current_job_id
|
||||
|
||||
|
||||
_runtime_context = None
|
||||
|
||||
+19
-20
@@ -316,7 +316,7 @@ class GlobalState(object):
|
||||
function_descriptor_list)
|
||||
|
||||
task_spec_info = {
|
||||
"DriverID": task.driver_id().hex(),
|
||||
"JobID": task.job_id().hex(),
|
||||
"TaskID": task.task_id().hex(),
|
||||
"ParentTaskID": task.parent_task_id().hex(),
|
||||
"ParentCounter": task.parent_counter(),
|
||||
@@ -817,19 +817,19 @@ class GlobalState(object):
|
||||
|
||||
return dict(total_available_resources)
|
||||
|
||||
def _error_messages(self, driver_id):
|
||||
def _error_messages(self, job_id):
|
||||
"""Get the error messages for a specific driver.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver to get the errors for.
|
||||
job_id: The ID of the job to get the errors for.
|
||||
|
||||
Returns:
|
||||
A list of the error messages for this driver.
|
||||
"""
|
||||
assert isinstance(driver_id, ray.DriverID)
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
message = self.redis_client.execute_command(
|
||||
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
|
||||
driver_id.binary())
|
||||
job_id.binary())
|
||||
|
||||
# If there are no errors, return early.
|
||||
if message is None:
|
||||
@@ -839,7 +839,7 @@ class GlobalState(object):
|
||||
error_messages = []
|
||||
for entry in gcs_entries.entries:
|
||||
error_data = gcs_utils.ErrorTableData.FromString(entry)
|
||||
assert driver_id.binary() == error_data.driver_id
|
||||
assert job_id.binary() == error_data.job_id
|
||||
error_message = {
|
||||
"type": error_data.type,
|
||||
"message": error_data.error_message,
|
||||
@@ -848,12 +848,12 @@ class GlobalState(object):
|
||||
error_messages.append(error_message)
|
||||
return error_messages
|
||||
|
||||
def error_messages(self, driver_id=None):
|
||||
def error_messages(self, job_id=None):
|
||||
"""Get the error messages for all drivers or a specific driver.
|
||||
|
||||
Args:
|
||||
driver_id: The specific driver to get the errors for. If this is
|
||||
None, then this method retrieves the errors for all drivers.
|
||||
job_id: The specific job to get the errors for. If this is
|
||||
None, then this method retrieves the errors for all jobs.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping driver ID to a list of the error messages for
|
||||
@@ -861,21 +861,20 @@ class GlobalState(object):
|
||||
"""
|
||||
self._check_connected()
|
||||
|
||||
if driver_id is not None:
|
||||
assert isinstance(driver_id, ray.DriverID)
|
||||
return self._error_messages(driver_id)
|
||||
if job_id is not None:
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
return self._error_messages(job_id)
|
||||
|
||||
error_table_keys = self.redis_client.keys(
|
||||
gcs_utils.TablePrefix_ERROR_INFO_string + "*")
|
||||
driver_ids = [
|
||||
job_ids = [
|
||||
key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
|
||||
for key in error_table_keys
|
||||
]
|
||||
|
||||
return {
|
||||
binary_to_hex(driver_id): self._error_messages(
|
||||
ray.DriverID(driver_id))
|
||||
for driver_id in driver_ids
|
||||
binary_to_hex(job_id): self._error_messages(ray.JobID(job_id))
|
||||
for job_id in job_ids
|
||||
}
|
||||
|
||||
def actor_checkpoint_info(self, actor_id):
|
||||
@@ -969,12 +968,12 @@ class DeprecatedGlobalState(object):
|
||||
"instead.")
|
||||
return ray.available_resources()
|
||||
|
||||
def error_messages(self, driver_id=None):
|
||||
def error_messages(self, job_id=None):
|
||||
logger.warning(
|
||||
"ray.global_state.error_messages() is deprecated and will be "
|
||||
"removed in a subsequent release. Use ray.errors() "
|
||||
"instead.")
|
||||
return ray.errors(driver_id=driver_id)
|
||||
return ray.errors(job_id=job_id)
|
||||
|
||||
|
||||
state = GlobalState()
|
||||
@@ -1095,7 +1094,7 @@ def errors(include_cluster_errors=True):
|
||||
Error messages pushed from the cluster.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
error_messages = state.error_messages(driver_id=worker.task_driver_id)
|
||||
error_messages = state.error_messages(job_id=worker.current_job_id)
|
||||
if include_cluster_errors:
|
||||
error_messages += state.error_messages(driver_id=ray.DriverID.nil())
|
||||
error_messages += state.error_messages(job_id=ray.JobID.nil())
|
||||
return error_messages
|
||||
|
||||
@@ -2439,7 +2439,7 @@ def test_global_state_api(shutdown_only):
|
||||
|
||||
assert ray.objects() == {}
|
||||
|
||||
driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
|
||||
job_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
|
||||
driver_task_id = ray.worker.global_worker.current_task_id.hex()
|
||||
|
||||
# One task is put in the task table which corresponds to this driver.
|
||||
@@ -2453,7 +2453,7 @@ def test_global_state_api(shutdown_only):
|
||||
assert task_spec["TaskID"] == driver_task_id
|
||||
assert task_spec["ActorID"] == nil_id_hex
|
||||
assert task_spec["Args"] == []
|
||||
assert task_spec["DriverID"] == driver_id
|
||||
assert task_spec["JobID"] == job_id
|
||||
assert task_spec["FunctionID"] == nil_id_hex
|
||||
assert task_spec["ReturnObjectIDs"] == []
|
||||
|
||||
@@ -2481,7 +2481,7 @@ def test_global_state_api(shutdown_only):
|
||||
task_spec = task_table[task_id]["TaskSpec"]
|
||||
assert task_spec["ActorID"] == nil_id_hex
|
||||
assert task_spec["Args"] == [1, "hi", x_id]
|
||||
assert task_spec["DriverID"] == driver_id
|
||||
assert task_spec["JobID"] == job_id
|
||||
assert task_spec["ReturnObjectIDs"] == [result_id]
|
||||
|
||||
assert task_table[task_id] == ray.tasks(task_id)
|
||||
@@ -2613,9 +2613,9 @@ def test_workers(shutdown_only):
|
||||
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
|
||||
|
||||
|
||||
def test_specific_driver_id():
|
||||
dummy_driver_id = ray.DriverID(b"00112233445566778899")
|
||||
ray.init(num_cpus=1, driver_id=dummy_driver_id)
|
||||
def test_specific_job_id():
|
||||
dummy_driver_id = ray.JobID(b"00112233445566778899")
|
||||
ray.init(num_cpus=1, job_id=dummy_driver_id)
|
||||
|
||||
# in driver
|
||||
assert dummy_driver_id == ray._get_runtime_context().current_driver_id
|
||||
@@ -2727,7 +2727,7 @@ def test_ray_setproctitle(ray_start_2_cpus):
|
||||
def test_duplicate_error_messages(shutdown_only):
|
||||
ray.init(num_cpus=0)
|
||||
|
||||
driver_id = ray.DriverID.nil()
|
||||
driver_id = ray.WorkerID.nil()
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, "test",
|
||||
"message", 0)
|
||||
|
||||
|
||||
+15
-14
@@ -51,7 +51,7 @@ def format_error_message(exception_message, task_exception=False):
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def push_error_to_driver(worker, error_type, message, driver_id=None):
|
||||
def push_error_to_driver(worker, error_type, message, job_id=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
Args:
|
||||
@@ -59,19 +59,19 @@ def push_error_to_driver(worker, error_type, message, driver_id=None):
|
||||
error_type (str): The type of the error.
|
||||
message (str): The message that will be printed in the background
|
||||
on the driver.
|
||||
driver_id: The ID of the driver to push the error message to. If this
|
||||
job_id: The ID of the driver to push the error message to. If this
|
||||
is None, then the message will be pushed to all drivers.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray.DriverID.nil()
|
||||
worker.raylet_client.push_error(driver_id, error_type, message,
|
||||
time.time())
|
||||
if job_id is None:
|
||||
job_id = ray.JobID.nil()
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
worker.raylet_client.push_error(job_id, error_type, message, time.time())
|
||||
|
||||
|
||||
def push_error_to_driver_through_redis(redis_client,
|
||||
error_type,
|
||||
message,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
Normally the push_error_to_driver function should be used. However, in some
|
||||
@@ -84,19 +84,20 @@ def push_error_to_driver_through_redis(redis_client,
|
||||
error_type (str): The type of the error.
|
||||
message (str): The message that will be printed in the background
|
||||
on the driver.
|
||||
driver_id: The ID of the driver to push the error message to. If this
|
||||
job_id: The ID of the driver to push the error message to. If this
|
||||
is None, then the message will be pushed to all drivers.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray.DriverID.nil()
|
||||
if job_id is None:
|
||||
job_id = ray.JobID.nil()
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
# Do everything in Python and through the Python Redis client instead
|
||||
# of through the raylet.
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
|
||||
error_data = ray.gcs_utils.construct_error_message(job_id, error_type,
|
||||
message, time.time())
|
||||
redis_client.execute_command(
|
||||
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
|
||||
driver_id.binary(), error_data)
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), job_id.binary(),
|
||||
error_data)
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
@@ -443,7 +444,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
|
||||
worker,
|
||||
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=worker.task_driver_id)
|
||||
job_id=worker.current_job_id)
|
||||
|
||||
|
||||
class _ThreadSafeProxy(object):
|
||||
|
||||
+77
-80
@@ -40,7 +40,8 @@ from ray import (
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
DriverID,
|
||||
WorkerID,
|
||||
JobID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
)
|
||||
@@ -145,9 +146,9 @@ class Worker(object):
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
# It is a DriverID.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Identity of the job that this worker is processing.
|
||||
# It is a JobID.
|
||||
self.current_job_id = JobID.nil()
|
||||
self._task_context = threading.local()
|
||||
# This event is checked regularly by all of the threads so that they
|
||||
# know when to exit.
|
||||
@@ -227,24 +228,24 @@ class Worker(object):
|
||||
if self.actor_init_error is not None:
|
||||
raise self.actor_init_error
|
||||
|
||||
def get_serialization_context(self, driver_id):
|
||||
"""Get the SerializationContext of the driver that this worker is processing.
|
||||
def get_serialization_context(self, job_id):
|
||||
"""Get the SerializationContext of the job that this worker is processing.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that indicates which driver to get
|
||||
job_id: The ID of the job that indicates which job to get
|
||||
the serialization context for.
|
||||
|
||||
Returns:
|
||||
The serialization context of the given driver.
|
||||
The serialization context of the given job.
|
||||
"""
|
||||
# This function needs to be protected by a lock, because it will be
|
||||
# called by `register_class_for_serialization`, as well as the import
|
||||
# thread, from different threads. Also, this function will recursively
|
||||
# call itself, so we use RLock here.
|
||||
with self.lock:
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
if job_id not in self.serialization_context_map:
|
||||
_initialize_serialization(job_id)
|
||||
return self.serialization_context_map[job_id]
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
@@ -314,7 +315,7 @@ class Worker(object):
|
||||
object_id=pyarrow.plasma.ObjectID(object_id.binary()),
|
||||
memcopy_threads=self.memcopy_threads,
|
||||
serialization_context=self.get_serialization_context(
|
||||
self.task_driver_id))
|
||||
self.current_job_id))
|
||||
break
|
||||
except pyarrow.SerializationCallbackError as e:
|
||||
try:
|
||||
@@ -388,17 +389,17 @@ class Worker(object):
|
||||
# should return an error code to the caller instead of printing a
|
||||
# message.
|
||||
logger.info(
|
||||
"The object with ID {} already exists in the object store."
|
||||
.format(object_id))
|
||||
"The object with ID {} already exists in the object store.".
|
||||
format(object_id))
|
||||
except TypeError:
|
||||
# This error can happen because one of the members of the object
|
||||
# may not be serializable for cloudpickle. So we need these extra
|
||||
# fallbacks here to start from the beginning. Hopefully the object
|
||||
# could have a `__reduce__` method.
|
||||
register_custom_serializer(type(value), use_pickle=True)
|
||||
warning_message = ("WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle."
|
||||
.format(type(value)))
|
||||
warning_message = (
|
||||
"WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle.".format(type(value)))
|
||||
logger.warning(warning_message)
|
||||
self.store_and_register(object_id, value)
|
||||
|
||||
@@ -407,7 +408,7 @@ class Worker(object):
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
serialization_context = self.get_serialization_context(
|
||||
self.task_driver_id)
|
||||
self.current_job_id)
|
||||
while True:
|
||||
try:
|
||||
# We divide very large get requests into smaller get requests
|
||||
@@ -449,7 +450,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
warning_sent = True
|
||||
|
||||
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
||||
@@ -575,7 +576,7 @@ class Worker(object):
|
||||
num_return_vals=None,
|
||||
resources=None,
|
||||
placement_resources=None,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with
|
||||
@@ -601,11 +602,11 @@ class Worker(object):
|
||||
placement_resources: The resources required for placing the task.
|
||||
If this is not provided or if it is an empty dictionary, then
|
||||
the placement resources will be equal to resources.
|
||||
driver_id: The ID of the relevant driver. This is almost always the
|
||||
driver ID of the driver that is currently running. However, in
|
||||
job_id: The ID of the relevant job. This is almost always the
|
||||
job ID of the job that is currently running. However, in
|
||||
the exceptional case that an actor task is being dispatched to
|
||||
an actor created by a different driver, this should be the
|
||||
driver ID of the driver that created the actor.
|
||||
an actor created by a different job, this should be the
|
||||
job ID of the job that created the actor.
|
||||
|
||||
Returns:
|
||||
The return object IDs for this task.
|
||||
@@ -642,8 +643,8 @@ class Worker(object):
|
||||
if new_actor_handles is None:
|
||||
new_actor_handles = []
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = self.task_driver_id
|
||||
if job_id is None:
|
||||
job_id = self.current_job_id
|
||||
|
||||
if resources is None:
|
||||
raise ValueError("The resources dictionary is required.")
|
||||
@@ -674,13 +675,13 @@ class Worker(object):
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Current driver id must not be nil when submitting a task.
|
||||
# Because every task must belong to a driver.
|
||||
assert not self.task_driver_id.is_nil()
|
||||
assert not self.current_job_id.is_nil()
|
||||
# Submit the task to raylet.
|
||||
function_descriptor_list = (
|
||||
function_descriptor.get_function_descriptor_list())
|
||||
assert isinstance(driver_id, DriverID)
|
||||
assert isinstance(job_id, JobID)
|
||||
task = ray._raylet.Task(
|
||||
driver_id,
|
||||
job_id,
|
||||
function_descriptor_list,
|
||||
args_for_raylet,
|
||||
num_return_vals,
|
||||
@@ -747,7 +748,7 @@ class Worker(object):
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.binary(),
|
||||
"job_id": self.current_job_id.binary(),
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickled_function,
|
||||
"run_on_other_drivers": str(run_on_other_drivers)
|
||||
@@ -853,17 +854,17 @@ class Worker(object):
|
||||
assert self.task_context.task_index == 0
|
||||
assert self.task_context.put_index == 1
|
||||
if task.actor_id().is_nil():
|
||||
# If this worker is not an actor, check that `task_driver_id`
|
||||
# If this worker is not an actor, check that `current_job_id`
|
||||
# was reset when the worker finished the previous task.
|
||||
assert self.task_driver_id.is_nil()
|
||||
assert self.current_job_id.is_nil()
|
||||
# Set the driver ID of the current running task. This is
|
||||
# needed so that if the task throws an exception, we propagate
|
||||
# the error message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_job_id = task.job_id()
|
||||
else:
|
||||
# If this worker is an actor, task_driver_id wasn't reset.
|
||||
# If this worker is an actor, current_job_id wasn't reset.
|
||||
# Check that current task's driver ID equals the previous one.
|
||||
assert self.task_driver_id == task.driver_id()
|
||||
assert self.current_job_id == task.job_id()
|
||||
|
||||
self.task_context.current_task_id = task.task_id()
|
||||
|
||||
@@ -945,7 +946,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.TASK_PUSH_ERROR,
|
||||
str(failure_object),
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
# Mark the actor init as failed
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
@@ -960,7 +961,7 @@ class Worker(object):
|
||||
"""
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
driver_id = task.driver_id()
|
||||
job_id = task.job_id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
@@ -969,7 +970,7 @@ class Worker(object):
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
actor_class = self.function_actor_manager.load_actor_class(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
self.actors[self.actor_id] = actor_class.__new__(actor_class)
|
||||
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
|
||||
num_tasks_since_last_checkpoint=0,
|
||||
@@ -978,7 +979,7 @@ class Worker(object):
|
||||
)
|
||||
|
||||
execution_info = self.function_actor_manager.get_execution_info(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
# Execute the task.
|
||||
function_name = execution_info.function_name
|
||||
@@ -1005,20 +1006,20 @@ class Worker(object):
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# Don't need to reset `current_job_id` if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
self.current_job_id = WorkerID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
reached_max_executions = (self.function_actor_manager.get_task_counter(
|
||||
driver_id, function_descriptor) == execution_info.max_calls)
|
||||
job_id, function_descriptor) == execution_info.max_calls)
|
||||
if reached_max_executions:
|
||||
self.raylet_client.disconnect()
|
||||
sys.exit(0)
|
||||
@@ -1141,7 +1142,7 @@ def print_failed_task(task_status):
|
||||
task_status["error_message"]))
|
||||
|
||||
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
def _initialize_serialization(job_id, worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
@@ -1177,7 +1178,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
custom_serializer=actor_handle_serializer,
|
||||
custom_deserializer=actor_handle_deserializer)
|
||||
|
||||
worker.serialization_context_map[driver_id] = serialization_context
|
||||
worker.serialization_context_map[job_id] = serialization_context
|
||||
|
||||
# Register exception types.
|
||||
for error_cls in RAY_EXCEPTION_TYPES:
|
||||
@@ -1185,7 +1186,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
error_cls,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id=error_cls.__module__ + ". " + error_cls.__name__,
|
||||
)
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
@@ -1193,22 +1194,18 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
type(lambda: 0),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="lambda")
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_custom_serializer(
|
||||
type(int),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="type")
|
||||
type(int), use_pickle=True, local=True, job_id=job_id, class_id="type")
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
register_custom_serializer(
|
||||
ray.signature.FunctionSignature,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="ray.signature.FunctionSignature")
|
||||
|
||||
|
||||
@@ -1231,7 +1228,7 @@ def init(redis_address=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
@@ -1302,7 +1299,7 @@ def init(redis_address=None,
|
||||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which displays the status of the Ray cluster.
|
||||
driver_id: The ID of driver.
|
||||
job_id: The ID of this job.
|
||||
configure_logging: True to allow the logging configuration here.
|
||||
Otherwise, the users may want to configure it by their own.
|
||||
logging_level: Logging level, default will be logging.INFO.
|
||||
@@ -1449,7 +1446,7 @@ def init(redis_address=None,
|
||||
mode=driver_mode,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
@@ -1660,10 +1657,10 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
assert len(gcs_entry.entries) == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.FromString(
|
||||
gcs_entry.entries[0])
|
||||
driver_id = error_data.driver_id
|
||||
if driver_id not in [
|
||||
worker.task_driver_id.binary(),
|
||||
DriverID.nil().binary()
|
||||
job_id = error_data.job_id
|
||||
if job_id not in [
|
||||
worker.current_job_id.binary(),
|
||||
JobID.nil().binary()
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -1691,7 +1688,7 @@ def connect(node,
|
||||
mode=WORKER_MODE,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Connect this worker to the raylet, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
@@ -1701,7 +1698,7 @@ def connect(node,
|
||||
log_to_driver (bool): If true, then output from all of the worker
|
||||
processes on all nodes will be directed to the driver.
|
||||
worker: The ray.Worker instance.
|
||||
driver_id: The ID of driver. If it's None, then we will generate one.
|
||||
job_id: The ID of job. If it's None, then we will generate one.
|
||||
"""
|
||||
# Do some basic checking to make sure we didn't call ray.init twice.
|
||||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
@@ -1721,20 +1718,20 @@ def connect(node,
|
||||
setproctitle.setproctitle("ray_worker")
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = DriverID.from_random()
|
||||
if job_id is None:
|
||||
job_id = JobID.from_random()
|
||||
|
||||
if not isinstance(driver_id, DriverID):
|
||||
raise TypeError("The type of given driver id must be DriverID.")
|
||||
if not isinstance(job_id, JobID):
|
||||
raise TypeError("The type of given job id must be JobID.")
|
||||
|
||||
worker.worker_id = driver_id.binary()
|
||||
worker.worker_id = job_id.binary()
|
||||
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
# drivers, the current job ID is used to keep track of which driver is
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = DriverID(worker.worker_id)
|
||||
worker.current_job_id = JobID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
@@ -1766,7 +1763,7 @@ def connect(node,
|
||||
worker.redis_client,
|
||||
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
job_id=None)
|
||||
|
||||
worker.lock = threading.RLock()
|
||||
|
||||
@@ -1831,7 +1828,7 @@ def connect(node,
|
||||
# Create an object store client.
|
||||
worker.plasma_client = thread_safe_client(
|
||||
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
|
||||
driver_id_str = _random_string()
|
||||
job_id_str = _random_string()
|
||||
|
||||
# If this is a driver, set the current task ID, the current job ID, and set
|
||||
# the task index to 0.
|
||||
@@ -1859,11 +1856,11 @@ def connect(node,
|
||||
|
||||
function_descriptor = FunctionDescriptor.for_driver_task()
|
||||
driver_task = ray._raylet.Task(
|
||||
worker.task_driver_id,
|
||||
worker.current_job_id,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
TaskID(driver_id_str[:TaskID.size()]), # parent_task_id.
|
||||
TaskID(job_id_str[:TaskID.size()]), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ActorID.nil(), # actor_creation_id.
|
||||
ObjectID.nil(), # actor_creation_dummy_object_id.
|
||||
@@ -1895,7 +1892,7 @@ def connect(node,
|
||||
node.raylet_socket_name,
|
||||
ClientID(worker.worker_id),
|
||||
(mode == WORKER_MODE),
|
||||
DriverID(driver_id_str),
|
||||
JobID(job_id_str),
|
||||
)
|
||||
|
||||
# Start the import thread
|
||||
@@ -2057,7 +2054,7 @@ def register_custom_serializer(cls,
|
||||
serializer=None,
|
||||
deserializer=None,
|
||||
local=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
@@ -2078,7 +2075,7 @@ def register_custom_serializer(cls,
|
||||
if and only if use_pickle and use_dict are False.
|
||||
local: True if the serializers should only be registered on the current
|
||||
worker. This should usually be False.
|
||||
driver_id: ID of the driver that we want to register the class for.
|
||||
job_id: ID of the job that we want to register the class for.
|
||||
class_id: ID of the class that we are registering. If this is not
|
||||
specified, we will calculate a new one inside the function.
|
||||
|
||||
@@ -2126,9 +2123,9 @@ def register_custom_serializer(cls,
|
||||
# Make sure class_id is a string.
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = worker.task_driver_id
|
||||
assert isinstance(driver_id, DriverID)
|
||||
if job_id is None:
|
||||
job_id = worker.current_job_id
|
||||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2138,7 +2135,7 @@ def register_custom_serializer(cls,
|
||||
# system.
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(driver_id)
|
||||
"worker"].get_serialization_context(job_id)
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
|
||||
@@ -102,7 +102,7 @@ if __name__ == "__main__":
|
||||
ray.worker.global_worker,
|
||||
"worker_crash",
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
job_id=None)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
||||
Reference in New Issue
Block a user