From 62e4b591e3d6443ce25b0f05cc32b43d5e2ebb3d Mon Sep 17 00:00:00 2001 From: Qing Wang Date: Fri, 28 Jun 2019 00:44:51 +0800 Subject: [PATCH] [ID Refactor] Rename DriverID to JobID (#5004) * WIP WIP WIP Rename Driver -> Job Fix complition Fix Rename in Java In py WIP Fix WIP Fix Fix test Fix Fix C++ linting Fix * Update java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java Co-Authored-By: Stephanie Wang * Update src/ray/core_worker/core_worker.cc Co-Authored-By: Stephanie Wang * Address comments * Fix * Fix CI * Fix cpp linting * Fix py lint * FIx * Address comments and fix * Address comments * Address * Fix import_threading --- .../api/runtimecontext/RuntimeContext.java | 7 +- .../org/ray/runtime/AbstractRayRuntime.java | 10 +- .../org/ray/runtime/RayNativeRuntime.java | 2 +- .../org/ray/runtime/RuntimeContextImpl.java | 4 +- .../src/main/java/org/ray/runtime/Worker.java | 2 +- .../java/org/ray/runtime/WorkerContext.java | 23 +- .../org/ray/runtime/config/RayConfig.java | 26 +- .../functionmanager/FunctionManager.java | 52 +-- .../ray/runtime/raylet/MockRayletClient.java | 2 +- .../org/ray/runtime/raylet/RayletClient.java | 2 +- .../ray/runtime/raylet/RayletClientImpl.java | 22 +- .../java/org/ray/runtime/task/TaskSpec.java | 10 +- .../src/main/resources/ray.default.conf | 12 +- .../functionmanager/FunctionManagerTest.java | 24 +- .../java/org/ray/api/test/RayConfigTest.java | 6 +- .../org/ray/api/test/RuntimeContextTest.java | 10 +- python/ray/__init__.py | 6 +- python/ray/_raylet.pyx | 12 +- python/ray/actor.py | 40 +- python/ray/function_manager.py | 126 +++--- python/ray/gcs_utils.py | 14 +- python/ray/import_thread.py | 8 +- python/ray/includes/common.pxd | 5 +- python/ray/includes/libraylet.pxd | 9 +- python/ray/includes/task.pxd | 8 +- python/ray/includes/task.pxi | 10 +- python/ray/includes/unique_ids.pxd | 4 +- python/ray/includes/unique_ids.pxi | 21 +- python/ray/monitor.py | 46 +- python/ray/remote_function.py | 10 +- python/ray/runtime_context.py | 2 +- python/ray/state.py | 39 +- python/ray/tests/test_basic.py | 14 +- python/ray/utils.py | 29 +- python/ray/worker.py | 157 ++++--- python/ray/workers/default_worker.py | 2 +- src/ray/common/id.cc | 16 +- src/ray/common/id.h | 8 +- src/ray/common/id_def.h | 2 +- src/ray/core_worker/context.cc | 17 +- src/ray/core_worker/context.h | 12 +- src/ray/core_worker/core_worker.cc | 9 +- src/ray/core_worker/core_worker.h | 2 +- src/ray/core_worker/core_worker_test.cc | 16 +- src/ray/core_worker/mock_worker.cc | 2 +- src/ray/core_worker/task_interface.cc | 12 +- src/ray/gcs/client.cc | 6 +- src/ray/gcs/client.h | 16 +- src/ray/gcs/client_test.cc | 416 +++++++++--------- src/ray/gcs/format/gcs.fbs | 4 +- src/ray/gcs/tables.cc | 101 +++-- src/ray/gcs/tables.h | 159 ++++--- src/ray/object_manager/object_directory.cc | 12 +- src/ray/protobuf/gcs.proto | 18 +- src/ray/raylet/actor_registration.cc | 4 +- src/ray/raylet/actor_registration.h | 4 +- src/ray/raylet/format/node_manager.fbs | 3 +- ...org_ray_runtime_raylet_RayletClientImpl.cc | 12 +- src/ray/raylet/lineage_cache.cc | 9 +- src/ray/raylet/lineage_cache_test.cc | 10 +- src/ray/raylet/monitor.cc | 8 +- src/ray/raylet/node_manager.cc | 108 +++-- src/ray/raylet/node_manager.h | 15 +- src/ray/raylet/raylet_client.cc | 14 +- src/ray/raylet/raylet_client.h | 16 +- src/ray/raylet/reconstruction_policy.cc | 6 +- src/ray/raylet/reconstruction_policy_test.cc | 16 +- src/ray/raylet/scheduling_queue.cc | 28 +- src/ray/raylet/scheduling_queue.h | 8 +- src/ray/raylet/task_dependency_manager.cc | 2 +- .../raylet/task_dependency_manager_test.cc | 4 +- src/ray/raylet/task_spec.cc | 14 +- src/ray/raylet/task_spec.h | 10 +- src/ray/raylet/task_test.cc | 4 +- src/ray/raylet/worker.cc | 6 +- src/ray/raylet/worker.h | 8 +- src/ray/raylet/worker_pool.cc | 8 +- src/ray/raylet/worker_pool.h | 10 +- src/ray/raylet/worker_pool_test.cc | 4 +- 79 files changed, 961 insertions(+), 974 deletions(-) diff --git a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java index c4f78c380..6e0feee10 100644 --- a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java +++ b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java @@ -9,12 +9,9 @@ import org.ray.api.id.UniqueId; public interface RuntimeContext { /** - * Get the current Driver ID. - * - * If called in a driver, this returns the driver ID. If called in a worker, this returns the ID - * of the associated driver. + * Get the current Job ID. */ - UniqueId getCurrentDriverId(); + UniqueId getCurrentJobId(); /** * Get the current actor ID. diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 26a8d6e54..f77cd6658 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -74,10 +74,10 @@ public abstract class AbstractRayRuntime implements RayRuntime { public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; - functionManager = new FunctionManager(rayConfig.driverResourcePath); + functionManager = new FunctionManager(rayConfig.jobResourcePath); worker = new Worker(this); workerContext = new WorkerContext(rayConfig.workerMode, - rayConfig.driverId, rayConfig.runMode); + rayConfig.jobId, rayConfig.runMode); runtimeContext = new RuntimeContextImpl(this); } @@ -346,7 +346,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { boolean isActorCreationTask, BaseTaskOptions taskOptions) { Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null)); - TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(), + TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentJobId(), workerContext.getCurrentTaskId(), workerContext.nextTaskIndex()); int numReturns = actor.getId().isNil() ? 1 : 2; ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns); @@ -377,7 +377,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { FunctionDescriptor functionDescriptor; if (func != null) { language = TaskLanguage.JAVA; - functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func) + functionDescriptor = functionManager.getFunction(workerContext.getCurrentJobId(), func) .getFunctionDescriptor(); } else { language = TaskLanguage.PYTHON; @@ -385,7 +385,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { } return new TaskSpec( - workerContext.getCurrentDriverId(), + workerContext.getCurrentJobId(), taskId, workerContext.getCurrentTaskId(), -1, diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index f86983809..49e4f6c39 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -101,7 +101,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime { rayConfig.rayletSocketName, workerContext.getCurrentWorkerId(), rayConfig.workerMode == WorkerMode.WORKER, - workerContext.getCurrentDriverId() + workerContext.getCurrentJobId() ); // register diff --git a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java index 937f13773..c9815fd26 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java @@ -17,8 +17,8 @@ public class RuntimeContextImpl implements RuntimeContext { } @Override - public UniqueId getCurrentDriverId() { - return runtime.getWorkerContext().getCurrentDriverId(); + public UniqueId getCurrentJobId() { + return runtime.getWorkerContext().getCurrentJobId(); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index b4de226e2..5a2109d98 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -86,7 +86,7 @@ public class Worker { try { // Get method RayFunction rayFunction = runtime.getFunctionManager() - .getFunction(spec.driverId, spec.getJavaFunctionDescriptor()); + .getFunction(spec.jobId, spec.getJavaFunctionDescriptor()); // Set context runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 44703bf67..3dc2be7ed 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -29,7 +29,7 @@ public class WorkerContext { private ThreadLocal currentTask; - private UniqueId currentDriverId; + private UniqueId currentJobId; private ClassLoader currentClassLoader; @@ -43,7 +43,7 @@ public class WorkerContext { */ private RunMode runMode; - public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) { + public WorkerContext(WorkerMode workerMode, UniqueId jobId, RunMode runMode) { mainThreadId = Thread.currentThread().getId(); taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); @@ -52,13 +52,15 @@ public class WorkerContext { currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { - workerId = driverId; + // TODO(qwang): Assign the driver id to worker id + // once we treat driver id as a special worker id. + workerId = jobId; currentTaskId.set(TaskId.randomId()); - currentDriverId = driverId; + currentJobId = jobId; } else { workerId = UniqueId.randomId(); this.currentTaskId.set(TaskId.NIL); - this.currentDriverId = UniqueId.NIL; + this.currentJobId = UniqueId.NIL; } } @@ -84,7 +86,7 @@ public class WorkerContext { Preconditions.checkNotNull(task); this.currentTaskId.set(task.taskId); - this.currentDriverId = task.driverId; + this.currentJobId = task.jobId; taskIndex.set(0); putIndex.set(0); this.currentTask.set(task); @@ -115,15 +117,14 @@ public class WorkerContext { } /** - * @return If this worker is a driver, this method returns the driver ID; Otherwise, it returns - * the driver ID of the current running task. + * The ID of the current job. */ - public UniqueId getCurrentDriverId() { - return currentDriverId; + public UniqueId getCurrentJobId() { + return currentJobId; } /** - * @return The class loader which is associated with the current driver. + * @return The class loader which is associated with the current job. */ public ClassLoader getCurrentClassLoader() { return currentClassLoader; diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 99e67288d..27a4ce38d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -32,7 +32,7 @@ public class RayConfig { public final WorkerMode workerMode; public final RunMode runMode; public final Map resources; - public final UniqueId driverId; + public final UniqueId jobId; public final String logDir; public final boolean redirectOutput; public final List libraryPath; @@ -53,7 +53,7 @@ public class RayConfig { public final String rayletSocketName; public final List rayletConfigParameters; - public final String driverResourcePath; + public final String jobResourcePath; public final String pythonWorkerCommand; /** @@ -105,12 +105,12 @@ public class RayConfig { resources.put("CPU", numCpu * 1.0); } } - // Driver id. - String driverId = config.getString("ray.driver.id"); - if (!driverId.isEmpty()) { - this.driverId = UniqueId.fromHexString(driverId); + // Job id. + String jobId = config.getString("ray.job.id"); + if (!jobId.isEmpty()) { + this.jobId = UniqueId.fromHexString(jobId); } else { - this.driverId = UniqueId.randomId(); + this.jobId = UniqueId.randomId(); } // Log dir. logDir = removeTrailingSlash(config.getString("ray.log-dir")); @@ -160,11 +160,11 @@ public class RayConfig { rayletConfigParameters.add(parameter); } - // Driver resource path. - if (config.hasPath("ray.driver.resource-path")) { - driverResourcePath = config.getString("ray.driver.resource-path"); + // Job resource path. + if (config.hasPath("ray.job.resource-path")) { + jobResourcePath = config.getString("ray.job.resource-path"); } else { - driverResourcePath = null; + jobResourcePath = null; } // Number of threads that execute tasks. @@ -205,7 +205,7 @@ public class RayConfig { + ", workerMode=" + workerMode + ", runMode=" + runMode + ", resources=" + resources - + ", driverId=" + driverId + + ", jobId=" + jobId + ", logDir='" + logDir + '\'' + ", redirectOutput=" + redirectOutput + ", libraryPath=" + libraryPath @@ -220,7 +220,7 @@ public class RayConfig { + ", objectStoreSize=" + objectStoreSize + ", rayletSocketName='" + rayletSocketName + '\'' + ", rayletConfigParameters=" + rayletConfigParameters - + ", driverResourcePath='" + driverResourcePath + '\'' + + ", jobResourcePath='" + jobResourcePath + '\'' + ", pythonWorkerCommand='" + pythonWorkerCommand + '\'' + '}'; } diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index 54f01aed1..988dac794 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -30,7 +30,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Manages functions by driver id. + * Manages functions by job id. */ public class FunctionManager { @@ -46,33 +46,33 @@ public class FunctionManager { RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new); /** - * Mapping from the driver id to the functions that belong to this driver. + * Mapping from the job id to the functions that belong to this job. */ - private Map driverFunctionTables = new HashMap<>(); + private Map jobFunctionTables = new HashMap<>(); /** - * The resource path which we can load the driver's jar resources. + * The resource path which we can load the job's jar resources. */ - private String driverResourcePath; + private String jobResourcePath; /** - * Construct a FunctionManager with the specified driver resource path. + * Construct a FunctionManager with the specified job resource path. * - * @param driverResourcePath The specified driver resource that can store the driver's + * @param jobResourcePath The specified job resource that can store the job's * resources. */ - public FunctionManager(String driverResourcePath) { - this.driverResourcePath = driverResourcePath; + public FunctionManager(String jobResourcePath) { + this.jobResourcePath = jobResourcePath; } /** * Get the RayFunction from a RayFunc instance (a lambda). * - * @param driverId current driver id. + * @param jobId current job id. * @param func The lambda. * @return A RayFunction object. */ - public RayFunction getFunction(UniqueId driverId, RayFunc func) { + public RayFunction getFunction(UniqueId jobId, RayFunc func) { JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); if (functionDescriptor == null) { SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func); @@ -82,24 +82,24 @@ public class FunctionManager { functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor); RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor); } - return getFunction(driverId, functionDescriptor); + return getFunction(jobId, functionDescriptor); } /** * Get the RayFunction from a function descriptor. * - * @param driverId Current driver id. + * @param jobId Current job id. * @param functionDescriptor The function descriptor. * @return A RayFunction object. */ - public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) { - DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId); - if (driverFunctionTable == null) { + public RayFunction getFunction(UniqueId jobId, JavaFunctionDescriptor functionDescriptor) { + JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId); + if (jobFunctionTable == null) { ClassLoader classLoader; - if (Strings.isNullOrEmpty(driverResourcePath)) { + if (Strings.isNullOrEmpty(jobResourcePath)) { classLoader = getClass().getClassLoader(); } else { - File resourceDir = new File(driverResourcePath + "/" + driverId.toString() + "/"); + File resourceDir = new File(jobResourcePath + "/" + jobId.toString() + "/"); Collection files = FileUtils.listFiles(resourceDir, new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY); files.add(resourceDir); @@ -111,23 +111,23 @@ public class FunctionManager { } }).collect(Collectors.toList()); classLoader = new URLClassLoader(urlList.toArray(new URL[urlList.size()])); - LOGGER.debug("Resource loaded for driver {} from path {}.", driverId, + LOGGER.debug("Resource loaded for job {} from path {}.", jobId, resourceDir.getAbsolutePath()); } - driverFunctionTable = new DriverFunctionTable(classLoader); - driverFunctionTables.put(driverId, driverFunctionTable); + jobFunctionTable = new JobFunctionTable(classLoader); + jobFunctionTables.put(jobId, jobFunctionTable); } - return driverFunctionTable.getFunction(functionDescriptor); + return jobFunctionTable.getFunction(functionDescriptor); } /** - * Manages all functions that belong to one driver. + * Manages all functions that belong to one job. */ - static class DriverFunctionTable { + static class JobFunctionTable { /** - * The driver's corresponding class loader. + * The job's corresponding class loader. */ ClassLoader classLoader; /** @@ -135,7 +135,7 @@ public class FunctionManager { */ Map, RayFunction>> functions; - DriverFunctionTable(ClassLoader classLoader) { + JobFunctionTable(ClassLoader classLoader) { this.classLoader = classLoader; this.functions = new HashMap<>(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index fe1f61d0b..9d014b72e 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -164,7 +164,7 @@ public class MockRayletClient implements RayletClient { } @Override - public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { + public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) { return TaskId.randomId(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 4a78fde94..0ef2163f7 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -21,7 +21,7 @@ public interface RayletClient { void notifyUnblocked(TaskId currentTaskId); - TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex); + TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex); WaitResult wait(List> waitFor, int numReturns, int timeoutMs, TaskId currentTaskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index c369e6f2c..00b114460 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -44,10 +44,11 @@ public class RayletClientImpl implements RayletClient { */ private long client = 0; + // TODO(qwang): JobId parameter can be removed once we embed jobId in driverId. public RayletClientImpl(String schedulerSockName, UniqueId clientId, - boolean isWorker, UniqueId driverId) { + boolean isWorker, UniqueId jobId) { client = nativeInit(schedulerSockName, clientId.getBytes(), - isWorker, driverId.getBytes()); + isWorker, jobId.getBytes()); } @Override @@ -83,7 +84,7 @@ public class RayletClientImpl implements RayletClient { public void submitTask(TaskSpec spec) { LOGGER.debug("Submitting task: {}", spec); Preconditions.checkState(!spec.parentTaskId.isNil()); - Preconditions.checkState(!spec.driverId.isNil()); + Preconditions.checkState(!spec.jobId.isNil()); ByteBuffer info = convertTaskSpecToFlatbuffer(spec); byte[] cursorId = null; @@ -114,8 +115,8 @@ public class RayletClientImpl implements RayletClient { } @Override - public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) { - byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex); + public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) { + byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); return new TaskId(bytes); } @@ -141,11 +142,10 @@ public class RayletClientImpl implements RayletClient { nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes()); } - private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) { bb.order(ByteOrder.LITTLE_ENDIAN); TaskInfo info = TaskInfo.getRootAsTaskInfo(bb); - UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer()); + UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer()); TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer()); TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer()); int parentCounter = info.parentCounter(); @@ -197,7 +197,7 @@ public class RayletClientImpl implements RayletClient { dynamicWorkerOptions.add(info.dynamicWorkerOptions(i)); } - return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId, + return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId, maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles, args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions); } @@ -207,7 +207,7 @@ public class RayletClientImpl implements RayletClient { bb.clear(); FlatBufferBuilder fbb = new FlatBufferBuilder(bb); - final int driverIdOffset = fbb.createString(task.driverId.toByteBuffer()); + final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer()); final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer()); final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer()); final int parentCounter = task.parentCounter; @@ -290,7 +290,7 @@ public class RayletClientImpl implements RayletClient { int root = TaskInfo.createTaskInfo( fbb, - driverIdOffset, + jobIdOffset, taskIdOffset, parentTaskIdOffset, parentCounter, @@ -363,7 +363,7 @@ public class RayletClientImpl implements RayletClient { private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds, int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException; - private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId, + private static native byte[] nativeGenerateTaskId(byte[] jobId, byte[] parentTaskId, int taskIndex); private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds, diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 060ca6fff..449ff6111 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -18,8 +18,8 @@ import org.ray.runtime.util.IdUtil; */ public class TaskSpec { - // ID of the driver that created this task. - public final UniqueId driverId; + // ID of the job that created this task. + public final UniqueId jobId; // Task ID of the task. public final TaskId taskId; @@ -81,7 +81,7 @@ public class TaskSpec { } public TaskSpec( - UniqueId driverId, + UniqueId jobId, TaskId taskId, TaskId parentTaskId, int parentCounter, @@ -97,7 +97,7 @@ public class TaskSpec { TaskLanguage language, FunctionDescriptor functionDescriptor, List dynamicWorkerOptions) { - this.driverId = driverId; + this.jobId = jobId; this.taskId = taskId; this.parentTaskId = parentTaskId; this.parentCounter = parentCounter; @@ -147,7 +147,7 @@ public class TaskSpec { @Override public String toString() { return "TaskSpec{" + - "driverId=" + driverId + + "jobId=" + jobId + ", taskId=" + taskId + ", parentTaskId=" + parentTaskId + ", parentCounter=" + parentCounter + diff --git a/java/runtime/src/main/resources/ray.default.conf b/java/runtime/src/main/resources/ray.default.conf index a2762e76d..c5fd12e92 100644 --- a/java/runtime/src/main/resources/ray.default.conf +++ b/java/runtime/src/main/resources/ray.default.conf @@ -20,14 +20,14 @@ ray { // Available resources on this node, for example "CPU:4,GPU:0". resources: "" - // Configuration items about driver. - driver { - // If worker.mode is DRIVER, specify the driver id. + // Configuration items about job. + job { + // If worker.mode is DRIVER, specify the job id. // If not provided, a random id will be used. id: "" - // If this config is set, worker will use different paths to loadresources when - // executing tasks from different drivers. E.g. if it's set to '/tm/driver_resources', - // the path for driver 123 will be '/tmp/driver_resources/123'. + // If this config is set, worker will use different paths to load resources when + // executing tasks from different jobs. E.g. if it's set to '/tm/job_resources', + // the path for job 123 will be '/tmp/job_resources/123'. resource-path: "" } diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index 6641abc6c..7c30ee755 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -13,7 +13,7 @@ import org.ray.api.annotation.RayRemote; import org.ray.api.function.RayFunc0; import org.ray.api.function.RayFunc1; import org.ray.api.id.UniqueId; -import org.ray.runtime.functionmanager.FunctionManager.DriverFunctionTable; +import org.ray.runtime.functionmanager.FunctionManager.JobFunctionTable; import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -106,7 +106,7 @@ public class FunctionManagerTest { @Test public void testLoadFunctionTableForClass() { - DriverFunctionTable functionTable = new DriverFunctionTable(getClass().getClassLoader()); + JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader()); Map, RayFunction> res = functionTable .loadFunctionsForClass(Bar.class.getName()); // The result should 2 entries, one for the constructor, the other for bar. @@ -119,13 +119,13 @@ public class FunctionManagerTest { @Test public void testGetFunctionFromLocalResource() throws Exception { - UniqueId driverId = UniqueId.randomId(); + UniqueId jobId = UniqueId.randomId(); final String resourcePath = FileUtils.getTempDirectoryPath() + "/ray_test_resources"; - final String driverResourcePath = resourcePath + "/" + driverId.toString(); - File driverResourceDir = new File(driverResourcePath); - FileUtils.deleteQuietly(driverResourceDir); - driverResourceDir.mkdirs(); - driverResourceDir.deleteOnExit(); + final String jobResourcePath = resourcePath + "/" + jobId.toString(); + File jobResourceDir = new File(jobResourcePath); + FileUtils.deleteQuietly(jobResourceDir); + jobResourceDir.mkdirs(); + jobResourceDir.deleteOnExit(); String demoJavaFile = ""; demoJavaFile += "public class DemoApp {\n"; @@ -134,13 +134,13 @@ public class FunctionManagerTest { demoJavaFile += " }\n"; demoJavaFile += "}"; - // Write the demo java file to the driver resource path. - String javaFilePath = driverResourcePath + "/DemoApp.java"; + // Write the demo java file to the job resource path. + String javaFilePath = jobResourcePath + "/DemoApp.java"; Files.write(Paths.get(javaFilePath), demoJavaFile.getBytes()); // Compile the java file. JavaCompiler compiler = ToolProvider.getSystemJavaCompiler(); - int result = compiler.run(null, null, null, "-d", driverResourcePath, javaFilePath); + int result = compiler.run(null, null, null, "-d", jobResourcePath, javaFilePath); if (result != 0) { throw new RuntimeException("Couldn't compile Demo.java."); } @@ -149,7 +149,7 @@ public class FunctionManagerTest { JavaFunctionDescriptor descriptor = new JavaFunctionDescriptor( "DemoApp", "hello", "()Ljava/lang/String;"); final FunctionManager functionManager = new FunctionManager(resourcePath); - RayFunction func = functionManager.getFunction(driverId, descriptor); + RayFunction func = functionManager.getFunction(jobId, descriptor); Assert.assertEquals(func.getFunctionDescriptor(), descriptor); } diff --git a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java index 26702ebe4..5b6834e5e 100644 --- a/java/test/src/main/java/org/ray/api/test/RayConfigTest.java +++ b/java/test/src/main/java/org/ray/api/test/RayConfigTest.java @@ -10,13 +10,13 @@ public class RayConfigTest { @Test public void testCreateRayConfig() { try { - System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path"); + System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path"); RayConfig rayConfig = RayConfig.create(); Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode); - Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath); + Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath); } finally { // Unset system properties. - System.clearProperty("ray.driver.resource-path"); + System.clearProperty("ray.job.resource-path"); } } diff --git a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java index 512519bce..33e2a345e 100644 --- a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java +++ b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java @@ -11,28 +11,28 @@ import org.testng.annotations.Test; public class RuntimeContextTest extends BaseTest { - private static UniqueId DRIVER_ID = + private static UniqueId JOB_ID = UniqueId.fromHexString("0011223344556677889900112233445566778899"); private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket"; @BeforeClass public void setUp() { - System.setProperty("ray.driver.id", DRIVER_ID.toString()); + System.setProperty("ray.job.id", JOB_ID.toString()); System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME); } @AfterClass public void tearDown() { - System.clearProperty("ray.driver.id"); + System.clearProperty("ray.job.id"); System.clearProperty("ray.raylet.socket-name"); System.clearProperty("ray.object-store.socket-name"); } @Test public void testRuntimeContextInDriver() { - Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId()); + Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId()); Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName()); Assert.assertEquals(OBJECT_STORE_SOCKET_NAME, Ray.getRuntimeContext().getObjectStoreSocketName()); @@ -42,7 +42,7 @@ public class RuntimeContextTest extends BaseTest { public static class RuntimeContextTester { public String testRuntimeContext(UniqueId actorId) { - Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId()); + Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId()); Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId()); Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName()); Assert.assertEquals(OBJECT_STORE_SOCKET_NAME, diff --git a/python/ray/__init__.py b/python/ray/__init__.py index 03792e5eb..e65b59a77 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -56,7 +56,8 @@ from ray._raylet import ( ActorID, ClientID, Config as _Config, - DriverID, + JobID, + WorkerID, FunctionID, ObjectID, TaskID, @@ -141,7 +142,8 @@ __all__ += [ "ActorHandleID", "ActorID", "ClientID", - "DriverID", + "JobID", + "WorkerID", "FunctionID", "ObjectID", "TaskID", diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 1cea8354d..f3968577d 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -221,13 +221,13 @@ cdef class RayletClient: def __cinit__(self, raylet_socket, ClientID client_id, c_bool is_worker, - DriverID driver_id): + JobID job_id): # We know that we are using Python, so just skip the language # parameter. # TODO(suquark): Should we allow unicode chars in "raylet_socket"? self.client.reset(new CRayletClient( raylet_socket.encode("ascii"), client_id.native(), is_worker, - driver_id.native(), LANGUAGE_PYTHON)) + job_id.native(), LANGUAGE_PYTHON)) def disconnect(self): check_status(self.client.get().Disconnect()) @@ -293,9 +293,9 @@ cdef class RayletClient: postincrement(iterator) return resources_dict - def push_error(self, DriverID driver_id, error_type, error_message, + def push_error(self, JobID job_id, error_type, error_message, double timestamp): - check_status(self.client.get().PushError(driver_id.native(), + check_status(self.client.get().PushError(job_id.native(), error_type.encode("ascii"), error_message.encode("ascii"), timestamp)) @@ -381,8 +381,8 @@ cdef class RayletClient: return ClientID(self.client.get().GetClientID().Binary()) @property - def driver_id(self): - return DriverID(self.client.get().GetDriverID().Binary()) + def job_id(self): + return JobID(self.client.get().GetJobID().Binary()) @property def is_worker(self): diff --git a/python/ray/actor.py b/python/ray/actor.py index 65642d992..3361b29c4 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -17,8 +17,7 @@ from ray.function_manager import FunctionDescriptor import ray.ray_constants as ray_constants import ray.signature as signature import ray.worker -from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID, - DriverID) +from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID) logger = logging.getLogger(__name__) @@ -186,7 +185,7 @@ class ActorClass(object): task. _resources: The default resources required by the actor creation task. _actor_method_cpus: The number of CPUs required by actor method tasks. - _last_driver_id_exported_for: The ID of the driver ID of the last Ray + _last_job_id_exported_for: The ID of the job of the last Ray session during which this actor class definition was exported. This is an imperfect mechanism used to determine if we need to export the remote function again. It is imperfect in the sense that the @@ -212,7 +211,7 @@ class ActorClass(object): self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._last_driver_id_exported_for = None + self._last_job_id_exported_for = None self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) @@ -345,13 +344,12 @@ class ActorClass(object): *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if (self._last_driver_id_exported_for is None - or self._last_driver_id_exported_for != - worker.task_driver_id): + if (self._last_job_id_exported_for is None or + self._last_job_id_exported_for != worker.current_job_id): # If this actor class was exported in a previous session, we # need to export this function again, because current GCS # doesn't have it. - self._last_driver_id_exported_for = worker.task_driver_id + self._last_job_id_exported_for = worker.current_job_id worker.function_actor_manager.export_actor_class( self._modified_class, self._actor_method_names) @@ -389,7 +387,7 @@ class ActorClass(object): actor_id, self._modified_class.__module__, self._class_name, actor_cursor, self._actor_method_names, self._method_decorators, self._method_signatures, self._actor_method_num_return_vals, - actor_cursor, actor_method_cpu, worker.task_driver_id) + actor_cursor, actor_method_cpu, worker.current_job_id) # We increment the actor counter by 1 to account for the actor creation # task. actor_handle._ray_actor_counter += 1 @@ -446,9 +444,9 @@ class ActorHandle(object): _ray_original_handle: True if this is the original actor handle for a given actor. If this is true, then the actor will be destroyed when this handle goes out of scope. - _ray_actor_driver_id: The driver ID of the job that created the actor - (it is possible that this ActorHandle exists on a driver with a - different driver ID). + _ray_actor_job_id: The ID of the job that created the actor + (it is possible that this ActorHandle exists on a job with a + different job ID). _ray_new_actor_handles: The new actor handles that were created from this handle since the last task on this handle was submitted. This is used to garbage-collect dummy objects that are no longer @@ -466,10 +464,10 @@ class ActorHandle(object): method_num_return_vals, actor_creation_dummy_object_id, actor_method_cpus, - actor_driver_id, + actor_job_id, actor_handle_id=None): assert isinstance(actor_id, ActorID) - assert isinstance(actor_driver_id, DriverID) + assert isinstance(actor_job_id, ray.JobID) self._ray_actor_id = actor_id self._ray_module_name = module_name # False if this actor handle was created by forking or pickling. True @@ -491,7 +489,7 @@ class ActorHandle(object): self._ray_actor_creation_dummy_object_id = ( actor_creation_dummy_object_id) self._ray_actor_method_cpus = actor_method_cpus - self._ray_actor_driver_id = actor_driver_id + self._ray_actor_job_id = actor_job_id self._ray_new_actor_handles = [] self._ray_actor_lock = threading.Lock() @@ -551,7 +549,7 @@ class ActorHandle(object): num_return_vals=num_return_vals + 1, resources={"CPU": self._ray_actor_method_cpus}, placement_resources={}, - driver_id=self._ray_actor_driver_id, + job_id=self._ray_actor_job_id, ) # Update the actor counter and cursor to reflect the most recent # invocation. @@ -612,7 +610,7 @@ class ActorHandle(object): # not just the first one. worker = ray.worker.get_global_worker() if (worker.mode == ray.worker.SCRIPT_MODE - and self._ray_actor_driver_id.binary() != worker.worker_id): + and self._ray_actor_job_id.binary() != worker.worker_id): # If the worker is a driver and driver id has changed because # Ray was shut down re-initialized, the actor is already cleaned up # and we don't need to send `__ray_terminate__` again. @@ -666,7 +664,7 @@ class ActorHandle(object): "actor_creation_dummy_object_id": self. _ray_actor_creation_dummy_object_id, "actor_method_cpus": self._ray_actor_method_cpus, - "actor_driver_id": self._ray_actor_driver_id, + "actor_job_id": self._ray_actor_job_id, "ray_forking": ray_forking } @@ -727,9 +725,9 @@ class ActorHandle(object): state["method_num_return_vals"], state["actor_creation_dummy_object_id"], state["actor_method_cpus"], - # This is the driver ID of the driver that owns the actor, not - # necessarily the driver that owns this actor handle. - state["actor_driver_id"], + # This is the ID of the job that owns the actor, not + # necessarily the job that owns this actor handle. + state["actor_job_id"], actor_handle_id=actor_handle_id) def __getstate__(self): diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 4914c9f87..4220feb82 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -277,9 +277,9 @@ class FunctionActorManager(object): the worker gets connected. _actors_to_export: The actors to export when the worker gets connected. - _function_execution_info: The map from driver_id to finction_id + _function_execution_info: The map from job_id to function_id and execution_info. - _num_task_executions: The map from driver_id to function + _num_task_executions: The map from job_id to function execution times. imported_actor_classes: The set of actor classes keys (format: ActorClass:function_id) that are already in GCS. @@ -303,17 +303,17 @@ class FunctionActorManager(object): self._loaded_actor_classes = {} self.lock = threading.Lock() - def increase_task_counter(self, driver_id, function_descriptor): + def increase_task_counter(self, job_id, function_descriptor): function_id = function_descriptor.function_id if self._worker.load_code_from_local: - driver_id = ray.DriverID.nil() - self._num_task_executions[driver_id][function_id] += 1 + job_id = ray.JobID.nil() + self._num_task_executions[job_id][function_id] += 1 - def get_task_counter(self, driver_id, function_descriptor): + def get_task_counter(self, job_id, function_descriptor): function_id = function_descriptor.function_id if self._worker.load_code_from_local: - driver_id = ray.DriverID.nil() - return self._num_task_executions[driver_id][function_id] + job_id = ray.JobID.nil() + return self._num_task_executions[job_id][function_id] def export_cached(self): """Export cached remote functions @@ -376,11 +376,11 @@ class FunctionActorManager(object): check_oversized_pickle(pickled_function, remote_function._function_name, "remote function", self._worker) - key = (b"RemoteFunction:" + self._worker.task_driver_id.binary() + b":" + key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":" + remote_function._function_descriptor.function_id.binary()) self._worker.redis_client.hmset( key, { - "driver_id": self._worker.task_driver_id.binary(), + "job_id": self._worker.current_job_id.binary(), "function_id": remote_function._function_descriptor. function_id.binary(), "name": remote_function._function_name, @@ -392,14 +392,14 @@ class FunctionActorManager(object): def fetch_and_register_remote_function(self, key): """Import a remote function.""" - (driver_id_str, function_id_str, function_name, serialized_function, + (job_id_str, function_id_str, function_name, serialized_function, num_return_vals, module, resources, max_calls) = self._worker.redis_client.hmget(key, [ - "driver_id", "function_id", "name", "function", "num_return_vals", + "job_id", "function_id", "name", "function", "num_return_vals", "module", "resources", "max_calls" ]) function_id = ray.FunctionID(function_id_str) - driver_id = ray.DriverID(driver_id_str) + job_id = ray.JobID(job_id_str) function_name = decode(function_name) max_calls = int(max_calls) module = decode(module) @@ -413,12 +413,12 @@ class FunctionActorManager(object): # atomic. Otherwise, there is race condition. Another thread may use # the temporary function above before the real function is ready. with self.lock: - self._function_execution_info[driver_id][function_id] = ( + self._function_execution_info[job_id][function_id] = ( FunctionExecutionInfo( function=f, function_name=function_name, max_calls=max_calls)) - self._num_task_executions[driver_id][function_id] = 0 + self._num_task_executions[job_id][function_id] = 0 try: function = pickle.loads(serialized_function) @@ -434,7 +434,7 @@ class FunctionActorManager(object): "Failed to unpickle the remote function '{}' with " "function ID {}. Traceback:\n{}".format( function_name, function_id.hex(), traceback_str), - driver_id=driver_id) + job_id=job_id) else: # The below line is necessary. Because in the driver process, # if the function is defined in the file where the python @@ -442,7 +442,7 @@ class FunctionActorManager(object): # However in the worker process, the `__main__` module is a # different module, which is `default_worker.py` function.__module__ = module - self._function_execution_info[driver_id][function_id] = ( + self._function_execution_info[job_id][function_id] = ( FunctionExecutionInfo( function=function, function_name=function_name, @@ -452,11 +452,11 @@ class FunctionActorManager(object): b"FunctionTable:" + function_id.binary(), self._worker.worker_id) - def get_execution_info(self, driver_id, function_descriptor): + def get_execution_info(self, job_id, function_descriptor): """Get the FunctionExecutionInfo of a remote function. Args: - driver_id: ID of the driver that the function belongs to. + job_id: ID of the job that the function belongs to. function_descriptor: The FunctionDescriptor of the function to get. Returns: @@ -464,11 +464,11 @@ class FunctionActorManager(object): """ if self._worker.load_code_from_local: # Load function from local code. - # Currently, we don't support isolating code by drivers, - # thus always set driver ID to NIL here. - driver_id = ray.DriverID.nil() + # Currently, we don't support isolating code by jobs, + # thus always set job ID to NIL here. + job_id = ray.JobID.nil() if not function_descriptor.is_actor_method(): - self._load_function_from_local(driver_id, function_descriptor) + self._load_function_from_local(job_id, function_descriptor) else: # Load function from GCS. # Wait until the function to be executed has actually been @@ -477,21 +477,21 @@ class FunctionActorManager(object): # The driver function may not be found in sys.path. Try to load # the function from GCS. with profiling.profile("wait_for_function"): - self._wait_for_function(function_descriptor, driver_id) + self._wait_for_function(function_descriptor, job_id) try: function_id = function_descriptor.function_id - info = self._function_execution_info[driver_id][function_id] + info = self._function_execution_info[job_id][function_id] except KeyError as e: message = ("Error occurs in get_execution_info: " - "driver_id: %s, function_descriptor: %s. Message: %s" % - (driver_id, function_descriptor, e)) + "job_id: %s, function_descriptor: %s. Message: %s" % + (job_id, function_descriptor, e)) raise KeyError(message) return info - def _load_function_from_local(self, driver_id, function_descriptor): + def _load_function_from_local(self, job_id, function_descriptor): assert not function_descriptor.is_actor_method() function_id = function_descriptor.function_id - if (driver_id in self._function_execution_info + if (job_id in self._function_execution_info and function_id in self._function_execution_info[function_id]): return module_name, function_name = ( @@ -501,13 +501,13 @@ class FunctionActorManager(object): try: module = importlib.import_module(module_name) function = getattr(module, function_name)._function - self._function_execution_info[driver_id][function_id] = ( + self._function_execution_info[job_id][function_id] = ( FunctionExecutionInfo( function=function, function_name=function_name, max_calls=0, )) - self._num_task_executions[driver_id][function_id] = 0 + self._num_task_executions[job_id][function_id] = 0 except Exception: logger.exception( "Failed to load function %s.".format(function_name)) @@ -515,7 +515,7 @@ class FunctionActorManager(object): "Function {} failed to be loaded from local code.".format( function_descriptor)) - def _wait_for_function(self, function_descriptor, driver_id, timeout=10): + def _wait_for_function(self, function_descriptor, job_id, timeout=10): """Wait until the function to be executed is present on this worker. This method will simply loop until the import thread has imported the @@ -528,7 +528,7 @@ class FunctionActorManager(object): Args: function_descriptor : The FunctionDescriptor of the function that we want to execute. - driver_id (str): The ID of the driver to push the error message to + job_id (str): The ID of the job to push the error message to if this times out. """ start_time = time.time() @@ -538,7 +538,7 @@ class FunctionActorManager(object): with self.lock: if (self._worker.actor_id.is_nil() and (function_descriptor.function_id in - self._function_execution_info[driver_id])): + self._function_execution_info[job_id])): break elif not self._worker.actor_id.is_nil() and ( self._worker.actor_id in self._worker.actors): @@ -553,7 +553,7 @@ class FunctionActorManager(object): self._worker, ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR, warning_message, - driver_id=driver_id) + job_id=job_id) warning_sent = True time.sleep(0.001) @@ -577,22 +577,22 @@ class FunctionActorManager(object): if self._worker.load_code_from_local: return function_descriptor = FunctionDescriptor.from_class(Class) - # `task_driver_id` shouldn't be NIL, unless: + # `current_job_id` shouldn't be NIL, unless: # 1) This worker isn't an actor; # 2) And a previous task started a background thread, which didn't # finish before the task finished, and still uses Ray API # after that. - assert not self._worker.task_driver_id.is_nil(), ( + assert not self._worker.current_job_id.is_nil(), ( "You might have started a background thread in a non-actor task, " "please make sure the thread finishes before the task finishes.") - driver_id = self._worker.task_driver_id - key = (b"ActorClass:" + driver_id.binary() + b":" + + job_id = self._worker.current_job_id + key = (b"ActorClass:" + job_id.binary() + b":" + function_descriptor.function_id.binary()) actor_class_info = { "class_name": Class.__name__, "module": Class.__module__, "class": pickle.dumps(Class), - "driver_id": driver_id.binary(), + "job_id": job_id.binary(), "actor_method_names": json.dumps(list(actor_method_names)) } @@ -616,11 +616,11 @@ class FunctionActorManager(object): # within tasks. I tried to disable this, but it may be necessary # because of https://github.com/ray-project/ray/issues/1146. - def load_actor_class(self, driver_id, function_descriptor): + def load_actor_class(self, job_id, function_descriptor): """Load the actor class. Args: - driver_id: Driver ID of the actor. + job_id: job ID of the actor. function_descriptor: Function descriptor of the actor constructor. Returns: @@ -632,14 +632,14 @@ class FunctionActorManager(object): if actor_class is None: # Load actor class. if self._worker.load_code_from_local: - driver_id = ray.DriverID.nil() + job_id = ray.JobID.nil() # Load actor class from local code. actor_class = self._load_actor_from_local( - driver_id, function_descriptor) + job_id, function_descriptor) else: # Load actor class from GCS. actor_class = self._load_actor_class_from_gcs( - driver_id, function_descriptor) + job_id, function_descriptor) # Save the loaded actor class in cache. self._loaded_actor_classes[function_id] = actor_class @@ -657,18 +657,19 @@ class FunctionActorManager(object): actor_method, actor_imported=True, ) - self._function_execution_info[driver_id][method_id] = ( + self._function_execution_info[job_id][method_id] = ( FunctionExecutionInfo( function=executor, function_name=actor_method_name, max_calls=0, )) - self._num_task_executions[driver_id][method_id] = 0 - self._num_task_executions[driver_id][function_id] = 0 + self._num_task_executions[job_id][method_id] = 0 + self._num_task_executions[job_id][function_id] = 0 return actor_class - def _load_actor_from_local(self, driver_id, function_descriptor): + def _load_actor_from_local(self, job_id, function_descriptor): """Load actor class from local code.""" + assert isinstance(job_id, ray.JobID) module_name, class_name = (function_descriptor.module_name, function_descriptor.class_name) try: @@ -699,9 +700,9 @@ class FunctionActorManager(object): return TemporaryActor - def _load_actor_class_from_gcs(self, driver_id, function_descriptor): + def _load_actor_class_from_gcs(self, job_id, function_descriptor): """Load actor class from GCS.""" - key = (b"ActorClass:" + driver_id.binary() + b":" + + key = (b"ActorClass:" + job_id.binary() + b":" + function_descriptor.function_id.binary()) # Wait for the actor class key to have been imported by the # import thread. TODO(rkn): It shouldn't be possible to end @@ -711,16 +712,14 @@ class FunctionActorManager(object): time.sleep(0.001) # Fetch raw data from GCS. - (driver_id_str, class_name, module, pickled_class, + (job_id_str, class_name, module, pickled_class, actor_method_names) = self._worker.redis_client.hmget( - key, [ - "driver_id", "class_name", "module", "class", - "actor_method_names" - ]) + key, + ["job_id", "class_name", "module", "class", "actor_method_names"]) class_name = ensure_str(class_name) module_name = ensure_str(module) - driver_id = ray.DriverID(driver_id_str) + job_id = ray.JobID(job_id_str) actor_method_names = json.loads(ensure_str(actor_method_names)) actor_class = None @@ -741,11 +740,12 @@ class FunctionActorManager(object): traceback.format_exc()) # Log the error message. push_error_to_driver( - self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR, + self._worker, + ray_constants.REGISTER_ACTOR_PUSH_ERROR, "Failed to unpickle actor class '{}' for actor ID {}. " - "Traceback:\n{}".format(class_name, - self._worker.actor_id.hex(), - traceback_str), driver_id) + "Traceback:\n{}".format( + class_name, self._worker.actor_id.hex(), traceback_str), + job_id=job_id) # TODO(rkn): In the future, it might make sense to have the worker # exit here. However, currently that would lead to hanging if # someone calls ray.get on a method invoked on the actor. @@ -859,7 +859,7 @@ class FunctionActorManager(object): self._worker, ray_constants.CHECKPOINT_PUSH_ERROR, traceback_str, - driver_id=self._worker.task_driver_id) + job_id=self._worker.current_job_id) def _restore_and_log_checkpoint(self, actor): """Restore an actor from a checkpoint if available and log any errors. @@ -898,4 +898,4 @@ class FunctionActorManager(object): self._worker, ray_constants.CHECKPOINT_PUSH_ERROR, traceback_str, - driver_id=self._worker.task_driver_id) + job_id=self._worker.current_job_id) diff --git a/python/ray/gcs_utils.py b/python/ray/gcs_utils.py index ba72e96f4..25157a62e 100644 --- a/python/ray/gcs_utils.py +++ b/python/ray/gcs_utils.py @@ -7,7 +7,7 @@ from ray.core.generated.ray.protocol.Task import Task from ray.core.generated.gcs_pb2 import ( ActorCheckpointIdData, ClientTableData, - DriverTableData, + JobTableData, ErrorTableData, ErrorType, GcsEntry, @@ -23,7 +23,7 @@ from ray.core.generated.gcs_pb2 import ( __all__ = [ "ActorCheckpointIdData", "ClientTableData", - "DriverTableData", + "JobTableData", "ErrorTableData", "ErrorType", "GcsEntry", @@ -48,8 +48,8 @@ XRAY_HEARTBEAT_CHANNEL = str( XRAY_HEARTBEAT_BATCH_CHANNEL = str( TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii") -# xray driver updates -XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii") +# xray job updates +XRAY_JOB_CHANNEL = str(TablePubsub.Value("JOB_PUBSUB")).encode("ascii") # These prefixes must be kept up-to-date with the TablePrefix enum in # gcs.proto. @@ -61,11 +61,11 @@ TablePrefix_ERROR_INFO_string = "ERROR_INFO" TablePrefix_PROFILE_string = "PROFILE" -def construct_error_message(driver_id, error_type, message, timestamp): +def construct_error_message(job_id, error_type, message, timestamp): """Construct a serialized ErrorTableData object. Args: - driver_id: The ID of the driver that the error should go to. If this is + job_id: The ID of the job that the error should go to. If this is nil, then the error will go to all drivers. error_type: The type of the error. message: The error message. @@ -75,7 +75,7 @@ def construct_error_message(driver_id, error_type, message, timestamp): The serialized object. """ data = ErrorTableData() - data.driver_id = driver_id.binary() + data.job_id = job_id.binary() data.type = error_type data.error_message = message data.timestamp = timestamp diff --git a/python/ray/import_thread.py b/python/ray/import_thread.py index 440417020..91f5e8b1d 100644 --- a/python/ray/import_thread.py +++ b/python/ray/import_thread.py @@ -114,13 +114,13 @@ class ImportThread(object): def fetch_and_execute_function_to_run(self, key): """Run on arbitrary function on the worker.""" - (driver_id, serialized_function, + (job_id, serialized_function, run_on_other_drivers) = self.redis_client.hmget( - key, ["driver_id", "function", "run_on_other_drivers"]) + key, ["job_id", "function", "run_on_other_drivers"]) if (utils.decode(run_on_other_drivers) == "False" and self.worker.mode == ray.SCRIPT_MODE - and driver_id != self.worker.task_driver_id.binary()): + and job_id != self.worker.current_job_id.binary()): return try: @@ -140,4 +140,4 @@ class ImportThread(object): self.worker, ray_constants.FUNCTION_TO_RUN_PUSH_ERROR, traceback_str, - driver_id=ray.DriverID(driver_id)) + job_id=ray.JobID(job_id)) diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 4c2cd8437..5c716b673 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -6,7 +6,8 @@ from libcpp.unordered_map cimport unordered_map from libcpp.vector cimport vector as c_vector from ray.includes.unique_ids cimport ( - CDriverID, + CJobID, + CWorkerID, CObjectID, CTaskID, ) @@ -81,7 +82,7 @@ cdef extern from "ray/common/status.h" namespace "ray::StatusCode" nogil: cdef extern from "ray/common/id.h" namespace "ray" nogil: - const CTaskID GenerateTaskId(const CDriverID &driver_id, + const CTaskID GenerateTaskId(const CJobID &job_id, const CTaskID &parent_task_id, int parent_task_counter) diff --git a/python/ray/includes/libraylet.pxd b/python/ray/includes/libraylet.pxd index 1b4c5e3cd..3bc6eddd0 100644 --- a/python/ray/includes/libraylet.pxd +++ b/python/ray/includes/libraylet.pxd @@ -14,7 +14,8 @@ from ray.includes.unique_ids cimport ( CActorCheckpointID, CActorID, CClientID, - CDriverID, + CJobID, + CWorkerID, CObjectID, CTaskID, ) @@ -46,7 +47,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: cdef cppclass CRayletClient "RayletClient": CRayletClient(const c_string &raylet_socket, const CClientID &client_id, - c_bool is_worker, const CDriverID &driver_id, + c_bool is_worker, const CJobID &job_id, const CLanguage &language) CRayStatus Disconnect() CRayStatus SubmitTask( @@ -62,7 +63,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: int num_returns, int64_t timeout_milliseconds, c_bool wait_local, const CTaskID ¤t_task_id, WaitResultPair *result) - CRayStatus PushError(const CDriverID &driver_id, const c_string &type, + CRayStatus PushError(const CJobID &job_id, const c_string &type, const c_string &error_message, double timestamp) CRayStatus PushProfileEvents( const GCSProfileTableDataT &profile_events) @@ -75,6 +76,6 @@ cdef extern from "ray/raylet/raylet_client.h" nogil: CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id) CLanguage GetLanguage() const CClientID GetClientID() const - CDriverID GetDriverID() const + CJobID GetJobID() const c_bool IsWorker() const const ResourceMappingType &GetResourceIDs() const diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index 128560f7b..5d6511c32 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -12,7 +12,7 @@ from ray.includes.common cimport ( from ray.includes.unique_ids cimport ( CActorHandleID, CActorID, - CDriverID, + CJobID, CObjectID, CTaskID, ) @@ -46,7 +46,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil: cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification": CTaskSpecification( - const CDriverID &driver_id, const CTaskID &parent_task_id, + const CJobID &job_id, const CTaskID &parent_task_id, int64_t parent_counter, const c_vector[shared_ptr[CTaskArgument]] &task_arguments, int64_t num_returns, @@ -54,7 +54,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil: const CLanguage &language, const c_vector[c_string] &function_descriptor) CTaskSpecification( - const CDriverID &driver_id, const CTaskID &parent_task_id, + const CJobID &job_id, const CTaskID &parent_task_id, int64_t parent_counter, const CActorID &actor_creation_id, const CObjectID &actor_creation_dummy_object_id, int64_t max_actor_reconstructions, const CActorID &actor_id, @@ -70,7 +70,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil: c_string SerializeAsString() const CTaskID TaskId() const - CDriverID DriverId() const + CJobID JobId() const CTaskID ParentTaskId() const int64_t ParentCounter() const c_vector[c_string] FunctionDescriptor() const diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index f8258e0a6..c9e64d83b 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -18,7 +18,7 @@ cdef class Task: unique_ptr[CTaskSpecification] task_spec unique_ptr[c_vector[CObjectID]] execution_dependencies - def __init__(self, DriverID driver_id, function_descriptor, arguments, + def __init__(self, JobID job_id, function_descriptor, arguments, int num_returns, TaskID parent_task_id, int parent_counter, ActorID actor_creation_id, ObjectID actor_creation_dummy_object_id, @@ -72,7 +72,7 @@ cdef class Task: (new_actor_handle).native()) self.task_spec.reset(new CTaskSpecification( - driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(), + job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(), actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(), actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns, required_resources, required_placement_resources, LANGUAGE_PYTHON, @@ -122,9 +122,9 @@ cdef class Task: return SerializeTaskAsString( self.execution_dependencies.get(), self.task_spec.get()) - def driver_id(self): - """Return the driver ID for this task.""" - return DriverID(self.task_spec.get().DriverId().Binary()) + def job_id(self): + """Return the job ID for this task.""" + return JobID(self.task_spec.get().JobId().Binary()) def task_id(self): """Return the task ID for this task.""" diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 1bce0f4ba..91410f1ae 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -78,10 +78,10 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod CFunctionID FromBinary(const c_string &binary) - cdef cppclass CDriverID "ray::DriverID"(CUniqueID): + cdef cppclass CJobID "ray::JobID"(CUniqueID): @staticmethod - CDriverID FromBinary(const c_string &binary) + CJobID FromBinary(const c_string &binary) cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 98a0a2913..76a61e177 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -15,7 +15,7 @@ from ray.includes.unique_ids cimport ( CActorID, CClientID, CConfigID, - CDriverID, + CJobID, CFunctionID, CObjectID, CTaskID, @@ -212,15 +212,23 @@ cdef class ClientID(UniqueID): return self.data -cdef class DriverID(UniqueID): +cdef class JobID(UniqueID): def __init__(self, id): check_id(id) - self.data = CDriverID.FromBinary(id) + self.data = CJobID.FromBinary(id) - cdef CDriverID native(self): - return self.data + cdef CJobID native(self): + return self.data +cdef class WorkerID(UniqueID): + + def __init__(self, id): + check_id(id) + self.data = CWorkerID.FromBinary(id) + + cdef CWorkerID native(self): + return self.data cdef class ActorID(UniqueID): @@ -277,7 +285,8 @@ _ID_TYPES = [ ActorHandleID, ActorID, ClientID, - DriverID, + JobID, + WorkerID, FunctionID, ObjectID, TaskID, diff --git a/python/ray/monitor.py b/python/ray/monitor.py index b1b3df37d..d02486277 100644 --- a/python/ray/monitor.py +++ b/python/ray/monitor.py @@ -130,14 +130,14 @@ class Monitor(object): "Monitor: " "could not find ip for client {}".format(client_id)) - def _xray_clean_up_entries_for_driver(self, driver_id): - """Remove this driver's object/task entries from redis. + def _xray_clean_up_entries_for_job(self, job_id): + """Remove this job's object/task entries from redis. Removes control-state entries of all tasks and task return objects belonging to the driver. Args: - driver_id: The driver id. + job_id: The job id. """ xray_task_table_prefix = ( @@ -146,23 +146,23 @@ class Monitor(object): ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii")) task_table_objects = ray.tasks() - driver_id_hex = binary_to_hex(driver_id) - driver_task_id_bins = set() + job_id_hex = binary_to_hex(job_id) + job_task_id_bins = set() for task_id_hex, task_info in task_table_objects.items(): task_table_object = task_info["TaskSpec"] - task_driver_id_hex = task_table_object["DriverID"] - if driver_id_hex != task_driver_id_hex: + task_job_id_hex = task_table_object["JobID"] + if job_id_hex != task_job_id_hex: # Ignore tasks that aren't from this driver. continue - driver_task_id_bins.add(hex_to_binary(task_id_hex)) + job_task_id_bins.add(hex_to_binary(task_id_hex)) # Get objects associated with the driver. object_table_objects = ray.objects() - driver_object_id_bins = set() + job_object_id_bins = set() for object_id, _ in object_table_objects.items(): task_id_bin = ray._raylet.compute_task_id(object_id).binary() - if task_id_bin in driver_task_id_bins: - driver_object_id_bins.add(object_id.binary()) + if task_id_bin in job_task_id_bins: + job_object_id_bins.add(object_id.binary()) def to_shard_index(id_bin): if len(id_bin) == ray.TaskID.size(): @@ -174,10 +174,10 @@ class Monitor(object): # Form the redis keys to delete. sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))] - for task_id_bin in driver_task_id_bins: + for task_id_bin in job_task_id_bins: sharded_keys[to_shard_index(task_id_bin)].append( xray_task_table_prefix + task_id_bin) - for object_id_bin in driver_object_id_bins: + for object_id_bin in job_object_id_bins: sharded_keys[to_shard_index(object_id_bin)].append( xray_object_table_prefix + object_id_bin) @@ -198,21 +198,21 @@ class Monitor(object): "entries from redis shard {}.".format( len(keys) - num_deleted, shard_index)) - def xray_driver_removed_handler(self, unused_channel, data): - """Handle a notification that a driver has been removed. + def xray_job_removed_handler(self, unused_channel, data): + """Handle a notification that a job has been removed. Args: unused_channel: The message channel. data: The message data. """ gcs_entries = ray.gcs_utils.GcsEntry.FromString(data) - driver_data = gcs_entries.entries[0] - message = ray.gcs_utils.DriverTableData.FromString(driver_data) - driver_id = message.driver_id + job_data = gcs_entries.entries[0] + message = ray.gcs_utils.JobTableData.FromString(job_data) + job_id = message.job_id logger.info("Monitor: " "XRay Driver {} has been removed.".format( - binary_to_hex(driver_id))) - self._xray_clean_up_entries_for_driver(driver_id) + binary_to_hex(job_id))) + self._xray_clean_up_entries_for_job(job_id) def process_messages(self, max_messages=10000): """Process all messages ready in the subscription channels. @@ -240,9 +240,9 @@ class Monitor(object): if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL: # Similar functionality as raylet info channel message_handler = self.xray_heartbeat_batch_handler - elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL: + elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL: # Handles driver death. - message_handler = self.xray_driver_removed_handler + message_handler = self.xray_job_removed_handler else: raise Exception("This code should be unreachable.") @@ -298,7 +298,7 @@ class Monitor(object): """ # Initialize the subscription channel. self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL) - self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL) + self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL) # TODO(rkn): If there were any dead clients at startup, we should clean # up the associated state in the state tables. diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 9ff6994b8..025218612 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -44,7 +44,7 @@ class RemoteFunction(object): return the resulting ObjectIDs. For an example, see "test_decorated_function" in "python/ray/tests/test_basic.py". _function_signature: The function signature. - _last_driver_id_exported_for: The ID of the driver ID of the last Ray + _last_job_id_exported_for: The ID of the job ID of the last Ray session during which this remote function definition was exported. This is an imperfect mechanism used to determine if we need to export the remote function again. It is imperfect in the sense that @@ -73,7 +73,7 @@ class RemoteFunction(object): self._function_signature = ray.signature.extract_signature( self._function) - self._last_driver_id_exported_for = None + self._last_job_id_exported_for = None # Override task.remote's signature and docstring @wraps(function) @@ -115,11 +115,11 @@ class RemoteFunction(object): worker = ray.worker.get_global_worker() worker.check_connected() - if (self._last_driver_id_exported_for is None - or self._last_driver_id_exported_for != worker.task_driver_id): + if (self._last_job_id_exported_for is None + or self._last_job_id_exported_for != worker.current_job_id): # If this function was exported in a previous session, we need to # export this function again, because current GCS doesn't have it. - self._last_driver_id_exported_for = worker.task_driver_id + self._last_job_id_exported_for = worker.current_job_id worker.function_actor_manager.export(self) kwargs = {} if kwargs is None else kwargs diff --git a/python/ray/runtime_context.py b/python/ray/runtime_context.py index cb3b004cb..0c1c88e00 100644 --- a/python/ray/runtime_context.py +++ b/python/ray/runtime_context.py @@ -20,7 +20,7 @@ class RuntimeContext(object): a task, return the driver ID of the associated driver. """ assert self.worker is not None - return self.worker.task_driver_id + return self.worker.current_job_id _runtime_context = None diff --git a/python/ray/state.py b/python/ray/state.py index 35f97cd65..288b64dc1 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -316,7 +316,7 @@ class GlobalState(object): function_descriptor_list) task_spec_info = { - "DriverID": task.driver_id().hex(), + "JobID": task.job_id().hex(), "TaskID": task.task_id().hex(), "ParentTaskID": task.parent_task_id().hex(), "ParentCounter": task.parent_counter(), @@ -817,19 +817,19 @@ class GlobalState(object): return dict(total_available_resources) - def _error_messages(self, driver_id): + def _error_messages(self, job_id): """Get the error messages for a specific driver. Args: - driver_id: The ID of the driver to get the errors for. + job_id: The ID of the job to get the errors for. Returns: A list of the error messages for this driver. """ - assert isinstance(driver_id, ray.DriverID) + assert isinstance(job_id, ray.JobID) message = self.redis_client.execute_command( "RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "", - driver_id.binary()) + job_id.binary()) # If there are no errors, return early. if message is None: @@ -839,7 +839,7 @@ class GlobalState(object): error_messages = [] for entry in gcs_entries.entries: error_data = gcs_utils.ErrorTableData.FromString(entry) - assert driver_id.binary() == error_data.driver_id + assert job_id.binary() == error_data.job_id error_message = { "type": error_data.type, "message": error_data.error_message, @@ -848,12 +848,12 @@ class GlobalState(object): error_messages.append(error_message) return error_messages - def error_messages(self, driver_id=None): + def error_messages(self, job_id=None): """Get the error messages for all drivers or a specific driver. Args: - driver_id: The specific driver to get the errors for. If this is - None, then this method retrieves the errors for all drivers. + job_id: The specific job to get the errors for. If this is + None, then this method retrieves the errors for all jobs. Returns: A dictionary mapping driver ID to a list of the error messages for @@ -861,21 +861,20 @@ class GlobalState(object): """ self._check_connected() - if driver_id is not None: - assert isinstance(driver_id, ray.DriverID) - return self._error_messages(driver_id) + if job_id is not None: + assert isinstance(job_id, ray.JobID) + return self._error_messages(job_id) error_table_keys = self.redis_client.keys( gcs_utils.TablePrefix_ERROR_INFO_string + "*") - driver_ids = [ + job_ids = [ key[len(gcs_utils.TablePrefix_ERROR_INFO_string):] for key in error_table_keys ] return { - binary_to_hex(driver_id): self._error_messages( - ray.DriverID(driver_id)) - for driver_id in driver_ids + binary_to_hex(job_id): self._error_messages(ray.JobID(job_id)) + for job_id in job_ids } def actor_checkpoint_info(self, actor_id): @@ -969,12 +968,12 @@ class DeprecatedGlobalState(object): "instead.") return ray.available_resources() - def error_messages(self, driver_id=None): + def error_messages(self, job_id=None): logger.warning( "ray.global_state.error_messages() is deprecated and will be " "removed in a subsequent release. Use ray.errors() " "instead.") - return ray.errors(driver_id=driver_id) + return ray.errors(job_id=job_id) state = GlobalState() @@ -1095,7 +1094,7 @@ def errors(include_cluster_errors=True): Error messages pushed from the cluster. """ worker = ray.worker.global_worker - error_messages = state.error_messages(driver_id=worker.task_driver_id) + error_messages = state.error_messages(job_id=worker.current_job_id) if include_cluster_errors: - error_messages += state.error_messages(driver_id=ray.DriverID.nil()) + error_messages += state.error_messages(job_id=ray.JobID.nil()) return error_messages diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index b2ef1a8f0..cff8d45b9 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2439,7 +2439,7 @@ def test_global_state_api(shutdown_only): assert ray.objects() == {} - driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id) + job_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id) driver_task_id = ray.worker.global_worker.current_task_id.hex() # One task is put in the task table which corresponds to this driver. @@ -2453,7 +2453,7 @@ def test_global_state_api(shutdown_only): assert task_spec["TaskID"] == driver_task_id assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [] - assert task_spec["DriverID"] == driver_id + assert task_spec["JobID"] == job_id assert task_spec["FunctionID"] == nil_id_hex assert task_spec["ReturnObjectIDs"] == [] @@ -2481,7 +2481,7 @@ def test_global_state_api(shutdown_only): task_spec = task_table[task_id]["TaskSpec"] assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [1, "hi", x_id] - assert task_spec["DriverID"] == driver_id + assert task_spec["JobID"] == job_id assert task_spec["ReturnObjectIDs"] == [result_id] assert task_table[task_id] == ray.tasks(task_id) @@ -2613,9 +2613,9 @@ def test_workers(shutdown_only): worker_ids = set(ray.get([f.remote() for _ in range(10)])) -def test_specific_driver_id(): - dummy_driver_id = ray.DriverID(b"00112233445566778899") - ray.init(num_cpus=1, driver_id=dummy_driver_id) +def test_specific_job_id(): + dummy_driver_id = ray.JobID(b"00112233445566778899") + ray.init(num_cpus=1, job_id=dummy_driver_id) # in driver assert dummy_driver_id == ray._get_runtime_context().current_driver_id @@ -2727,7 +2727,7 @@ def test_ray_setproctitle(ray_start_2_cpus): def test_duplicate_error_messages(shutdown_only): ray.init(num_cpus=0) - driver_id = ray.DriverID.nil() + driver_id = ray.WorkerID.nil() error_data = ray.gcs_utils.construct_error_message(driver_id, "test", "message", 0) diff --git a/python/ray/utils.py b/python/ray/utils.py index 0db48e41d..8be4fe1df 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py @@ -51,7 +51,7 @@ def format_error_message(exception_message, task_exception=False): return "\n".join(lines) -def push_error_to_driver(worker, error_type, message, driver_id=None): +def push_error_to_driver(worker, error_type, message, job_id=None): """Push an error message to the driver to be printed in the background. Args: @@ -59,19 +59,19 @@ def push_error_to_driver(worker, error_type, message, driver_id=None): error_type (str): The type of the error. message (str): The message that will be printed in the background on the driver. - driver_id: The ID of the driver to push the error message to. If this + job_id: The ID of the driver to push the error message to. If this is None, then the message will be pushed to all drivers. """ - if driver_id is None: - driver_id = ray.DriverID.nil() - worker.raylet_client.push_error(driver_id, error_type, message, - time.time()) + if job_id is None: + job_id = ray.JobID.nil() + assert isinstance(job_id, ray.JobID) + worker.raylet_client.push_error(job_id, error_type, message, time.time()) def push_error_to_driver_through_redis(redis_client, error_type, message, - driver_id=None): + job_id=None): """Push an error message to the driver to be printed in the background. Normally the push_error_to_driver function should be used. However, in some @@ -84,19 +84,20 @@ def push_error_to_driver_through_redis(redis_client, error_type (str): The type of the error. message (str): The message that will be printed in the background on the driver. - driver_id: The ID of the driver to push the error message to. If this + job_id: The ID of the driver to push the error message to. If this is None, then the message will be pushed to all drivers. """ - if driver_id is None: - driver_id = ray.DriverID.nil() + if job_id is None: + job_id = ray.JobID.nil() + assert isinstance(job_id, ray.JobID) # Do everything in Python and through the Python Redis client instead # of through the raylet. - error_data = ray.gcs_utils.construct_error_message(driver_id, error_type, + error_data = ray.gcs_utils.construct_error_message(job_id, error_type, message, time.time()) redis_client.execute_command( "RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"), - ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), - driver_id.binary(), error_data) + ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), job_id.binary(), + error_data) def is_cython(obj): @@ -443,7 +444,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker): worker, ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR, warning_message, - driver_id=worker.task_driver_id) + job_id=worker.current_job_id) class _ThreadSafeProxy(object): diff --git a/python/ray/worker.py b/python/ray/worker.py index 710f0db43..15ee79890 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -40,7 +40,8 @@ from ray import ( ActorHandleID, ActorID, ClientID, - DriverID, + WorkerID, + JobID, ObjectID, TaskID, ) @@ -145,9 +146,9 @@ class Worker(object): # TODO: clean up the SerializationContext once the job finished. self.serialization_context_map = {} self.function_actor_manager = FunctionActorManager(self) - # Identity of the driver that this worker is processing. - # It is a DriverID. - self.task_driver_id = DriverID.nil() + # Identity of the job that this worker is processing. + # It is a JobID. + self.current_job_id = JobID.nil() self._task_context = threading.local() # This event is checked regularly by all of the threads so that they # know when to exit. @@ -227,24 +228,24 @@ class Worker(object): if self.actor_init_error is not None: raise self.actor_init_error - def get_serialization_context(self, driver_id): - """Get the SerializationContext of the driver that this worker is processing. + def get_serialization_context(self, job_id): + """Get the SerializationContext of the job that this worker is processing. Args: - driver_id: The ID of the driver that indicates which driver to get + job_id: The ID of the job that indicates which job to get the serialization context for. Returns: - The serialization context of the given driver. + The serialization context of the given job. """ # This function needs to be proctected by a lock, because it will be # called by`register_class_for_serialization`, as well as the import # thread, from different threads. Also, this function will recursively # call itself, so we use RLock here. with self.lock: - if driver_id not in self.serialization_context_map: - _initialize_serialization(driver_id) - return self.serialization_context_map[driver_id] + if job_id not in self.serialization_context_map: + _initialize_serialization(job_id) + return self.serialization_context_map[job_id] def check_connected(self): """Check if the worker is connected. @@ -314,7 +315,7 @@ class Worker(object): object_id=pyarrow.plasma.ObjectID(object_id.binary()), memcopy_threads=self.memcopy_threads, serialization_context=self.get_serialization_context( - self.task_driver_id)) + self.current_job_id)) break except pyarrow.SerializationCallbackError as e: try: @@ -388,17 +389,17 @@ class Worker(object): # should return an error code to the caller instead of printing a # message. logger.info( - "The object with ID {} already exists in the object store." - .format(object_id)) + "The object with ID {} already exists in the object store.". + format(object_id)) except TypeError: # This error can happen because one of the members of the object # may not be serializable for cloudpickle. So we need these extra # fallbacks here to start from the beginning. Hopefully the object # could have a `__reduce__` method. register_custom_serializer(type(value), use_pickle=True) - warning_message = ("WARNING: Serializing the class {} failed, " - "so are are falling back to cloudpickle." - .format(type(value))) + warning_message = ( + "WARNING: Serializing the class {} failed, " + "so are are falling back to cloudpickle.".format(type(value))) logger.warning(warning_message) self.store_and_register(object_id, value) @@ -407,7 +408,7 @@ class Worker(object): # Only send the warning once. warning_sent = False serialization_context = self.get_serialization_context( - self.task_driver_id) + self.current_job_id) while True: try: # We divide very large get requests into smaller get requests @@ -449,7 +450,7 @@ class Worker(object): self, ray_constants.WAIT_FOR_CLASS_PUSH_ERROR, warning_message, - driver_id=self.task_driver_id) + job_id=self.current_job_id) warning_sent = True def _deserialize_object_from_arrow(self, data, metadata, object_id, @@ -575,7 +576,7 @@ class Worker(object): num_return_vals=None, resources=None, placement_resources=None, - driver_id=None): + job_id=None): """Submit a remote task to the scheduler. Tell the scheduler to schedule the execution of the function with @@ -601,11 +602,11 @@ class Worker(object): placement_resources: The resources required for placing the task. If this is not provided or if it is an empty dictionary, then the placement resources will be equal to resources. - driver_id: The ID of the relevant driver. This is almost always the - driver ID of the driver that is currently running. However, in + job_id: The ID of the relevant job. This is almost always the + job ID of the job that is currently running. However, in the exceptional case that an actor task is being dispatched to - an actor created by a different driver, this should be the - driver ID of the driver that created the actor. + an actor created by a different job, this should be the + job ID of the job that created the actor. Returns: The return object IDs for this task. @@ -642,8 +643,8 @@ class Worker(object): if new_actor_handles is None: new_actor_handles = [] - if driver_id is None: - driver_id = self.task_driver_id + if job_id is None: + job_id = self.current_job_id if resources is None: raise ValueError("The resources dictionary is required.") @@ -674,13 +675,13 @@ class Worker(object): assert not self.current_task_id.is_nil() # Current driver id must not be nil when submitting a task. # Because every task must belong to a driver. - assert not self.task_driver_id.is_nil() + assert not self.current_job_id.is_nil() # Submit the task to raylet. function_descriptor_list = ( function_descriptor.get_function_descriptor_list()) - assert isinstance(driver_id, DriverID) + assert isinstance(job_id, JobID) task = ray._raylet.Task( - driver_id, + job_id, function_descriptor_list, args_for_raylet, num_return_vals, @@ -747,7 +748,7 @@ class Worker(object): # Run the function on all workers. self.redis_client.hmset( key, { - "driver_id": self.task_driver_id.binary(), + "job_id": self.current_job_id.binary(), "function_id": function_to_run_id, "function": pickled_function, "run_on_other_drivers": str(run_on_other_drivers) @@ -853,17 +854,17 @@ class Worker(object): assert self.task_context.task_index == 0 assert self.task_context.put_index == 1 if task.actor_id().is_nil(): - # If this worker is not an actor, check that `task_driver_id` + # If this worker is not an actor, check that `current_job_id` # was reset when the worker finished the previous task. - assert self.task_driver_id.is_nil() + assert self.current_job_id.is_nil() # Set the driver ID of the current running task. This is # needed so that if the task throws an exception, we propagate # the error message to the correct driver. - self.task_driver_id = task.driver_id() + self.current_job_id = task.job_id() else: - # If this worker is an actor, task_driver_id wasn't reset. + # If this worker is an actor, current_job_id wasn't reset. # Check that current task's driver ID equals the previous one. - assert self.task_driver_id == task.driver_id() + assert self.current_job_id == task.job_id() self.task_context.current_task_id = task.task_id() @@ -945,7 +946,7 @@ class Worker(object): self, ray_constants.TASK_PUSH_ERROR, str(failure_object), - driver_id=self.task_driver_id) + job_id=self.current_job_id) # Mark the actor init as failed if not self.actor_id.is_nil() and function_name == "__init__": self.mark_actor_init_failed(error) @@ -960,7 +961,7 @@ class Worker(object): """ function_descriptor = FunctionDescriptor.from_bytes_list( task.function_descriptor_list()) - driver_id = task.driver_id() + job_id = task.job_id() # TODO(rkn): It would be preferable for actor creation tasks to share # more of the code path with regular task execution. @@ -969,7 +970,7 @@ class Worker(object): self.actor_id = task.actor_creation_id() self.actor_creation_task_id = task.task_id() actor_class = self.function_actor_manager.load_actor_class( - driver_id, function_descriptor) + job_id, function_descriptor) self.actors[self.actor_id] = actor_class.__new__(actor_class) self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo( num_tasks_since_last_checkpoint=0, @@ -978,7 +979,7 @@ class Worker(object): ) execution_info = self.function_actor_manager.get_execution_info( - driver_id, function_descriptor) + job_id, function_descriptor) # Execute the task. function_name = execution_info.function_name @@ -1005,20 +1006,20 @@ class Worker(object): self.task_context.task_index = 0 self.task_context.put_index = 1 if self.actor_id.is_nil(): - # Don't need to reset task_driver_id if the worker is an + # Don't need to reset `current_job_id` if the worker is an # actor. Because the following tasks should all have the # same driver id. - self.task_driver_id = DriverID.nil() + self.current_job_id = WorkerID.nil() # Reset signal counters so that the next task can get # all past signals. ray_signal.reset() # Increase the task execution counter. self.function_actor_manager.increase_task_counter( - driver_id, function_descriptor) + job_id, function_descriptor) reached_max_executions = (self.function_actor_manager.get_task_counter( - driver_id, function_descriptor) == execution_info.max_calls) + job_id, function_descriptor) == execution_info.max_calls) if reached_max_executions: self.raylet_client.disconnect() sys.exit(0) @@ -1141,7 +1142,7 @@ def print_failed_task(task_status): task_status["error_message"])) -def _initialize_serialization(driver_id, worker=global_worker): +def _initialize_serialization(job_id, worker=global_worker): """Initialize the serialization library. This defines a custom serializer for object IDs and also tells ray to @@ -1177,7 +1178,7 @@ def _initialize_serialization(driver_id, worker=global_worker): custom_serializer=actor_handle_serializer, custom_deserializer=actor_handle_deserializer) - worker.serialization_context_map[driver_id] = serialization_context + worker.serialization_context_map[job_id] = serialization_context # Register exception types. for error_cls in RAY_EXCEPTION_TYPES: @@ -1185,7 +1186,7 @@ def _initialize_serialization(driver_id, worker=global_worker): error_cls, use_dict=True, local=True, - driver_id=driver_id, + job_id=job_id, class_id=error_cls.__module__ + ". " + error_cls.__name__, ) # Tell Ray to serialize lambdas with pickle. @@ -1193,22 +1194,18 @@ def _initialize_serialization(driver_id, worker=global_worker): type(lambda: 0), use_pickle=True, local=True, - driver_id=driver_id, + job_id=job_id, class_id="lambda") # Tell Ray to serialize types with pickle. register_custom_serializer( - type(int), - use_pickle=True, - local=True, - driver_id=driver_id, - class_id="type") + type(int), use_pickle=True, local=True, job_id=job_id, class_id="type") # Tell Ray to serialize FunctionSignatures as dictionaries. This is # used when passing around actor handles. register_custom_serializer( ray.signature.FunctionSignature, use_dict=True, local=True, - driver_id=driver_id, + job_id=job_id, class_id="ray.signature.FunctionSignature") @@ -1231,7 +1228,7 @@ def init(redis_address=None, plasma_directory=None, huge_pages=False, include_webui=False, - driver_id=None, + job_id=None, configure_logging=True, logging_level=logging.INFO, logging_format=ray_constants.LOGGER_FORMAT, @@ -1302,7 +1299,7 @@ def init(redis_address=None, Store with hugetlbfs support. Requires plasma_directory. include_webui: Boolean flag indicating whether to start the web UI, which displays the status of the Ray cluster. - driver_id: The ID of driver. + job_id: The ID of this job. configure_logging: True if allow the logging cofiguration here. Otherwise, the users may want to configure it by their own. logging_level: Logging level, default will be logging.INFO. @@ -1449,7 +1446,7 @@ def init(redis_address=None, mode=driver_mode, log_to_driver=log_to_driver, worker=global_worker, - driver_id=driver_id) + job_id=job_id) for hook in _post_init_hooks: hook() @@ -1660,10 +1657,10 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped): assert len(gcs_entry.entries) == 1 error_data = ray.gcs_utils.ErrorTableData.FromString( gcs_entry.entries[0]) - driver_id = error_data.driver_id - if driver_id not in [ - worker.task_driver_id.binary(), - DriverID.nil().binary() + job_id = error_data.job_id + if job_id not in [ + worker.current_job_id.binary(), + JobID.nil().binary() ]: continue @@ -1691,7 +1688,7 @@ def connect(node, mode=WORKER_MODE, log_to_driver=False, worker=global_worker, - driver_id=None): + job_id=None): """Connect this worker to the raylet, to Plasma, and to Redis. Args: @@ -1701,7 +1698,7 @@ def connect(node, log_to_driver (bool): If true, then output from all of the worker processes on all nodes will be directed to the driver. worker: The ray.Worker instance. - driver_id: The ID of driver. If it's None, then we will generate one. + job_id: The ID of job. If it's None, then we will generate one. """ # Do some basic checking to make sure we didn't call ray.init twice. error_message = "Perhaps you called ray.init twice by accident?" @@ -1721,20 +1718,20 @@ def connect(node, setproctitle.setproctitle("ray_worker") else: # This is the code path of driver mode. - if driver_id is None: - driver_id = DriverID.from_random() + if job_id is None: + job_id = JobID.from_random() - if not isinstance(driver_id, DriverID): - raise TypeError("The type of given driver id must be DriverID.") + if not isinstance(job_id, JobID): + raise TypeError("The type of given job id must be JobID.") - worker.worker_id = driver_id.binary() + worker.worker_id = job_id.binary() # When tasks are executed on remote workers in the context of multiple - # drivers, the task driver ID is used to keep track of which driver is + # drivers, the current job ID is used to keep track of which driver is # responsible for the task so that error messages will be propagated to # the correct driver. if mode != WORKER_MODE: - worker.task_driver_id = DriverID(worker.worker_id) + worker.current_job_id = JobID(worker.worker_id) # All workers start out as non-actors. A worker can be turned into an actor # after it is created. @@ -1766,7 +1763,7 @@ def connect(node, worker.redis_client, ray_constants.VERSION_MISMATCH_PUSH_ERROR, traceback_str, - driver_id=None) + job_id=None) worker.lock = threading.RLock() @@ -1831,7 +1828,7 @@ def connect(node, # Create an object store client. worker.plasma_client = thread_safe_client( plasma.connect(node.plasma_store_socket_name, None, 0, 300)) - driver_id_str = _random_string() + job_id_str = _random_string() # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1859,11 +1856,11 @@ def connect(node, function_descriptor = FunctionDescriptor.for_driver_task() driver_task = ray._raylet.Task( - worker.task_driver_id, + worker.current_job_id, function_descriptor.get_function_descriptor_list(), [], # arguments. 0, # num_returns. - TaskID(driver_id_str[:TaskID.size()]), # parent_task_id. + TaskID(job_id_str[:TaskID.size()]), # parent_task_id. 0, # parent_counter. ActorID.nil(), # actor_creation_id. ObjectID.nil(), # actor_creation_dummy_object_id. @@ -1895,7 +1892,7 @@ def connect(node, node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), - DriverID(driver_id_str), + JobID(job_id_str), ) # Start the import thread @@ -2057,7 +2054,7 @@ def register_custom_serializer(cls, serializer=None, deserializer=None, local=False, - driver_id=None, + job_id=None, class_id=None): """Enable serialization and deserialization for a particular class. @@ -2078,7 +2075,7 @@ def register_custom_serializer(cls, if and only if use_pickle and use_dict are False. local: True if the serializers should only be registered on the current worker. This should usually be False. - driver_id: ID of the driver that we want to register the class for. + job_id: ID of the job that we want to register the class for. class_id: ID of the class that we are registering. If this is not specified, we will calculate a new one inside the function. @@ -2126,9 +2123,9 @@ def register_custom_serializer(cls, # Make sure class_id is a string. class_id = ray.utils.binary_to_hex(class_id) - if driver_id is None: - driver_id = worker.task_driver_id - assert isinstance(driver_id, DriverID) + if job_id is None: + job_id = worker.current_job_id + assert isinstance(job_id, JobID) def register_class_for_serialization(worker_info): # TODO(rkn): We need to be more thoughtful about what to do if custom @@ -2138,7 +2135,7 @@ def register_custom_serializer(cls, # system. serialization_context = worker_info[ - "worker"].get_serialization_context(driver_id) + "worker"].get_serialization_context(job_id) serialization_context.register_type( cls, class_id, diff --git a/python/ray/workers/default_worker.py b/python/ray/workers/default_worker.py index 43ecd0658..508a38fae 100644 --- a/python/ray/workers/default_worker.py +++ b/python/ray/workers/default_worker.py @@ -102,7 +102,7 @@ if __name__ == "__main__": ray.worker.global_worker, "worker_crash", traceback_str, - driver_id=None) + job_id=None) # TODO(rkn): Note that if the worker was in the middle of executing # a task, then any worker or driver that is blocking in a get call # and waiting for the output of that task will hang. We need to diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 3928d4adf..2379a22fd 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -85,7 +85,7 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) { return h; } -TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) { +TaskID TaskID::GetDriverTaskID(const WorkerID &driver_id) { std::string driver_id_str = driver_id.Binary(); driver_id_str.resize(Size()); return TaskID::FromBinary(driver_id_str); @@ -113,12 +113,12 @@ ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) { return object_id; } -const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, +const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id, int parent_task_counter) { // Compute hashes. SHA256_CTX ctx; sha256_init(&ctx); - sha256_update(&ctx, reinterpret_cast(driver_id.Data()), driver_id.Size()); + sha256_update(&ctx, reinterpret_cast(job_id.Data()), job_id.Size()); sha256_update(&ctx, reinterpret_cast(parent_task_id.Data()), parent_task_id.Size()); sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter)); @@ -129,6 +129,16 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task return TaskID::FromBinary(std::string(buff, buff + TaskID::Size())); } +const WorkerID ComputeDriverId(const JobID &job_id) { + // Currently, a job id equals its driver id. + return WorkerID(job_id); +} + +const JobID ComputeJobId(const WorkerID &driver_id) { + // Currently, a job id equals its driver id. + return JobID(driver_id); +} + #define ID_OSTREAM_OPERATOR(id_type) \ std::ostream &operator<<(std::ostream &os, const id_type &id) { \ if (id.IsNil()) { \ diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 3b2d244cf..09f4a16aa 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -17,7 +17,7 @@ namespace ray { -class DriverID; +class WorkerID; class UniqueID; // Declaration. @@ -72,7 +72,7 @@ class TaskID : public BaseID { public: TaskID() : BaseID() {} static size_t Size() { return kTaskIDSize; } - static TaskID GetDriverTaskID(const DriverID &driver_id); + static TaskID GetDriverTaskID(const WorkerID &driver_id); private: uint8_t id_[kTaskIDSize]; @@ -152,11 +152,11 @@ std::ostream &operator<<(std::ostream &os, const ObjectID &id); /// Generate a task ID from the given info. /// -/// \param driver_id The driver that creates the task. +/// \param job_id The job that creates the task. /// \param parent_task_id The parent task of this task. /// \param parent_task_counter The task index of the worker. /// \return The task ID generated from the given info. -const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id, +const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id, int parent_task_counter); template diff --git a/src/ray/common/id_def.h b/src/ray/common/id_def.h index 96c7d59d1..7344793f6 100644 --- a/src/ray/common/id_def.h +++ b/src/ray/common/id_def.h @@ -10,6 +10,6 @@ DEFINE_UNIQUE_ID(ActorID) DEFINE_UNIQUE_ID(ActorHandleID) DEFINE_UNIQUE_ID(ActorCheckpointID) DEFINE_UNIQUE_ID(WorkerID) -DEFINE_UNIQUE_ID(DriverID) +DEFINE_UNIQUE_ID(JobID) DEFINE_UNIQUE_ID(ConfigID) DEFINE_UNIQUE_ID(ClientID) diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index 717c52e07..af44d3c8b 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -38,35 +38,34 @@ struct WorkerThreadContext { thread_local std::unique_ptr WorkerContext::thread_context_ = nullptr; -WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id) +WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) : worker_type(worker_type), - worker_id(worker_type == WorkerType::DRIVER - ? ClientID::FromBinary(driver_id.Binary()) - : ClientID::FromRandom()), - current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) { + worker_id(worker_type == WorkerType::DRIVER ? WorkerID::FromBinary(job_id.Binary()) + : WorkerID::FromRandom()), + current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) { // For worker main thread which initializes the WorkerContext, // set task_id according to whether current worker is a driver. - // (For other threads it's set to randmom ID via GetThreadContext). + // (For other threads it's set to random ID via GetThreadContext). GetThreadContext().SetCurrentTask( (worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil()); } const WorkerType WorkerContext::GetWorkerType() const { return worker_type; } -const ClientID &WorkerContext::GetWorkerID() const { return worker_id; } +const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; } int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); } int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); } -const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; } +const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; } const TaskID &WorkerContext::GetCurrentTaskID() const { return GetThreadContext().GetCurrentTaskID(); } void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) { - current_driver_id = spec.DriverId(); + current_job_id = spec.JobId(); GetThreadContext().SetCurrentTask(spec); } diff --git a/src/ray/core_worker/context.h b/src/ray/core_worker/context.h index 932d02891..5bfb830af 100644 --- a/src/ray/core_worker/context.h +++ b/src/ray/core_worker/context.h @@ -10,13 +10,13 @@ struct WorkerThreadContext; class WorkerContext { public: - WorkerContext(WorkerType worker_type, const DriverID &driver_id); + WorkerContext(WorkerType worker_type, const JobID &job_id); const WorkerType GetWorkerType() const; - const ClientID &GetWorkerID() const; + const WorkerID &GetWorkerID() const; - const DriverID &GetCurrentDriverID() const; + const JobID &GetCurrentJobID() const; const TaskID &GetCurrentTaskID() const; @@ -31,10 +31,10 @@ class WorkerContext { const WorkerType worker_type; /// ID for this worker. - const ClientID worker_id; + const WorkerID worker_id; - /// Driver ID for this worker. - DriverID current_driver_id; + /// Job ID for this worker. + JobID current_job_id; private: static WorkerThreadContext &GetThreadContext(); diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index bcc1bdd96..c8ef189c8 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -6,15 +6,16 @@ namespace ray { CoreWorker::CoreWorker(const enum WorkerType worker_type, const enum WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, - DriverID driver_id) + const JobID &job_id) : worker_type_(worker_type), language_(language), store_socket_(store_socket), raylet_socket_(raylet_socket), - worker_context_(worker_type, driver_id), - raylet_client_(raylet_socket_, worker_context_.GetWorkerID(), + worker_context_(worker_type, job_id), + raylet_client_(raylet_socket_, + ClientID::FromBinary(worker_context_.GetWorkerID().Binary()), (worker_type_ == ray::WorkerType::WORKER), - worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)), + worker_context_.GetCurrentJobID(), ToTaskLanguage(language_)), task_interface_(*this), object_interface_(*this), task_execution_interface_(*this) { diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index e03a8700b..3afc28ee5 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -24,7 +24,7 @@ class CoreWorker { /// NOTE(zhijunfu): the constructor would throw if a failure happens. CoreWorker(const WorkerType worker_type, const WorkerLanguage language, const std::string &store_socket, const std::string &raylet_socket, - DriverID driver_id = DriverID::Nil()); + const JobID &job_id = JobID::Nil()); /// Type of this worker. enum WorkerType WorkerType() const { return worker_type_; } diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index 6e4ecc161..6e866fb00 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -126,7 +126,7 @@ class CoreWorkerTest : public ::testing::Test { void TestNormalTask(const std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - DriverID::FromRandom()); + JobID::FromRandom()); // Test pass by value. { @@ -184,7 +184,7 @@ class CoreWorkerTest : public ::testing::Test { void TestActorTask(const std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - DriverID::FromRandom()); + JobID::FromRandom()); std::unique_ptr actor_handle; @@ -275,9 +275,9 @@ TEST_F(ZeroNodeTest, TestTaskArg) { } TEST_F(ZeroNodeTest, TestWorkerContext) { - auto driver_id = DriverID::FromRandom(); + auto job_id = JobID::FromRandom(); - WorkerContext context(WorkerType::WORKER, driver_id); + WorkerContext context(WorkerType::WORKER, job_id); ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); ASSERT_EQ(context.GetNextTaskIndex(), 1); ASSERT_EQ(context.GetNextTaskIndex(), 2); @@ -302,7 +302,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) { TEST_F(SingleNodeTest, TestObjectInterface) { CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - DriverID::FromRandom()); + JobID::FromRandom()); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -358,11 +358,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) { TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) { CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - DriverID::FromRandom()); + JobID::FromRandom()); CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON, raylet_store_socket_names_[1], raylet_socket_names_[1], - DriverID::FromRandom()); + JobID::FromRandom()); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -446,7 +446,7 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) { TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) { try { CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "", - raylet_socket_names_[0], DriverID::FromRandom()); + raylet_socket_names_[0], JobID::FromRandom()); } catch (const std::exception &e) { std::cout << "Caught exception when constructing core worker: " << e.what(); } diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc index a331a0b6a..b7f4b5a63 100644 --- a/src/ray/core_worker/mock_worker.cc +++ b/src/ray/core_worker/mock_worker.cc @@ -17,7 +17,7 @@ class MockWorker { public: MockWorker(const std::string &store_socket, const std::string &raylet_socket) : worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket, - DriverID::FromRandom()) {} + JobID::FromRandom()) {} void Run() { auto executor_func = [this](const RayFunction &ray_function, diff --git a/src/ray/core_worker/task_interface.cc b/src/ray/core_worker/task_interface.cc index 6a91bd6b2..bd84e5e0b 100644 --- a/src/ray/core_worker/task_interface.cc +++ b/src/ray/core_worker/task_interface.cc @@ -20,7 +20,7 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, std::vector *return_ids) { auto &context = core_worker_.worker_context_; auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + const auto task_id = GenerateTaskId(context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index); auto num_returns = task_options.num_returns; @@ -32,7 +32,7 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function, auto task_arguments = BuildTaskArguments(args); auto language = core_worker_.ToTaskLanguage(function.language); - ray::raylet::TaskSpecification spec(context.GetCurrentDriverID(), + ray::raylet::TaskSpecification spec(context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index, task_arguments, num_returns, task_options.resources, language, function.function_descriptor); @@ -48,7 +48,7 @@ Status CoreWorkerTaskInterface::CreateActor( std::unique_ptr *actor_handle) { auto &context = core_worker_.worker_context_; auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + const auto task_id = GenerateTaskId(context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index); std::vector return_ids; @@ -66,7 +66,7 @@ Status CoreWorkerTaskInterface::CreateActor( // Note that the caller is supposed to specify required placement resources // correctly via actor_creation_options.resources. ray::raylet::TaskSpecification spec( - context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index, actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions, ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1, actor_creation_options.resources, actor_creation_options.resources, language, @@ -84,7 +84,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, std::vector *return_ids) { auto &context = core_worker_.worker_context_; auto next_task_index = context.GetNextTaskIndex(); - const auto task_id = GenerateTaskId(context.GetCurrentDriverID(), + const auto task_id = GenerateTaskId(context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index); // add one for actor cursor object id. @@ -102,7 +102,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle, std::vector new_actor_handles; ray::raylet::TaskSpecification spec( - context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index, + context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index, ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(), actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(), new_actor_handles, task_arguments, num_returns, task_options.resources, task_options.resources, diff --git a/src/ray/gcs/client.cc b/src/ray/gcs/client.cc index 6de29bb52..e96c5ad38 100644 --- a/src/ray/gcs/client.cc +++ b/src/ray/gcs/client.cc @@ -109,7 +109,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port, actor_table_.reset(new ActorTable({primary_context_}, this)); client_table_.reset(new ClientTable({primary_context_}, this, client_id)); error_table_.reset(new ErrorTable({primary_context_}, this)); - driver_table_.reset(new DriverTable({primary_context_}, this)); + job_table_.reset(new JobTable({primary_context_}, this)); heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this)); // Tables below would be sharded. object_table_.reset(new ObjectTable(shard_contexts_, this)); @@ -188,7 +188,7 @@ std::string AsyncGcsClient::DebugString() const { result << "\n- ErrorTable: " << error_table_->DebugString(); result << "\n- ProfileTable: " << profile_table_->DebugString(); result << "\n- ClientTable: " << client_table_->DebugString(); - result << "\n- DriverTable: " << driver_table_->DebugString(); + result << "\n- JobTable: " << job_table_->DebugString(); return result.str(); } @@ -214,7 +214,7 @@ HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() { ErrorTable &AsyncGcsClient::error_table() { return *error_table_; } -DriverTable &AsyncGcsClient::driver_table() { return *driver_table_; } +JobTable &AsyncGcsClient::job_table() { return *job_table_; } ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; } diff --git a/src/ray/gcs/client.h b/src/ray/gcs/client.h index 5e70025b3..0ebee0d70 100644 --- a/src/ray/gcs/client.h +++ b/src/ray/gcs/client.h @@ -54,7 +54,7 @@ class RAY_EXPORT AsyncGcsClient { HeartbeatTable &heartbeat_table(); HeartbeatBatchTable &heartbeat_batch_table(); ErrorTable &error_table(); - DriverTable &driver_table(); + JobTable &job_table(); ProfileTable &profile_table(); ActorCheckpointTable &actor_checkpoint_table(); ActorCheckpointIdTable &actor_checkpoint_id_table(); @@ -64,8 +64,8 @@ class RAY_EXPORT AsyncGcsClient { // driver (to set the PYTHONPATH) using GetExportCallback = std::function; - Status AddExport(const std::string &driver_id, std::string &export_data); - Status GetExport(const std::string &driver_id, int64_t export_index, + Status AddExport(const std::string &job_id, std::string &export_data); + Status GetExport(const std::string &job_id, int64_t export_index, const GetExportCallback &done_callback); std::vector> shard_contexts() { return shard_contexts_; } @@ -96,7 +96,7 @@ class RAY_EXPORT AsyncGcsClient { std::vector> shard_asio_subscribe_clients_; // The following context writes everything to the primary shard std::shared_ptr primary_context_; - std::unique_ptr driver_table_; + std::unique_ptr job_table_; std::unique_ptr asio_async_auxiliary_client_; std::unique_ptr asio_subscribe_auxiliary_client_; CommandType command_type_; @@ -105,14 +105,14 @@ class RAY_EXPORT AsyncGcsClient { class SyncGcsClient { Status LogEvent(const std::string &key, const std::string &value, double timestamp); Status NotifyError(const std::map &error_info); - Status RegisterFunction(const DriverID &driver_id, const FunctionID &function_id, + Status RegisterFunction(const JobID &job_id, const FunctionID &function_id, const std::string &language, const std::string &name, const std::string &data); - Status RetrieveFunction(const DriverID &driver_id, const FunctionID &function_id, + Status RetrieveFunction(const JobID &job_id, const FunctionID &function_id, std::string *name, std::string *data); - Status AddExport(const std::string &driver_id, std::string &export_data); - Status GetExport(const std::string &driver_id, int64_t export_index, std::string *data); + Status AddExport(const std::string &job_id, std::string &export_data); + Status GetExport(const std::string &job_id, int64_t export_index, std::string *data); }; } // namespace gcs diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 55115b1e2..7fdd48f32 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -29,7 +29,7 @@ class TestGcs : public ::testing::Test { TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { client_ = std::make_shared("127.0.0.1", 6379, command_type_, /*is_test_client=*/true); - driver_id_ = DriverID::FromRandom(); + job_id_ = JobID::FromRandom(); } virtual ~TestGcs() { @@ -49,7 +49,7 @@ class TestGcs : public ::testing::Test { uint64_t num_callbacks_; gcs::CommandType command_type_; std::shared_ptr client_; - DriverID driver_id_; + JobID job_id_; }; TestGcs *test; @@ -82,8 +82,7 @@ class TestGcsWithChainAsio : public TestGcsWithAsio { TestGcsWithChainAsio() : TestGcsWithAsio(gcs::CommandType::kChain){}; }; -void TestTableLookup(const DriverID &driver_id, - std::shared_ptr client) { +void TestTableLookup(const JobID &job_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); auto data = std::make_shared(); data->set_task("123"); @@ -109,8 +108,8 @@ void TestTableLookup(const DriverID &driver_id, }; // Add the task, then do a lookup. - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); - RAY_CHECK_OK(client->raylet_task_table().Lookup(driver_id, task_id, lookup_callback, + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, failure_callback)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). @@ -122,7 +121,7 @@ void TestTableLookup(const DriverID &driver_id, #define TEST_MACRO(FIXTURE, TEST) \ TEST_F(FIXTURE, TEST) { \ test = this; \ - TEST(driver_id_, client_); \ + TEST(job_id_, client_); \ } TEST_MACRO(TestGcsWithAsio, TestTableLookup); @@ -130,8 +129,7 @@ TEST_MACRO(TestGcsWithAsio, TestTableLookup); TEST_MACRO(TestGcsWithChainAsio, TestTableLookup); #endif -void TestLogLookup(const DriverID &driver_id, - std::shared_ptr client) { +void TestLogLookup(const JobID &job_id, std::shared_ptr client) { // Append some entries to the log at an object ID. TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"abc", "def", "ghi"}; @@ -145,7 +143,7 @@ void TestLogLookup(const DriverID &driver_id, ASSERT_EQ(data->node_manager_id(), d.node_manager_id()); }; RAY_CHECK_OK( - client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); + client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); } // Check that lookup returns the added object entries. @@ -164,7 +162,7 @@ void TestLogLookup(const DriverID &driver_id, // Do a lookup at the object ID. RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(driver_id, task_id, lookup_callback)); + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). test->Start(); @@ -173,10 +171,10 @@ void TestLogLookup(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestLogLookup) { test = this; - TestLogLookup(driver_id_, client_); + TestLogLookup(job_id_, client_); } -void TestTableLookupFailure(const DriverID &driver_id, +void TestTableLookupFailure(const JobID &job_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); @@ -191,7 +189,7 @@ void TestTableLookupFailure(const DriverID &driver_id, }; // Lookup the task. We have not done any writes, so the key should be empty. - RAY_CHECK_OK(client->raylet_task_table().Lookup(driver_id, task_id, lookup_callback, + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, lookup_callback, failure_callback)); // Run the event loop. The loop will only stop if the failure callback is // called (or an assertion failure). @@ -203,8 +201,7 @@ TEST_MACRO(TestGcsWithAsio, TestTableLookupFailure); TEST_MACRO(TestGcsWithChainAsio, TestTableLookupFailure); #endif -void TestLogAppendAt(const DriverID &driver_id, - std::shared_ptr client) { +void TestLogAppendAt(const JobID &job_id, std::shared_ptr client) { TaskID task_id = TaskID::FromRandom(); std::vector node_manager_ids = {"A", "B"}; std::vector> data_log; @@ -222,22 +219,21 @@ void TestLogAppendAt(const DriverID &driver_id, }; // Will succeed. - RAY_CHECK_OK(client->task_reconstruction_log().Append(driver_id, task_id, - data_log.front(), + RAY_CHECK_OK(client->task_reconstruction_log().Append(job_id, task_id, data_log.front(), /*done callback=*/nullptr)); // Append at index 0 will fail. RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - driver_id, task_id, data_log[1], + job_id, task_id, data_log[1], /*done callback=*/nullptr, failure_callback, /*log_length=*/0)); // Append at index 2 will fail. RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - driver_id, task_id, data_log[1], + job_id, task_id, data_log[1], /*done callback=*/nullptr, failure_callback, /*log_length=*/2)); // Append at index 1 will succeed. RAY_CHECK_OK(client->task_reconstruction_log().AppendAt( - driver_id, task_id, data_log[1], + job_id, task_id, data_log[1], /*done callback=*/nullptr, failure_callback, /*log_length=*/1)); auto lookup_callback = [node_manager_ids]( @@ -251,7 +247,7 @@ void TestLogAppendAt(const DriverID &driver_id, test->Stop(); }; RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(driver_id, task_id, lookup_callback)); + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). test->Start(); @@ -260,10 +256,10 @@ void TestLogAppendAt(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestLogAppendAt) { test = this; - TestLogAppendAt(driver_id_, client_); + TestLogAppendAt(job_id_, client_); } -void TestSet(const DriverID &driver_id, std::shared_ptr client) { +void TestSet(const JobID &job_id, std::shared_ptr client) { // Add some entries to the set at an object ID. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"abc", "def", "ghi"}; @@ -277,7 +273,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); } // Check that lookup returns the added object entries. @@ -290,7 +286,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli }; // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); for (auto &manager : managers) { auto data = std::make_shared(); @@ -304,7 +300,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli test->IncrementNumCallbacks(); }; RAY_CHECK_OK( - client->object_table().Remove(driver_id, object_id, data, remove_entry_callback)); + client->object_table().Remove(job_id, object_id, data, remove_entry_callback)); } // Check that the entries are removed. @@ -318,7 +314,7 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli }; // Do a lookup at the object ID. - RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback2)); + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback2)); // Run the event loop. The loop will only stop if the Lookup callback is // called (or an assertion failure). test->Start(); @@ -327,11 +323,11 @@ void TestSet(const DriverID &driver_id, std::shared_ptr cli TEST_F(TestGcsWithAsio, TestSet) { test = this; - TestSet(driver_id_, client_); + TestSet(job_id_, client_); } void TestDeleteKeysFromLog( - const DriverID &driver_id, std::shared_ptr client, + const JobID &job_id, std::shared_ptr client, std::vector> &data_vector) { std::vector ids; TaskID task_id; @@ -346,7 +342,7 @@ void TestDeleteKeysFromLog( test->IncrementNumCallbacks(); }; RAY_CHECK_OK( - client->task_reconstruction_log().Append(driver_id, task_id, data, add_callback)); + client->task_reconstruction_log().Append(job_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { // Check that lookup returns the added object entries. @@ -358,12 +354,12 @@ void TestDeleteKeysFromLog( test->IncrementNumCallbacks(); }; RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(driver_id, task_id, lookup_callback)); + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); } if (ids.size() == 1) { - client->task_reconstruction_log().Delete(driver_id, ids[0]); + client->task_reconstruction_log().Delete(job_id, ids[0]); } else { - client->task_reconstruction_log().Delete(driver_id, ids); + client->task_reconstruction_log().Delete(job_id, ids); } for (const auto &task_id : ids) { auto lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, @@ -373,11 +369,11 @@ void TestDeleteKeysFromLog( test->IncrementNumCallbacks(); }; RAY_CHECK_OK( - client->task_reconstruction_log().Lookup(driver_id, task_id, lookup_callback)); + client->task_reconstruction_log().Lookup(job_id, task_id, lookup_callback)); } } -void TestDeleteKeysFromTable(const DriverID &driver_id, +void TestDeleteKeysFromTable(const JobID &job_id, std::shared_ptr client, std::vector> &data_vector, bool stop_at_end) { @@ -393,7 +389,7 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ASSERT_EQ(data->task(), d.task()); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, add_callback)); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, add_callback)); } for (const auto &task_id : ids) { auto task_lookup_callback = [task_id](gcs::AsyncGcsClient *client, const TaskID &id, @@ -401,13 +397,13 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, ASSERT_EQ(id, task_id); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->raylet_task_table().Lookup(driver_id, task_id, - task_lookup_callback, nullptr)); + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, task_lookup_callback, + nullptr)); } if (ids.size() == 1) { - client->raylet_task_table().Delete(driver_id, ids[0]); + client->raylet_task_table().Delete(job_id, ids[0]); } else { - client->raylet_task_table().Delete(driver_id, ids); + client->raylet_task_table().Delete(job_id, ids); } auto expected_failure_callback = [](AsyncGcsClient *client, const TaskID &id) { ASSERT_TRUE(true); @@ -416,17 +412,17 @@ void TestDeleteKeysFromTable(const DriverID &driver_id, auto undesired_callback = [](gcs::AsyncGcsClient *client, const TaskID &id, const TaskTableData &data) { ASSERT_TRUE(false); }; for (size_t i = 0; i < ids.size(); ++i) { - RAY_CHECK_OK(client->raylet_task_table().Lookup( - driver_id, task_id, undesired_callback, expected_failure_callback)); + RAY_CHECK_OK(client->raylet_task_table().Lookup(job_id, task_id, undesired_callback, + expected_failure_callback)); } if (stop_at_end) { auto stop_callback = [](AsyncGcsClient *client, const TaskID &id) { test->Stop(); }; RAY_CHECK_OK( - client->raylet_task_table().Lookup(driver_id, ids[0], nullptr, stop_callback)); + client->raylet_task_table().Lookup(job_id, ids[0], nullptr, stop_callback)); } } -void TestDeleteKeysFromSet(const DriverID &driver_id, +void TestDeleteKeysFromSet(const JobID &job_id, std::shared_ptr client, std::vector> &data_vector) { std::vector ids; @@ -441,7 +437,7 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ASSERT_EQ(data->manager(), d.manager()); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, add_callback)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, add_callback)); } for (const auto &object_id : ids) { // Check that lookup returns the added object entries. @@ -452,12 +448,12 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ASSERT_EQ(data.size(), 1); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); } if (ids.size() == 1) { - client->object_table().Delete(driver_id, ids[0]); + client->object_table().Delete(job_id, ids[0]); } else { - client->object_table().Delete(driver_id, ids); + client->object_table().Delete(job_id, ids); } for (const auto &object_id : ids) { auto lookup_callback = [object_id](gcs::AsyncGcsClient *client, const ObjectID &id, @@ -466,13 +462,12 @@ void TestDeleteKeysFromSet(const DriverID &driver_id, ASSERT_TRUE(data.size() == 0); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->object_table().Lookup(driver_id, object_id, lookup_callback)); + RAY_CHECK_OK(client->object_table().Lookup(job_id, object_id, lookup_callback)); } } // Test delete function for keys of Log or Table. -void TestDeleteKeys(const DriverID &driver_id, - std::shared_ptr client) { +void TestDeleteKeys(const JobID &job_id, std::shared_ptr client) { // Test delete function for keys of Log. std::vector> task_reconstruction_vector; auto AppendTaskReconstructionData = [&task_reconstruction_vector](size_t add_count) { @@ -485,7 +480,7 @@ void TestDeleteKeys(const DriverID &driver_id, // Test one element case. AppendTaskReconstructionData(1); ASSERT_EQ(task_reconstruction_vector.size(), 1); - TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test the case for more than one elements and less than // maximum_gcs_deletion_batch_size. AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / @@ -493,14 +488,14 @@ void TestDeleteKeys(const DriverID &driver_id, ASSERT_GT(task_reconstruction_vector.size(), 1); ASSERT_LT(task_reconstruction_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test the case for more than maximum_gcs_deletion_batch_size. // The Delete function will split the data into two commands. AppendTaskReconstructionData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(task_reconstruction_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromLog(driver_id, client, task_reconstruction_vector); + TestDeleteKeysFromLog(job_id, client, task_reconstruction_vector); // Test delete function for keys of Table. std::vector> task_vector; @@ -513,16 +508,16 @@ void TestDeleteKeys(const DriverID &driver_id, }; AppendTaskData(1); ASSERT_EQ(task_vector.size(), 1); - TestDeleteKeysFromTable(driver_id, client, task_vector, false); + TestDeleteKeysFromTable(job_id, client, task_vector, false); AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(task_vector.size(), 1); ASSERT_LT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromTable(driver_id, client, task_vector, false); + TestDeleteKeysFromTable(job_id, client, task_vector, false); AppendTaskData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(task_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromTable(driver_id, client, task_vector, true); + TestDeleteKeysFromTable(job_id, client, task_vector, true); test->Start(); ASSERT_GT(test->NumCallbacks(), @@ -540,76 +535,75 @@ void TestDeleteKeys(const DriverID &driver_id, // Test one element case. AppendObjectData(1); ASSERT_EQ(object_vector.size(), 1); - TestDeleteKeysFromSet(driver_id, client, object_vector); + TestDeleteKeysFromSet(job_id, client, object_vector); // Test the case for more than one elements and less than // maximum_gcs_deletion_batch_size. AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(object_vector.size(), 1); ASSERT_LT(object_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromSet(driver_id, client, object_vector); + TestDeleteKeysFromSet(job_id, client, object_vector); // Test the case for more than maximum_gcs_deletion_batch_size. // The Delete function will split the data into two commands. AppendObjectData(RayConfig::instance().maximum_gcs_deletion_batch_size() / 2); ASSERT_GT(object_vector.size(), RayConfig::instance().maximum_gcs_deletion_batch_size()); - TestDeleteKeysFromSet(driver_id, client, object_vector); + TestDeleteKeysFromSet(job_id, client, object_vector); } TEST_F(TestGcsWithAsio, TestDeleteKey) { test = this; - TestDeleteKeys(driver_id_, client_); + TestDeleteKeys(job_id_, client_); } -void TestLogSubscribeAll(const DriverID &driver_id, +void TestLogSubscribeAll(const JobID &job_id, std::shared_ptr client) { - std::vector driver_ids; + std::vector job_ids; for (int i = 0; i < 3; i++) { - driver_ids.emplace_back(DriverID::FromRandom()); + job_ids.emplace_back(JobID::FromRandom()); } // Callback for a notification. - auto notification_callback = [driver_ids](gcs::AsyncGcsClient *client, - const DriverID &id, - const std::vector data) { - ASSERT_EQ(id, driver_ids[test->NumCallbacks()]); + auto notification_callback = [job_ids](gcs::AsyncGcsClient *client, const JobID &id, + const std::vector data) { + ASSERT_EQ(id, job_ids[test->NumCallbacks()]); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids[test->NumCallbacks()].Binary()); + ASSERT_EQ(entry.job_id(), job_ids[test->NumCallbacks()].Binary()); test->IncrementNumCallbacks(); } - if (test->NumCallbacks() == driver_ids.size()) { + if (test->NumCallbacks() == job_ids.size()) { test->Stop(); } }; // Callback for subscription success. We are guaranteed to receive // notifications after this is called. - auto subscribe_callback = [driver_ids](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_ids](gcs::AsyncGcsClient *client) { // We have subscribed. Do the writes to the table. - for (size_t i = 0; i < driver_ids.size(); i++) { - RAY_CHECK_OK(client->driver_table().AppendDriverData(driver_ids[i], false)); + for (size_t i = 0; i < job_ids.size(); i++) { + RAY_CHECK_OK(client->job_table().AppendJobData(job_ids[i], false)); } }; // Subscribe to all driver table notifications. Once we have successfully // subscribed, we will append to the key several times and check that we get // notified for each. - RAY_CHECK_OK(client->driver_table().Subscribe( - driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + RAY_CHECK_OK(client->job_table().Subscribe(job_id, ClientID::Nil(), + notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called (or an assertion failure). test->Start(); // Check that we received one notification callback for each write. - ASSERT_EQ(test->NumCallbacks(), driver_ids.size()); + ASSERT_EQ(test->NumCallbacks(), job_ids.size()); } TEST_F(TestGcsWithAsio, TestLogSubscribeAll) { test = this; - TestLogSubscribeAll(driver_id_, client_); + TestLogSubscribeAll(job_id_, client_); } -void TestSetSubscribeAll(const DriverID &driver_id, +void TestSetSubscribeAll(const JobID &job_id, std::shared_ptr client) { std::vector object_ids; for (int i = 0; i < 3; i++) { @@ -640,8 +634,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Callback for subscription success. We are guaranteed to receive // notifications after this is called. - auto subscribe_callback = [driver_id, object_ids, - managers](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, object_ids, managers](gcs::AsyncGcsClient *client) { // We have subscribed. Do the writes to the table. for (size_t i = 0; i < object_ids.size(); i++) { for (size_t j = 0; j < managers.size(); j++) { @@ -650,8 +643,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, for (int k = 0; k < 3; k++) { // Add the same entry several times. // Expect no notification if the entry already exists. - RAY_CHECK_OK( - client->object_table().Add(driver_id, object_ids[i], data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_ids[i], data, nullptr)); } } } @@ -663,7 +655,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // Remove the same entry several times. // Expect no notification if the entry doesn't exist. RAY_CHECK_OK( - client->object_table().Remove(driver_id, object_ids[i], data, nullptr)); + client->object_table().Remove(job_id, object_ids[i], data, nullptr)); } } } @@ -673,7 +665,7 @@ void TestSetSubscribeAll(const DriverID &driver_id, // subscribed, we will append to the key several times and check that we get // notified for each. RAY_CHECK_OK(client->object_table().Subscribe( - driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + job_id, ClientID::Nil(), notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called (or an assertion failure). @@ -684,10 +676,10 @@ void TestSetSubscribeAll(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestSetSubscribeAll) { test = this; - TestSetSubscribeAll(driver_id_, client_); + TestSetSubscribeAll(job_id_, client_); } -void TestTableSubscribeId(const DriverID &driver_id, +void TestTableSubscribeId(const JobID &job_id, std::shared_ptr client) { // Add a table entry. TaskID task_id1 = TaskID::FromRandom(); @@ -724,29 +716,29 @@ void TestTableSubscribeId(const DriverID &driver_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [driver_id, task_id1, task_id2, task_specs1, + auto subscribe_callback = [job_id, task_id1, task_id2, task_specs1, task_specs2](gcs::AsyncGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - driver_id, task_id2, client->client_table().GetLocalClientId())); + job_id, task_id2, client->client_table().GetLocalClientId())); // Write both keys. We should only receive notifications for the key that // we requested them for. for (const auto &task_spec : task_specs1) { auto data = std::make_shared(); data->set_task(task_spec); - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id1, data, nullptr)); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id1, data, nullptr)); } for (const auto &task_spec : task_specs2) { auto data = std::make_shared(); data->set_task(task_spec); - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id2, data, nullptr)); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id2, data, nullptr)); } }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. RAY_CHECK_OK(client->raylet_task_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, + job_id, client->client_table().GetLocalClientId(), notification_callback, failure_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. @@ -764,95 +756,95 @@ TEST_MACRO(TestGcsWithAsio, TestTableSubscribeId); TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); #endif -void TestLogSubscribeId(const DriverID &driver_id, +void TestLogSubscribeId(const JobID &job_id, std::shared_ptr client) { // Add a log entry. - DriverID driver_id1 = DriverID::FromRandom(); - std::vector driver_ids1 = {"abc", "def", "ghi"}; - auto data1 = std::make_shared(); - data1->set_driver_id(driver_ids1[0]); - RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data1, nullptr)); + JobID job_id1 = JobID::FromRandom(); + std::vector job_ids1 = {"abc", "def", "ghi"}; + auto data1 = std::make_shared(); + data1->set_job_id(job_ids1[0]); + RAY_CHECK_OK(client->job_table().Append(job_id, job_id1, data1, nullptr)); // Add a log entry at a second key. - DriverID driver_id2 = DriverID::FromRandom(); - std::vector driver_ids2 = {"jkl", "mno", "pqr"}; - auto data2 = std::make_shared(); - data2->set_driver_id(driver_ids2[0]); - RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data2, nullptr)); + JobID job_id2 = JobID::FromRandom(); + std::vector job_ids2 = {"jkl", "mno", "pqr"}; + auto data2 = std::make_shared(); + data2->set_job_id(job_ids2[0]); + RAY_CHECK_OK(client->job_table().Append(job_id, job_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto notification_callback = [driver_id2, driver_ids2]( + auto notification_callback = [job_id2, job_ids2]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { + const std::vector &data) { // Check that we only get notifications for the requested key. - ASSERT_EQ(id, driver_id2); + ASSERT_EQ(id, job_id2); // Check that we get notifications in the same order as the writes. for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids2[test->NumCallbacks()]); + ASSERT_EQ(entry.job_id(), job_ids2[test->NumCallbacks()]); test->IncrementNumCallbacks(); } - if (test->NumCallbacks() == driver_ids2.size()) { + if (test->NumCallbacks() == job_ids2.size()) { test->Stop(); } }; // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [driver_id, driver_id1, driver_id2, driver_ids1, - driver_ids2](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, job_id1, job_id2, job_ids1, + job_ids2](gcs::AsyncGcsClient *client) { // Request notifications for one of the keys. - RAY_CHECK_OK(client->driver_table().RequestNotifications( - driver_id, driver_id2, client->client_table().GetLocalClientId())); + RAY_CHECK_OK(client->job_table().RequestNotifications( + job_id, job_id2, client->client_table().GetLocalClientId())); // Write both keys. We should only receive notifications for the key that // we requested them for. - auto remaining = std::vector(++driver_ids1.begin(), driver_ids1.end()); - for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->set_driver_id(driver_id_it); - RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id1, data, nullptr)); + auto remaining = std::vector(++job_ids1.begin(), job_ids1.end()); + for (const auto &job_id_it : remaining) { + auto data = std::make_shared(); + data->set_job_id(job_id_it); + RAY_CHECK_OK(client->job_table().Append(job_id, job_id1, data, nullptr)); } - remaining = std::vector(++driver_ids2.begin(), driver_ids2.end()); - for (const auto &driver_id_it : remaining) { - auto data = std::make_shared(); - data->set_driver_id(driver_id_it); - RAY_CHECK_OK(client->driver_table().Append(driver_id, driver_id2, data, nullptr)); + remaining = std::vector(++job_ids2.begin(), job_ids2.end()); + for (const auto &job_id_it : remaining) { + auto data = std::make_shared(); + data->set_job_id(job_id_it); + RAY_CHECK_OK(client->job_table().Append(job_id, job_id2, data, nullptr)); } }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK(client->driver_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, - subscribe_callback)); + RAY_CHECK_OK(client->job_table().Subscribe(job_id, + client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); // Check that we received one notification callback for each write to the // requested key. - ASSERT_EQ(test->NumCallbacks(), driver_ids2.size()); + ASSERT_EQ(test->NumCallbacks(), job_ids2.size()); } TEST_F(TestGcsWithAsio, TestLogSubscribeId) { test = this; - TestLogSubscribeId(driver_id_, client_); + TestLogSubscribeId(job_id_, client_); } -void TestSetSubscribeId(const DriverID &driver_id, +void TestSetSubscribeId(const JobID &job_id, std::shared_ptr client) { // Add a set entry. ObjectID object_id1 = ObjectID::FromRandom(); std::vector managers1 = {"abc", "def", "ghi"}; auto data1 = std::make_shared(); data1->set_manager(managers1[0]); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data1, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data1, nullptr)); // Add a set entry at a second key. ObjectID object_id2 = ObjectID::FromRandom(); std::vector managers2 = {"jkl", "mno", "pqr"}; auto data2 = std::make_shared(); data2->set_manager(managers2[0]); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data2, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data2, nullptr)); // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. @@ -875,32 +867,32 @@ void TestSetSubscribeId(const DriverID &driver_id, // The callback for subscription success. Once we've subscribed, request // notifications for only one of the keys, then write to both keys. - auto subscribe_callback = [driver_id, object_id1, object_id2, managers1, + auto subscribe_callback = [job_id, object_id1, object_id2, managers1, managers2](gcs::AsyncGcsClient *client) { // Request notifications for one of the keys. RAY_CHECK_OK(client->object_table().RequestNotifications( - driver_id, object_id2, client->client_table().GetLocalClientId())); + job_id, object_id2, client->client_table().GetLocalClientId())); // Write both keys. We should only receive notifications for the key that // we requested them for. auto remaining = std::vector(++managers1.begin(), managers1.end()); for (const auto &manager : remaining) { auto data = std::make_shared(); data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id1, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id1, data, nullptr)); } remaining = std::vector(++managers2.begin(), managers2.end()); for (const auto &manager : remaining) { auto data = std::make_shared(); data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id2, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id2, data, nullptr)); } }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK(client->object_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, - subscribe_callback)); + RAY_CHECK_OK( + client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -911,17 +903,17 @@ void TestSetSubscribeId(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestSetSubscribeId) { test = this; - TestSetSubscribeId(driver_id_, client_); + TestSetSubscribeId(job_id_, client_); } -void TestTableSubscribeCancel(const DriverID &driver_id, +void TestTableSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a table entry. TaskID task_id = TaskID::FromRandom(); std::vector task_specs = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->set_task(task_specs[0]); - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); // The failure callback should not be called since all keys are non-empty // when notifications are requested. @@ -950,32 +942,31 @@ void TestTableSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto subscribe_callback = [driver_id, task_id, - task_specs](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, task_id, task_specs](gcs::AsyncGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - driver_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId())); RAY_CHECK_OK(client->raylet_task_table().CancelNotifications( - driver_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId())); // Write to the key. Since we canceled notifications, we should not receive // a notification for these writes. auto remaining = std::vector(++task_specs.begin(), task_specs.end()); for (const auto &task_spec : remaining) { auto data = std::make_shared(); data->set_task(task_spec); - RAY_CHECK_OK(client->raylet_task_table().Add(driver_id, task_id, data, nullptr)); + RAY_CHECK_OK(client->raylet_task_table().Add(job_id, task_id, data, nullptr)); } // Request notifications again. We should receive a notification for the // current value at the key. RAY_CHECK_OK(client->raylet_task_table().RequestNotifications( - driver_id, task_id, client->client_table().GetLocalClientId())); + job_id, task_id, client->client_table().GetLocalClientId())); }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. RAY_CHECK_OK(client->raylet_task_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, + job_id, client->client_table().GetLocalClientId(), notification_callback, failure_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. @@ -990,87 +981,86 @@ TEST_MACRO(TestGcsWithAsio, TestTableSubscribeCancel); TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); #endif -void TestLogSubscribeCancel(const DriverID &driver_id, +void TestLogSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a log entry. - DriverID random_driver_id = DriverID::FromRandom(); - std::vector driver_ids = {"jkl", "mno", "pqr"}; - auto data = std::make_shared(); - data->set_driver_id(driver_ids[0]); - RAY_CHECK_OK(client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); + JobID random_job_id = JobID::FromRandom(); + std::vector job_ids = {"jkl", "mno", "pqr"}; + auto data = std::make_shared(); + data->set_job_id(job_ids[0]); + RAY_CHECK_OK(client->job_table().Append(job_id, random_job_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. - auto notification_callback = [random_driver_id, driver_ids]( + auto notification_callback = [random_job_id, job_ids]( gcs::AsyncGcsClient *client, const UniqueID &id, - const std::vector &data) { - ASSERT_EQ(id, random_driver_id); + const std::vector &data) { + ASSERT_EQ(id, random_job_id); // Check that we get a duplicate notification for the first write. We get a // duplicate notification because the log is append-only and notifications // are canceled after the first write, then requested again. - auto driver_ids_copy = driver_ids; - driver_ids_copy.insert(driver_ids_copy.begin(), driver_ids_copy.front()); + auto job_ids_copy = job_ids; + job_ids_copy.insert(job_ids_copy.begin(), job_ids_copy.front()); for (const auto &entry : data) { - ASSERT_EQ(entry.driver_id(), driver_ids_copy[test->NumCallbacks()]); + ASSERT_EQ(entry.job_id(), job_ids_copy[test->NumCallbacks()]); test->IncrementNumCallbacks(); } - if (test->NumCallbacks() == driver_ids_copy.size()) { + if (test->NumCallbacks() == job_ids_copy.size()) { test->Stop(); } }; // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto subscribe_callback = [driver_id, random_driver_id, - driver_ids](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, random_job_id, + job_ids](gcs::AsyncGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. - RAY_CHECK_OK(client->driver_table().RequestNotifications( - driver_id, random_driver_id, client->client_table().GetLocalClientId())); - RAY_CHECK_OK(client->driver_table().CancelNotifications( - driver_id, random_driver_id, client->client_table().GetLocalClientId())); + RAY_CHECK_OK(client->job_table().RequestNotifications( + job_id, random_job_id, client->client_table().GetLocalClientId())); + RAY_CHECK_OK(client->job_table().CancelNotifications( + job_id, random_job_id, client->client_table().GetLocalClientId())); // Append to the key. Since we canceled notifications, we should not // receive a notification for these writes. - auto remaining = std::vector(++driver_ids.begin(), driver_ids.end()); - for (const auto &remaining_driver_id : remaining) { - auto data = std::make_shared(); - data->set_driver_id(remaining_driver_id); - RAY_CHECK_OK( - client->driver_table().Append(driver_id, random_driver_id, data, nullptr)); + auto remaining = std::vector(++job_ids.begin(), job_ids.end()); + for (const auto &remaining_job_id : remaining) { + auto data = std::make_shared(); + data->set_job_id(remaining_job_id); + RAY_CHECK_OK(client->job_table().Append(job_id, random_job_id, data, nullptr)); } // Request notifications again. We should receive a notification for the // current values at the key. - RAY_CHECK_OK(client->driver_table().RequestNotifications( - driver_id, random_driver_id, client->client_table().GetLocalClientId())); + RAY_CHECK_OK(client->job_table().RequestNotifications( + job_id, random_job_id, client->client_table().GetLocalClientId())); }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK(client->driver_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, - subscribe_callback)); + RAY_CHECK_OK(client->job_table().Subscribe(job_id, + client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); // Check that we received a notification callback for the first append to the // key, then a notification for all of the appends, because we cancel // notifications in between. - ASSERT_EQ(test->NumCallbacks(), driver_ids.size() + 1); + ASSERT_EQ(test->NumCallbacks(), job_ids.size() + 1); } TEST_F(TestGcsWithAsio, TestLogSubscribeCancel) { test = this; - TestLogSubscribeCancel(driver_id_, client_); + TestLogSubscribeCancel(job_id_, client_); } -void TestSetSubscribeCancel(const DriverID &driver_id, +void TestSetSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a set entry. ObjectID object_id = ObjectID::FromRandom(); std::vector managers = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->set_manager(managers[0]); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. @@ -1105,33 +1095,32 @@ void TestSetSubscribeCancel(const DriverID &driver_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. - auto subscribe_callback = [driver_id, object_id, - managers](gcs::AsyncGcsClient *client) { + auto subscribe_callback = [job_id, object_id, managers](gcs::AsyncGcsClient *client) { // Request notifications, then cancel immediately. We should receive a // notification for the current value at the key. RAY_CHECK_OK(client->object_table().RequestNotifications( - driver_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId())); RAY_CHECK_OK(client->object_table().CancelNotifications( - driver_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId())); // Add to the key. Since we canceled notifications, we should not // receive a notification for these writes. auto remaining = std::vector(++managers.begin(), managers.end()); for (const auto &manager : remaining) { auto data = std::make_shared(); data->set_manager(manager); - RAY_CHECK_OK(client->object_table().Add(driver_id, object_id, data, nullptr)); + RAY_CHECK_OK(client->object_table().Add(job_id, object_id, data, nullptr)); } // Request notifications again. We should receive a notification for the // current values at the key. RAY_CHECK_OK(client->object_table().RequestNotifications( - driver_id, object_id, client->client_table().GetLocalClientId())); + job_id, object_id, client->client_table().GetLocalClientId())); }; // Subscribe to notifications for this client. This allows us to request and // receive notifications for specific keys. - RAY_CHECK_OK(client->object_table().Subscribe( - driver_id, client->client_table().GetLocalClientId(), notification_callback, - subscribe_callback)); + RAY_CHECK_OK( + client->object_table().Subscribe(job_id, client->client_table().GetLocalClientId(), + notification_callback, subscribe_callback)); // Run the event loop. The loop will only stop if the registered subscription // callback is called for the requested key. test->Start(); @@ -1143,7 +1132,7 @@ void TestSetSubscribeCancel(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestSetSubscribeCancel) { test = this; - TestSetSubscribeCancel(driver_id_, client_); + TestSetSubscribeCancel(job_id_, client_); } void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client_id, @@ -1160,7 +1149,7 @@ void ClientTableNotification(gcs::AsyncGcsClient *client, const ClientID &client ASSERT_EQ(cached_client.entry_type() == ClientTableData::INSERTION, is_insertion); } -void TestClientTableConnect(const DriverID &driver_id, +void TestClientTableConnect(const JobID &job_id, std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. @@ -1182,10 +1171,10 @@ void TestClientTableConnect(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestClientTableConnect) { test = this; - TestClientTableConnect(driver_id_, client_); + TestClientTableConnect(job_id_, client_); } -void TestClientTableDisconnect(const DriverID &driver_id, +void TestClientTableDisconnect(const JobID &job_id, std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. @@ -1213,10 +1202,10 @@ void TestClientTableDisconnect(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestClientTableDisconnect) { test = this; - TestClientTableDisconnect(driver_id_, client_); + TestClientTableDisconnect(job_id_, client_); } -void TestClientTableImmediateDisconnect(const DriverID &driver_id, +void TestClientTableImmediateDisconnect(const JobID &job_id, std::shared_ptr client) { // Register callbacks for when a client gets added and removed. The latter // event will stop the event loop. @@ -1242,10 +1231,10 @@ void TestClientTableImmediateDisconnect(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestClientTableImmediateDisconnect) { test = this; - TestClientTableImmediateDisconnect(driver_id_, client_); + TestClientTableImmediateDisconnect(job_id_, client_); } -void TestClientTableMarkDisconnected(const DriverID &driver_id, +void TestClientTableMarkDisconnected(const JobID &job_id, std::shared_ptr client) { ClientTableData local_client_info = client->client_table().GetLocalClient(); local_client_info.set_node_manager_address("127.0.0.1"); @@ -1269,11 +1258,10 @@ void TestClientTableMarkDisconnected(const DriverID &driver_id, TEST_F(TestGcsWithAsio, TestClientTableMarkDisconnected) { test = this; - TestClientTableMarkDisconnected(driver_id_, client_); + TestClientTableMarkDisconnected(job_id_, client_); } -void TestHashTable(const DriverID &driver_id, - std::shared_ptr client) { +void TestHashTable(const JobID &job_id, std::shared_ptr client) { const int expected_count = 14; ClientID client_id = ClientID::FromRandom(); // Prepare the first resource map: data_map1. @@ -1343,9 +1331,9 @@ void TestHashTable(const DriverID &driver_id, }; // Step 0: Subscribe the change of the hash table. RAY_CHECK_OK(client->resource_table().Subscribe( - driver_id, ClientID::Nil(), notification_callback, subscribe_callback)); + job_id, ClientID::Nil(), notification_callback, subscribe_callback)); RAY_CHECK_OK(client->resource_table().RequestNotifications( - driver_id, client_id, client->client_table().GetLocalClientId())); + job_id, client_id, client->client_table().GetLocalClientId())); // Step 1: Add elements to the hash table. auto update_callback1 = [data_map1, compare_test]( @@ -1355,24 +1343,24 @@ void TestHashTable(const DriverID &driver_id, test->IncrementNumCallbacks(); }; RAY_CHECK_OK( - client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + client->resource_table().Update(job_id, client_id, data_map1, update_callback1)); auto lookup_callback1 = [data_map1, compare_test]( AsyncGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback1)); + RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback1)); // Step 2: Decrease one element, increase one and add a new one. - RAY_CHECK_OK(client->resource_table().Update(driver_id, client_id, data_map2, nullptr)); + RAY_CHECK_OK(client->resource_table().Update(job_id, client_id, data_map2, nullptr)); auto lookup_callback2 = [data_map2, compare_test]( AsyncGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map2, callback_data); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback2)); + RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback2)); std::vector delete_keys({"GPU", "CUSTOM", "None-Existent"}); auto remove_callback = [delete_keys](AsyncGcsClient *client, const ClientID &id, const std::vector &callback_data) { @@ -1382,7 +1370,7 @@ void TestHashTable(const DriverID &driver_id, } test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->resource_table().RemoveEntries(driver_id, client_id, delete_keys, + RAY_CHECK_OK(client->resource_table().RemoveEntries(job_id, client_id, delete_keys, remove_callback)); DynamicResourceTable::DataMap data_map3(data_map2); data_map3.erase("GPU"); @@ -1393,22 +1381,22 @@ void TestHashTable(const DriverID &driver_id, compare_test(data_map3, callback_data); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback3)); + RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback3)); // Step 3: Reset the the resources to data_map1. RAY_CHECK_OK( - client->resource_table().Update(driver_id, client_id, data_map1, update_callback1)); + client->resource_table().Update(job_id, client_id, data_map1, update_callback1)); auto lookup_callback4 = [data_map1, compare_test]( AsyncGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { compare_test(data_map1, callback_data); test->IncrementNumCallbacks(); }; - RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback4)); + RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback4)); // Step 4: Removing all elements will remove the home Hash table from GCS. RAY_CHECK_OK(client->resource_table().RemoveEntries( - driver_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); + job_id, client_id, {"GPU", "CPU", "CUSTOM", "None-Existent"}, nullptr)); auto lookup_callback5 = [](AsyncGcsClient *client, const ClientID &id, const DynamicResourceTable::DataMap &callback_data) { ASSERT_EQ(callback_data.size(), 0); @@ -1418,14 +1406,14 @@ void TestHashTable(const DriverID &driver_id, test->Stop(); } }; - RAY_CHECK_OK(client->resource_table().Lookup(driver_id, client_id, lookup_callback5)); + RAY_CHECK_OK(client->resource_table().Lookup(job_id, client_id, lookup_callback5)); test->Start(); ASSERT_EQ(test->NumCallbacks(), expected_count); } TEST_F(TestGcsWithAsio, TestHashTable) { test = this; - TestHashTable(driver_id_, client_); + TestHashTable(job_id_, client_); } #undef TEST_MACRO diff --git a/src/ray/gcs/format/gcs.fbs b/src/ray/gcs/format/gcs.fbs index c06c79a02..4ab9f385a 100644 --- a/src/ray/gcs/format/gcs.fbs +++ b/src/ray/gcs/format/gcs.fbs @@ -18,8 +18,8 @@ table Arg { } table TaskInfo { - // ID of the driver that created this task. - driver_id: string; + // ID of the job that created this task. + job_id: string; // Task ID of the task. task_id: string; // Task ID of the parent task. diff --git a/src/ray/gcs/tables.cc b/src/ray/gcs/tables.cc index b7c19ebfd..5fc004ee6 100644 --- a/src/ray/gcs/tables.cc +++ b/src/ray/gcs/tables.cc @@ -39,7 +39,7 @@ namespace ray { namespace gcs { template -Status Log::Append(const DriverID &driver_id, const ID &id, +Status Log::Append(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) { num_appends_++; auto callback = [this, id, data, done](const CallbackReply &reply) { @@ -58,7 +58,7 @@ Status Log::Append(const DriverID &driver_id, const ID &id, } template -Status Log::AppendAt(const DriverID &driver_id, const ID &id, +Status Log::AppendAt(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) { num_appends_++; @@ -81,8 +81,7 @@ Status Log::AppendAt(const DriverID &driver_id, const ID &id, } template -Status Log::Lookup(const DriverID &driver_id, const ID &id, - const Callback &lookup) { +Status Log::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) { num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { if (lookup != nullptr) { @@ -106,7 +105,7 @@ Status Log::Lookup(const DriverID &driver_id, const ID &id, } template -Status Log::Subscribe(const DriverID &driver_id, const ClientID &client_id, +Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const SubscriptionCallback &done) { auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id, @@ -115,11 +114,11 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien RAY_CHECK(change_mode != GcsChangeMode::REMOVE); subscribe(client, id, data); }; - return Subscribe(driver_id, client_id, subscribe_wrapper, done); + return Subscribe(job_id, client_id, subscribe_wrapper, done); } template -Status Log::Subscribe(const DriverID &driver_id, const ClientID &client_id, +Status Log::Subscribe(const JobID &job_id, const ClientID &client_id, const NotificationCallback &subscribe, const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) @@ -160,7 +159,7 @@ Status Log::Subscribe(const DriverID &driver_id, const ClientID &clien } template -Status Log::RequestNotifications(const DriverID &driver_id, const ID &id, +Status Log::RequestNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client requested notifications on a key before Subscribe completed"; @@ -170,7 +169,7 @@ Status Log::RequestNotifications(const DriverID &driver_id, const ID & } template -Status Log::CancelNotifications(const DriverID &driver_id, const ID &id, +Status Log::CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) { RAY_CHECK(subscribe_callback_index_ >= 0) << "Client canceled notifications on a key before Subscribe completed"; @@ -180,7 +179,7 @@ Status Log::CancelNotifications(const DriverID &driver_id, const ID &i } template -void Log::Delete(const DriverID &driver_id, const std::vector &ids) { +void Log::Delete(const JobID &job_id, const std::vector &ids) { if (ids.empty()) { return; } @@ -214,8 +213,8 @@ void Log::Delete(const DriverID &driver_id, const std::vector &ids } template -void Log::Delete(const DriverID &driver_id, const ID &id) { - Delete(driver_id, std::vector({id})); +void Log::Delete(const JobID &job_id, const ID &id) { + Delete(job_id, std::vector({id})); } template @@ -226,7 +225,7 @@ std::string Log::DebugString() const { } template -Status Table::Add(const DriverID &driver_id, const ID &id, +Status Table::Add(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) { num_adds_++; auto callback = [this, id, data, done](const CallbackReply &reply) { @@ -241,10 +240,10 @@ Status Table::Add(const DriverID &driver_id, const ID &id, } template -Status Table::Lookup(const DriverID &driver_id, const ID &id, - const Callback &lookup, const FailureCallback &failure) { +Status Table::Lookup(const JobID &job_id, const ID &id, const Callback &lookup, + const FailureCallback &failure) { num_lookups_++; - return Log::Lookup(driver_id, id, + return Log::Lookup(job_id, id, [lookup, failure](AsyncGcsClient *client, const ID &id, const std::vector &data) { if (data.empty()) { @@ -261,12 +260,12 @@ Status Table::Lookup(const DriverID &driver_id, const ID &id, } template -Status Table::Subscribe(const DriverID &driver_id, const ClientID &client_id, +Status Table::Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const FailureCallback &failure, const SubscriptionCallback &done) { return Log::Subscribe( - driver_id, client_id, + job_id, client_id, [subscribe, failure](AsyncGcsClient *client, const ID &id, const std::vector &data) { RAY_CHECK(data.empty() || data.size() == 1); @@ -289,8 +288,8 @@ std::string Table::DebugString() const { } template -Status Set::Add(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) { +Status Set::Add(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) { num_adds_++; auto callback = [this, id, data, done](const CallbackReply &reply) { if (done != nullptr) { @@ -303,7 +302,7 @@ Status Set::Add(const DriverID &driver_id, const ID &id, } template -Status Set::Remove(const DriverID &driver_id, const ID &id, +Status Set::Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) { num_removes_++; auto callback = [this, id, data, done](const CallbackReply &reply) { @@ -325,8 +324,8 @@ std::string Set::DebugString() const { } template -Status Hash::Update(const DriverID &driver_id, const ID &id, - const DataMap &data_map, const HashCallback &done) { +Status Hash::Update(const JobID &job_id, const ID &id, const DataMap &data_map, + const HashCallback &done) { num_adds_++; auto callback = [this, id, data_map, done](const CallbackReply &reply) { if (done != nullptr) { @@ -346,7 +345,7 @@ Status Hash::Update(const DriverID &driver_id, const ID &id, } template -Status Hash::RemoveEntries(const DriverID &driver_id, const ID &id, +Status Hash::RemoveEntries(const JobID &job_id, const ID &id, const std::vector &keys, const HashRemoveCallback &remove_callback) { num_removes_++; @@ -375,7 +374,7 @@ std::string Hash::DebugString() const { } template -Status Hash::Lookup(const DriverID &driver_id, const ID &id, +Status Hash::Lookup(const JobID &job_id, const ID &id, const HashCallback &lookup) { num_lookups_++; auto callback = [this, id, lookup](const CallbackReply &reply) { @@ -403,7 +402,7 @@ Status Hash::Lookup(const DriverID &driver_id, const ID &id, } template -Status Hash::Subscribe(const DriverID &driver_id, const ClientID &client_id, +Status Hash::Subscribe(const JobID &job_id, const ClientID &client_id, const HashNotificationCallback &subscribe, const SubscriptionCallback &done) { RAY_CHECK(subscribe_callback_index_ == -1) @@ -450,25 +449,25 @@ Status Hash::Subscribe(const DriverID &driver_id, const ClientID &clie return Status::OK(); } -Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type, +Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type, const std::string &error_message, double timestamp) { auto data = std::make_shared(); - data->set_driver_id(driver_id.Binary()); + data->set_job_id(job_id.Binary()); data->set_type(type); data->set_error_message(error_message); data->set_timestamp(timestamp); - return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); + return Append(job_id, job_id, data, /*done_callback=*/nullptr); } std::string ErrorTable::DebugString() const { - return Log::DebugString(); + return Log::DebugString(); } Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) { // TODO(hchen): Change the parameter to shared_ptr to avoid copying data. auto data = std::make_shared(); data->CopyFrom(profile_events); - return Append(DriverID::Nil(), UniqueID::FromRandom(), data, + return Append(JobID::Nil(), UniqueID::FromRandom(), data, /*done_callback=*/nullptr); } @@ -476,11 +475,11 @@ std::string ProfileTable::DebugString() const { return Log::DebugString(); } -Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) { - auto data = std::make_shared(); - data->set_driver_id(driver_id.Binary()); +Status JobTable::AppendJobData(const JobID &job_id, bool is_dead) { + auto data = std::make_shared(); + data->set_job_id(job_id.Binary()); data->set_is_dead(is_dead); - return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr); + return Append(JobID(job_id), job_id, data, /*done_callback=*/nullptr); } void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) { @@ -694,13 +693,13 @@ Status ClientTable::Connect(const ClientTableData &local_client) { // Callback to request notifications from the client table once we've // successfully subscribed. auto subscription_callback = [this](AsyncGcsClient *c) { - RAY_CHECK_OK(RequestNotifications(DriverID::Nil(), client_log_key_, client_id_)); + RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, client_id_)); }; // Subscribe to the client table. - RAY_CHECK_OK(Subscribe(DriverID::Nil(), client_id_, notification_callback, + RAY_CHECK_OK(Subscribe(JobID::Nil(), client_id_, notification_callback, subscription_callback)); }; - return Append(DriverID::Nil(), client_log_key_, data, add_callback); + return Append(JobID::Nil(), client_log_key_, data, add_callback); } Status ClientTable::Disconnect(const DisconnectCallback &callback) { @@ -709,12 +708,12 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) { auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id, const ClientTableData &data) { HandleConnected(client, data); - RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id)); + RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id)); if (callback != nullptr) { callback(); } }; - RAY_RETURN_NOT_OK(Append(DriverID::Nil(), client_log_key_, data, add_callback)); + RAY_RETURN_NOT_OK(Append(JobID::Nil(), client_log_key_, data, add_callback)); // We successfully added the deletion entry. Mark ourselves as disconnected. disconnected_ = true; return Status::OK(); @@ -724,7 +723,7 @@ ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) { auto data = std::make_shared(); data->set_client_id(dead_client_id.Binary()); data->set_entry_type(ClientTableData::DELETION); - return Append(DriverID::Nil(), client_log_key_, data, nullptr); + return Append(JobID::Nil(), client_log_key_, data, nullptr); } void ClientTable::GetClient(const ClientID &client_id, @@ -744,7 +743,7 @@ const std::unordered_map &ClientTable::GetAllClients( Status ClientTable::Lookup(const Callback &lookup) { RAY_CHECK(lookup != nullptr); - return Log::Lookup(DriverID::Nil(), client_log_key_, lookup); + return Log::Lookup(JobID::Nil(), client_log_key_, lookup); } std::string ClientTable::DebugString() const { @@ -755,10 +754,10 @@ std::string ClientTable::DebugString() const { return result.str(); } -Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, +Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id) { - auto lookup_callback = [this, checkpoint_id, driver_id, actor_id]( + auto lookup_callback = [this, checkpoint_id, job_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id, const ActorCheckpointIdData &data) { std::shared_ptr copy = @@ -772,20 +771,20 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id, RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id; copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin()); copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin()); - client_->actor_checkpoint_table().Delete(driver_id, to_delete); + client_->actor_checkpoint_table().Delete(job_id, to_delete); } - RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr)); + RAY_CHECK_OK(Add(job_id, actor_id, copy, nullptr)); }; - auto failure_callback = [this, checkpoint_id, driver_id, actor_id]( + auto failure_callback = [this, checkpoint_id, job_id, actor_id]( ray::gcs::AsyncGcsClient *client, const UniqueID &id) { std::shared_ptr data = std::make_shared(); data->set_actor_id(id.Binary()); data->add_timestamps(current_sys_time_ms()); *data->add_checkpoint_ids() = checkpoint_id.Binary(); - RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr)); + RAY_CHECK_OK(Add(job_id, actor_id, data, nullptr)); }; - return Lookup(driver_id, actor_id, lookup_callback, failure_callback); + return Lookup(job_id, actor_id, lookup_callback, failure_callback); } template class Log; @@ -797,9 +796,9 @@ template class Log; template class Table; template class Table; template class Table; -template class Log; +template class Log; template class Log; -template class Log; +template class Log; template class Log; template class Table; template class Table; diff --git a/src/ray/gcs/tables.h b/src/ray/gcs/tables.h index 2ecc34408..4984fdde7 100644 --- a/src/ray/gcs/tables.h +++ b/src/ray/gcs/tables.h @@ -24,12 +24,12 @@ using rpc::ActorCheckpointData; using rpc::ActorCheckpointIdData; using rpc::ActorTableData; using rpc::ClientTableData; -using rpc::DriverTableData; using rpc::ErrorTableData; using rpc::GcsChangeMode; using rpc::GcsEntry; using rpc::HeartbeatBatchTableData; using rpc::HeartbeatTableData; +using rpc::JobTableData; using rpc::ObjectTableData; using rpc::ProfileTableData; using rpc::RayResource; @@ -55,9 +55,9 @@ enum class CommandType { kRegular, kChain }; template class PubsubInterface { public: - virtual Status RequestNotifications(const DriverID &driver_id, const ID &id, + virtual Status RequestNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) = 0; - virtual Status CancelNotifications(const DriverID &driver_id, const ID &id, + virtual Status CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id) = 0; virtual ~PubsubInterface(){}; }; @@ -67,9 +67,9 @@ class LogInterface { public: using WriteCallback = std::function; - virtual Status Append(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; - virtual Status AppendAt(const DriverID &driver_id, const ID &task_id, + virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; + virtual Status AppendAt(const JobID &job_id, const ID &task_id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length) = 0; virtual ~LogInterface(){}; @@ -119,20 +119,20 @@ class Log : public LogInterface, virtual public PubsubInterface { /// Append a log entry to a key. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is added to the GCS. /// \param data Data to append to the log. TODO(rkn): This can be made const, /// right? /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Append(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Append a log entry to a key if and only if the log has the given number /// of entries. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is added to the GCS. /// \param data Data to append to the log. /// \param done Callback that is called if the data was appended to the log. @@ -141,25 +141,22 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param log_length The number of entries that the log must have for the /// append to succeed. /// \return Status - Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done, const WriteCallback &failure, int log_length); /// Lookup the log values at a key asynchronously. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is looked up in the GCS. /// \param lookup Callback that is called after lookup. If the callback is /// called with an empty vector, then there was no data at the key. /// \return Status - Status Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup); - + Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup); /// Subscribe to any Append operations to this table. The caller may choose - /// to subscribe to all Appends, or to subscribe only to keys that it /// requests notifications for. This may only be called once per Log - /// instance. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param client_id The type of update to listen to. If this is nil, then a /// message for each Add to the table will be received. Else, only /// messages for the given client will be received. In the latter @@ -170,7 +167,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called when subscription is complete and we /// are ready to receive messages. /// \return Status - Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + Status Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const SubscriptionCallback &done); /// Request notifications about a key in this table. @@ -182,37 +179,37 @@ class Log : public LogInterface, virtual public PubsubInterface { /// notifications can be requested, the caller must first call `Subscribe`, /// with the same `client_id`. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the key to request notifications for. /// \param client_id The client who is requesting notifications. Before /// notifications can be requested, a call to `Subscribe` to this /// table with the same `client_id` must complete successfully. /// \return Status - Status RequestNotifications(const DriverID &driver_id, const ID &id, + Status RequestNotifications(const JobID &job_id, const ID &id, const ClientID &client_id); /// Cancel notifications about a key in this table. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the key to request notifications for. /// \param client_id The client who originally requested notifications. /// \return Status - Status CancelNotifications(const DriverID &driver_id, const ID &id, + Status CancelNotifications(const JobID &job_id, const ID &id, const ClientID &client_id); /// Delete an entire key from redis. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data to delete from the GCS. /// \return Void. - void Delete(const DriverID &driver_id, const ID &id); + void Delete(const JobID &job_id, const ID &id); /// Delete several keys from redis. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param ids The vector of IDs to delete from the GCS. /// \return Void. - void Delete(const DriverID &driver_id, const std::vector &ids); + void Delete(const JobID &job_id, const std::vector &ids); /// Returns debug string for class. /// @@ -232,7 +229,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// an additional parameter change_mode in NotificationCallback. Therefore this /// function supports notifications of remove operations. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param client_id The type of update to listen to. If this is nil, then a /// message for each Add to the table will be received. Else, only /// messages for the given client will be received. In the latter @@ -243,7 +240,7 @@ class Log : public LogInterface, virtual public PubsubInterface { /// \param done Callback that is called when subscription is complete and we /// are ready to receive messages. /// \return Status - Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + Status Subscribe(const JobID &job_id, const ClientID &client_id, const NotificationCallback &subscribe, const SubscriptionCallback &done); @@ -275,8 +272,8 @@ template class TableInterface { public: using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &task_id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual ~TableInterface(){}; }; @@ -312,32 +309,32 @@ class Table : private Log, /// Add an entry to the table. This overwrites any existing data at the key. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is added to the GCS. /// \param data Data that is added to the GCS. /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Lookup an entry asynchronously. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is looked up in the GCS. /// \param lookup Callback that is called after lookup if there was data the /// key. /// \param failure Callback that is called after lookup if there was no data /// at the key. /// \return Status - Status Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup, + Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup, const FailureCallback &failure); /// Subscribe to any Add operations to this table. The caller may choose to /// subscribe to all Adds, or to subscribe only to keys that it requests /// notifications for. This may only be called once per Table instance. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param client_id The type of update to listen to. If this is nil, then a /// message for each Add to the table will be received. Else, only /// messages for the given client will be received. In the latter @@ -350,16 +347,14 @@ class Table : private Log, /// \param done Callback that is called when subscription is complete and we /// are ready to receive messages. /// \return Status - Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + Status Subscribe(const JobID &job_id, const ClientID &client_id, const Callback &subscribe, const FailureCallback &failure, const SubscriptionCallback &done); - void Delete(const DriverID &driver_id, const ID &id) { - Log::Delete(driver_id, id); - } + void Delete(const JobID &job_id, const ID &id) { Log::Delete(job_id, id); } - void Delete(const DriverID &driver_id, const std::vector &ids) { - Log::Delete(driver_id, ids); + void Delete(const JobID &job_id, const std::vector &ids) { + Log::Delete(job_id, ids); } /// Returns debug string for class. @@ -383,10 +378,10 @@ template class SetInterface { public: using WriteCallback = typename Log::WriteCallback; - virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done) = 0; - virtual Status Remove(const DriverID &driver_id, const ID &id, - std::shared_ptr &data, const WriteCallback &done) = 0; + virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, + const WriteCallback &done) = 0; virtual ~SetInterface(){}; }; @@ -419,30 +414,30 @@ class Set : private Log, /// Add an entry to the set. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is added to the GCS. /// \param data Data to add to the set. /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Add(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); /// Remove an entry from the set. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is removed from the GCS. /// \param data Data to remove from the set. /// \param done Callback that is called once the data has been written to the /// GCS. /// \return Status - Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr &data, + Status Remove(const JobID &job_id, const ID &id, std::shared_ptr &data, const WriteCallback &done); - Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + Status Subscribe(const JobID &job_id, const ClientID &client_id, const NotificationCallback &subscribe, const SubscriptionCallback &done) { - return Log::Subscribe(driver_id, client_id, subscribe, done); + return Log::Subscribe(job_id, client_id, subscribe, done); } /// Returns debug string for class. @@ -499,40 +494,40 @@ class HashInterface { /// Add entries of a hash table. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is added to the GCS. /// \param pairs Map data to add to the hash table. /// \param done HashCallback that is called once the request data has been written to /// the GCS. /// \return Status - virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + virtual Status Update(const JobID &job_id, const ID &id, const DataMap &pairs, const HashCallback &done) = 0; /// Remove entries from the hash table. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is removed from the GCS. /// \param keys The entry keys of the hash table. /// \param remove_callback HashRemoveCallback that is called once the data has been /// written to the GCS no matter whether the key exists in the hash table. /// \return Status - virtual Status RemoveEntries(const DriverID &driver_id, const ID &id, + virtual Status RemoveEntries(const JobID &job_id, const ID &id, const std::vector &keys, const HashRemoveCallback &remove_callback) = 0; /// Lookup the map data of a hash table. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param id The ID of the data that is looked up in the GCS. /// \param lookup HashCallback that is called after lookup. If the callback is /// called with an empty hash table, then there was no data in the callback. /// \return Status - virtual Status Lookup(const DriverID &driver_id, const ID &id, + virtual Status Lookup(const JobID &job_id, const ID &id, const HashCallback &lookup) = 0; /// Subscribe to any Update or Remove operations to this hash table. /// - /// \param driver_id The ID of the driver. + /// \param job_id The ID of the job. /// \param client_id The type of update to listen to. If this is nil, then a /// message for each Update to the table will be received. Else, only /// messages for the given client will be received. In the latter @@ -542,7 +537,7 @@ class HashInterface { /// \param done SubscriptionCallback that is called when subscription is complete and /// we are ready to receive messages. /// \return Status - virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + virtual Status Subscribe(const JobID &job_id, const ClientID &client_id, const HashNotificationCallback &subscribe, const SubscriptionCallback &done) = 0; @@ -567,17 +562,16 @@ class Hash : private Log, using Log::RequestNotifications; using Log::CancelNotifications; - Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs, + Status Update(const JobID &job_id, const ID &id, const DataMap &pairs, const HashCallback &done) override; - Status Subscribe(const DriverID &driver_id, const ClientID &client_id, + Status Subscribe(const JobID &job_id, const ClientID &client_id, const HashNotificationCallback &subscribe, const SubscriptionCallback &done) override; - Status Lookup(const DriverID &driver_id, const ID &id, - const HashCallback &lookup) override; + Status Lookup(const JobID &job_id, const ID &id, const HashCallback &lookup) override; - Status RemoveEntries(const DriverID &driver_id, const ID &id, + Status RemoveEntries(const JobID &job_id, const ID &id, const std::vector &keys, const HashRemoveCallback &remove_callback) override; @@ -645,23 +639,23 @@ class HeartbeatBatchTable : public Table { virtual ~HeartbeatBatchTable() {} }; -class DriverTable : public Log { +class JobTable : public Log { public: - DriverTable(const std::vector> &contexts, - AsyncGcsClient *client) + JobTable(const std::vector> &contexts, + AsyncGcsClient *client) : Log(contexts, client) { - pubsub_channel_ = TablePubsub::DRIVER_PUBSUB; - prefix_ = TablePrefix::DRIVER; + pubsub_channel_ = TablePubsub::JOB_PUBSUB; + prefix_ = TablePrefix::JOB; }; - virtual ~DriverTable() {} + virtual ~JobTable() {} - /// Appends driver data to the driver table. + /// Appends job data to the job table. /// - /// \param driver_id The driver id. - /// \param is_dead Whether the driver is dead. + /// \param job_id The job id. + /// \param is_dead Whether the job is dead. /// \return The return status. - Status AppendDriverData(const DriverID &driver_id, bool is_dead); + Status AppendJobData(const JobID &job_id, bool is_dead); }; /// Actor table starts with an ALIVE entry, which represents the first time the actor @@ -697,9 +691,9 @@ class TaskLeaseTable : public Table { prefix_ = TablePrefix::TASK_LEASE; } - Status Add(const DriverID &driver_id, const TaskID &id, - std::shared_ptr &data, const WriteCallback &done) override { - RAY_RETURN_NOT_OK((Table::Add(driver_id, id, data, done))); + Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr &data, + const WriteCallback &done) override { + RAY_RETURN_NOT_OK((Table::Add(job_id, id, data, done))); // Mark the entry for expiration in Redis. It's okay if this command fails // since the lease entry itself contains the expiration period. In the // worst case, if the command fails, then a client that looks up the lease @@ -733,11 +727,11 @@ class ActorCheckpointIdTable : public Table { /// Add a checkpoint id to an actor, and remove a previous checkpoint if the /// total number of checkpoints in GCS exceeds the max allowed value. /// - /// \param driver_id The ID of the job (= driver). + /// \param job_id The ID of the job. /// \param actor_id ID of the actor. /// \param checkpoint_id ID of the checkpoint. /// \return Status. - Status AddCheckpointId(const DriverID &driver_id, const ActorID &actor_id, + Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id, const ActorCheckpointID &checkpoint_id); }; @@ -761,7 +755,7 @@ class TaskTable : public Table { } // namespace raylet -class ErrorTable : private Log { +class ErrorTable : private Log { public: ErrorTable(const std::vector> &contexts, AsyncGcsClient *client) @@ -770,19 +764,20 @@ class ErrorTable : private Log { prefix_ = TablePrefix::ERROR_INFO; }; - /// Push an error message for a specific job. + /// Push an error message for the driver of a specific. /// /// TODO(rkn): We need to make sure that the errors are unique because /// duplicate messages currently cause failures (the GCS doesn't allow it). A /// natural way to do this is to have finer-grained time stamps. /// - /// \param driver_id The ID of the job that generated the error. If the error - /// should be pushed to all jobs, then this should be nil. + /// \param job_id The ID of the job that generated the error. If the error + /// should be pushed to all drivers, then this should be nil. /// \param type The type of the error. /// \param error_message The error message to push. /// \param timestamp The timestamp of the error. /// \return Status. - Status PushErrorToDriver(const DriverID &driver_id, const std::string &type, + // TODO(qwang): refactor this API to implement broadcast. + Status PushErrorToDriver(const JobID &job_id, const std::string &type, const std::string &error_message, double timestamp); /// Returns debug string for class. diff --git a/src/ray/object_manager/object_directory.cc b/src/ray/object_manager/object_directory.cc index 454379d18..5aa598924 100644 --- a/src/ray/object_manager/object_directory.cc +++ b/src/ray/object_manager/object_directory.cc @@ -74,7 +74,7 @@ void ObjectDirectory::RegisterBackend() { } }; RAY_CHECK_OK(gcs_client_->object_table().Subscribe( - DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), + JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), object_notification_callback, nullptr)); } @@ -87,7 +87,7 @@ ray::Status ObjectDirectory::ReportObjectAdded( data->set_manager(client_id.Binary()); data->set_object_size(object_info.data_size); ray::Status status = - gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr); + gcs_client_->object_table().Add(JobID::Nil(), object_id, data, nullptr); return status; } @@ -100,7 +100,7 @@ ray::Status ObjectDirectory::ReportObjectRemoved( data->set_manager(client_id.Binary()); data->set_object_size(object_info.data_size); ray::Status status = - gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr); + gcs_client_->object_table().Remove(JobID::Nil(), object_id, data, nullptr); return status; }; @@ -159,7 +159,7 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i if (it == listeners_.end()) { it = listeners_.emplace(object_id, LocationListenerState()).first; status = gcs_client_->object_table().RequestNotifications( - DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); } auto &listener_state = it->second; // TODO(hme): Make this fatal after implementing Pull suppression. @@ -187,7 +187,7 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback entry->second.callbacks.erase(callback_id); if (entry->second.callbacks.empty()) { status = gcs_client_->object_table().CancelNotifications( - DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); + JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId()); listeners_.erase(entry); } return status; @@ -210,7 +210,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id, // SubscribeObjectLocations call, so look up the object's locations // directly from the GCS. status = gcs_client_->object_table().Lookup( - DriverID::Nil(), object_id, + JobID::Nil(), object_id, [this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id, const std::vector &location_updates) { // Build the set of current locations based on the entries in the log. diff --git a/src/ray/protobuf/gcs.proto b/src/ray/protobuf/gcs.proto index d0b2c5e00..b4de8b9ca 100644 --- a/src/ray/protobuf/gcs.proto +++ b/src/ray/protobuf/gcs.proto @@ -25,7 +25,7 @@ enum TablePrefix { HEARTBEAT = 9; HEARTBEAT_BATCH = 10; ERROR_INFO = 11; - DRIVER = 12; + JOB = 12; PROFILE = 13; TASK_LEASE = 14; ACTOR_CHECKPOINT = 15; @@ -47,7 +47,7 @@ enum TablePubsub { HEARTBEAT_BATCH_PUBSUB = 8; ERROR_INFO_PUBSUB = 9; TASK_LEASE_PUBSUB = 10; - DRIVER_PUBSUB = 11; + JOB_PUBSUB = 11; NODE_RESOURCE_PUBSUB = 12; TABLE_PUBSUB_MAX = 13; } @@ -102,8 +102,8 @@ message ActorTableData { // dies, then this is the object that should be reconstructed for the actor // to be recreated. bytes actor_creation_dummy_object_id = 2; - // The ID of the driver that created the actor. - bytes driver_id = 3; + // The ID of the job that created the actor. + bytes job_id = 3; // The ID of the node manager that created the actor. bytes node_manager_id = 4; // Current state of this actor. @@ -115,8 +115,8 @@ message ActorTableData { } message ErrorTableData { - // The ID of the driver that the error is for. - bytes driver_id = 1; + // The ID of the job that the error is for. + bytes job_id = 1; // The type of the error. string type = 2; // The error message. @@ -222,9 +222,9 @@ message TaskLeaseData { uint64 timeout = 3; } -message DriverTableData { - // The driver ID. - bytes driver_id = 1; +message JobTableData { + // The job ID. + bytes job_id = 1; // Whether it's dead. bool is_dead = 2; } diff --git a/src/ray/raylet/actor_registration.cc b/src/ray/raylet/actor_registration.cc index 7f940006b..afc3a7d89 100644 --- a/src/ray/raylet/actor_registration.cc +++ b/src/ray/raylet/actor_registration.cc @@ -43,8 +43,8 @@ const ObjectID ActorRegistration::GetExecutionDependency() const { return execution_dependency_; } -const DriverID ActorRegistration::GetDriverId() const { - return DriverID::FromBinary(actor_table_data_.driver_id()); +const JobID ActorRegistration::GetJobId() const { + return JobID::FromBinary(actor_table_data_.job_id()); } const int64_t ActorRegistration::GetMaxReconstructions() const { diff --git a/src/ray/raylet/actor_registration.h b/src/ray/raylet/actor_registration.h index 208e49982..7a7ce8e17 100644 --- a/src/ray/raylet/actor_registration.h +++ b/src/ray/raylet/actor_registration.h @@ -73,8 +73,8 @@ class ActorRegistration { /// \return The execution dependency returned by the actor's creation task. const ObjectID GetActorCreationDependency() const; - /// Get actor's driver ID. - const DriverID GetDriverId() const; + /// Get actor's job ID. + const JobID GetJobId() const; /// Get the max number of times this actor should be reconstructed. const int64_t GetMaxReconstructions() const; diff --git a/src/ray/raylet/format/node_manager.fbs b/src/ray/raylet/format/node_manager.fbs index a5b041f29..aba7ab8cf 100644 --- a/src/ray/raylet/format/node_manager.fbs +++ b/src/ray/raylet/format/node_manager.fbs @@ -135,6 +135,7 @@ table RegisterClientRequest { // The process ID of this worker. worker_pid: long; // The driver ID. This is non-nil if the client is a driver. + // TODO(qwang): rename this to driver_task_id. driver_id: string; // Language of this worker. language: Language; @@ -196,7 +197,7 @@ table WaitReply { // This struct is the same as ErrorTableData. table PushErrorRequest { // The ID of the job that the error is for. - driver_id: string; + job_id: string; // The type of the error. type: string; // The error message. diff --git a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc index 319d29d4a..8e7750aae 100644 --- a/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc +++ b/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc @@ -43,12 +43,12 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) { */ JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit( JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker, - jbyteArray driverId) { + jbyteArray jobId) { UniqueIdFromJByteArray worker_id(env, workerId); - UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray job_id(env, jobId); const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE); auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker, - driver_id.GetId(), Language::JAVA); + job_id.GetId(), Language::JAVA); env->ReleaseStringUTFChars(sockName, nativeString); return reinterpret_cast(raylet_client); } @@ -224,13 +224,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject( */ JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId( - JNIEnv *env, jclass, jbyteArray driverId, jbyteArray parentTaskId, + JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId, jint parent_task_counter) { - UniqueIdFromJByteArray driver_id(env, driverId); + UniqueIdFromJByteArray job_id(env, jobId); UniqueIdFromJByteArray parent_task_id(env, parentTaskId); TaskID task_id = - ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter); + ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter); jbyteArray result = env->NewByteArray(task_id.Size()); if (nullptr == result) { return nullptr; diff --git a/src/ray/raylet/lineage_cache.cc b/src/ray/raylet/lineage_cache.cc index 9d700b805..6f6d2cb69 100644 --- a/src/ray/raylet/lineage_cache.cc +++ b/src/ray/raylet/lineage_cache.cc @@ -275,9 +275,8 @@ void LineageCache::FlushTask(const TaskID &task_id) { // TODO(swang): Make this better... auto task_data = std::make_shared(); task_data->set_task(task->TaskData().Serialize()); - RAY_CHECK_OK( - task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()), - task_id, task_data, task_callback)); + RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()), + task_id, task_data, task_callback)); // We successfully wrote the task, so mark it as committing. // TODO(swang): Use a batched interface and write with all object entries. @@ -290,7 +289,7 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) { if (unsubscribed) { // Request notifications for the task if we haven't already requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.RequestNotifications(DriverID::Nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_)); } // Return whether we were previously unsubscribed to this task and are now // subscribed. @@ -303,7 +302,7 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) { if (subscribed) { // Cancel notifications for the task if we previously requested // notifications for it. - RAY_CHECK_OK(task_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_)); + RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); subscribed_tasks_.erase(it); } // Return whether we were previously subscribed to this task and are now diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index 27a231a85..1ecc13599 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -22,7 +22,7 @@ class MockGcs : public gcs::TableInterface, notification_callback_ = notification_callback; } - Status Add(const DriverID &driver_id, const TaskID &task_id, + Status Add(const JobID &job_id, const TaskID &task_id, std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done) { task_table_[task_id] = task_data; @@ -57,10 +57,10 @@ class MockGcs : public gcs::TableInterface, notification_callback_(client, task_id, data); } }; - return Add(DriverID::Nil(), task_id, task_data, callback); + return Add(JobID::Nil(), task_id, task_data, callback); } - Status RequestNotifications(const DriverID &driver_id, const TaskID &task_id, + Status RequestNotifications(const JobID &job_id, const TaskID &task_id, const ClientID &client_id) { subscribed_tasks_.insert(task_id); if (task_table_.count(task_id) == 1) { @@ -70,7 +70,7 @@ class MockGcs : public gcs::TableInterface, return ray::Status::OK(); } - Status CancelNotifications(const DriverID &driver_id, const TaskID &task_id, + Status CancelNotifications(const JobID &job_id, const TaskID &task_id, const ClientID &client_id) { subscribed_tasks_.erase(task_id); return ray::Status::OK(); @@ -133,7 +133,7 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments, + auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments, num_returns, required_resources, Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); diff --git a/src/ray/raylet/monitor.cc b/src/ray/raylet/monitor.cc index 0a8532608..d26f33cba 100644 --- a/src/ray/raylet/monitor.cc +++ b/src/ray/raylet/monitor.cc @@ -35,7 +35,7 @@ void Monitor::Start() { HandleHeartbeat(id, heartbeat_data); }; RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe( - DriverID::Nil(), ClientID::Nil(), heartbeat_callback, nullptr, nullptr)); + JobID::Nil(), ClientID::Nil(), heartbeat_callback, nullptr, nullptr)); Tick(); } @@ -68,9 +68,9 @@ void Monitor::Tick() { error_message << "The node with client ID " << client_id << " has been marked dead because the monitor" << " has missed too many heartbeats from it."; - // We use the nil DriverID to broadcast the message to all drivers. + // We use the nil JobID to broadcast the message to all drivers. RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver( - DriverID::Nil(), type, error_message.str(), current_time_ms())); + JobID::Nil(), type, error_message.str(), current_time_ms())); } }; RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback)); @@ -88,7 +88,7 @@ void Monitor::Tick() { for (const auto &heartbeat : heartbeat_buffer_) { batch->add_batch()->CopyFrom(heartbeat.second); } - RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(), + RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(JobID::Nil(), ClientID::Nil(), batch, nullptr)); heartbeat_buffer_.clear(); } diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 106c83832..8e2cdf684 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -136,7 +136,7 @@ ray::Status NodeManager::RegisterGcs() { lineage_cache_.HandleEntryCommitted(task_id); }; RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe( - DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), + JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), task_committed_callback, nullptr, nullptr)); const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client, @@ -160,7 +160,7 @@ ray::Status NodeManager::RegisterGcs() { reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0); }; RAY_RETURN_NOT_OK(gcs_client_->task_lease_table().Subscribe( - DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), + JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), task_lease_notification_callback, task_lease_empty_callback, nullptr)); // Register a callback to handle actor notifications. @@ -175,7 +175,7 @@ ray::Status NodeManager::RegisterGcs() { }; RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe( - DriverID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr)); + JobID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr)); // Register a callback on the client table for new clients. auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id, @@ -210,18 +210,17 @@ ray::Status NodeManager::RegisterGcs() { HeartbeatBatchAdded(heartbeat_batch); }; RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe( - DriverID::Nil(), ClientID::Nil(), heartbeat_batch_added, + JobID::Nil(), ClientID::Nil(), heartbeat_batch_added, /*subscribe_callback=*/nullptr, /*done_callback=*/nullptr)); // Subscribe to driver table updates. - const auto driver_table_handler = - [this](gcs::AsyncGcsClient *client, const DriverID &client_id, - const std::vector &driver_data) { - HandleDriverTableUpdate(client_id, driver_data); - }; - RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe( - DriverID::Nil(), ClientID::Nil(), driver_table_handler, nullptr)); + const auto job_table_handler = [this](gcs::AsyncGcsClient *client, const JobID &job_id, + const std::vector &job_data) { + HandleJobTableUpdate(job_id, job_data); + }; + RAY_RETURN_NOT_OK(gcs_client_->job_table().Subscribe(JobID::Nil(), ClientID::Nil(), + job_table_handler, nullptr)); // Start sending heartbeats to the GCS. last_heartbeat_at_ms_ = current_time_ms(); @@ -252,14 +251,14 @@ void NodeManager::KillWorker(std::shared_ptr worker) { }); } -void NodeManager::HandleDriverTableUpdate( - const DriverID &id, const std::vector &driver_data) { - for (const auto &entry : driver_data) { - RAY_LOG(DEBUG) << "HandleDriverTableUpdate " - << UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead(); +void NodeManager::HandleJobTableUpdate(const JobID &id, + const std::vector &job_data) { + for (const auto &entry : job_data) { + RAY_LOG(DEBUG) << "HandleJobTableUpdate " << UniqueID::FromBinary(entry.job_id()) + << " " << entry.is_dead(); if (entry.is_dead()) { - auto driver_id = DriverID::FromBinary(entry.driver_id()); - auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id); + auto job_id = JobID::FromBinary(entry.job_id()); + auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id); // Kill all the workers. The actual cleanup for these workers is done // later when we receive the DisconnectClient message from them. @@ -271,11 +270,11 @@ void NodeManager::HandleDriverTableUpdate( KillWorker(worker); } - // Remove all tasks for this driver from the scheduling queues, mark + // Remove all tasks for this job from the scheduling queues, mark // the results for these tasks as not required, cancel any attempts // at reconstruction. Note that at this time the workers are likely // alive because of the delay in killing workers. - CleanUpTasksForDeadDriver(driver_id); + CleanUpTasksForFinishedJob(job_id); } } } @@ -313,7 +312,7 @@ void NodeManager::Heartbeat() { } ray::Status status = heartbeat_table.Add( - DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, + JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data, /*success_callback=*/nullptr); RAY_CHECK_OK_PREPEND(status, "Heartbeat failed"); @@ -605,7 +604,7 @@ void NodeManager::PublishActorStateTransition( RAY_CHECK_OK(redis_context->RunArgvAsync(args)); } }; - RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(DriverID::Nil(), actor_id, + RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(JobID::Nil(), actor_id, actor_notification, success_callback, failure_callback, log_length)); } @@ -690,8 +689,8 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id, } } -void NodeManager::CleanUpTasksForDeadDriver(const DriverID &driver_id) { - auto tasks_to_remove = local_queues_.GetTaskIdsForDriver(driver_id); +void NodeManager::CleanUpTasksForFinishedJob(const JobID &job_id) { + auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id); task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove); // NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must // call it last. @@ -749,7 +748,7 @@ void NodeManager::ProcessClientMessage( << (registered_worker ? std::to_string(registered_worker->Pid()) : "nil"); if (registered_worker && registered_worker->IsDead()) { - // For a worker that is marked as dead (because the driver has died already), + // For a worker that is marked as dead (because the job has died already), // all the messages are ignored except DisconnectClient. if ((message_type_value != protocol::MessageType::DisconnectClient) && (message_type_value != protocol::MessageType::IntentionalDisconnectClient)) { @@ -824,7 +823,7 @@ void NodeManager::ProcessClientMessage( for (const auto &object_id : object_ids) { creating_task_ids.push_back(object_id.TaskId()); } - gcs_client_->raylet_task_table().Delete(DriverID::Nil(), creating_task_ids); + gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids); } } break; case protocol::MessageType::PrepareActorCheckpointRequest: { @@ -857,10 +856,11 @@ void NodeManager::ProcessRegisterClientRequestMessage( // message is actually the ID of the driver task, while client_id represents the // real driver ID, which can associate all the tasks/actors for a given driver, // which is set to the worker ID. - const DriverID driver_id = from_flatbuf(*message->driver_id()); + // TODO(qwang): Use driver_task_id instead here. + const WorkerID driver_id = from_flatbuf(*message->driver_id()); TaskID driver_task_id = TaskID::GetDriverTaskID(driver_id); worker->AssignTaskId(driver_task_id); - worker->AssignDriverId(from_flatbuf(*message->client_id())); + worker->AssignJobId(from_flatbuf(*message->client_id())); worker_pool_.RegisterDriver(std::move(worker)); local_queues_.AddDriverTaskId(driver_task_id); } @@ -992,14 +992,14 @@ void NodeManager::ProcessDisconnectClientMessage( if (!intentional_disconnect) { // Push the error to driver. - const DriverID &driver_id = worker->GetAssignedDriverId(); + const JobID &job_id = worker->GetAssignedJobId(); // TODO(rkn): Define this constant somewhere else. std::string type = "worker_died"; std::ostringstream error_message; error_message << "A worker died or was killed while executing task " << task_id << "."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - driver_id, type, error_message.str(), current_time_ms())); + job_id, type, error_message.str(), current_time_ms())); } } @@ -1022,22 +1022,21 @@ void NodeManager::ProcessDisconnectClientMessage( worker->ResetLifetimeResourceIds(); RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. " - << "driver_id: " << worker->GetAssignedDriverId(); + << "job_id: " << worker->GetAssignedJobId(); // Since some resources may have been released, we can try to dispatch more tasks. DispatchTasks(local_queues_.GetReadyTasksWithResources()); } else if (is_driver) { // The client is a driver. - RAY_CHECK_OK( - gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()), - /*is_dead=*/true)); - auto driver_id = worker->GetAssignedTaskId(); - RAY_CHECK(!driver_id.IsNil()); - local_queues_.RemoveDriverTaskId(driver_id); + RAY_CHECK_OK(gcs_client_->job_table().AppendJobData(JobID(client->GetClientId()), + /*is_dead=*/true)); + auto job_id = worker->GetAssignedTaskId(); + RAY_CHECK(!job_id.IsNil()); + local_queues_.RemoveDriverTaskId(job_id); worker_pool_.DisconnectDriver(worker); RAY_LOG(DEBUG) << "Driver (pid=" << worker->Pid() << ") is disconnected. " - << "driver_id: " << worker->GetAssignedDriverId(); + << "job_id: " << worker->GetAssignedJobId(); } // TODO(rkn): Tell the object manager that this client has disconnected so @@ -1142,13 +1141,13 @@ void NodeManager::ProcessWaitRequestMessage( void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) { auto message = flatbuffers::GetRoot(message_data); - DriverID driver_id = from_flatbuf(*message->driver_id()); + JobID job_id = from_flatbuf(*message->job_id()); auto const &type = string_from_flatbuf(*message->type()); auto const &error_message = string_from_flatbuf(*message->error_message()); double timestamp = message->timestamp(); - RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(driver_id, type, - error_message, timestamp)); + RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message, + timestamp)); } void NodeManager::ProcessPrepareActorCheckpointRequest( @@ -1173,7 +1172,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( // Write checkpoint data to GCS. RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add( - DriverID::Nil(), checkpoint_id, checkpoint_data, + JobID::Nil(), checkpoint_id, checkpoint_data, [worker, actor_id, this](ray::gcs::AsyncGcsClient *client, const ActorCheckpointID &checkpoint_id, const ActorCheckpointData &data) { @@ -1182,7 +1181,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest( // Save this actor-to-checkpoint mapping, and remove old checkpoints associated // with this actor. RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId( - DriverID::Nil(), actor_id, checkpoint_id)); + JobID::Nil(), actor_id, checkpoint_id)); // Send reply to worker. flatbuffers::FlatBufferBuilder fbb; auto reply = ray::protocol::CreatePrepareActorCheckpointReply( @@ -1284,7 +1283,7 @@ void NodeManager::ProcessSetResourceRequest( auto data_shared_ptr = std::make_shared(data); auto client_table = gcs_client_->client_table(); RAY_CHECK_OK(gcs_client_->client_table().Append( - DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); + JobID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr)); } void NodeManager::ScheduleTasks( @@ -1354,7 +1353,7 @@ void NodeManager::ScheduleTasks( << task.GetTaskSpecification().GetRequiredPlacementResources().ToString() << " for placement. Check the client table to view node resources."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - task.GetTaskSpecification().DriverId(), type, error_message.str(), + task.GetTaskSpecification().JobId(), type, error_message.str(), current_time_ms())); } // Assert that this placeable task is not feasible locally (necessary but not @@ -1415,8 +1414,7 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ std::string error_message = stream.str(); RAY_LOG(WARNING) << error_message; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - task.GetTaskSpecification().DriverId(), "task", error_message, - current_time_ms())); + task.GetTaskSpecification().JobId(), "task", error_message, current_time_ms())); } } task_dependency_manager_.TaskCanceled(spec.TaskId()); @@ -1558,7 +1556,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag HandleActorStateTransition(actor_id, ActorRegistration(data.back())); } }; - RAY_CHECK_OK(gcs_client_->actor_table().Lookup(DriverID::Nil(), spec.ActorId(), + RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::Nil(), spec.ActorId(), lookup_callback)); actor_creation_dummy_object = spec.ActorCreationDummyObjectId(); } else { @@ -1783,7 +1781,7 @@ bool NodeManager::AssignTask(const Task &task) { auto spec = assigned_task.GetTaskSpecification(); // We successfully assigned the task to the worker. worker->AssignTaskId(spec.TaskId()); - worker->AssignDriverId(spec.DriverId()); + worker->AssignJobId(spec.JobId()); // Actor tasks require extra accounting to track the actor's state. if (spec.IsActorTask()) { auto actor_entry = actor_registry_.find(spec.ActorId()); @@ -1870,10 +1868,10 @@ void NodeManager::FinishAssignedTask(Worker &worker) { // Unset the worker's assigned task. worker.AssignTaskId(TaskID::Nil()); - // Unset the worker's assigned driver Id if this is not an actor. + // Unset the worker's assigned job Id if this is not an actor. if (!task.GetTaskSpecification().IsActorCreationTask() && !task.GetTaskSpecification().IsActorTask()) { - worker.AssignDriverId(DriverID::Nil()); + worker.AssignJobId(JobID::Nil()); } } @@ -1892,7 +1890,7 @@ ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &tas new_actor_data.set_actor_id(actor_id.Binary()); new_actor_data.set_actor_creation_dummy_object_id( task.GetTaskSpecification().ActorDummyObject().Binary()); - new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary()); + new_actor_data.set_job_id(task.GetTaskSpecification().JobId().Binary()); new_actor_data.set_max_reconstructions( task.GetTaskSpecification().MaxActorReconstructions()); // This is the first time that the actor has been created, so the number @@ -1948,7 +1946,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { RAY_LOG(DEBUG) << "Looking up checkpoint " << checkpoint_id << " for actor " << actor_id; RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Lookup( - DriverID::Nil(), checkpoint_id, + JobID::Nil(), checkpoint_id, [this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client, const UniqueID &checkpoint_id, const ActorCheckpointData &checkpoint_data) { @@ -2017,7 +2015,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) { void NodeManager::HandleTaskReconstruction(const TaskID &task_id) { // Retrieve the task spec in order to re-execute the task. RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup( - DriverID::Nil(), task_id, + JobID::Nil(), task_id, /*success_callback=*/ [this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskTableData &task_data) { @@ -2072,7 +2070,7 @@ void NodeManager::ResubmitTask(const Task &task) { << " is a driver task and so the object created by ray.put " << "could not be reconstructed."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - task.GetTaskSpecification().DriverId(), type, error_message.str(), + task.GetTaskSpecification().JobId(), type, error_message.str(), current_time_ms())); return; } diff --git a/src/ray/raylet/node_manager.h b/src/ray/raylet/node_manager.h index 7e8121836..464e0c917 100644 --- a/src/ray/raylet/node_manager.h +++ b/src/ray/raylet/node_manager.h @@ -27,10 +27,10 @@ namespace raylet { using rpc::ActorTableData; using rpc::ClientTableData; -using rpc::DriverTableData; using rpc::ErrorType; using rpc::HeartbeatBatchTableData; using rpc::HeartbeatTableData; +using rpc::JobTableData; struct NodeManagerConfig { /// The node's resource configuration. @@ -326,12 +326,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { const ActorID &actor_id, const ActorTableData &data, const ray::gcs::ActorTable::WriteCallback &failure_callback); - /// When a driver dies, loop over all of the queued tasks for that driver and + /// When a job finished, loop over all of the queued tasks for that job and /// treat them as failed. /// - /// \param driver_id The driver that died. + /// \param job_id The job that exited. /// \return Void. - void CleanUpTasksForDeadDriver(const DriverID &driver_id); + void CleanUpTasksForFinishedJob(const JobID &job_id); /// Handle an object becoming local. This updates any local accounting, but /// does not write to any global accounting in the GCS. @@ -346,13 +346,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler { /// \return Void. void HandleObjectMissing(const ObjectID &object_id); - /// Handles updates to driver table. + /// Handles updates to job table. /// /// \param id An unused value. TODO(rkn): Should this be removed? - /// \param driver_data Data associated with a driver table event. + /// \param job_data Data associated with a job table event. /// \return Void. - void HandleDriverTableUpdate(const DriverID &id, - const std::vector &driver_data); + void HandleJobTableUpdate(const JobID &id, const std::vector &job_data); /// Check if certain invariants associated with the task dependency manager /// and the local queues are satisfied. This is only used for debugging diff --git a/src/ray/raylet/raylet_client.cc b/src/ray/raylet/raylet_client.cc index 801bb9112..2c4cbf60c 100644 --- a/src/ray/raylet/raylet_client.cc +++ b/src/ray/raylet/raylet_client.cc @@ -202,18 +202,14 @@ ray::Status RayletConnection::AtomicRequestReply( } RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id, - bool is_worker, const DriverID &driver_id, - const Language &language) - : client_id_(client_id), - is_worker_(is_worker), - driver_id_(driver_id), - language_(language) { + bool is_worker, const JobID &job_id, const Language &language) + : client_id_(client_id), is_worker_(is_worker), job_id_(job_id), language_(language) { // For C++14, we could use std::make_unique conn_ = std::unique_ptr(new RayletConnection(raylet_socket, -1, -1)); flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreateRegisterClientRequest( - fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, driver_id), + fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, job_id), language); fbb.Finish(message); // Register the process ID with the raylet. @@ -323,11 +319,11 @@ ray::Status RayletClient::Wait(const std::vector &object_ids, int num_ return ray::Status::OK(); } -ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string &type, +ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type, const std::string &error_message, double timestamp) { flatbuffers::FlatBufferBuilder fbb; auto message = ray::protocol::CreatePushErrorRequest( - fbb, to_flatbuf(fbb, driver_id), fbb.CreateString(type), + fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type), fbb.CreateString(error_message), timestamp); fbb.Finish(message); diff --git a/src/ray/raylet/raylet_client.h b/src/ray/raylet/raylet_client.h index 8b4dfad5b..53e880452 100644 --- a/src/ray/raylet/raylet_client.h +++ b/src/ray/raylet/raylet_client.h @@ -12,7 +12,7 @@ using ray::ActorCheckpointID; using ray::ActorID; using ray::ClientID; -using ray::DriverID; +using ray::JobID; using ray::ObjectID; using ray::TaskID; using ray::UniqueID; @@ -30,7 +30,7 @@ class RayletConnection { /// \param worker_id A unique ID to represent the worker. /// \param is_worker Whether this client is a worker. If it is a worker, an /// additional message will be sent to register as one. - /// \param driver_id The ID of the driver. This is non-nil if the client is a + /// \param job_id The ID of the driver. This is non-nil if the client is a /// driver. /// \return The connection information. RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout); @@ -66,10 +66,10 @@ class RayletClient { /// \param worker_id A unique ID to represent the worker. /// \param is_worker Whether this client is a worker. If it is a worker, an /// additional message will be sent to register as one. - /// \param driver_id The ID of the driver. This is non-nil if the client is a driver. + /// \param job_id The ID of the driver. This is non-nil if the client is a driver. /// \return The connection information. RayletClient(const std::string &raylet_socket, const ClientID &client_id, - bool is_worker, const DriverID &driver_id, const Language &language); + bool is_worker, const JobID &job_id, const Language &language); ray::Status Disconnect() { return conn_->Disconnect(); }; @@ -125,12 +125,12 @@ class RayletClient { /// Push an error to the relevant driver. /// - /// \param The ID of the job that the error is for. + /// \param The ID of the job_id that the error is for. /// \param The type of the error. /// \param The error message. /// \param The timestamp of the error. /// \return ray::Status. - ray::Status PushError(const DriverID &driver_id, const std::string &type, + ray::Status PushError(const ray::JobID &job_id, const std::string &type, const std::string &error_message, double timestamp); /// Store some profile events in the GCS. @@ -177,7 +177,7 @@ class RayletClient { ClientID GetClientID() const { return client_id_; } - DriverID GetDriverID() const { return driver_id_; } + JobID GetJobID() const { return job_id_; } bool IsWorker() const { return is_worker_; } @@ -186,7 +186,7 @@ class RayletClient { private: const ClientID client_id_; const bool is_worker_; - const DriverID driver_id_; + const JobID job_id_; const Language language_; /// A map from resource name to the resource IDs that are currently reserved /// for this worker. Each pair consists of the resource ID and the fraction diff --git a/src/ray/raylet/reconstruction_policy.cc b/src/ray/raylet/reconstruction_policy.cc index bf5c1acfa..f522c8986 100644 --- a/src/ray/raylet/reconstruction_policy.cc +++ b/src/ray/raylet/reconstruction_policy.cc @@ -52,7 +52,7 @@ void ReconstructionPolicy::SetTaskTimeout( // required by the task are no longer needed soon after. If the // task is still required after this initial period, then we now // subscribe to task lease notifications. - RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(DriverID::Nil(), task_id, + RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_)); it->second.subscribed = true; } @@ -110,7 +110,7 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id, reconstruction_entry->set_num_reconstructions(reconstruction_attempt); reconstruction_entry->set_node_manager_id(client_id_.Binary()); RAY_CHECK_OK(task_reconstruction_log_.AppendAt( - DriverID::Nil(), task_id, reconstruction_entry, + JobID::Nil(), task_id, reconstruction_entry, /*success_callback=*/ [this](gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionData &data) { @@ -199,7 +199,7 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) { // Cancel notifications for the task lease if we were subscribed to them. if (it->second.subscribed) { RAY_CHECK_OK( - task_lease_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_)); + task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_)); } listening_tasks_.erase(it); } diff --git a/src/ray/raylet/reconstruction_policy_test.cc b/src/ray/raylet/reconstruction_policy_test.cc index 12d9336a3..033da0561 100644 --- a/src/ray/raylet/reconstruction_policy_test.cc +++ b/src/ray/raylet/reconstruction_policy_test.cc @@ -84,7 +84,7 @@ class MockGcs : public gcs::PubsubInterface, failure_callback_ = failure_callback; } - void Add(const DriverID &driver_id, const TaskID &task_id, + void Add(const JobID &job_id, const TaskID &task_id, std::shared_ptr &task_lease_data) { task_lease_table_[task_id] = task_lease_data; if (subscribed_tasks_.count(task_id) == 1) { @@ -92,7 +92,7 @@ class MockGcs : public gcs::PubsubInterface, } } - Status RequestNotifications(const DriverID &driver_id, const TaskID &task_id, + Status RequestNotifications(const JobID &job_id, const TaskID &task_id, const ClientID &client_id) { subscribed_tasks_.insert(task_id); auto entry = task_lease_table_.find(task_id); @@ -104,14 +104,14 @@ class MockGcs : public gcs::PubsubInterface, return ray::Status::OK(); } - Status CancelNotifications(const DriverID &driver_id, const TaskID &task_id, + Status CancelNotifications(const JobID &job_id, const TaskID &task_id, const ClientID &client_id) { subscribed_tasks_.erase(task_id); return ray::Status::OK(); } Status AppendAt( - const DriverID &driver_id, const TaskID &task_id, + const JobID &job_id, const TaskID &task_id, std::shared_ptr &task_data, const ray::gcs::LogInterface::WriteCallback &success_callback, @@ -134,7 +134,7 @@ class MockGcs : public gcs::PubsubInterface, MOCK_METHOD4( Append, ray::Status( - const DriverID &, const TaskID &, std::shared_ptr &, + const JobID &, const TaskID &, std::shared_ptr &, const ray::gcs::LogInterface::WriteCallback &)); private: @@ -320,7 +320,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) { task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); task_lease_data->set_acquired_at(current_sys_time_ms()); task_lease_data->set_timeout(2 * test_period); - mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::Nil(), task_id, task_lease_data); // Listen for an object. reconstruction_policy_->ListenAndMaybeReconstruct(object_id); @@ -347,7 +347,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) { task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary()); task_lease_data->set_acquired_at(current_sys_time_ms()); task_lease_data->set_timeout(reconstruction_timeout_ms_); - mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data); + mock_gcs_.Add(JobID::Nil(), task_id, task_lease_data); }); // Run the test for much longer than the reconstruction timeout. Run(reconstruction_timeout_ms_ * 2); @@ -399,7 +399,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) { task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary()); task_reconstruction_data->set_num_reconstructions(0); RAY_CHECK_OK( - mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr, + mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr, /*failure_callback=*/ [](ray::gcs::AsyncGcsClient *client, const TaskID &task_id, const TaskReconstructionData &data) { ASSERT_TRUE(false); }, diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 73f0e2ef8..701b0c06b 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -19,15 +19,14 @@ inline const char *GetTaskStateString(ray::raylet::TaskState task_state) { return task_state_strings[static_cast(task_state)]; } -// Helper function to get tasks for a driver from a given state. +// Helper function to get tasks for a job from a given state. template -inline void GetDriverTasksFromQueue(const TaskQueue &queue, - const ray::DriverID &driver_id, +inline void GetTasksForJobFromQueue(const TaskQueue &queue, const ray::JobID &job_id, std::unordered_set &task_ids) { const auto &tasks = queue.GetTasks(); for (const auto &task : tasks) { auto const &spec = task.GetTaskSpecification(); - if (driver_id == spec.DriverId()) { + if (job_id == spec.JobId()) { task_ids.insert(spec.TaskId()); } } @@ -187,9 +186,9 @@ void SchedulingQueue::FilterState(std::unordered_set &task_ids, } } break; case TaskState::DRIVER: { - const auto driver_ids = GetDriverTaskIds(); + const auto driver_task_ids = GetDriverTaskIds(); for (auto it = task_ids.begin(); it != task_ids.end();) { - if (driver_ids.count(*it) == 1) { + if (driver_task_ids.count(*it) == 1) { it = task_ids.erase(it); } else { it++; @@ -356,11 +355,10 @@ bool SchedulingQueue::HasTask(const TaskID &task_id) const { return false; } -std::unordered_set SchedulingQueue::GetTaskIdsForDriver( - const DriverID &driver_id) const { +std::unordered_set SchedulingQueue::GetTaskIdsForJob(const JobID &job_id) const { std::unordered_set task_ids; for (const auto &task_queue : task_queues_) { - GetDriverTasksFromQueue(*task_queue, driver_id, task_ids); + GetTasksForJobFromQueue(*task_queue, job_id, task_ids); } return task_ids; } @@ -394,15 +392,15 @@ void SchedulingQueue::RemoveBlockedTaskId(const TaskID &task_id) { RAY_CHECK(erased == 1); } -void SchedulingQueue::AddDriverTaskId(const TaskID &driver_id) { - RAY_LOG(DEBUG) << "Added driver task " << driver_id; - auto inserted = driver_task_ids_.insert(driver_id); +void SchedulingQueue::AddDriverTaskId(const TaskID &task_id) { + RAY_LOG(DEBUG) << "Added driver task " << task_id; + auto inserted = driver_task_ids_.insert(task_id); RAY_CHECK(inserted.second); } -void SchedulingQueue::RemoveDriverTaskId(const TaskID &driver_id) { - RAY_LOG(DEBUG) << "Removed driver task " << driver_id; - auto erased = driver_task_ids_.erase(driver_id); +void SchedulingQueue::RemoveDriverTaskId(const TaskID &task_id) { + RAY_LOG(DEBUG) << "Removed driver task " << task_id; + auto erased = driver_task_ids_.erase(task_id); RAY_CHECK(erased == 1); } diff --git a/src/ray/raylet/scheduling_queue.h b/src/ray/raylet/scheduling_queue.h index 465f2a434..2cb0ca7de 100644 --- a/src/ray/raylet/scheduling_queue.h +++ b/src/ray/raylet/scheduling_queue.h @@ -283,11 +283,11 @@ class SchedulingQueue { /// \param filter_state The task state to filter out. void FilterState(std::unordered_set &task_ids, TaskState filter_state) const; - /// \brief Get all the task IDs for a driver. + /// \brief Get all the task IDs for a job. /// - /// \param driver_id All the tasks that have the given driver_id are returned. - /// \return All the tasks that have the given driver ID. - std::unordered_set GetTaskIdsForDriver(const DriverID &driver_id) const; + /// \param job_id All the tasks that have the given job_id are returned. + /// \return All the tasks that have the given job ID. + std::unordered_set GetTaskIdsForJob(const JobID &job_id) const; /// \brief Get all the task IDs for an actor. /// diff --git a/src/ray/raylet/task_dependency_manager.cc b/src/ray/raylet/task_dependency_manager.cc index 89028c733..8b1671f98 100644 --- a/src/ray/raylet/task_dependency_manager.cc +++ b/src/ray/raylet/task_dependency_manager.cc @@ -265,7 +265,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) { task_lease_data->set_node_manager_id(client_id_.Hex()); task_lease_data->set_acquired_at(current_sys_time_ms()); task_lease_data->set_timeout(it->second.lease_period); - RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr)); + RAY_CHECK_OK(task_lease_table_.Add(JobID::Nil(), task_id, task_lease_data, nullptr)); auto period = boost::posix_time::milliseconds(it->second.lease_period / 2); it->second.lease_timer->expires_from_now(period); diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index f7a60989f..f16c17b6b 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -29,7 +29,7 @@ class MockGcs : public gcs::TableInterface { public: MOCK_METHOD4( Add, - ray::Status(const DriverID &driver_id, const TaskID &task_id, + ray::Status(const JobID &job_id, const TaskID &task_id, std::shared_ptr &task_data, const gcs::TableInterface::WriteCallback &done)); }; @@ -75,7 +75,7 @@ static inline Task ExampleTask(const std::vector &arguments, task_arguments.emplace_back(std::make_shared(references)); } std::vector function_descriptor(3); - auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments, + auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments, num_returns, required_resources, Language::PYTHON, function_descriptor); auto execution_spec = TaskExecutionSpecification(std::vector()); diff --git a/src/ray/raylet/task_spec.cc b/src/ray/raylet/task_spec.cc index 1d722de18..e401e5a2b 100644 --- a/src/ray/raylet/task_spec.cc +++ b/src/ray/raylet/task_spec.cc @@ -61,18 +61,18 @@ TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size) { } TaskSpecification::TaskSpecification( - const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, const std::unordered_map &required_resources, const Language &language, const std::vector &function_descriptor) - : TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::Nil(), + : TaskSpecification(job_id, parent_task_id, parent_counter, ActorID::Nil(), ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), -1, {}, task_arguments, num_returns, required_resources, std::unordered_map(), language, function_descriptor) {} TaskSpecification::TaskSpecification( - const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, const int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -85,7 +85,7 @@ TaskSpecification::TaskSpecification( : spec_() { flatbuffers::FlatBufferBuilder fbb; - TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter); + TaskID task_id = GenerateTaskId(job_id, parent_task_id, parent_counter); // Add argument object IDs. std::vector> arguments; for (auto &argument : task_arguments) { @@ -94,7 +94,7 @@ TaskSpecification::TaskSpecification( // Serialize the TaskSpecification. auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id), + fbb, to_flatbuf(fbb, job_id), to_flatbuf(fbb, task_id), to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id), to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions, to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter, @@ -123,9 +123,9 @@ TaskID TaskSpecification::TaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); return from_flatbuf(*message->task_id()); } -DriverID TaskSpecification::DriverId() const { +JobID TaskSpecification::JobId() const { auto message = flatbuffers::GetRoot(spec_.data()); - return from_flatbuf(*message->driver_id()); + return from_flatbuf(*message->job_id()); } TaskID TaskSpecification::ParentTaskId() const { auto message = flatbuffers::GetRoot(spec_.data()); diff --git a/src/ray/raylet/task_spec.h b/src/ray/raylet/task_spec.h index 8a08e9974..4339a1a4c 100644 --- a/src/ray/raylet/task_spec.h +++ b/src/ray/raylet/task_spec.h @@ -86,7 +86,7 @@ class TaskSpecification { /// Create a task specification from the raw fields. This constructor omits /// some values and sets them to sensible defaults. /// - /// \param driver_id The driver ID, representing the job that this task is a + /// \param job_id The driver ID, representing the job that this task is a /// part of. /// \param parent_task_id The task ID of the task that spawned this task. /// \param parent_counter The number of tasks that this task's parent spawned @@ -96,7 +96,7 @@ class TaskSpecification { /// \param num_returns The number of values returned by the task. /// \param required_resources The task's resource demands. /// \param language The language of the worker that must execute the function. - TaskSpecification(const DriverID &driver_id, const TaskID &parent_task_id, + TaskSpecification(const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, const std::vector> &task_arguments, int64_t num_returns, @@ -107,7 +107,7 @@ class TaskSpecification { // TODO(swang): Define an actor task constructor. /// Create a task specification from the raw fields. /// - /// \param driver_id The driver ID, representing the job that this task is a + /// \param job_id The driver ID, representing the job that this task is a /// part of. /// \param parent_task_id The task ID of the task that spawned this task. /// \param parent_counter The number of tasks that this task's parent spawned @@ -130,7 +130,7 @@ class TaskSpecification { /// \param function_descriptor The function descriptor. /// \param dynamic_worker_options The dynamic options for starting an actor worker. TaskSpecification( - const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter, + const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter, const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id, int64_t max_actor_reconstructions, const ActorID &actor_id, const ActorHandleID &actor_handle_id, int64_t actor_counter, @@ -171,7 +171,7 @@ class TaskSpecification { // TODO(swang): Finalize and document these methods. TaskID TaskId() const; - DriverID DriverId() const; + JobID JobId() const; TaskID ParentTaskId() const; int64_t ParentCounter() const; std::vector FunctionDescriptor() const; diff --git a/src/ray/raylet/task_test.cc b/src/ray/raylet/task_test.cc index 1e26cb33b..72864785a 100644 --- a/src/ray/raylet/task_test.cc +++ b/src/ray/raylet/task_test.cc @@ -64,7 +64,7 @@ TEST(TaskSpecTest, TaskInfoSize) { } // General task. auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id), + fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id), to_flatbuf(fbb, TaskID::FromRandom()), 0, to_flatbuf(fbb, ActorID::Nil()), to_flatbuf(fbb, ObjectID::Nil()), 0, to_flatbuf(fbb, ActorID::Nil()), to_flatbuf(fbb, ActorHandleID::Nil()), 0, @@ -83,7 +83,7 @@ TEST(TaskSpecTest, TaskInfoSize) { } // General task. auto spec = CreateTaskInfo( - fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id), + fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id), to_flatbuf(fbb, TaskID::FromRandom()), 10, to_flatbuf(fbb, ActorID::FromRandom()), to_flatbuf(fbb, ObjectID::FromRandom()), 10000000, to_flatbuf(fbb, ActorID::FromRandom()), diff --git a/src/ray/raylet/worker.cc b/src/ray/raylet/worker.cc index 36bfc6d84..359754340 100644 --- a/src/ray/raylet/worker.cc +++ b/src/ray/raylet/worker.cc @@ -50,11 +50,9 @@ const std::unordered_set &Worker::GetBlockedTaskIds() const { return blocked_task_ids_; } -void Worker::AssignDriverId(const DriverID &driver_id) { - assigned_driver_id_ = driver_id; -} +void Worker::AssignJobId(const JobID &job_id) { assigned_job_id_ = job_id; } -const DriverID &Worker::GetAssignedDriverId() const { return assigned_driver_id_; } +const JobID &Worker::GetAssignedJobId() const { return assigned_job_id_; } void Worker::AssignActorId(const ActorID &actor_id) { RAY_CHECK(actor_id_.IsNil()) diff --git a/src/ray/raylet/worker.h b/src/ray/raylet/worker.h index cb0797ddd..7cd8d5e1d 100644 --- a/src/ray/raylet/worker.h +++ b/src/ray/raylet/worker.h @@ -34,8 +34,8 @@ class Worker { bool AddBlockedTaskId(const TaskID &task_id); bool RemoveBlockedTaskId(const TaskID &task_id); const std::unordered_set &GetBlockedTaskIds() const; - void AssignDriverId(const DriverID &driver_id); - const DriverID &GetAssignedDriverId() const; + void AssignJobId(const JobID &job_id); + const JobID &GetAssignedJobId() const; void AssignActorId(const ActorID &actor_id); const ActorID &GetActorId() const; /// Return the worker's connection. @@ -60,8 +60,8 @@ class Worker { std::shared_ptr connection_; /// The worker's currently assigned task. TaskID assigned_task_id_; - /// Driver ID for the worker's current assigned task. - DriverID assigned_driver_id_; + /// Job ID for the worker's current assigned task. + JobID assigned_job_id_; /// The worker's actor ID. If this is nil, then the worker is not an actor. ActorID actor_id_; /// Whether the worker is dead. diff --git a/src/ray/raylet/worker_pool.cc b/src/ray/raylet/worker_pool.cc index 16086565d..f15df88ae 100644 --- a/src/ray/raylet/worker_pool.cc +++ b/src/ray/raylet/worker_pool.cc @@ -319,13 +319,13 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua return state->second; } -std::vector> WorkerPool::GetWorkersRunningTasksForDriver( - const DriverID &driver_id) const { +std::vector> WorkerPool::GetWorkersRunningTasksForJob( + const JobID &job_id) const { std::vector> workers; for (const auto &entry : states_by_lang_) { for (const auto &worker : entry.second.registered_workers) { - if (worker->GetAssignedDriverId() == driver_id) { + if (worker->GetAssignedJobId() == job_id) { workers.push_back(worker); } } @@ -355,7 +355,7 @@ void WorkerPool::WarnAboutSize() { << "(see https://github.com/ray-project/ray/issues/3644) for " << "some a discussion of workarounds."; RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver( - DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); + JobID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms())); } } diff --git a/src/ray/raylet/worker_pool.h b/src/ray/raylet/worker_pool.h index e1e726268..4ea2648d7 100644 --- a/src/ray/raylet/worker_pool.h +++ b/src/ray/raylet/worker_pool.h @@ -102,12 +102,12 @@ class WorkerPool { /// \return The total count of all workers (actor and non-actor) in the pool. uint32_t Size(const Language &language) const; - /// Get all the workers which are running tasks for a given driver. + /// Get all the workers which are running tasks for a given job. /// - /// \param driver_id The driver ID. - /// \return A list containing all the workers which are running tasks for the driver. - std::vector> GetWorkersRunningTasksForDriver( - const DriverID &driver_id) const; + /// \param job_id The job ID. + /// \return A list containing all the workers which are running tasks for the job. + std::vector> GetWorkersRunningTasksForJob( + const JobID &job_id) const; /// Whether there is a pending worker for the given task. /// Note that, this is only used for actor creation task with dynamic options. diff --git a/src/ray/raylet/worker_pool_test.cc b/src/ray/raylet/worker_pool_test.cc index 15a5fb047..698387a8f 100644 --- a/src/ray/raylet/worker_pool_test.cc +++ b/src/ray/raylet/worker_pool_test.cc @@ -109,7 +109,7 @@ static inline TaskSpecification ExampleTaskSpec( const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON, const ActorID actor_creation_id = ActorID::Nil()) { std::vector function_descriptor(3); - return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id, + return TaskSpecification(JobID::Nil(), TaskID::Nil(), 0, actor_creation_id, ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, language, function_descriptor); } @@ -226,7 +226,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) { SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}}, {Language::JAVA, java_worker_command}}); - TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), + TaskSpecification task_spec(JobID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(), ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0, {}, {}, 0, {}, {}, Language::JAVA, {"", "", ""}, {"test_op_0", "test_op_1"});