[ID Refactor] Rename DriverID to JobID (#5004)

* WIP

WIP

WIP

Rename Driver -> Job

Fix compilation

Fix

Rename in Java

In py

WIP

Fix

WIP

Fix

Fix test

Fix

Fix C++ linting

Fix

* Update java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java

Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu>

* Update src/ray/core_worker/core_worker.cc

Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu>

* Address comments

* Fix

* Fix CI

* Fix cpp linting

* Fix py lint

* Fix

* Address comments and fix

* Address comments

* Address

* Fix import_threading
This commit is contained in:
Qing Wang
2019-06-28 00:44:51 +08:00
committed by GitHub
parent d9768c1cd2
commit 62e4b591e3
79 changed files with 961 additions and 974 deletions
+4 -2
View File
@@ -56,7 +56,8 @@ from ray._raylet import (
ActorID,
ClientID,
Config as _Config,
DriverID,
JobID,
WorkerID,
FunctionID,
ObjectID,
TaskID,
@@ -141,7 +142,8 @@ __all__ += [
"ActorHandleID",
"ActorID",
"ClientID",
"DriverID",
"JobID",
"WorkerID",
"FunctionID",
"ObjectID",
"TaskID",
+6 -6
View File
@@ -221,13 +221,13 @@ cdef class RayletClient:
def __cinit__(self, raylet_socket,
ClientID client_id,
c_bool is_worker,
DriverID driver_id):
JobID job_id):
# We know that we are using Python, so just skip the language
# parameter.
# TODO(suquark): Should we allow unicode chars in "raylet_socket"?
self.client.reset(new CRayletClient(
raylet_socket.encode("ascii"), client_id.native(), is_worker,
driver_id.native(), LANGUAGE_PYTHON))
job_id.native(), LANGUAGE_PYTHON))
def disconnect(self):
check_status(self.client.get().Disconnect())
@@ -293,9 +293,9 @@ cdef class RayletClient:
postincrement(iterator)
return resources_dict
def push_error(self, DriverID driver_id, error_type, error_message,
def push_error(self, JobID job_id, error_type, error_message,
double timestamp):
check_status(self.client.get().PushError(driver_id.native(),
check_status(self.client.get().PushError(job_id.native(),
error_type.encode("ascii"),
error_message.encode("ascii"),
timestamp))
@@ -381,8 +381,8 @@ cdef class RayletClient:
return ClientID(self.client.get().GetClientID().Binary())
@property
def driver_id(self):
return DriverID(self.client.get().GetDriverID().Binary())
def job_id(self):
return JobID(self.client.get().GetJobID().Binary())
@property
def is_worker(self):
+19 -21
View File
@@ -17,8 +17,7 @@ from ray.function_manager import FunctionDescriptor
import ray.ray_constants as ray_constants
import ray.signature as signature
import ray.worker
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID,
DriverID)
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID)
logger = logging.getLogger(__name__)
@@ -186,7 +185,7 @@ class ActorClass(object):
task.
_resources: The default resources required by the actor creation task.
_actor_method_cpus: The number of CPUs required by actor method tasks.
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
_last_job_id_exported_for: The ID of the job of the last Ray
session during which this actor class definition was exported. This
is an imperfect mechanism used to determine if we need to export
the remote function again. It is imperfect in the sense that the
@@ -212,7 +211,7 @@ class ActorClass(object):
self._num_cpus = num_cpus
self._num_gpus = num_gpus
self._resources = resources
self._last_driver_id_exported_for = None
self._last_job_id_exported_for = None
self._actor_methods = inspect.getmembers(
self._modified_class, ray.utils.is_function_or_method)
@@ -345,13 +344,12 @@ class ActorClass(object):
*copy.deepcopy(args), **copy.deepcopy(kwargs))
else:
# Export the actor.
if (self._last_driver_id_exported_for is None
or self._last_driver_id_exported_for !=
worker.task_driver_id):
if (self._last_job_id_exported_for is None or
self._last_job_id_exported_for != worker.current_job_id):
# If this actor class was exported in a previous session, we
# need to export this function again, because current GCS
# doesn't have it.
self._last_driver_id_exported_for = worker.task_driver_id
self._last_job_id_exported_for = worker.current_job_id
worker.function_actor_manager.export_actor_class(
self._modified_class, self._actor_method_names)
@@ -389,7 +387,7 @@ class ActorClass(object):
actor_id, self._modified_class.__module__, self._class_name,
actor_cursor, self._actor_method_names, self._method_decorators,
self._method_signatures, self._actor_method_num_return_vals,
actor_cursor, actor_method_cpu, worker.task_driver_id)
actor_cursor, actor_method_cpu, worker.current_job_id)
# We increment the actor counter by 1 to account for the actor creation
# task.
actor_handle._ray_actor_counter += 1
@@ -446,9 +444,9 @@ class ActorHandle(object):
_ray_original_handle: True if this is the original actor handle for a
given actor. If this is true, then the actor will be destroyed when
this handle goes out of scope.
_ray_actor_driver_id: The driver ID of the job that created the actor
(it is possible that this ActorHandle exists on a driver with a
different driver ID).
_ray_actor_job_id: The ID of the job that created the actor
(it is possible that this ActorHandle exists on a job with a
different job ID).
_ray_new_actor_handles: The new actor handles that were created from
this handle since the last task on this handle was submitted. This
is used to garbage-collect dummy objects that are no longer
@@ -466,10 +464,10 @@ class ActorHandle(object):
method_num_return_vals,
actor_creation_dummy_object_id,
actor_method_cpus,
actor_driver_id,
actor_job_id,
actor_handle_id=None):
assert isinstance(actor_id, ActorID)
assert isinstance(actor_driver_id, DriverID)
assert isinstance(actor_job_id, ray.JobID)
self._ray_actor_id = actor_id
self._ray_module_name = module_name
# False if this actor handle was created by forking or pickling. True
@@ -491,7 +489,7 @@ class ActorHandle(object):
self._ray_actor_creation_dummy_object_id = (
actor_creation_dummy_object_id)
self._ray_actor_method_cpus = actor_method_cpus
self._ray_actor_driver_id = actor_driver_id
self._ray_actor_job_id = actor_job_id
self._ray_new_actor_handles = []
self._ray_actor_lock = threading.Lock()
@@ -551,7 +549,7 @@ class ActorHandle(object):
num_return_vals=num_return_vals + 1,
resources={"CPU": self._ray_actor_method_cpus},
placement_resources={},
driver_id=self._ray_actor_driver_id,
job_id=self._ray_actor_job_id,
)
# Update the actor counter and cursor to reflect the most recent
# invocation.
@@ -612,7 +610,7 @@ class ActorHandle(object):
# not just the first one.
worker = ray.worker.get_global_worker()
if (worker.mode == ray.worker.SCRIPT_MODE
and self._ray_actor_driver_id.binary() != worker.worker_id):
and self._ray_actor_job_id.binary() != worker.worker_id):
# If the worker is a driver and driver id has changed because
# Ray was shut down and re-initialized, the actor is already cleaned up
# and we don't need to send `__ray_terminate__` again.
@@ -666,7 +664,7 @@ class ActorHandle(object):
"actor_creation_dummy_object_id": self.
_ray_actor_creation_dummy_object_id,
"actor_method_cpus": self._ray_actor_method_cpus,
"actor_driver_id": self._ray_actor_driver_id,
"actor_job_id": self._ray_actor_job_id,
"ray_forking": ray_forking
}
@@ -727,9 +725,9 @@ class ActorHandle(object):
state["method_num_return_vals"],
state["actor_creation_dummy_object_id"],
state["actor_method_cpus"],
# This is the driver ID of the driver that owns the actor, not
# necessarily the driver that owns this actor handle.
state["actor_driver_id"],
# This is the ID of the job that owns the actor, not
# necessarily the job that owns this actor handle.
state["actor_job_id"],
actor_handle_id=actor_handle_id)
def __getstate__(self):
+63 -63
View File
@@ -277,9 +277,9 @@ class FunctionActorManager(object):
the worker gets connected.
_actors_to_export: The actors to export when the worker gets
connected.
_function_execution_info: The map from driver_id to finction_id
_function_execution_info: The map from job_id to function_id
and execution_info.
_num_task_executions: The map from driver_id to function
_num_task_executions: The map from job_id to function
execution times.
imported_actor_classes: The set of actor classes keys (format:
ActorClass:function_id) that are already in GCS.
@@ -303,17 +303,17 @@ class FunctionActorManager(object):
self._loaded_actor_classes = {}
self.lock = threading.Lock()
def increase_task_counter(self, driver_id, function_descriptor):
def increase_task_counter(self, job_id, function_descriptor):
function_id = function_descriptor.function_id
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
self._num_task_executions[driver_id][function_id] += 1
job_id = ray.JobID.nil()
self._num_task_executions[job_id][function_id] += 1
def get_task_counter(self, driver_id, function_descriptor):
def get_task_counter(self, job_id, function_descriptor):
function_id = function_descriptor.function_id
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
return self._num_task_executions[driver_id][function_id]
job_id = ray.JobID.nil()
return self._num_task_executions[job_id][function_id]
def export_cached(self):
"""Export cached remote functions
@@ -376,11 +376,11 @@ class FunctionActorManager(object):
check_oversized_pickle(pickled_function,
remote_function._function_name,
"remote function", self._worker)
key = (b"RemoteFunction:" + self._worker.task_driver_id.binary() + b":"
key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
+ remote_function._function_descriptor.function_id.binary())
self._worker.redis_client.hmset(
key, {
"driver_id": self._worker.task_driver_id.binary(),
"job_id": self._worker.current_job_id.binary(),
"function_id": remote_function._function_descriptor.
function_id.binary(),
"name": remote_function._function_name,
@@ -392,14 +392,14 @@ class FunctionActorManager(object):
def fetch_and_register_remote_function(self, key):
"""Import a remote function."""
(driver_id_str, function_id_str, function_name, serialized_function,
(job_id_str, function_id_str, function_name, serialized_function,
num_return_vals, module, resources,
max_calls) = self._worker.redis_client.hmget(key, [
"driver_id", "function_id", "name", "function", "num_return_vals",
"job_id", "function_id", "name", "function", "num_return_vals",
"module", "resources", "max_calls"
])
function_id = ray.FunctionID(function_id_str)
driver_id = ray.DriverID(driver_id_str)
job_id = ray.JobID(job_id_str)
function_name = decode(function_name)
max_calls = int(max_calls)
module = decode(module)
@@ -413,12 +413,12 @@ class FunctionActorManager(object):
# atomic. Otherwise, there is race condition. Another thread may use
# the temporary function above before the real function is ready.
with self.lock:
self._function_execution_info[driver_id][function_id] = (
self._function_execution_info[job_id][function_id] = (
FunctionExecutionInfo(
function=f,
function_name=function_name,
max_calls=max_calls))
self._num_task_executions[driver_id][function_id] = 0
self._num_task_executions[job_id][function_id] = 0
try:
function = pickle.loads(serialized_function)
@@ -434,7 +434,7 @@ class FunctionActorManager(object):
"Failed to unpickle the remote function '{}' with "
"function ID {}. Traceback:\n{}".format(
function_name, function_id.hex(), traceback_str),
driver_id=driver_id)
job_id=job_id)
else:
# The below line is necessary. Because in the driver process,
# if the function is defined in the file where the python
@@ -442,7 +442,7 @@ class FunctionActorManager(object):
# However in the worker process, the `__main__` module is a
# different module, which is `default_worker.py`
function.__module__ = module
self._function_execution_info[driver_id][function_id] = (
self._function_execution_info[job_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
@@ -452,11 +452,11 @@ class FunctionActorManager(object):
b"FunctionTable:" + function_id.binary(),
self._worker.worker_id)
def get_execution_info(self, driver_id, function_descriptor):
def get_execution_info(self, job_id, function_descriptor):
"""Get the FunctionExecutionInfo of a remote function.
Args:
driver_id: ID of the driver that the function belongs to.
job_id: ID of the job that the function belongs to.
function_descriptor: The FunctionDescriptor of the function to get.
Returns:
@@ -464,11 +464,11 @@ class FunctionActorManager(object):
"""
if self._worker.load_code_from_local:
# Load function from local code.
# Currently, we don't support isolating code by drivers,
# thus always set driver ID to NIL here.
driver_id = ray.DriverID.nil()
# Currently, we don't support isolating code by jobs,
# thus always set job ID to NIL here.
job_id = ray.JobID.nil()
if not function_descriptor.is_actor_method():
self._load_function_from_local(driver_id, function_descriptor)
self._load_function_from_local(job_id, function_descriptor)
else:
# Load function from GCS.
# Wait until the function to be executed has actually been
@@ -477,21 +477,21 @@ class FunctionActorManager(object):
# The driver function may not be found in sys.path. Try to load
# the function from GCS.
with profiling.profile("wait_for_function"):
self._wait_for_function(function_descriptor, driver_id)
self._wait_for_function(function_descriptor, job_id)
try:
function_id = function_descriptor.function_id
info = self._function_execution_info[driver_id][function_id]
info = self._function_execution_info[job_id][function_id]
except KeyError as e:
message = ("Error occurs in get_execution_info: "
"driver_id: %s, function_descriptor: %s. Message: %s" %
(driver_id, function_descriptor, e))
"job_id: %s, function_descriptor: %s. Message: %s" %
(job_id, function_descriptor, e))
raise KeyError(message)
return info
def _load_function_from_local(self, driver_id, function_descriptor):
def _load_function_from_local(self, job_id, function_descriptor):
assert not function_descriptor.is_actor_method()
function_id = function_descriptor.function_id
if (driver_id in self._function_execution_info
if (job_id in self._function_execution_info
and function_id in self._function_execution_info[function_id]):
return
module_name, function_name = (
@@ -501,13 +501,13 @@ class FunctionActorManager(object):
try:
module = importlib.import_module(module_name)
function = getattr(module, function_name)._function
self._function_execution_info[driver_id][function_id] = (
self._function_execution_info[job_id][function_id] = (
FunctionExecutionInfo(
function=function,
function_name=function_name,
max_calls=0,
))
self._num_task_executions[driver_id][function_id] = 0
self._num_task_executions[job_id][function_id] = 0
except Exception:
logger.exception(
"Failed to load function %s.".format(function_name))
@@ -515,7 +515,7 @@ class FunctionActorManager(object):
"Function {} failed to be loaded from local code.".format(
function_descriptor))
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
def _wait_for_function(self, function_descriptor, job_id, timeout=10):
"""Wait until the function to be executed is present on this worker.
This method will simply loop until the import thread has imported the
@@ -528,7 +528,7 @@ class FunctionActorManager(object):
Args:
function_descriptor : The FunctionDescriptor of the function that
we want to execute.
driver_id (str): The ID of the driver to push the error message to
job_id (str): The ID of the job to push the error message to
if this times out.
"""
start_time = time.time()
@@ -538,7 +538,7 @@ class FunctionActorManager(object):
with self.lock:
if (self._worker.actor_id.is_nil()
and (function_descriptor.function_id in
self._function_execution_info[driver_id])):
self._function_execution_info[job_id])):
break
elif not self._worker.actor_id.is_nil() and (
self._worker.actor_id in self._worker.actors):
@@ -553,7 +553,7 @@ class FunctionActorManager(object):
self._worker,
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
warning_message,
driver_id=driver_id)
job_id=job_id)
warning_sent = True
time.sleep(0.001)
@@ -577,22 +577,22 @@ class FunctionActorManager(object):
if self._worker.load_code_from_local:
return
function_descriptor = FunctionDescriptor.from_class(Class)
# `task_driver_id` shouldn't be NIL, unless:
# `current_job_id` shouldn't be NIL, unless:
# 1) This worker isn't an actor;
# 2) And a previous task started a background thread, which didn't
# finish before the task finished, and still uses Ray API
# after that.
assert not self._worker.task_driver_id.is_nil(), (
assert not self._worker.current_job_id.is_nil(), (
"You might have started a background thread in a non-actor task, "
"please make sure the thread finishes before the task finishes.")
driver_id = self._worker.task_driver_id
key = (b"ActorClass:" + driver_id.binary() + b":" +
job_id = self._worker.current_job_id
key = (b"ActorClass:" + job_id.binary() + b":" +
function_descriptor.function_id.binary())
actor_class_info = {
"class_name": Class.__name__,
"module": Class.__module__,
"class": pickle.dumps(Class),
"driver_id": driver_id.binary(),
"job_id": job_id.binary(),
"actor_method_names": json.dumps(list(actor_method_names))
}
@@ -616,11 +616,11 @@ class FunctionActorManager(object):
# within tasks. I tried to disable this, but it may be necessary
# because of https://github.com/ray-project/ray/issues/1146.
def load_actor_class(self, driver_id, function_descriptor):
def load_actor_class(self, job_id, function_descriptor):
"""Load the actor class.
Args:
driver_id: Driver ID of the actor.
job_id: Job ID of the actor.
function_descriptor: Function descriptor of the actor constructor.
Returns:
@@ -632,14 +632,14 @@ class FunctionActorManager(object):
if actor_class is None:
# Load actor class.
if self._worker.load_code_from_local:
driver_id = ray.DriverID.nil()
job_id = ray.JobID.nil()
# Load actor class from local code.
actor_class = self._load_actor_from_local(
driver_id, function_descriptor)
job_id, function_descriptor)
else:
# Load actor class from GCS.
actor_class = self._load_actor_class_from_gcs(
driver_id, function_descriptor)
job_id, function_descriptor)
# Save the loaded actor class in cache.
self._loaded_actor_classes[function_id] = actor_class
@@ -657,18 +657,19 @@ class FunctionActorManager(object):
actor_method,
actor_imported=True,
)
self._function_execution_info[driver_id][method_id] = (
self._function_execution_info[job_id][method_id] = (
FunctionExecutionInfo(
function=executor,
function_name=actor_method_name,
max_calls=0,
))
self._num_task_executions[driver_id][method_id] = 0
self._num_task_executions[driver_id][function_id] = 0
self._num_task_executions[job_id][method_id] = 0
self._num_task_executions[job_id][function_id] = 0
return actor_class
def _load_actor_from_local(self, driver_id, function_descriptor):
def _load_actor_from_local(self, job_id, function_descriptor):
"""Load actor class from local code."""
assert isinstance(job_id, ray.JobID)
module_name, class_name = (function_descriptor.module_name,
function_descriptor.class_name)
try:
@@ -699,9 +700,9 @@ class FunctionActorManager(object):
return TemporaryActor
def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
def _load_actor_class_from_gcs(self, job_id, function_descriptor):
"""Load actor class from GCS."""
key = (b"ActorClass:" + driver_id.binary() + b":" +
key = (b"ActorClass:" + job_id.binary() + b":" +
function_descriptor.function_id.binary())
# Wait for the actor class key to have been imported by the
# import thread. TODO(rkn): It shouldn't be possible to end
@@ -711,16 +712,14 @@ class FunctionActorManager(object):
time.sleep(0.001)
# Fetch raw data from GCS.
(driver_id_str, class_name, module, pickled_class,
(job_id_str, class_name, module, pickled_class,
actor_method_names) = self._worker.redis_client.hmget(
key, [
"driver_id", "class_name", "module", "class",
"actor_method_names"
])
key,
["job_id", "class_name", "module", "class", "actor_method_names"])
class_name = ensure_str(class_name)
module_name = ensure_str(module)
driver_id = ray.DriverID(driver_id_str)
job_id = ray.JobID(job_id_str)
actor_method_names = json.loads(ensure_str(actor_method_names))
actor_class = None
@@ -741,11 +740,12 @@ class FunctionActorManager(object):
traceback.format_exc())
# Log the error message.
push_error_to_driver(
self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
self._worker,
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
"Failed to unpickle actor class '{}' for actor ID {}. "
"Traceback:\n{}".format(class_name,
self._worker.actor_id.hex(),
traceback_str), driver_id)
"Traceback:\n{}".format(
class_name, self._worker.actor_id.hex(), traceback_str),
job_id=job_id)
# TODO(rkn): In the future, it might make sense to have the worker
# exit here. However, currently that would lead to hanging if
# someone calls ray.get on a method invoked on the actor.
@@ -859,7 +859,7 @@ class FunctionActorManager(object):
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)
job_id=self._worker.current_job_id)
def _restore_and_log_checkpoint(self, actor):
"""Restore an actor from a checkpoint if available and log any errors.
@@ -898,4 +898,4 @@ class FunctionActorManager(object):
self._worker,
ray_constants.CHECKPOINT_PUSH_ERROR,
traceback_str,
driver_id=self._worker.task_driver_id)
job_id=self._worker.current_job_id)
+7 -7
View File
@@ -7,7 +7,7 @@ from ray.core.generated.ray.protocol.Task import Task
from ray.core.generated.gcs_pb2 import (
ActorCheckpointIdData,
ClientTableData,
DriverTableData,
JobTableData,
ErrorTableData,
ErrorType,
GcsEntry,
@@ -23,7 +23,7 @@ from ray.core.generated.gcs_pb2 import (
__all__ = [
"ActorCheckpointIdData",
"ClientTableData",
"DriverTableData",
"JobTableData",
"ErrorTableData",
"ErrorType",
"GcsEntry",
@@ -48,8 +48,8 @@ XRAY_HEARTBEAT_CHANNEL = str(
XRAY_HEARTBEAT_BATCH_CHANNEL = str(
TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
# xray driver updates
XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
# xray job updates
XRAY_JOB_CHANNEL = str(TablePubsub.Value("JOB_PUBSUB")).encode("ascii")
# These prefixes must be kept up-to-date with the TablePrefix enum in
# gcs.proto.
@@ -61,11 +61,11 @@ TablePrefix_ERROR_INFO_string = "ERROR_INFO"
TablePrefix_PROFILE_string = "PROFILE"
def construct_error_message(driver_id, error_type, message, timestamp):
def construct_error_message(job_id, error_type, message, timestamp):
"""Construct a serialized ErrorTableData object.
Args:
driver_id: The ID of the driver that the error should go to. If this is
job_id: The ID of the job that the error should go to. If this is
nil, then the error will go to all drivers.
error_type: The type of the error.
message: The error message.
@@ -75,7 +75,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
The serialized object.
"""
data = ErrorTableData()
data.driver_id = driver_id.binary()
data.job_id = job_id.binary()
data.type = error_type
data.error_message = message
data.timestamp = timestamp
+4 -4
View File
@@ -114,13 +114,13 @@ class ImportThread(object):
def fetch_and_execute_function_to_run(self, key):
"""Run an arbitrary function on the worker."""
(driver_id, serialized_function,
(job_id, serialized_function,
run_on_other_drivers) = self.redis_client.hmget(
key, ["driver_id", "function", "run_on_other_drivers"])
key, ["job_id", "function", "run_on_other_drivers"])
if (utils.decode(run_on_other_drivers) == "False"
and self.worker.mode == ray.SCRIPT_MODE
and driver_id != self.worker.task_driver_id.binary()):
and job_id != self.worker.current_job_id.binary()):
return
try:
@@ -140,4 +140,4 @@ class ImportThread(object):
self.worker,
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
traceback_str,
driver_id=ray.DriverID(driver_id))
job_id=ray.JobID(job_id))
+3 -2
View File
@@ -6,7 +6,8 @@ from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector as c_vector
from ray.includes.unique_ids cimport (
CDriverID,
CJobID,
CWorkerID,
CObjectID,
CTaskID,
)
@@ -81,7 +82,7 @@ cdef extern from "ray/common/status.h" namespace "ray::StatusCode" nogil:
cdef extern from "ray/common/id.h" namespace "ray" nogil:
const CTaskID GenerateTaskId(const CDriverID &driver_id,
const CTaskID GenerateTaskId(const CJobID &job_id,
const CTaskID &parent_task_id,
int parent_task_counter)
+5 -4
View File
@@ -14,7 +14,8 @@ from ray.includes.unique_ids cimport (
CActorCheckpointID,
CActorID,
CClientID,
CDriverID,
CJobID,
CWorkerID,
CObjectID,
CTaskID,
)
@@ -46,7 +47,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
cdef cppclass CRayletClient "RayletClient":
CRayletClient(const c_string &raylet_socket,
const CClientID &client_id,
c_bool is_worker, const CDriverID &driver_id,
c_bool is_worker, const CJobID &job_id,
const CLanguage &language)
CRayStatus Disconnect()
CRayStatus SubmitTask(
@@ -62,7 +63,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
int num_returns, int64_t timeout_milliseconds,
c_bool wait_local, const CTaskID &current_task_id,
WaitResultPair *result)
CRayStatus PushError(const CDriverID &driver_id, const c_string &type,
CRayStatus PushError(const CJobID &job_id, const c_string &type,
const c_string &error_message, double timestamp)
CRayStatus PushProfileEvents(
const GCSProfileTableDataT &profile_events)
@@ -75,6 +76,6 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id)
CLanguage GetLanguage() const
CClientID GetClientID() const
CDriverID GetDriverID() const
CJobID GetJobID() const
c_bool IsWorker() const
const ResourceMappingType &GetResourceIDs() const
+4 -4
View File
@@ -12,7 +12,7 @@ from ray.includes.common cimport (
from ray.includes.unique_ids cimport (
CActorHandleID,
CActorID,
CDriverID,
CJobID,
CObjectID,
CTaskID,
)
@@ -46,7 +46,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification":
CTaskSpecification(
const CDriverID &driver_id, const CTaskID &parent_task_id,
const CJobID &job_id, const CTaskID &parent_task_id,
int64_t parent_counter,
const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
int64_t num_returns,
@@ -54,7 +54,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
const CLanguage &language,
const c_vector[c_string] &function_descriptor)
CTaskSpecification(
const CDriverID &driver_id, const CTaskID &parent_task_id,
const CJobID &job_id, const CTaskID &parent_task_id,
int64_t parent_counter, const CActorID &actor_creation_id,
const CObjectID &actor_creation_dummy_object_id,
int64_t max_actor_reconstructions, const CActorID &actor_id,
@@ -70,7 +70,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
c_string SerializeAsString() const
CTaskID TaskId() const
CDriverID DriverId() const
CJobID JobId() const
CTaskID ParentTaskId() const
int64_t ParentCounter() const
c_vector[c_string] FunctionDescriptor() const
+5 -5
View File
@@ -18,7 +18,7 @@ cdef class Task:
unique_ptr[CTaskSpecification] task_spec
unique_ptr[c_vector[CObjectID]] execution_dependencies
def __init__(self, DriverID driver_id, function_descriptor, arguments,
def __init__(self, JobID job_id, function_descriptor, arguments,
int num_returns, TaskID parent_task_id, int parent_counter,
ActorID actor_creation_id,
ObjectID actor_creation_dummy_object_id,
@@ -72,7 +72,7 @@ cdef class Task:
(<ActorHandleID?>new_actor_handle).native())
self.task_spec.reset(new CTaskSpecification(
driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
required_resources, required_placement_resources, LANGUAGE_PYTHON,
@@ -122,9 +122,9 @@ cdef class Task:
return SerializeTaskAsString(
self.execution_dependencies.get(), self.task_spec.get())
def driver_id(self):
"""Return the driver ID for this task."""
return DriverID(self.task_spec.get().DriverId().Binary())
def job_id(self):
"""Return the job ID for this task."""
return JobID(self.task_spec.get().JobId().Binary())
def task_id(self):
"""Return the task ID for this task."""
+2 -2
View File
@@ -78,10 +78,10 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
@staticmethod
CFunctionID FromBinary(const c_string &binary)
cdef cppclass CDriverID "ray::DriverID"(CUniqueID):
cdef cppclass CJobID "ray::JobID"(CUniqueID):
@staticmethod
CDriverID FromBinary(const c_string &binary)
CJobID FromBinary(const c_string &binary)
cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]):
+15 -6
View File
@@ -15,7 +15,7 @@ from ray.includes.unique_ids cimport (
CActorID,
CClientID,
CConfigID,
CDriverID,
CJobID,
CFunctionID,
CObjectID,
CTaskID,
@@ -212,15 +212,23 @@ cdef class ClientID(UniqueID):
return <CClientID>self.data
cdef class DriverID(UniqueID):
cdef class JobID(UniqueID):
def __init__(self, id):
check_id(id)
self.data = CDriverID.FromBinary(<c_string>id)
self.data = CJobID.FromBinary(<c_string>id)
cdef CDriverID native(self):
return <CDriverID>self.data
cdef CJobID native(self):
return <CJobID>self.data
cdef class WorkerID(UniqueID):
def __init__(self, id):
check_id(id)
self.data = CWorkerID.FromBinary(<c_string>id)
cdef CWorkerID native(self):
return <CWorkerID>self.data
cdef class ActorID(UniqueID):
@@ -277,7 +285,8 @@ _ID_TYPES = [
ActorHandleID,
ActorID,
ClientID,
DriverID,
JobID,
WorkerID,
FunctionID,
ObjectID,
TaskID,
+23 -23
View File
@@ -130,14 +130,14 @@ class Monitor(object):
"Monitor: "
"could not find ip for client {}".format(client_id))
def _xray_clean_up_entries_for_driver(self, driver_id):
"""Remove this driver's object/task entries from redis.
def _xray_clean_up_entries_for_job(self, job_id):
"""Remove this job's object/task entries from redis.
Removes control-state entries of all tasks and task return
objects belonging to the job.
Args:
driver_id: The driver id.
job_id: The job id.
"""
xray_task_table_prefix = (
@@ -146,23 +146,23 @@ class Monitor(object):
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
task_table_objects = ray.tasks()
driver_id_hex = binary_to_hex(driver_id)
driver_task_id_bins = set()
job_id_hex = binary_to_hex(job_id)
job_task_id_bins = set()
for task_id_hex, task_info in task_table_objects.items():
task_table_object = task_info["TaskSpec"]
task_driver_id_hex = task_table_object["DriverID"]
if driver_id_hex != task_driver_id_hex:
task_job_id_hex = task_table_object["JobID"]
if job_id_hex != task_job_id_hex:
# Ignore tasks that aren't from this driver.
continue
driver_task_id_bins.add(hex_to_binary(task_id_hex))
job_task_id_bins.add(hex_to_binary(task_id_hex))
# Get objects associated with the driver.
object_table_objects = ray.objects()
driver_object_id_bins = set()
job_object_id_bins = set()
for object_id, _ in object_table_objects.items():
task_id_bin = ray._raylet.compute_task_id(object_id).binary()
if task_id_bin in driver_task_id_bins:
driver_object_id_bins.add(object_id.binary())
if task_id_bin in job_task_id_bins:
job_object_id_bins.add(object_id.binary())
def to_shard_index(id_bin):
if len(id_bin) == ray.TaskID.size():
@@ -174,10 +174,10 @@ class Monitor(object):
# Form the redis keys to delete.
sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
for task_id_bin in driver_task_id_bins:
for task_id_bin in job_task_id_bins:
sharded_keys[to_shard_index(task_id_bin)].append(
xray_task_table_prefix + task_id_bin)
for object_id_bin in driver_object_id_bins:
for object_id_bin in job_object_id_bins:
sharded_keys[to_shard_index(object_id_bin)].append(
xray_object_table_prefix + object_id_bin)
@@ -198,21 +198,21 @@ class Monitor(object):
"entries from redis shard {}.".format(
len(keys) - num_deleted, shard_index))
def xray_driver_removed_handler(self, unused_channel, data):
"""Handle a notification that a driver has been removed.
def xray_job_removed_handler(self, unused_channel, data):
"""Handle a notification that a job has been removed.
Args:
unused_channel: The message channel.
data: The message data.
"""
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
driver_data = gcs_entries.entries[0]
message = ray.gcs_utils.DriverTableData.FromString(driver_data)
driver_id = message.driver_id
job_data = gcs_entries.entries[0]
message = ray.gcs_utils.JobTableData.FromString(job_data)
job_id = message.job_id
logger.info("Monitor: "
"XRay Driver {} has been removed.".format(
binary_to_hex(driver_id)))
self._xray_clean_up_entries_for_driver(driver_id)
binary_to_hex(job_id)))
self._xray_clean_up_entries_for_job(job_id)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -240,9 +240,9 @@ class Monitor(object):
if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
# Similar functionality as raylet info channel
message_handler = self.xray_heartbeat_batch_handler
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL:
# Handles job removal.
message_handler = self.xray_driver_removed_handler
message_handler = self.xray_job_removed_handler
else:
raise Exception("This code should be unreachable.")
@@ -298,7 +298,7 @@ class Monitor(object):
"""
# Initialize the subscription channel.
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)
# TODO(rkn): If there were any dead clients at startup, we should clean
# up the associated state in the state tables.
+5 -5
View File
@@ -44,7 +44,7 @@ class RemoteFunction(object):
return the resulting ObjectIDs. For an example, see
"test_decorated_function" in "python/ray/tests/test_basic.py".
_function_signature: The function signature.
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
_last_job_id_exported_for: The job ID of the last Ray
session during which this remote function definition was exported.
This is an imperfect mechanism used to determine if we need to
export the remote function again. It is imperfect in the sense that
@@ -73,7 +73,7 @@ class RemoteFunction(object):
self._function_signature = ray.signature.extract_signature(
self._function)
self._last_driver_id_exported_for = None
self._last_job_id_exported_for = None
# Override task.remote's signature and docstring
@wraps(function)
@@ -115,11 +115,11 @@ class RemoteFunction(object):
worker = ray.worker.get_global_worker()
worker.check_connected()
if (self._last_driver_id_exported_for is None
or self._last_driver_id_exported_for != worker.task_driver_id):
if (self._last_job_id_exported_for is None
or self._last_job_id_exported_for != worker.current_job_id):
# If this function was exported in a previous session, we need to
# export this function again, because current GCS doesn't have it.
self._last_driver_id_exported_for = worker.task_driver_id
self._last_job_id_exported_for = worker.current_job_id
worker.function_actor_manager.export(self)
kwargs = {} if kwargs is None else kwargs
+1 -1
View File
@@ -20,7 +20,7 @@ class RuntimeContext(object):
a task, return the driver ID of the associated driver.
"""
assert self.worker is not None
return self.worker.task_driver_id
return self.worker.current_job_id
_runtime_context = None
+19 -20
View File
@@ -316,7 +316,7 @@ class GlobalState(object):
function_descriptor_list)
task_spec_info = {
"DriverID": task.driver_id().hex(),
"JobID": task.job_id().hex(),
"TaskID": task.task_id().hex(),
"ParentTaskID": task.parent_task_id().hex(),
"ParentCounter": task.parent_counter(),
@@ -817,19 +817,19 @@ class GlobalState(object):
return dict(total_available_resources)
def _error_messages(self, driver_id):
def _error_messages(self, job_id):
"""Get the error messages for a specific driver.
Args:
driver_id: The ID of the driver to get the errors for.
job_id: The ID of the job to get the errors for.
Returns:
A list of the error messages for this driver.
"""
assert isinstance(driver_id, ray.DriverID)
assert isinstance(job_id, ray.JobID)
message = self.redis_client.execute_command(
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
driver_id.binary())
job_id.binary())
# If there are no errors, return early.
if message is None:
@@ -839,7 +839,7 @@ class GlobalState(object):
error_messages = []
for entry in gcs_entries.entries:
error_data = gcs_utils.ErrorTableData.FromString(entry)
assert driver_id.binary() == error_data.driver_id
assert job_id.binary() == error_data.job_id
error_message = {
"type": error_data.type,
"message": error_data.error_message,
@@ -848,12 +848,12 @@ class GlobalState(object):
error_messages.append(error_message)
return error_messages
def error_messages(self, driver_id=None):
def error_messages(self, job_id=None):
"""Get the error messages for all drivers or a specific driver.
Args:
driver_id: The specific driver to get the errors for. If this is
None, then this method retrieves the errors for all drivers.
job_id: The specific job to get the errors for. If this is
None, then this method retrieves the errors for all jobs.
Returns:
A dictionary mapping driver ID to a list of the error messages for
@@ -861,21 +861,20 @@ class GlobalState(object):
"""
self._check_connected()
if driver_id is not None:
assert isinstance(driver_id, ray.DriverID)
return self._error_messages(driver_id)
if job_id is not None:
assert isinstance(job_id, ray.JobID)
return self._error_messages(job_id)
error_table_keys = self.redis_client.keys(
gcs_utils.TablePrefix_ERROR_INFO_string + "*")
driver_ids = [
job_ids = [
key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
for key in error_table_keys
]
return {
binary_to_hex(driver_id): self._error_messages(
ray.DriverID(driver_id))
for driver_id in driver_ids
binary_to_hex(job_id): self._error_messages(ray.JobID(job_id))
for job_id in job_ids
}
def actor_checkpoint_info(self, actor_id):
@@ -969,12 +968,12 @@ class DeprecatedGlobalState(object):
"instead.")
return ray.available_resources()
def error_messages(self, driver_id=None):
def error_messages(self, job_id=None):
logger.warning(
"ray.global_state.error_messages() is deprecated and will be "
"removed in a subsequent release. Use ray.errors() "
"instead.")
return ray.errors(driver_id=driver_id)
return ray.errors(job_id=job_id)
state = GlobalState()
@@ -1095,7 +1094,7 @@ def errors(include_cluster_errors=True):
Error messages pushed from the cluster.
"""
worker = ray.worker.global_worker
error_messages = state.error_messages(driver_id=worker.task_driver_id)
error_messages = state.error_messages(job_id=worker.current_job_id)
if include_cluster_errors:
error_messages += state.error_messages(driver_id=ray.DriverID.nil())
error_messages += state.error_messages(job_id=ray.JobID.nil())
return error_messages
+7 -7
View File
@@ -2439,7 +2439,7 @@ def test_global_state_api(shutdown_only):
assert ray.objects() == {}
driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
job_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
driver_task_id = ray.worker.global_worker.current_task_id.hex()
# One task is put in the task table which corresponds to this driver.
@@ -2453,7 +2453,7 @@ def test_global_state_api(shutdown_only):
assert task_spec["TaskID"] == driver_task_id
assert task_spec["ActorID"] == nil_id_hex
assert task_spec["Args"] == []
assert task_spec["DriverID"] == driver_id
assert task_spec["JobID"] == job_id
assert task_spec["FunctionID"] == nil_id_hex
assert task_spec["ReturnObjectIDs"] == []
@@ -2481,7 +2481,7 @@ def test_global_state_api(shutdown_only):
task_spec = task_table[task_id]["TaskSpec"]
assert task_spec["ActorID"] == nil_id_hex
assert task_spec["Args"] == [1, "hi", x_id]
assert task_spec["DriverID"] == driver_id
assert task_spec["JobID"] == job_id
assert task_spec["ReturnObjectIDs"] == [result_id]
assert task_table[task_id] == ray.tasks(task_id)
@@ -2613,9 +2613,9 @@ def test_workers(shutdown_only):
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
def test_specific_driver_id():
dummy_driver_id = ray.DriverID(b"00112233445566778899")
ray.init(num_cpus=1, driver_id=dummy_driver_id)
def test_specific_job_id():
dummy_driver_id = ray.JobID(b"00112233445566778899")
ray.init(num_cpus=1, job_id=dummy_driver_id)
# in driver
assert dummy_driver_id == ray._get_runtime_context().current_driver_id
@@ -2727,7 +2727,7 @@ def test_ray_setproctitle(ray_start_2_cpus):
def test_duplicate_error_messages(shutdown_only):
ray.init(num_cpus=0)
driver_id = ray.DriverID.nil()
driver_id = ray.WorkerID.nil()
error_data = ray.gcs_utils.construct_error_message(driver_id, "test",
"message", 0)
+15 -14
View File
@@ -51,7 +51,7 @@ def format_error_message(exception_message, task_exception=False):
return "\n".join(lines)
def push_error_to_driver(worker, error_type, message, driver_id=None):
def push_error_to_driver(worker, error_type, message, job_id=None):
"""Push an error message to the driver to be printed in the background.
Args:
@@ -59,19 +59,19 @@ def push_error_to_driver(worker, error_type, message, driver_id=None):
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
driver_id: The ID of the driver to push the error message to. If this
job_id: The ID of the driver to push the error message to. If this
is None, then the message will be pushed to all drivers.
"""
if driver_id is None:
driver_id = ray.DriverID.nil()
worker.raylet_client.push_error(driver_id, error_type, message,
time.time())
if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
worker.raylet_client.push_error(job_id, error_type, message, time.time())
def push_error_to_driver_through_redis(redis_client,
error_type,
message,
driver_id=None):
job_id=None):
"""Push an error message to the driver to be printed in the background.
Normally the push_error_to_driver function should be used. However, in some
@@ -84,19 +84,20 @@ def push_error_to_driver_through_redis(redis_client,
error_type (str): The type of the error.
message (str): The message that will be printed in the background
on the driver.
driver_id: The ID of the driver to push the error message to. If this
job_id: The ID of the driver to push the error message to. If this
is None, then the message will be pushed to all drivers.
"""
if driver_id is None:
driver_id = ray.DriverID.nil()
if job_id is None:
job_id = ray.JobID.nil()
assert isinstance(job_id, ray.JobID)
# Do everything in Python and through the Python Redis client instead
# of through the raylet.
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
error_data = ray.gcs_utils.construct_error_message(job_id, error_type,
message, time.time())
redis_client.execute_command(
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
driver_id.binary(), error_data)
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), job_id.binary(),
error_data)
def is_cython(obj):
@@ -443,7 +444,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
worker,
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
warning_message,
driver_id=worker.task_driver_id)
job_id=worker.current_job_id)
class _ThreadSafeProxy(object):
+77 -80
View File
@@ -40,7 +40,8 @@ from ray import (
ActorHandleID,
ActorID,
ClientID,
DriverID,
WorkerID,
JobID,
ObjectID,
TaskID,
)
@@ -145,9 +146,9 @@ class Worker(object):
# TODO: clean up the SerializationContext once the job finished.
self.serialization_context_map = {}
self.function_actor_manager = FunctionActorManager(self)
# Identity of the driver that this worker is processing.
# It is a DriverID.
self.task_driver_id = DriverID.nil()
# Identity of the job that this worker is processing.
# It is a JobID.
self.current_job_id = JobID.nil()
self._task_context = threading.local()
# This event is checked regularly by all of the threads so that they
# know when to exit.
@@ -227,24 +228,24 @@ class Worker(object):
if self.actor_init_error is not None:
raise self.actor_init_error
def get_serialization_context(self, driver_id):
"""Get the SerializationContext of the driver that this worker is processing.
def get_serialization_context(self, job_id):
"""Get the SerializationContext of the job that this worker is processing.
Args:
driver_id: The ID of the driver that indicates which driver to get
job_id: The ID of the job that indicates which job to get
the serialization context for.
Returns:
The serialization context of the given driver.
The serialization context of the given job.
"""
# This function needs to be protected by a lock, because it will be
# called by `register_class_for_serialization`, as well as the import
# thread, from different threads. Also, this function will recursively
# call itself, so we use RLock here.
with self.lock:
if driver_id not in self.serialization_context_map:
_initialize_serialization(driver_id)
return self.serialization_context_map[driver_id]
if job_id not in self.serialization_context_map:
_initialize_serialization(job_id)
return self.serialization_context_map[job_id]
def check_connected(self):
"""Check if the worker is connected.
@@ -314,7 +315,7 @@ class Worker(object):
object_id=pyarrow.plasma.ObjectID(object_id.binary()),
memcopy_threads=self.memcopy_threads,
serialization_context=self.get_serialization_context(
self.task_driver_id))
self.current_job_id))
break
except pyarrow.SerializationCallbackError as e:
try:
@@ -388,17 +389,17 @@ class Worker(object):
# should return an error code to the caller instead of printing a
# message.
logger.info(
"The object with ID {} already exists in the object store."
.format(object_id))
"The object with ID {} already exists in the object store.".
format(object_id))
except TypeError:
# This error can happen because one of the members of the object
# may not be serializable for cloudpickle. So we need these extra
# fallbacks here to start from the beginning. Hopefully the object
# could have a `__reduce__` method.
register_custom_serializer(type(value), use_pickle=True)
warning_message = ("WARNING: Serializing the class {} failed, "
"so are are falling back to cloudpickle."
.format(type(value)))
warning_message = (
"WARNING: Serializing the class {} failed, "
"so are are falling back to cloudpickle.".format(type(value)))
logger.warning(warning_message)
self.store_and_register(object_id, value)
@@ -407,7 +408,7 @@ class Worker(object):
# Only send the warning once.
warning_sent = False
serialization_context = self.get_serialization_context(
self.task_driver_id)
self.current_job_id)
while True:
try:
# We divide very large get requests into smaller get requests
@@ -449,7 +450,7 @@ class Worker(object):
self,
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
warning_message,
driver_id=self.task_driver_id)
job_id=self.current_job_id)
warning_sent = True
def _deserialize_object_from_arrow(self, data, metadata, object_id,
@@ -575,7 +576,7 @@ class Worker(object):
num_return_vals=None,
resources=None,
placement_resources=None,
driver_id=None):
job_id=None):
"""Submit a remote task to the scheduler.
Tell the scheduler to schedule the execution of the function with
@@ -601,11 +602,11 @@ class Worker(object):
placement_resources: The resources required for placing the task.
If this is not provided or if it is an empty dictionary, then
the placement resources will be equal to resources.
driver_id: The ID of the relevant driver. This is almost always the
driver ID of the driver that is currently running. However, in
job_id: The ID of the relevant job. This is almost always the
job ID of the job that is currently running. However, in
the exceptional case that an actor task is being dispatched to
an actor created by a different driver, this should be the
driver ID of the driver that created the actor.
an actor created by a different job, this should be the
job ID of the job that created the actor.
Returns:
The return object IDs for this task.
@@ -642,8 +643,8 @@ class Worker(object):
if new_actor_handles is None:
new_actor_handles = []
if driver_id is None:
driver_id = self.task_driver_id
if job_id is None:
job_id = self.current_job_id
if resources is None:
raise ValueError("The resources dictionary is required.")
@@ -674,13 +675,13 @@ class Worker(object):
assert not self.current_task_id.is_nil()
# Current job id must not be nil when submitting a task,
# because every task must belong to a job.
assert not self.task_driver_id.is_nil()
assert not self.current_job_id.is_nil()
# Submit the task to raylet.
function_descriptor_list = (
function_descriptor.get_function_descriptor_list())
assert isinstance(driver_id, DriverID)
assert isinstance(job_id, JobID)
task = ray._raylet.Task(
driver_id,
job_id,
function_descriptor_list,
args_for_raylet,
num_return_vals,
@@ -747,7 +748,7 @@ class Worker(object):
# Run the function on all workers.
self.redis_client.hmset(
key, {
"driver_id": self.task_driver_id.binary(),
"job_id": self.current_job_id.binary(),
"function_id": function_to_run_id,
"function": pickled_function,
"run_on_other_drivers": str(run_on_other_drivers)
@@ -853,17 +854,17 @@ class Worker(object):
assert self.task_context.task_index == 0
assert self.task_context.put_index == 1
if task.actor_id().is_nil():
# If this worker is not an actor, check that `task_driver_id`
# If this worker is not an actor, check that `current_job_id`
# was reset when the worker finished the previous task.
assert self.task_driver_id.is_nil()
assert self.current_job_id.is_nil()
# Set the job ID of the current running task. This is
# needed so that if the task throws an exception, we propagate
# the error message to the correct driver.
self.task_driver_id = task.driver_id()
self.current_job_id = task.job_id()
else:
# If this worker is an actor, task_driver_id wasn't reset.
# If this worker is an actor, current_job_id wasn't reset.
# Check that current task's job ID equals the previous one.
assert self.task_driver_id == task.driver_id()
assert self.current_job_id == task.job_id()
self.task_context.current_task_id = task.task_id()
@@ -945,7 +946,7 @@ class Worker(object):
self,
ray_constants.TASK_PUSH_ERROR,
str(failure_object),
driver_id=self.task_driver_id)
job_id=self.current_job_id)
# Mark the actor init as failed
if not self.actor_id.is_nil() and function_name == "__init__":
self.mark_actor_init_failed(error)
@@ -960,7 +961,7 @@ class Worker(object):
"""
function_descriptor = FunctionDescriptor.from_bytes_list(
task.function_descriptor_list())
driver_id = task.driver_id()
job_id = task.job_id()
# TODO(rkn): It would be preferable for actor creation tasks to share
# more of the code path with regular task execution.
@@ -969,7 +970,7 @@ class Worker(object):
self.actor_id = task.actor_creation_id()
self.actor_creation_task_id = task.task_id()
actor_class = self.function_actor_manager.load_actor_class(
driver_id, function_descriptor)
job_id, function_descriptor)
self.actors[self.actor_id] = actor_class.__new__(actor_class)
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
num_tasks_since_last_checkpoint=0,
@@ -978,7 +979,7 @@ class Worker(object):
)
execution_info = self.function_actor_manager.get_execution_info(
driver_id, function_descriptor)
job_id, function_descriptor)
# Execute the task.
function_name = execution_info.function_name
@@ -1005,20 +1006,20 @@ class Worker(object):
self.task_context.task_index = 0
self.task_context.put_index = 1
if self.actor_id.is_nil():
# Don't need to reset task_driver_id if the worker is an
# Don't need to reset `current_job_id` if the worker is an
# actor, because the following tasks should all have the
# same job id.
self.task_driver_id = DriverID.nil()
self.current_job_id = WorkerID.nil()
# Reset signal counters so that the next task can get
# all past signals.
ray_signal.reset()
# Increase the task execution counter.
self.function_actor_manager.increase_task_counter(
driver_id, function_descriptor)
job_id, function_descriptor)
reached_max_executions = (self.function_actor_manager.get_task_counter(
driver_id, function_descriptor) == execution_info.max_calls)
job_id, function_descriptor) == execution_info.max_calls)
if reached_max_executions:
self.raylet_client.disconnect()
sys.exit(0)
@@ -1141,7 +1142,7 @@ def print_failed_task(task_status):
task_status["error_message"]))
def _initialize_serialization(driver_id, worker=global_worker):
def _initialize_serialization(job_id, worker=global_worker):
"""Initialize the serialization library.
This defines a custom serializer for object IDs and also tells ray to
@@ -1177,7 +1178,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
custom_serializer=actor_handle_serializer,
custom_deserializer=actor_handle_deserializer)
worker.serialization_context_map[driver_id] = serialization_context
worker.serialization_context_map[job_id] = serialization_context
# Register exception types.
for error_cls in RAY_EXCEPTION_TYPES:
@@ -1185,7 +1186,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
error_cls,
use_dict=True,
local=True,
driver_id=driver_id,
job_id=job_id,
class_id=error_cls.__module__ + ". " + error_cls.__name__,
)
# Tell Ray to serialize lambdas with pickle.
@@ -1193,22 +1194,18 @@ def _initialize_serialization(driver_id, worker=global_worker):
type(lambda: 0),
use_pickle=True,
local=True,
driver_id=driver_id,
job_id=job_id,
class_id="lambda")
# Tell Ray to serialize types with pickle.
register_custom_serializer(
type(int),
use_pickle=True,
local=True,
driver_id=driver_id,
class_id="type")
type(int), use_pickle=True, local=True, job_id=job_id, class_id="type")
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
# used when passing around actor handles.
register_custom_serializer(
ray.signature.FunctionSignature,
use_dict=True,
local=True,
driver_id=driver_id,
job_id=job_id,
class_id="ray.signature.FunctionSignature")
@@ -1231,7 +1228,7 @@ def init(redis_address=None,
plasma_directory=None,
huge_pages=False,
include_webui=False,
driver_id=None,
job_id=None,
configure_logging=True,
logging_level=logging.INFO,
logging_format=ray_constants.LOGGER_FORMAT,
@@ -1302,7 +1299,7 @@ def init(redis_address=None,
Store with hugetlbfs support. Requires plasma_directory.
include_webui: Boolean flag indicating whether to start the web
UI, which displays the status of the Ray cluster.
driver_id: The ID of driver.
job_id: The ID of this job.
configure_logging: True to allow the logging configuration here.
Otherwise, the users may want to configure it on their own.
logging_level: Logging level, default will be logging.INFO.
@@ -1449,7 +1446,7 @@ def init(redis_address=None,
mode=driver_mode,
log_to_driver=log_to_driver,
worker=global_worker,
driver_id=driver_id)
job_id=job_id)
for hook in _post_init_hooks:
hook()
@@ -1660,10 +1657,10 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
assert len(gcs_entry.entries) == 1
error_data = ray.gcs_utils.ErrorTableData.FromString(
gcs_entry.entries[0])
driver_id = error_data.driver_id
if driver_id not in [
worker.task_driver_id.binary(),
DriverID.nil().binary()
job_id = error_data.job_id
if job_id not in [
worker.current_job_id.binary(),
JobID.nil().binary()
]:
continue
@@ -1691,7 +1688,7 @@ def connect(node,
mode=WORKER_MODE,
log_to_driver=False,
worker=global_worker,
driver_id=None):
job_id=None):
"""Connect this worker to the raylet, to Plasma, and to Redis.
Args:
@@ -1701,7 +1698,7 @@ def connect(node,
log_to_driver (bool): If true, then output from all of the worker
processes on all nodes will be directed to the driver.
worker: The ray.Worker instance.
driver_id: The ID of driver. If it's None, then we will generate one.
job_id: The ID of the job. If it's None, then we will generate one.
"""
# Do some basic checking to make sure we didn't call ray.init twice.
error_message = "Perhaps you called ray.init twice by accident?"
@@ -1721,20 +1718,20 @@ def connect(node,
setproctitle.setproctitle("ray_worker")
else:
# This is the code path of driver mode.
if driver_id is None:
driver_id = DriverID.from_random()
if job_id is None:
job_id = JobID.from_random()
if not isinstance(driver_id, DriverID):
raise TypeError("The type of given driver id must be DriverID.")
if not isinstance(job_id, JobID):
raise TypeError("The type of given job id must be JobID.")
worker.worker_id = driver_id.binary()
worker.worker_id = job_id.binary()
# When tasks are executed on remote workers in the context of multiple
# drivers, the task driver ID is used to keep track of which driver is
# drivers, the current job ID is used to keep track of which driver is
# responsible for the task so that error messages will be propagated to
# the correct driver.
if mode != WORKER_MODE:
worker.task_driver_id = DriverID(worker.worker_id)
worker.current_job_id = JobID(worker.worker_id)
# All workers start out as non-actors. A worker can be turned into an actor
# after it is created.
@@ -1766,7 +1763,7 @@ def connect(node,
worker.redis_client,
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
traceback_str,
driver_id=None)
job_id=None)
worker.lock = threading.RLock()
@@ -1831,7 +1828,7 @@ def connect(node,
# Create an object store client.
worker.plasma_client = thread_safe_client(
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
driver_id_str = _random_string()
job_id_str = _random_string()
# If this is a driver, set the current task ID, the task driver ID, and set
# the task index to 0.
@@ -1859,11 +1856,11 @@ def connect(node,
function_descriptor = FunctionDescriptor.for_driver_task()
driver_task = ray._raylet.Task(
worker.task_driver_id,
worker.current_job_id,
function_descriptor.get_function_descriptor_list(),
[], # arguments.
0, # num_returns.
TaskID(driver_id_str[:TaskID.size()]), # parent_task_id.
TaskID(job_id_str[:TaskID.size()]), # parent_task_id.
0, # parent_counter.
ActorID.nil(), # actor_creation_id.
ObjectID.nil(), # actor_creation_dummy_object_id.
@@ -1895,7 +1892,7 @@ def connect(node,
node.raylet_socket_name,
ClientID(worker.worker_id),
(mode == WORKER_MODE),
DriverID(driver_id_str),
JobID(job_id_str),
)
# Start the import thread
@@ -2057,7 +2054,7 @@ def register_custom_serializer(cls,
serializer=None,
deserializer=None,
local=False,
driver_id=None,
job_id=None,
class_id=None):
"""Enable serialization and deserialization for a particular class.
@@ -2078,7 +2075,7 @@ def register_custom_serializer(cls,
if and only if use_pickle and use_dict are False.
local: True if the serializers should only be registered on the current
worker. This should usually be False.
driver_id: ID of the driver that we want to register the class for.
job_id: ID of the job that we want to register the class for.
class_id: ID of the class that we are registering. If this is not
specified, we will calculate a new one inside the function.
@@ -2126,9 +2123,9 @@ def register_custom_serializer(cls,
# Make sure class_id is a string.
class_id = ray.utils.binary_to_hex(class_id)
if driver_id is None:
driver_id = worker.task_driver_id
assert isinstance(driver_id, DriverID)
if job_id is None:
job_id = worker.current_job_id
assert isinstance(job_id, JobID)
def register_class_for_serialization(worker_info):
# TODO(rkn): We need to be more thoughtful about what to do if custom
@@ -2138,7 +2135,7 @@ def register_custom_serializer(cls,
# system.
serialization_context = worker_info[
"worker"].get_serialization_context(driver_id)
"worker"].get_serialization_context(job_id)
serialization_context.register_type(
cls,
class_id,
+1 -1
View File
@@ -102,7 +102,7 @@ if __name__ == "__main__":
ray.worker.global_worker,
"worker_crash",
traceback_str,
driver_id=None)
job_id=None)
# TODO(rkn): Note that if the worker was in the middle of executing
# a task, then any worker or driver that is blocking in a get call
# and waiting for the output of that task will hang. We need to