mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[ID Refactor] Rename DriverID to JobID (#5004)
* WIP WIP WIP Rename Driver -> Job Fix complition Fix Rename in Java In py WIP Fix WIP Fix Fix test Fix Fix C++ linting Fix * Update java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Update src/ray/core_worker/core_worker.cc Co-Authored-By: Stephanie Wang <swang@cs.berkeley.edu> * Address comments * Fix * Fix CI * Fix cpp linting * Fix py lint * FIx * Address comments and fix * Address comments * Address * Fix import_threading
This commit is contained in:
@@ -9,12 +9,9 @@ import org.ray.api.id.UniqueId;
|
||||
public interface RuntimeContext {
|
||||
|
||||
/**
|
||||
* Get the current Driver ID.
|
||||
*
|
||||
* If called in a driver, this returns the driver ID. If called in a worker, this returns the ID
|
||||
* of the associated driver.
|
||||
* Get the current Job ID.
|
||||
*/
|
||||
UniqueId getCurrentDriverId();
|
||||
UniqueId getCurrentJobId();
|
||||
|
||||
/**
|
||||
* Get the current actor ID.
|
||||
|
||||
@@ -74,10 +74,10 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
functionManager = new FunctionManager(rayConfig.driverResourcePath);
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
worker = new Worker(this);
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.driverId, rayConfig.runMode);
|
||||
rayConfig.jobId, rayConfig.runMode);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
@@ -346,7 +346,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
boolean isActorCreationTask, BaseTaskOptions taskOptions) {
|
||||
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
|
||||
|
||||
TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentDriverId(),
|
||||
TaskId taskId = rayletClient.generateTaskId(workerContext.getCurrentJobId(),
|
||||
workerContext.getCurrentTaskId(), workerContext.nextTaskIndex());
|
||||
int numReturns = actor.getId().isNil() ? 1 : 2;
|
||||
ObjectId[] returnIds = IdUtil.genReturnIds(taskId, numReturns);
|
||||
@@ -377,7 +377,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
FunctionDescriptor functionDescriptor;
|
||||
if (func != null) {
|
||||
language = TaskLanguage.JAVA;
|
||||
functionDescriptor = functionManager.getFunction(workerContext.getCurrentDriverId(), func)
|
||||
functionDescriptor = functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.getFunctionDescriptor();
|
||||
} else {
|
||||
language = TaskLanguage.PYTHON;
|
||||
@@ -385,7 +385,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
}
|
||||
|
||||
return new TaskSpec(
|
||||
workerContext.getCurrentDriverId(),
|
||||
workerContext.getCurrentJobId(),
|
||||
taskId,
|
||||
workerContext.getCurrentTaskId(),
|
||||
-1,
|
||||
|
||||
@@ -101,7 +101,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
rayConfig.rayletSocketName,
|
||||
workerContext.getCurrentWorkerId(),
|
||||
rayConfig.workerMode == WorkerMode.WORKER,
|
||||
workerContext.getCurrentDriverId()
|
||||
workerContext.getCurrentJobId()
|
||||
);
|
||||
|
||||
// register
|
||||
|
||||
@@ -17,8 +17,8 @@ public class RuntimeContextImpl implements RuntimeContext {
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentDriverId() {
|
||||
return runtime.getWorkerContext().getCurrentDriverId();
|
||||
public UniqueId getCurrentJobId() {
|
||||
return runtime.getWorkerContext().getCurrentJobId();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -86,7 +86,7 @@ public class Worker {
|
||||
try {
|
||||
// Get method
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.driverId, spec.getJavaFunctionDescriptor());
|
||||
.getFunction(spec.jobId, spec.getJavaFunctionDescriptor());
|
||||
// Set context
|
||||
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
|
||||
@@ -29,7 +29,7 @@ public class WorkerContext {
|
||||
|
||||
private ThreadLocal<TaskSpec> currentTask;
|
||||
|
||||
private UniqueId currentDriverId;
|
||||
private UniqueId currentJobId;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
@@ -43,7 +43,7 @@ public class WorkerContext {
|
||||
*/
|
||||
private RunMode runMode;
|
||||
|
||||
public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) {
|
||||
public WorkerContext(WorkerMode workerMode, UniqueId jobId, RunMode runMode) {
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
taskIndex = ThreadLocal.withInitial(() -> 0);
|
||||
putIndex = ThreadLocal.withInitial(() -> 0);
|
||||
@@ -52,13 +52,15 @@ public class WorkerContext {
|
||||
currentTask = ThreadLocal.withInitial(() -> null);
|
||||
currentClassLoader = null;
|
||||
if (workerMode == WorkerMode.DRIVER) {
|
||||
workerId = driverId;
|
||||
// TODO(qwang): Assign the driver id to worker id
|
||||
// once we treat driver id as a special worker id.
|
||||
workerId = jobId;
|
||||
currentTaskId.set(TaskId.randomId());
|
||||
currentDriverId = driverId;
|
||||
currentJobId = jobId;
|
||||
} else {
|
||||
workerId = UniqueId.randomId();
|
||||
this.currentTaskId.set(TaskId.NIL);
|
||||
this.currentDriverId = UniqueId.NIL;
|
||||
this.currentJobId = UniqueId.NIL;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +86,7 @@ public class WorkerContext {
|
||||
|
||||
Preconditions.checkNotNull(task);
|
||||
this.currentTaskId.set(task.taskId);
|
||||
this.currentDriverId = task.driverId;
|
||||
this.currentJobId = task.jobId;
|
||||
taskIndex.set(0);
|
||||
putIndex.set(0);
|
||||
this.currentTask.set(task);
|
||||
@@ -115,15 +117,14 @@ public class WorkerContext {
|
||||
}
|
||||
|
||||
/**
|
||||
* @return If this worker is a driver, this method returns the driver ID; Otherwise, it returns
|
||||
* the driver ID of the current running task.
|
||||
* The ID of the current job.
|
||||
*/
|
||||
public UniqueId getCurrentDriverId() {
|
||||
return currentDriverId;
|
||||
public UniqueId getCurrentJobId() {
|
||||
return currentJobId;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The class loader which is associated with the current driver.
|
||||
* @return The class loader which is associated with the current job.
|
||||
*/
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
|
||||
@@ -32,7 +32,7 @@ public class RayConfig {
|
||||
public final WorkerMode workerMode;
|
||||
public final RunMode runMode;
|
||||
public final Map<String, Double> resources;
|
||||
public final UniqueId driverId;
|
||||
public final UniqueId jobId;
|
||||
public final String logDir;
|
||||
public final boolean redirectOutput;
|
||||
public final List<String> libraryPath;
|
||||
@@ -53,7 +53,7 @@ public class RayConfig {
|
||||
public final String rayletSocketName;
|
||||
public final List<String> rayletConfigParameters;
|
||||
|
||||
public final String driverResourcePath;
|
||||
public final String jobResourcePath;
|
||||
public final String pythonWorkerCommand;
|
||||
|
||||
/**
|
||||
@@ -105,12 +105,12 @@ public class RayConfig {
|
||||
resources.put("CPU", numCpu * 1.0);
|
||||
}
|
||||
}
|
||||
// Driver id.
|
||||
String driverId = config.getString("ray.driver.id");
|
||||
if (!driverId.isEmpty()) {
|
||||
this.driverId = UniqueId.fromHexString(driverId);
|
||||
// Job id.
|
||||
String jobId = config.getString("ray.job.id");
|
||||
if (!jobId.isEmpty()) {
|
||||
this.jobId = UniqueId.fromHexString(jobId);
|
||||
} else {
|
||||
this.driverId = UniqueId.randomId();
|
||||
this.jobId = UniqueId.randomId();
|
||||
}
|
||||
// Log dir.
|
||||
logDir = removeTrailingSlash(config.getString("ray.log-dir"));
|
||||
@@ -160,11 +160,11 @@ public class RayConfig {
|
||||
rayletConfigParameters.add(parameter);
|
||||
}
|
||||
|
||||
// Driver resource path.
|
||||
if (config.hasPath("ray.driver.resource-path")) {
|
||||
driverResourcePath = config.getString("ray.driver.resource-path");
|
||||
// Job resource path.
|
||||
if (config.hasPath("ray.job.resource-path")) {
|
||||
jobResourcePath = config.getString("ray.job.resource-path");
|
||||
} else {
|
||||
driverResourcePath = null;
|
||||
jobResourcePath = null;
|
||||
}
|
||||
|
||||
// Number of threads that execute tasks.
|
||||
@@ -205,7 +205,7 @@ public class RayConfig {
|
||||
+ ", workerMode=" + workerMode
|
||||
+ ", runMode=" + runMode
|
||||
+ ", resources=" + resources
|
||||
+ ", driverId=" + driverId
|
||||
+ ", jobId=" + jobId
|
||||
+ ", logDir='" + logDir + '\''
|
||||
+ ", redirectOutput=" + redirectOutput
|
||||
+ ", libraryPath=" + libraryPath
|
||||
@@ -220,7 +220,7 @@ public class RayConfig {
|
||||
+ ", objectStoreSize=" + objectStoreSize
|
||||
+ ", rayletSocketName='" + rayletSocketName + '\''
|
||||
+ ", rayletConfigParameters=" + rayletConfigParameters
|
||||
+ ", driverResourcePath='" + driverResourcePath + '\''
|
||||
+ ", jobResourcePath='" + jobResourcePath + '\''
|
||||
+ ", pythonWorkerCommand='" + pythonWorkerCommand + '\''
|
||||
+ '}';
|
||||
}
|
||||
|
||||
@@ -30,7 +30,7 @@ import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Manages functions by driver id.
|
||||
* Manages functions by job id.
|
||||
*/
|
||||
public class FunctionManager {
|
||||
|
||||
@@ -46,33 +46,33 @@ public class FunctionManager {
|
||||
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
|
||||
|
||||
/**
|
||||
* Mapping from the driver id to the functions that belong to this driver.
|
||||
* Mapping from the job id to the functions that belong to this job.
|
||||
*/
|
||||
private Map<UniqueId, DriverFunctionTable> driverFunctionTables = new HashMap<>();
|
||||
private Map<UniqueId, JobFunctionTable> jobFunctionTables = new HashMap<>();
|
||||
|
||||
/**
|
||||
* The resource path which we can load the driver's jar resources.
|
||||
* The resource path from which we can load the job's jar resources.
|
||||
*/
|
||||
private String driverResourcePath;
|
||||
private String jobResourcePath;
|
||||
|
||||
/**
|
||||
* Construct a FunctionManager with the specified driver resource path.
|
||||
* Construct a FunctionManager with the specified job resource path.
|
||||
*
|
||||
* @param driverResourcePath The specified driver resource that can store the driver's
|
||||
* @param jobResourcePath The specified job resource that can store the job's
|
||||
* resources.
|
||||
*/
|
||||
public FunctionManager(String driverResourcePath) {
|
||||
this.driverResourcePath = driverResourcePath;
|
||||
public FunctionManager(String jobResourcePath) {
|
||||
this.jobResourcePath = jobResourcePath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a RayFunc instance (a lambda).
|
||||
*
|
||||
* @param driverId current driver id.
|
||||
* @param jobId current job id.
|
||||
* @param func The lambda.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, RayFunc func) {
|
||||
public RayFunction getFunction(UniqueId jobId, RayFunc func) {
|
||||
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
|
||||
if (functionDescriptor == null) {
|
||||
SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func);
|
||||
@@ -82,24 +82,24 @@ public class FunctionManager {
|
||||
functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor);
|
||||
RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor);
|
||||
}
|
||||
return getFunction(driverId, functionDescriptor);
|
||||
return getFunction(jobId, functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the RayFunction from a function descriptor.
|
||||
*
|
||||
* @param driverId Current driver id.
|
||||
* @param jobId Current job id.
|
||||
* @param functionDescriptor The function descriptor.
|
||||
* @return A RayFunction object.
|
||||
*/
|
||||
public RayFunction getFunction(UniqueId driverId, JavaFunctionDescriptor functionDescriptor) {
|
||||
DriverFunctionTable driverFunctionTable = driverFunctionTables.get(driverId);
|
||||
if (driverFunctionTable == null) {
|
||||
public RayFunction getFunction(UniqueId jobId, JavaFunctionDescriptor functionDescriptor) {
|
||||
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
|
||||
if (jobFunctionTable == null) {
|
||||
ClassLoader classLoader;
|
||||
if (Strings.isNullOrEmpty(driverResourcePath)) {
|
||||
if (Strings.isNullOrEmpty(jobResourcePath)) {
|
||||
classLoader = getClass().getClassLoader();
|
||||
} else {
|
||||
File resourceDir = new File(driverResourcePath + "/" + driverId.toString() + "/");
|
||||
File resourceDir = new File(jobResourcePath + "/" + jobId.toString() + "/");
|
||||
Collection<File> files = FileUtils.listFiles(resourceDir,
|
||||
new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY);
|
||||
files.add(resourceDir);
|
||||
@@ -111,23 +111,23 @@ public class FunctionManager {
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
classLoader = new URLClassLoader(urlList.toArray(new URL[urlList.size()]));
|
||||
LOGGER.debug("Resource loaded for driver {} from path {}.", driverId,
|
||||
LOGGER.debug("Resource loaded for job {} from path {}.", jobId,
|
||||
resourceDir.getAbsolutePath());
|
||||
}
|
||||
|
||||
driverFunctionTable = new DriverFunctionTable(classLoader);
|
||||
driverFunctionTables.put(driverId, driverFunctionTable);
|
||||
jobFunctionTable = new JobFunctionTable(classLoader);
|
||||
jobFunctionTables.put(jobId, jobFunctionTable);
|
||||
}
|
||||
return driverFunctionTable.getFunction(functionDescriptor);
|
||||
return jobFunctionTable.getFunction(functionDescriptor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages all functions that belong to one driver.
|
||||
* Manages all functions that belong to one job.
|
||||
*/
|
||||
static class DriverFunctionTable {
|
||||
static class JobFunctionTable {
|
||||
|
||||
/**
|
||||
* The driver's corresponding class loader.
|
||||
* The job's corresponding class loader.
|
||||
*/
|
||||
ClassLoader classLoader;
|
||||
/**
|
||||
@@ -135,7 +135,7 @@ public class FunctionManager {
|
||||
*/
|
||||
Map<String, Map<Pair<String, String>, RayFunction>> functions;
|
||||
|
||||
DriverFunctionTable(ClassLoader classLoader) {
|
||||
JobFunctionTable(ClassLoader classLoader) {
|
||||
this.classLoader = classLoader;
|
||||
this.functions = new HashMap<>();
|
||||
}
|
||||
|
||||
@@ -164,7 +164,7 @@ public class MockRayletClient implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) {
|
||||
public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) {
|
||||
return TaskId.randomId();
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ public interface RayletClient {
|
||||
|
||||
void notifyUnblocked(TaskId currentTaskId);
|
||||
|
||||
TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex);
|
||||
TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex);
|
||||
|
||||
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId);
|
||||
|
||||
@@ -44,10 +44,11 @@ public class RayletClientImpl implements RayletClient {
|
||||
*/
|
||||
private long client = 0;
|
||||
|
||||
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
|
||||
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
|
||||
boolean isWorker, UniqueId driverId) {
|
||||
boolean isWorker, UniqueId jobId) {
|
||||
client = nativeInit(schedulerSockName, clientId.getBytes(),
|
||||
isWorker, driverId.getBytes());
|
||||
isWorker, jobId.getBytes());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -83,7 +84,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
public void submitTask(TaskSpec spec) {
|
||||
LOGGER.debug("Submitting task: {}", spec);
|
||||
Preconditions.checkState(!spec.parentTaskId.isNil());
|
||||
Preconditions.checkState(!spec.driverId.isNil());
|
||||
Preconditions.checkState(!spec.jobId.isNil());
|
||||
|
||||
ByteBuffer info = convertTaskSpecToFlatbuffer(spec);
|
||||
byte[] cursorId = null;
|
||||
@@ -114,8 +115,8 @@ public class RayletClientImpl implements RayletClient {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId generateTaskId(UniqueId driverId, TaskId parentTaskId, int taskIndex) {
|
||||
byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
|
||||
public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) {
|
||||
byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex);
|
||||
return new TaskId(bytes);
|
||||
}
|
||||
|
||||
@@ -141,11 +142,10 @@ public class RayletClientImpl implements RayletClient {
|
||||
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
|
||||
}
|
||||
|
||||
|
||||
private static TaskSpec parseTaskSpecFromFlatbuffer(ByteBuffer bb) {
|
||||
bb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
|
||||
UniqueId driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer());
|
||||
UniqueId jobId = UniqueId.fromByteBuffer(info.jobIdAsByteBuffer());
|
||||
TaskId taskId = TaskId.fromByteBuffer(info.taskIdAsByteBuffer());
|
||||
TaskId parentTaskId = TaskId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
|
||||
int parentCounter = info.parentCounter();
|
||||
@@ -197,7 +197,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
dynamicWorkerOptions.add(info.dynamicWorkerOptions(i));
|
||||
}
|
||||
|
||||
return new TaskSpec(driverId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
maxActorReconstructions, actorId, actorHandleId, actorCounter, newActorHandles,
|
||||
args, numReturns, resources, TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
|
||||
}
|
||||
@@ -207,7 +207,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
bb.clear();
|
||||
|
||||
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
|
||||
final int driverIdOffset = fbb.createString(task.driverId.toByteBuffer());
|
||||
final int jobIdOffset = fbb.createString(task.jobId.toByteBuffer());
|
||||
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
|
||||
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
|
||||
final int parentCounter = task.parentCounter;
|
||||
@@ -290,7 +290,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
|
||||
int root = TaskInfo.createTaskInfo(
|
||||
fbb,
|
||||
driverIdOffset,
|
||||
jobIdOffset,
|
||||
taskIdOffset,
|
||||
parentTaskIdOffset,
|
||||
parentCounter,
|
||||
@@ -363,7 +363,7 @@ public class RayletClientImpl implements RayletClient {
|
||||
private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
|
||||
int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException;
|
||||
|
||||
private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
|
||||
private static native byte[] nativeGenerateTaskId(byte[] jobId, byte[] parentTaskId,
|
||||
int taskIndex);
|
||||
|
||||
private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
|
||||
|
||||
@@ -18,8 +18,8 @@ import org.ray.runtime.util.IdUtil;
|
||||
*/
|
||||
public class TaskSpec {
|
||||
|
||||
// ID of the driver that created this task.
|
||||
public final UniqueId driverId;
|
||||
// ID of the job that created this task.
|
||||
public final UniqueId jobId;
|
||||
|
||||
// Task ID of the task.
|
||||
public final TaskId taskId;
|
||||
@@ -81,7 +81,7 @@ public class TaskSpec {
|
||||
}
|
||||
|
||||
public TaskSpec(
|
||||
UniqueId driverId,
|
||||
UniqueId jobId,
|
||||
TaskId taskId,
|
||||
TaskId parentTaskId,
|
||||
int parentCounter,
|
||||
@@ -97,7 +97,7 @@ public class TaskSpec {
|
||||
TaskLanguage language,
|
||||
FunctionDescriptor functionDescriptor,
|
||||
List<String> dynamicWorkerOptions) {
|
||||
this.driverId = driverId;
|
||||
this.jobId = jobId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.parentCounter = parentCounter;
|
||||
@@ -147,7 +147,7 @@ public class TaskSpec {
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TaskSpec{" +
|
||||
"driverId=" + driverId +
|
||||
"jobId=" + jobId +
|
||||
", taskId=" + taskId +
|
||||
", parentTaskId=" + parentTaskId +
|
||||
", parentCounter=" + parentCounter +
|
||||
|
||||
@@ -20,14 +20,14 @@ ray {
|
||||
// Available resources on this node, for example "CPU:4,GPU:0".
|
||||
resources: ""
|
||||
|
||||
// Configuration items about driver.
|
||||
driver {
|
||||
// If worker.mode is DRIVER, specify the driver id.
|
||||
// Configuration items about job.
|
||||
job {
|
||||
// If worker.mode is DRIVER, specify the job id.
|
||||
// If not provided, a random id will be used.
|
||||
id: ""
|
||||
// If this config is set, worker will use different paths to loadresources when
|
||||
// executing tasks from different drivers. E.g. if it's set to '/tm/driver_resources',
|
||||
// the path for driver 123 will be '/tmp/driver_resources/123'.
|
||||
// If this config is set, worker will use different paths to load resources when
|
||||
// executing tasks from different jobs. E.g. if it's set to '/tmp/job_resources',
|
||||
// the path for job 123 will be '/tmp/job_resources/123'.
|
||||
resource-path: ""
|
||||
}
|
||||
|
||||
|
||||
+12
-12
@@ -13,7 +13,7 @@ import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionManager.DriverFunctionTable;
|
||||
import org.ray.runtime.functionmanager.FunctionManager.JobFunctionTable;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -106,7 +106,7 @@ public class FunctionManagerTest {
|
||||
|
||||
@Test
|
||||
public void testLoadFunctionTableForClass() {
|
||||
DriverFunctionTable functionTable = new DriverFunctionTable(getClass().getClassLoader());
|
||||
JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader());
|
||||
Map<Pair<String, String>, RayFunction> res = functionTable
|
||||
.loadFunctionsForClass(Bar.class.getName());
|
||||
// The result should have 2 entries, one for the constructor, the other for bar.
|
||||
@@ -119,13 +119,13 @@ public class FunctionManagerTest {
|
||||
|
||||
@Test
|
||||
public void testGetFunctionFromLocalResource() throws Exception {
|
||||
UniqueId driverId = UniqueId.randomId();
|
||||
UniqueId jobId = UniqueId.randomId();
|
||||
final String resourcePath = FileUtils.getTempDirectoryPath() + "/ray_test_resources";
|
||||
final String driverResourcePath = resourcePath + "/" + driverId.toString();
|
||||
File driverResourceDir = new File(driverResourcePath);
|
||||
FileUtils.deleteQuietly(driverResourceDir);
|
||||
driverResourceDir.mkdirs();
|
||||
driverResourceDir.deleteOnExit();
|
||||
final String jobResourcePath = resourcePath + "/" + jobId.toString();
|
||||
File jobResourceDir = new File(jobResourcePath);
|
||||
FileUtils.deleteQuietly(jobResourceDir);
|
||||
jobResourceDir.mkdirs();
|
||||
jobResourceDir.deleteOnExit();
|
||||
|
||||
String demoJavaFile = "";
|
||||
demoJavaFile += "public class DemoApp {\n";
|
||||
@@ -134,13 +134,13 @@ public class FunctionManagerTest {
|
||||
demoJavaFile += " }\n";
|
||||
demoJavaFile += "}";
|
||||
|
||||
// Write the demo java file to the driver resource path.
|
||||
String javaFilePath = driverResourcePath + "/DemoApp.java";
|
||||
// Write the demo java file to the job resource path.
|
||||
String javaFilePath = jobResourcePath + "/DemoApp.java";
|
||||
Files.write(Paths.get(javaFilePath), demoJavaFile.getBytes());
|
||||
|
||||
// Compile the java file.
|
||||
JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
|
||||
int result = compiler.run(null, null, null, "-d", driverResourcePath, javaFilePath);
|
||||
int result = compiler.run(null, null, null, "-d", jobResourcePath, javaFilePath);
|
||||
if (result != 0) {
|
||||
throw new RuntimeException("Couldn't compile Demo.java.");
|
||||
}
|
||||
@@ -149,7 +149,7 @@ public class FunctionManagerTest {
|
||||
JavaFunctionDescriptor descriptor = new JavaFunctionDescriptor(
|
||||
"DemoApp", "hello", "()Ljava/lang/String;");
|
||||
final FunctionManager functionManager = new FunctionManager(resourcePath);
|
||||
RayFunction func = functionManager.getFunction(driverId, descriptor);
|
||||
RayFunction func = functionManager.getFunction(jobId, descriptor);
|
||||
Assert.assertEquals(func.getFunctionDescriptor(), descriptor);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,13 +10,13 @@ public class RayConfigTest {
|
||||
@Test
|
||||
public void testCreateRayConfig() {
|
||||
try {
|
||||
System.setProperty("ray.driver.resource-path", "path/to/ray/driver/resource/path");
|
||||
System.setProperty("ray.job.resource-path", "path/to/ray/job/resource/path");
|
||||
RayConfig rayConfig = RayConfig.create();
|
||||
Assert.assertEquals(WorkerMode.DRIVER, rayConfig.workerMode);
|
||||
Assert.assertEquals("path/to/ray/driver/resource/path", rayConfig.driverResourcePath);
|
||||
Assert.assertEquals("path/to/ray/job/resource/path", rayConfig.jobResourcePath);
|
||||
} finally {
|
||||
// Unset system properties.
|
||||
System.clearProperty("ray.driver.resource-path");
|
||||
System.clearProperty("ray.job.resource-path");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,28 +11,28 @@ import org.testng.annotations.Test;
|
||||
|
||||
public class RuntimeContextTest extends BaseTest {
|
||||
|
||||
private static UniqueId DRIVER_ID =
|
||||
private static UniqueId JOB_ID =
|
||||
UniqueId.fromHexString("0011223344556677889900112233445566778899");
|
||||
private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
|
||||
private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket";
|
||||
|
||||
@BeforeClass
|
||||
public void setUp() {
|
||||
System.setProperty("ray.driver.id", DRIVER_ID.toString());
|
||||
System.setProperty("ray.job.id", JOB_ID.toString());
|
||||
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
|
||||
System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME);
|
||||
}
|
||||
|
||||
@AfterClass
|
||||
public void tearDown() {
|
||||
System.clearProperty("ray.driver.id");
|
||||
System.clearProperty("ray.job.id");
|
||||
System.clearProperty("ray.raylet.socket-name");
|
||||
System.clearProperty("ray.object-store.socket-name");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRuntimeContextInDriver() {
|
||||
Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId());
|
||||
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
|
||||
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
|
||||
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
|
||||
Ray.getRuntimeContext().getObjectStoreSocketName());
|
||||
@@ -42,7 +42,7 @@ public class RuntimeContextTest extends BaseTest {
|
||||
public static class RuntimeContextTester {
|
||||
|
||||
public String testRuntimeContext(UniqueId actorId) {
|
||||
Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId());
|
||||
Assert.assertEquals(JOB_ID, Ray.getRuntimeContext().getCurrentJobId());
|
||||
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
|
||||
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
|
||||
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
|
||||
|
||||
@@ -56,7 +56,8 @@ from ray._raylet import (
|
||||
ActorID,
|
||||
ClientID,
|
||||
Config as _Config,
|
||||
DriverID,
|
||||
JobID,
|
||||
WorkerID,
|
||||
FunctionID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
@@ -141,7 +142,8 @@ __all__ += [
|
||||
"ActorHandleID",
|
||||
"ActorID",
|
||||
"ClientID",
|
||||
"DriverID",
|
||||
"JobID",
|
||||
"WorkerID",
|
||||
"FunctionID",
|
||||
"ObjectID",
|
||||
"TaskID",
|
||||
|
||||
@@ -221,13 +221,13 @@ cdef class RayletClient:
|
||||
def __cinit__(self, raylet_socket,
|
||||
ClientID client_id,
|
||||
c_bool is_worker,
|
||||
DriverID driver_id):
|
||||
JobID job_id):
|
||||
# We know that we are using Python, so just skip the language
|
||||
# parameter.
|
||||
# TODO(suquark): Should we allow unicode chars in "raylet_socket"?
|
||||
self.client.reset(new CRayletClient(
|
||||
raylet_socket.encode("ascii"), client_id.native(), is_worker,
|
||||
driver_id.native(), LANGUAGE_PYTHON))
|
||||
job_id.native(), LANGUAGE_PYTHON))
|
||||
|
||||
def disconnect(self):
|
||||
check_status(self.client.get().Disconnect())
|
||||
@@ -293,9 +293,9 @@ cdef class RayletClient:
|
||||
postincrement(iterator)
|
||||
return resources_dict
|
||||
|
||||
def push_error(self, DriverID driver_id, error_type, error_message,
|
||||
def push_error(self, JobID job_id, error_type, error_message,
|
||||
double timestamp):
|
||||
check_status(self.client.get().PushError(driver_id.native(),
|
||||
check_status(self.client.get().PushError(job_id.native(),
|
||||
error_type.encode("ascii"),
|
||||
error_message.encode("ascii"),
|
||||
timestamp))
|
||||
@@ -381,8 +381,8 @@ cdef class RayletClient:
|
||||
return ClientID(self.client.get().GetClientID().Binary())
|
||||
|
||||
@property
|
||||
def driver_id(self):
|
||||
return DriverID(self.client.get().GetDriverID().Binary())
|
||||
def job_id(self):
|
||||
return JobID(self.client.get().GetJobID().Binary())
|
||||
|
||||
@property
|
||||
def is_worker(self):
|
||||
|
||||
+19
-21
@@ -17,8 +17,7 @@ from ray.function_manager import FunctionDescriptor
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.signature as signature
|
||||
import ray.worker
|
||||
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID,
|
||||
DriverID)
|
||||
from ray import (ObjectID, ActorID, ActorHandleID, ActorClassID, TaskID)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -186,7 +185,7 @@ class ActorClass(object):
|
||||
task.
|
||||
_resources: The default resources required by the actor creation task.
|
||||
_actor_method_cpus: The number of CPUs required by actor method tasks.
|
||||
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
|
||||
_last_job_id_exported_for: The ID of the job of the last Ray
|
||||
session during which this actor class definition was exported. This
|
||||
is an imperfect mechanism used to determine if we need to export
|
||||
the remote function again. It is imperfect in the sense that the
|
||||
@@ -212,7 +211,7 @@ class ActorClass(object):
|
||||
self._num_cpus = num_cpus
|
||||
self._num_gpus = num_gpus
|
||||
self._resources = resources
|
||||
self._last_driver_id_exported_for = None
|
||||
self._last_job_id_exported_for = None
|
||||
|
||||
self._actor_methods = inspect.getmembers(
|
||||
self._modified_class, ray.utils.is_function_or_method)
|
||||
@@ -345,13 +344,12 @@ class ActorClass(object):
|
||||
*copy.deepcopy(args), **copy.deepcopy(kwargs))
|
||||
else:
|
||||
# Export the actor.
|
||||
if (self._last_driver_id_exported_for is None
|
||||
or self._last_driver_id_exported_for !=
|
||||
worker.task_driver_id):
|
||||
if (self._last_job_id_exported_for is None or
|
||||
self._last_job_id_exported_for != worker.current_job_id):
|
||||
# If this actor class was exported in a previous session, we
|
||||
# need to export this function again, because current GCS
|
||||
# doesn't have it.
|
||||
self._last_driver_id_exported_for = worker.task_driver_id
|
||||
self._last_job_id_exported_for = worker.current_job_id
|
||||
worker.function_actor_manager.export_actor_class(
|
||||
self._modified_class, self._actor_method_names)
|
||||
|
||||
@@ -389,7 +387,7 @@ class ActorClass(object):
|
||||
actor_id, self._modified_class.__module__, self._class_name,
|
||||
actor_cursor, self._actor_method_names, self._method_decorators,
|
||||
self._method_signatures, self._actor_method_num_return_vals,
|
||||
actor_cursor, actor_method_cpu, worker.task_driver_id)
|
||||
actor_cursor, actor_method_cpu, worker.current_job_id)
|
||||
# We increment the actor counter by 1 to account for the actor creation
|
||||
# task.
|
||||
actor_handle._ray_actor_counter += 1
|
||||
@@ -446,9 +444,9 @@ class ActorHandle(object):
|
||||
_ray_original_handle: True if this is the original actor handle for a
|
||||
given actor. If this is true, then the actor will be destroyed when
|
||||
this handle goes out of scope.
|
||||
_ray_actor_driver_id: The driver ID of the job that created the actor
|
||||
(it is possible that this ActorHandle exists on a driver with a
|
||||
different driver ID).
|
||||
_ray_actor_job_id: The ID of the job that created the actor
|
||||
(it is possible that this ActorHandle exists on a job with a
|
||||
different job ID).
|
||||
_ray_new_actor_handles: The new actor handles that were created from
|
||||
this handle since the last task on this handle was submitted. This
|
||||
is used to garbage-collect dummy objects that are no longer
|
||||
@@ -466,10 +464,10 @@ class ActorHandle(object):
|
||||
method_num_return_vals,
|
||||
actor_creation_dummy_object_id,
|
||||
actor_method_cpus,
|
||||
actor_driver_id,
|
||||
actor_job_id,
|
||||
actor_handle_id=None):
|
||||
assert isinstance(actor_id, ActorID)
|
||||
assert isinstance(actor_driver_id, DriverID)
|
||||
assert isinstance(actor_job_id, ray.JobID)
|
||||
self._ray_actor_id = actor_id
|
||||
self._ray_module_name = module_name
|
||||
# False if this actor handle was created by forking or pickling. True
|
||||
@@ -491,7 +489,7 @@ class ActorHandle(object):
|
||||
self._ray_actor_creation_dummy_object_id = (
|
||||
actor_creation_dummy_object_id)
|
||||
self._ray_actor_method_cpus = actor_method_cpus
|
||||
self._ray_actor_driver_id = actor_driver_id
|
||||
self._ray_actor_job_id = actor_job_id
|
||||
self._ray_new_actor_handles = []
|
||||
self._ray_actor_lock = threading.Lock()
|
||||
|
||||
@@ -551,7 +549,7 @@ class ActorHandle(object):
|
||||
num_return_vals=num_return_vals + 1,
|
||||
resources={"CPU": self._ray_actor_method_cpus},
|
||||
placement_resources={},
|
||||
driver_id=self._ray_actor_driver_id,
|
||||
job_id=self._ray_actor_job_id,
|
||||
)
|
||||
# Update the actor counter and cursor to reflect the most recent
|
||||
# invocation.
|
||||
@@ -612,7 +610,7 @@ class ActorHandle(object):
|
||||
# not just the first one.
|
||||
worker = ray.worker.get_global_worker()
|
||||
if (worker.mode == ray.worker.SCRIPT_MODE
|
||||
and self._ray_actor_driver_id.binary() != worker.worker_id):
|
||||
and self._ray_actor_job_id.binary() != worker.worker_id):
|
||||
# If the worker is a driver and driver id has changed because
|
||||
# Ray was shut down and re-initialized, the actor is already cleaned up
|
||||
# and we don't need to send `__ray_terminate__` again.
|
||||
@@ -666,7 +664,7 @@ class ActorHandle(object):
|
||||
"actor_creation_dummy_object_id": self.
|
||||
_ray_actor_creation_dummy_object_id,
|
||||
"actor_method_cpus": self._ray_actor_method_cpus,
|
||||
"actor_driver_id": self._ray_actor_driver_id,
|
||||
"actor_job_id": self._ray_actor_job_id,
|
||||
"ray_forking": ray_forking
|
||||
}
|
||||
|
||||
@@ -727,9 +725,9 @@ class ActorHandle(object):
|
||||
state["method_num_return_vals"],
|
||||
state["actor_creation_dummy_object_id"],
|
||||
state["actor_method_cpus"],
|
||||
# This is the driver ID of the driver that owns the actor, not
|
||||
# necessarily the driver that owns this actor handle.
|
||||
state["actor_driver_id"],
|
||||
# This is the ID of the job that owns the actor, not
|
||||
# necessarily the job that owns this actor handle.
|
||||
state["actor_job_id"],
|
||||
actor_handle_id=actor_handle_id)
|
||||
|
||||
def __getstate__(self):
|
||||
|
||||
@@ -277,9 +277,9 @@ class FunctionActorManager(object):
|
||||
the worker gets connected.
|
||||
_actors_to_export: The actors to export when the worker gets
|
||||
connected.
|
||||
_function_execution_info: The map from driver_id to finction_id
|
||||
_function_execution_info: The map from job_id to function_id
|
||||
and execution_info.
|
||||
_num_task_executions: The map from driver_id to function
|
||||
_num_task_executions: The map from job_id to function
|
||||
execution times.
|
||||
imported_actor_classes: The set of actor classes keys (format:
|
||||
ActorClass:function_id) that are already in GCS.
|
||||
@@ -303,17 +303,17 @@ class FunctionActorManager(object):
|
||||
self._loaded_actor_classes = {}
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def increase_task_counter(self, driver_id, function_descriptor):
|
||||
def increase_task_counter(self, job_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
self._num_task_executions[driver_id][function_id] += 1
|
||||
job_id = ray.JobID.nil()
|
||||
self._num_task_executions[job_id][function_id] += 1
|
||||
|
||||
def get_task_counter(self, driver_id, function_descriptor):
|
||||
def get_task_counter(self, job_id, function_descriptor):
|
||||
function_id = function_descriptor.function_id
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
return self._num_task_executions[driver_id][function_id]
|
||||
job_id = ray.JobID.nil()
|
||||
return self._num_task_executions[job_id][function_id]
|
||||
|
||||
def export_cached(self):
|
||||
"""Export cached remote functions
|
||||
@@ -376,11 +376,11 @@ class FunctionActorManager(object):
|
||||
check_oversized_pickle(pickled_function,
|
||||
remote_function._function_name,
|
||||
"remote function", self._worker)
|
||||
key = (b"RemoteFunction:" + self._worker.task_driver_id.binary() + b":"
|
||||
key = (b"RemoteFunction:" + self._worker.current_job_id.binary() + b":"
|
||||
+ remote_function._function_descriptor.function_id.binary())
|
||||
self._worker.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self._worker.task_driver_id.binary(),
|
||||
"job_id": self._worker.current_job_id.binary(),
|
||||
"function_id": remote_function._function_descriptor.
|
||||
function_id.binary(),
|
||||
"name": remote_function._function_name,
|
||||
@@ -392,14 +392,14 @@ class FunctionActorManager(object):
|
||||
|
||||
def fetch_and_register_remote_function(self, key):
|
||||
"""Import a remote function."""
|
||||
(driver_id_str, function_id_str, function_name, serialized_function,
|
||||
(job_id_str, function_id_str, function_name, serialized_function,
|
||||
num_return_vals, module, resources,
|
||||
max_calls) = self._worker.redis_client.hmget(key, [
|
||||
"driver_id", "function_id", "name", "function", "num_return_vals",
|
||||
"job_id", "function_id", "name", "function", "num_return_vals",
|
||||
"module", "resources", "max_calls"
|
||||
])
|
||||
function_id = ray.FunctionID(function_id_str)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
job_id = ray.JobID(job_id_str)
|
||||
function_name = decode(function_name)
|
||||
max_calls = int(max_calls)
|
||||
module = decode(module)
|
||||
@@ -413,12 +413,12 @@ class FunctionActorManager(object):
|
||||
# atomic. Otherwise, there is race condition. Another thread may use
|
||||
# the temporary function above before the real function is ready.
|
||||
with self.lock:
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=f,
|
||||
function_name=function_name,
|
||||
max_calls=max_calls))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
|
||||
try:
|
||||
function = pickle.loads(serialized_function)
|
||||
@@ -434,7 +434,7 @@ class FunctionActorManager(object):
|
||||
"Failed to unpickle the remote function '{}' with "
|
||||
"function ID {}. Traceback:\n{}".format(
|
||||
function_name, function_id.hex(), traceback_str),
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
else:
|
||||
# The below line is necessary. Because in the driver process,
|
||||
# if the function is defined in the file where the python
|
||||
@@ -442,7 +442,7 @@ class FunctionActorManager(object):
|
||||
# However in the worker process, the `__main__` module is a
|
||||
# different module, which is `default_worker.py`
|
||||
function.__module__ = module
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
@@ -452,11 +452,11 @@ class FunctionActorManager(object):
|
||||
b"FunctionTable:" + function_id.binary(),
|
||||
self._worker.worker_id)
|
||||
|
||||
def get_execution_info(self, driver_id, function_descriptor):
|
||||
def get_execution_info(self, job_id, function_descriptor):
|
||||
"""Get the FunctionExecutionInfo of a remote function.
|
||||
|
||||
Args:
|
||||
driver_id: ID of the driver that the function belongs to.
|
||||
job_id: ID of the job that the function belongs to.
|
||||
function_descriptor: The FunctionDescriptor of the function to get.
|
||||
|
||||
Returns:
|
||||
@@ -464,11 +464,11 @@ class FunctionActorManager(object):
|
||||
"""
|
||||
if self._worker.load_code_from_local:
|
||||
# Load function from local code.
|
||||
# Currently, we don't support isolating code by drivers,
|
||||
# thus always set driver ID to NIL here.
|
||||
driver_id = ray.DriverID.nil()
|
||||
# Currently, we don't support isolating code by jobs,
|
||||
# thus always set job ID to NIL here.
|
||||
job_id = ray.JobID.nil()
|
||||
if not function_descriptor.is_actor_method():
|
||||
self._load_function_from_local(driver_id, function_descriptor)
|
||||
self._load_function_from_local(job_id, function_descriptor)
|
||||
else:
|
||||
# Load function from GCS.
|
||||
# Wait until the function to be executed has actually been
|
||||
@@ -477,21 +477,21 @@ class FunctionActorManager(object):
|
||||
# The driver function may not be found in sys.path. Try to load
|
||||
# the function from GCS.
|
||||
with profiling.profile("wait_for_function"):
|
||||
self._wait_for_function(function_descriptor, driver_id)
|
||||
self._wait_for_function(function_descriptor, job_id)
|
||||
try:
|
||||
function_id = function_descriptor.function_id
|
||||
info = self._function_execution_info[driver_id][function_id]
|
||||
info = self._function_execution_info[job_id][function_id]
|
||||
except KeyError as e:
|
||||
message = ("Error occurs in get_execution_info: "
|
||||
"driver_id: %s, function_descriptor: %s. Message: %s" %
|
||||
(driver_id, function_descriptor, e))
|
||||
"job_id: %s, function_descriptor: %s. Message: %s" %
|
||||
(job_id, function_descriptor, e))
|
||||
raise KeyError(message)
|
||||
return info
|
||||
|
||||
def _load_function_from_local(self, driver_id, function_descriptor):
|
||||
def _load_function_from_local(self, job_id, function_descriptor):
|
||||
assert not function_descriptor.is_actor_method()
|
||||
function_id = function_descriptor.function_id
|
||||
if (driver_id in self._function_execution_info
|
||||
if (job_id in self._function_execution_info
|
||||
and function_id in self._function_execution_info[function_id]):
|
||||
return
|
||||
module_name, function_name = (
|
||||
@@ -501,13 +501,13 @@ class FunctionActorManager(object):
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
function = getattr(module, function_name)._function
|
||||
self._function_execution_info[driver_id][function_id] = (
|
||||
self._function_execution_info[job_id][function_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=function,
|
||||
function_name=function_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to load function %s.".format(function_name))
|
||||
@@ -515,7 +515,7 @@ class FunctionActorManager(object):
|
||||
"Function {} failed to be loaded from local code.".format(
|
||||
function_descriptor))
|
||||
|
||||
def _wait_for_function(self, function_descriptor, driver_id, timeout=10):
|
||||
def _wait_for_function(self, function_descriptor, job_id, timeout=10):
|
||||
"""Wait until the function to be executed is present on this worker.
|
||||
|
||||
This method will simply loop until the import thread has imported the
|
||||
@@ -528,7 +528,7 @@ class FunctionActorManager(object):
|
||||
Args:
|
||||
function_descriptor : The FunctionDescriptor of the function that
|
||||
we want to execute.
|
||||
driver_id (str): The ID of the driver to push the error message to
|
||||
job_id (str): The ID of the job to push the error message to
|
||||
if this times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
@@ -538,7 +538,7 @@ class FunctionActorManager(object):
|
||||
with self.lock:
|
||||
if (self._worker.actor_id.is_nil()
|
||||
and (function_descriptor.function_id in
|
||||
self._function_execution_info[driver_id])):
|
||||
self._function_execution_info[job_id])):
|
||||
break
|
||||
elif not self._worker.actor_id.is_nil() and (
|
||||
self._worker.actor_id in self._worker.actors):
|
||||
@@ -553,7 +553,7 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.WAIT_FOR_FUNCTION_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
warning_sent = True
|
||||
time.sleep(0.001)
|
||||
|
||||
@@ -577,22 +577,22 @@ class FunctionActorManager(object):
|
||||
if self._worker.load_code_from_local:
|
||||
return
|
||||
function_descriptor = FunctionDescriptor.from_class(Class)
|
||||
# `task_driver_id` shouldn't be NIL, unless:
|
||||
# `current_job_id` shouldn't be NIL, unless:
|
||||
# 1) This worker isn't an actor;
|
||||
# 2) And a previous task started a background thread, which didn't
|
||||
# finish before the task finished, and still uses Ray API
|
||||
# after that.
|
||||
assert not self._worker.task_driver_id.is_nil(), (
|
||||
assert not self._worker.current_job_id.is_nil(), (
|
||||
"You might have started a background thread in a non-actor task, "
|
||||
"please make sure the thread finishes before the task finishes.")
|
||||
driver_id = self._worker.task_driver_id
|
||||
key = (b"ActorClass:" + driver_id.binary() + b":" +
|
||||
job_id = self._worker.current_job_id
|
||||
key = (b"ActorClass:" + job_id.binary() + b":" +
|
||||
function_descriptor.function_id.binary())
|
||||
actor_class_info = {
|
||||
"class_name": Class.__name__,
|
||||
"module": Class.__module__,
|
||||
"class": pickle.dumps(Class),
|
||||
"driver_id": driver_id.binary(),
|
||||
"job_id": job_id.binary(),
|
||||
"actor_method_names": json.dumps(list(actor_method_names))
|
||||
}
|
||||
|
||||
@@ -616,11 +616,11 @@ class FunctionActorManager(object):
|
||||
# within tasks. I tried to disable this, but it may be necessary
|
||||
# because of https://github.com/ray-project/ray/issues/1146.
|
||||
|
||||
def load_actor_class(self, driver_id, function_descriptor):
|
||||
def load_actor_class(self, job_id, function_descriptor):
|
||||
"""Load the actor class.
|
||||
|
||||
Args:
|
||||
driver_id: Driver ID of the actor.
|
||||
job_id: Job ID of the actor.
|
||||
function_descriptor: Function descriptor of the actor constructor.
|
||||
|
||||
Returns:
|
||||
@@ -632,14 +632,14 @@ class FunctionActorManager(object):
|
||||
if actor_class is None:
|
||||
# Load actor class.
|
||||
if self._worker.load_code_from_local:
|
||||
driver_id = ray.DriverID.nil()
|
||||
job_id = ray.JobID.nil()
|
||||
# Load actor class from local code.
|
||||
actor_class = self._load_actor_from_local(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
else:
|
||||
# Load actor class from GCS.
|
||||
actor_class = self._load_actor_class_from_gcs(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
# Save the loaded actor class in cache.
|
||||
self._loaded_actor_classes[function_id] = actor_class
|
||||
|
||||
@@ -657,18 +657,19 @@ class FunctionActorManager(object):
|
||||
actor_method,
|
||||
actor_imported=True,
|
||||
)
|
||||
self._function_execution_info[driver_id][method_id] = (
|
||||
self._function_execution_info[job_id][method_id] = (
|
||||
FunctionExecutionInfo(
|
||||
function=executor,
|
||||
function_name=actor_method_name,
|
||||
max_calls=0,
|
||||
))
|
||||
self._num_task_executions[driver_id][method_id] = 0
|
||||
self._num_task_executions[driver_id][function_id] = 0
|
||||
self._num_task_executions[job_id][method_id] = 0
|
||||
self._num_task_executions[job_id][function_id] = 0
|
||||
return actor_class
|
||||
|
||||
def _load_actor_from_local(self, driver_id, function_descriptor):
|
||||
def _load_actor_from_local(self, job_id, function_descriptor):
|
||||
"""Load actor class from local code."""
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
module_name, class_name = (function_descriptor.module_name,
|
||||
function_descriptor.class_name)
|
||||
try:
|
||||
@@ -699,9 +700,9 @@ class FunctionActorManager(object):
|
||||
|
||||
return TemporaryActor
|
||||
|
||||
def _load_actor_class_from_gcs(self, driver_id, function_descriptor):
|
||||
def _load_actor_class_from_gcs(self, job_id, function_descriptor):
|
||||
"""Load actor class from GCS."""
|
||||
key = (b"ActorClass:" + driver_id.binary() + b":" +
|
||||
key = (b"ActorClass:" + job_id.binary() + b":" +
|
||||
function_descriptor.function_id.binary())
|
||||
# Wait for the actor class key to have been imported by the
|
||||
# import thread. TODO(rkn): It shouldn't be possible to end
|
||||
@@ -711,16 +712,14 @@ class FunctionActorManager(object):
|
||||
time.sleep(0.001)
|
||||
|
||||
# Fetch raw data from GCS.
|
||||
(driver_id_str, class_name, module, pickled_class,
|
||||
(job_id_str, class_name, module, pickled_class,
|
||||
actor_method_names) = self._worker.redis_client.hmget(
|
||||
key, [
|
||||
"driver_id", "class_name", "module", "class",
|
||||
"actor_method_names"
|
||||
])
|
||||
key,
|
||||
["job_id", "class_name", "module", "class", "actor_method_names"])
|
||||
|
||||
class_name = ensure_str(class_name)
|
||||
module_name = ensure_str(module)
|
||||
driver_id = ray.DriverID(driver_id_str)
|
||||
job_id = ray.JobID(job_id_str)
|
||||
actor_method_names = json.loads(ensure_str(actor_method_names))
|
||||
|
||||
actor_class = None
|
||||
@@ -741,11 +740,12 @@ class FunctionActorManager(object):
|
||||
traceback.format_exc())
|
||||
# Log the error message.
|
||||
push_error_to_driver(
|
||||
self._worker, ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
self._worker,
|
||||
ray_constants.REGISTER_ACTOR_PUSH_ERROR,
|
||||
"Failed to unpickle actor class '{}' for actor ID {}. "
|
||||
"Traceback:\n{}".format(class_name,
|
||||
self._worker.actor_id.hex(),
|
||||
traceback_str), driver_id)
|
||||
"Traceback:\n{}".format(
|
||||
class_name, self._worker.actor_id.hex(), traceback_str),
|
||||
job_id=job_id)
|
||||
# TODO(rkn): In the future, it might make sense to have the worker
|
||||
# exit here. However, currently that would lead to hanging if
|
||||
# someone calls ray.get on a method invoked on the actor.
|
||||
@@ -859,7 +859,7 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
job_id=self._worker.current_job_id)
|
||||
|
||||
def _restore_and_log_checkpoint(self, actor):
|
||||
"""Restore an actor from a checkpoint if available and log any errors.
|
||||
@@ -898,4 +898,4 @@ class FunctionActorManager(object):
|
||||
self._worker,
|
||||
ray_constants.CHECKPOINT_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=self._worker.task_driver_id)
|
||||
job_id=self._worker.current_job_id)
|
||||
|
||||
@@ -7,7 +7,7 @@ from ray.core.generated.ray.protocol.Task import Task
|
||||
from ray.core.generated.gcs_pb2 import (
|
||||
ActorCheckpointIdData,
|
||||
ClientTableData,
|
||||
DriverTableData,
|
||||
JobTableData,
|
||||
ErrorTableData,
|
||||
ErrorType,
|
||||
GcsEntry,
|
||||
@@ -23,7 +23,7 @@ from ray.core.generated.gcs_pb2 import (
|
||||
__all__ = [
|
||||
"ActorCheckpointIdData",
|
||||
"ClientTableData",
|
||||
"DriverTableData",
|
||||
"JobTableData",
|
||||
"ErrorTableData",
|
||||
"ErrorType",
|
||||
"GcsEntry",
|
||||
@@ -48,8 +48,8 @@ XRAY_HEARTBEAT_CHANNEL = str(
|
||||
XRAY_HEARTBEAT_BATCH_CHANNEL = str(
|
||||
TablePubsub.Value("HEARTBEAT_BATCH_PUBSUB")).encode("ascii")
|
||||
|
||||
# xray driver updates
|
||||
XRAY_DRIVER_CHANNEL = str(TablePubsub.Value("DRIVER_PUBSUB")).encode("ascii")
|
||||
# xray job updates
|
||||
XRAY_JOB_CHANNEL = str(TablePubsub.Value("JOB_PUBSUB")).encode("ascii")
|
||||
|
||||
# These prefixes must be kept up-to-date with the TablePrefix enum in
|
||||
# gcs.proto.
|
||||
@@ -61,11 +61,11 @@ TablePrefix_ERROR_INFO_string = "ERROR_INFO"
|
||||
TablePrefix_PROFILE_string = "PROFILE"
|
||||
|
||||
|
||||
def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
def construct_error_message(job_id, error_type, message, timestamp):
|
||||
"""Construct a serialized ErrorTableData object.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that the error should go to. If this is
|
||||
job_id: The ID of the job that the error should go to. If this is
|
||||
nil, then the error will go to all drivers.
|
||||
error_type: The type of the error.
|
||||
message: The error message.
|
||||
@@ -75,7 +75,7 @@ def construct_error_message(driver_id, error_type, message, timestamp):
|
||||
The serialized object.
|
||||
"""
|
||||
data = ErrorTableData()
|
||||
data.driver_id = driver_id.binary()
|
||||
data.job_id = job_id.binary()
|
||||
data.type = error_type
|
||||
data.error_message = message
|
||||
data.timestamp = timestamp
|
||||
|
||||
@@ -114,13 +114,13 @@ class ImportThread(object):
|
||||
|
||||
def fetch_and_execute_function_to_run(self, key):
|
||||
"""Run an arbitrary function on the worker."""
|
||||
(driver_id, serialized_function,
|
||||
(job_id, serialized_function,
|
||||
run_on_other_drivers) = self.redis_client.hmget(
|
||||
key, ["driver_id", "function", "run_on_other_drivers"])
|
||||
key, ["job_id", "function", "run_on_other_drivers"])
|
||||
|
||||
if (utils.decode(run_on_other_drivers) == "False"
|
||||
and self.worker.mode == ray.SCRIPT_MODE
|
||||
and driver_id != self.worker.task_driver_id.binary()):
|
||||
and job_id != self.worker.current_job_id.binary()):
|
||||
return
|
||||
|
||||
try:
|
||||
@@ -140,4 +140,4 @@ class ImportThread(object):
|
||||
self.worker,
|
||||
ray_constants.FUNCTION_TO_RUN_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=ray.DriverID(driver_id))
|
||||
job_id=ray.JobID(job_id))
|
||||
|
||||
@@ -6,7 +6,8 @@ from libcpp.unordered_map cimport unordered_map
|
||||
from libcpp.vector cimport vector as c_vector
|
||||
|
||||
from ray.includes.unique_ids cimport (
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CWorkerID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -81,7 +82,7 @@ cdef extern from "ray/common/status.h" namespace "ray::StatusCode" nogil:
|
||||
|
||||
|
||||
cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
const CTaskID GenerateTaskId(const CDriverID &driver_id,
|
||||
const CTaskID GenerateTaskId(const CJobID &job_id,
|
||||
const CTaskID &parent_task_id,
|
||||
int parent_task_counter)
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ from ray.includes.unique_ids cimport (
|
||||
CActorCheckpointID,
|
||||
CActorID,
|
||||
CClientID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CWorkerID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -46,7 +47,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
cdef cppclass CRayletClient "RayletClient":
|
||||
CRayletClient(const c_string &raylet_socket,
|
||||
const CClientID &client_id,
|
||||
c_bool is_worker, const CDriverID &driver_id,
|
||||
c_bool is_worker, const CJobID &job_id,
|
||||
const CLanguage &language)
|
||||
CRayStatus Disconnect()
|
||||
CRayStatus SubmitTask(
|
||||
@@ -62,7 +63,7 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
int num_returns, int64_t timeout_milliseconds,
|
||||
c_bool wait_local, const CTaskID ¤t_task_id,
|
||||
WaitResultPair *result)
|
||||
CRayStatus PushError(const CDriverID &driver_id, const c_string &type,
|
||||
CRayStatus PushError(const CJobID &job_id, const c_string &type,
|
||||
const c_string &error_message, double timestamp)
|
||||
CRayStatus PushProfileEvents(
|
||||
const GCSProfileTableDataT &profile_events)
|
||||
@@ -75,6 +76,6 @@ cdef extern from "ray/raylet/raylet_client.h" nogil:
|
||||
CRayStatus SetResource(const c_string &resource_name, const double capacity, const CClientID &client_Id)
|
||||
CLanguage GetLanguage() const
|
||||
CClientID GetClientID() const
|
||||
CDriverID GetDriverID() const
|
||||
CJobID GetJobID() const
|
||||
c_bool IsWorker() const
|
||||
const ResourceMappingType &GetResourceIDs() const
|
||||
|
||||
@@ -12,7 +12,7 @@ from ray.includes.common cimport (
|
||||
from ray.includes.unique_ids cimport (
|
||||
CActorHandleID,
|
||||
CActorID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
)
|
||||
@@ -46,7 +46,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
|
||||
cdef cppclass CTaskSpecification "ray::raylet::TaskSpecification":
|
||||
CTaskSpecification(
|
||||
const CDriverID &driver_id, const CTaskID &parent_task_id,
|
||||
const CJobID &job_id, const CTaskID &parent_task_id,
|
||||
int64_t parent_counter,
|
||||
const c_vector[shared_ptr[CTaskArgument]] &task_arguments,
|
||||
int64_t num_returns,
|
||||
@@ -54,7 +54,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
const CLanguage &language,
|
||||
const c_vector[c_string] &function_descriptor)
|
||||
CTaskSpecification(
|
||||
const CDriverID &driver_id, const CTaskID &parent_task_id,
|
||||
const CJobID &job_id, const CTaskID &parent_task_id,
|
||||
int64_t parent_counter, const CActorID &actor_creation_id,
|
||||
const CObjectID &actor_creation_dummy_object_id,
|
||||
int64_t max_actor_reconstructions, const CActorID &actor_id,
|
||||
@@ -70,7 +70,7 @@ cdef extern from "ray/raylet/task_spec.h" namespace "ray::raylet" nogil:
|
||||
c_string SerializeAsString() const
|
||||
|
||||
CTaskID TaskId() const
|
||||
CDriverID DriverId() const
|
||||
CJobID JobId() const
|
||||
CTaskID ParentTaskId() const
|
||||
int64_t ParentCounter() const
|
||||
c_vector[c_string] FunctionDescriptor() const
|
||||
|
||||
@@ -18,7 +18,7 @@ cdef class Task:
|
||||
unique_ptr[CTaskSpecification] task_spec
|
||||
unique_ptr[c_vector[CObjectID]] execution_dependencies
|
||||
|
||||
def __init__(self, DriverID driver_id, function_descriptor, arguments,
|
||||
def __init__(self, JobID job_id, function_descriptor, arguments,
|
||||
int num_returns, TaskID parent_task_id, int parent_counter,
|
||||
ActorID actor_creation_id,
|
||||
ObjectID actor_creation_dummy_object_id,
|
||||
@@ -72,7 +72,7 @@ cdef class Task:
|
||||
(<ActorHandleID?>new_actor_handle).native())
|
||||
|
||||
self.task_spec.reset(new CTaskSpecification(
|
||||
driver_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
|
||||
job_id.native(), parent_task_id.native(), parent_counter, actor_creation_id.native(),
|
||||
actor_creation_dummy_object_id.native(), max_actor_reconstructions, actor_id.native(),
|
||||
actor_handle_id.native(), actor_counter, task_new_actor_handles, task_args, num_returns,
|
||||
required_resources, required_placement_resources, LANGUAGE_PYTHON,
|
||||
@@ -122,9 +122,9 @@ cdef class Task:
|
||||
return SerializeTaskAsString(
|
||||
self.execution_dependencies.get(), self.task_spec.get())
|
||||
|
||||
def driver_id(self):
|
||||
"""Return the driver ID for this task."""
|
||||
return DriverID(self.task_spec.get().DriverId().Binary())
|
||||
def job_id(self):
|
||||
"""Return the job ID for this task."""
|
||||
return JobID(self.task_spec.get().JobId().Binary())
|
||||
|
||||
def task_id(self):
|
||||
"""Return the task ID for this task."""
|
||||
|
||||
@@ -78,10 +78,10 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil:
|
||||
@staticmethod
|
||||
CFunctionID FromBinary(const c_string &binary)
|
||||
|
||||
cdef cppclass CDriverID "ray::DriverID"(CUniqueID):
|
||||
cdef cppclass CJobID "ray::JobID"(CUniqueID):
|
||||
|
||||
@staticmethod
|
||||
CDriverID FromBinary(const c_string &binary)
|
||||
CJobID FromBinary(const c_string &binary)
|
||||
|
||||
cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]):
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ray.includes.unique_ids cimport (
|
||||
CActorID,
|
||||
CClientID,
|
||||
CConfigID,
|
||||
CDriverID,
|
||||
CJobID,
|
||||
CFunctionID,
|
||||
CObjectID,
|
||||
CTaskID,
|
||||
@@ -212,15 +212,23 @@ cdef class ClientID(UniqueID):
|
||||
return <CClientID>self.data
|
||||
|
||||
|
||||
cdef class DriverID(UniqueID):
|
||||
cdef class JobID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
check_id(id)
|
||||
self.data = CDriverID.FromBinary(<c_string>id)
|
||||
self.data = CJobID.FromBinary(<c_string>id)
|
||||
|
||||
cdef CDriverID native(self):
|
||||
return <CDriverID>self.data
|
||||
cdef CJobID native(self):
|
||||
return <CJobID>self.data
|
||||
|
||||
cdef class WorkerID(UniqueID):
|
||||
|
||||
def __init__(self, id):
|
||||
check_id(id)
|
||||
self.data = CWorkerID.FromBinary(<c_string>id)
|
||||
|
||||
cdef CWorkerID native(self):
|
||||
return <CWorkerID>self.data
|
||||
|
||||
cdef class ActorID(UniqueID):
|
||||
|
||||
@@ -277,7 +285,8 @@ _ID_TYPES = [
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
DriverID,
|
||||
JobID,
|
||||
WorkerID,
|
||||
FunctionID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
|
||||
+23
-23
@@ -130,14 +130,14 @@ class Monitor(object):
|
||||
"Monitor: "
|
||||
"could not find ip for client {}".format(client_id))
|
||||
|
||||
def _xray_clean_up_entries_for_driver(self, driver_id):
|
||||
"""Remove this driver's object/task entries from redis.
|
||||
def _xray_clean_up_entries_for_job(self, job_id):
|
||||
"""Remove this job's object/task entries from redis.
|
||||
|
||||
Removes control-state entries of all tasks and task return
|
||||
objects belonging to the driver.
|
||||
|
||||
Args:
|
||||
driver_id: The driver id.
|
||||
job_id: The job id.
|
||||
"""
|
||||
|
||||
xray_task_table_prefix = (
|
||||
@@ -146,23 +146,23 @@ class Monitor(object):
|
||||
ray.gcs_utils.TablePrefix_OBJECT_string.encode("ascii"))
|
||||
|
||||
task_table_objects = ray.tasks()
|
||||
driver_id_hex = binary_to_hex(driver_id)
|
||||
driver_task_id_bins = set()
|
||||
job_id_hex = binary_to_hex(job_id)
|
||||
job_task_id_bins = set()
|
||||
for task_id_hex, task_info in task_table_objects.items():
|
||||
task_table_object = task_info["TaskSpec"]
|
||||
task_driver_id_hex = task_table_object["DriverID"]
|
||||
if driver_id_hex != task_driver_id_hex:
|
||||
task_job_id_hex = task_table_object["JobID"]
|
||||
if job_id_hex != task_job_id_hex:
|
||||
# Ignore tasks that aren't from this driver.
|
||||
continue
|
||||
driver_task_id_bins.add(hex_to_binary(task_id_hex))
|
||||
job_task_id_bins.add(hex_to_binary(task_id_hex))
|
||||
|
||||
# Get objects associated with the driver.
|
||||
object_table_objects = ray.objects()
|
||||
driver_object_id_bins = set()
|
||||
job_object_id_bins = set()
|
||||
for object_id, _ in object_table_objects.items():
|
||||
task_id_bin = ray._raylet.compute_task_id(object_id).binary()
|
||||
if task_id_bin in driver_task_id_bins:
|
||||
driver_object_id_bins.add(object_id.binary())
|
||||
if task_id_bin in job_task_id_bins:
|
||||
job_object_id_bins.add(object_id.binary())
|
||||
|
||||
def to_shard_index(id_bin):
|
||||
if len(id_bin) == ray.TaskID.size():
|
||||
@@ -174,10 +174,10 @@ class Monitor(object):
|
||||
|
||||
# Form the redis keys to delete.
|
||||
sharded_keys = [[] for _ in range(len(ray.state.state.redis_clients))]
|
||||
for task_id_bin in driver_task_id_bins:
|
||||
for task_id_bin in job_task_id_bins:
|
||||
sharded_keys[to_shard_index(task_id_bin)].append(
|
||||
xray_task_table_prefix + task_id_bin)
|
||||
for object_id_bin in driver_object_id_bins:
|
||||
for object_id_bin in job_object_id_bins:
|
||||
sharded_keys[to_shard_index(object_id_bin)].append(
|
||||
xray_object_table_prefix + object_id_bin)
|
||||
|
||||
@@ -198,21 +198,21 @@ class Monitor(object):
|
||||
"entries from redis shard {}.".format(
|
||||
len(keys) - num_deleted, shard_index))
|
||||
|
||||
def xray_driver_removed_handler(self, unused_channel, data):
|
||||
"""Handle a notification that a driver has been removed.
|
||||
def xray_job_removed_handler(self, unused_channel, data):
|
||||
"""Handle a notification that a job has been removed.
|
||||
|
||||
Args:
|
||||
unused_channel: The message channel.
|
||||
data: The message data.
|
||||
"""
|
||||
gcs_entries = ray.gcs_utils.GcsEntry.FromString(data)
|
||||
driver_data = gcs_entries.entries[0]
|
||||
message = ray.gcs_utils.DriverTableData.FromString(driver_data)
|
||||
driver_id = message.driver_id
|
||||
job_data = gcs_entries.entries[0]
|
||||
message = ray.gcs_utils.JobTableData.FromString(job_data)
|
||||
job_id = message.job_id
|
||||
logger.info("Monitor: "
|
||||
"XRay Driver {} has been removed.".format(
|
||||
binary_to_hex(driver_id)))
|
||||
self._xray_clean_up_entries_for_driver(driver_id)
|
||||
binary_to_hex(job_id)))
|
||||
self._xray_clean_up_entries_for_job(job_id)
|
||||
|
||||
def process_messages(self, max_messages=10000):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
@@ -240,9 +240,9 @@ class Monitor(object):
|
||||
if channel == ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL:
|
||||
# Similar functionality as raylet info channel
|
||||
message_handler = self.xray_heartbeat_batch_handler
|
||||
elif channel == ray.gcs_utils.XRAY_DRIVER_CHANNEL:
|
||||
elif channel == ray.gcs_utils.XRAY_JOB_CHANNEL:
|
||||
# Handles driver death.
|
||||
message_handler = self.xray_driver_removed_handler
|
||||
message_handler = self.xray_job_removed_handler
|
||||
else:
|
||||
raise Exception("This code should be unreachable.")
|
||||
|
||||
@@ -298,7 +298,7 @@ class Monitor(object):
|
||||
"""
|
||||
# Initialize the subscription channel.
|
||||
self.subscribe(ray.gcs_utils.XRAY_HEARTBEAT_BATCH_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.XRAY_DRIVER_CHANNEL)
|
||||
self.subscribe(ray.gcs_utils.XRAY_JOB_CHANNEL)
|
||||
|
||||
# TODO(rkn): If there were any dead clients at startup, we should clean
|
||||
# up the associated state in the state tables.
|
||||
|
||||
@@ -44,7 +44,7 @@ class RemoteFunction(object):
|
||||
return the resulting ObjectIDs. For an example, see
|
||||
"test_decorated_function" in "python/ray/tests/test_basic.py".
|
||||
_function_signature: The function signature.
|
||||
_last_driver_id_exported_for: The ID of the driver ID of the last Ray
|
||||
_last_job_id_exported_for: The ID of the job ID of the last Ray
|
||||
session during which this remote function definition was exported.
|
||||
This is an imperfect mechanism used to determine if we need to
|
||||
export the remote function again. It is imperfect in the sense that
|
||||
@@ -73,7 +73,7 @@ class RemoteFunction(object):
|
||||
self._function_signature = ray.signature.extract_signature(
|
||||
self._function)
|
||||
|
||||
self._last_driver_id_exported_for = None
|
||||
self._last_job_id_exported_for = None
|
||||
|
||||
# Override task.remote's signature and docstring
|
||||
@wraps(function)
|
||||
@@ -115,11 +115,11 @@ class RemoteFunction(object):
|
||||
worker = ray.worker.get_global_worker()
|
||||
worker.check_connected()
|
||||
|
||||
if (self._last_driver_id_exported_for is None
|
||||
or self._last_driver_id_exported_for != worker.task_driver_id):
|
||||
if (self._last_job_id_exported_for is None
|
||||
or self._last_job_id_exported_for != worker.current_job_id):
|
||||
# If this function was exported in a previous session, we need to
|
||||
# export this function again, because current GCS doesn't have it.
|
||||
self._last_driver_id_exported_for = worker.task_driver_id
|
||||
self._last_job_id_exported_for = worker.current_job_id
|
||||
worker.function_actor_manager.export(self)
|
||||
|
||||
kwargs = {} if kwargs is None else kwargs
|
||||
|
||||
@@ -20,7 +20,7 @@ class RuntimeContext(object):
|
||||
a task, return the driver ID of the associated driver.
|
||||
"""
|
||||
assert self.worker is not None
|
||||
return self.worker.task_driver_id
|
||||
return self.worker.current_job_id
|
||||
|
||||
|
||||
_runtime_context = None
|
||||
|
||||
+19
-20
@@ -316,7 +316,7 @@ class GlobalState(object):
|
||||
function_descriptor_list)
|
||||
|
||||
task_spec_info = {
|
||||
"DriverID": task.driver_id().hex(),
|
||||
"JobID": task.job_id().hex(),
|
||||
"TaskID": task.task_id().hex(),
|
||||
"ParentTaskID": task.parent_task_id().hex(),
|
||||
"ParentCounter": task.parent_counter(),
|
||||
@@ -817,19 +817,19 @@ class GlobalState(object):
|
||||
|
||||
return dict(total_available_resources)
|
||||
|
||||
def _error_messages(self, driver_id):
|
||||
def _error_messages(self, job_id):
|
||||
"""Get the error messages for a specific driver.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver to get the errors for.
|
||||
job_id: The ID of the job to get the errors for.
|
||||
|
||||
Returns:
|
||||
A list of the error messages for this driver.
|
||||
"""
|
||||
assert isinstance(driver_id, ray.DriverID)
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
message = self.redis_client.execute_command(
|
||||
"RAY.TABLE_LOOKUP", gcs_utils.TablePrefix.Value("ERROR_INFO"), "",
|
||||
driver_id.binary())
|
||||
job_id.binary())
|
||||
|
||||
# If there are no errors, return early.
|
||||
if message is None:
|
||||
@@ -839,7 +839,7 @@ class GlobalState(object):
|
||||
error_messages = []
|
||||
for entry in gcs_entries.entries:
|
||||
error_data = gcs_utils.ErrorTableData.FromString(entry)
|
||||
assert driver_id.binary() == error_data.driver_id
|
||||
assert job_id.binary() == error_data.job_id
|
||||
error_message = {
|
||||
"type": error_data.type,
|
||||
"message": error_data.error_message,
|
||||
@@ -848,12 +848,12 @@ class GlobalState(object):
|
||||
error_messages.append(error_message)
|
||||
return error_messages
|
||||
|
||||
def error_messages(self, driver_id=None):
|
||||
def error_messages(self, job_id=None):
|
||||
"""Get the error messages for all drivers or a specific driver.
|
||||
|
||||
Args:
|
||||
driver_id: The specific driver to get the errors for. If this is
|
||||
None, then this method retrieves the errors for all drivers.
|
||||
job_id: The specific job to get the errors for. If this is
|
||||
None, then this method retrieves the errors for all jobs.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping driver ID to a list of the error messages for
|
||||
@@ -861,21 +861,20 @@ class GlobalState(object):
|
||||
"""
|
||||
self._check_connected()
|
||||
|
||||
if driver_id is not None:
|
||||
assert isinstance(driver_id, ray.DriverID)
|
||||
return self._error_messages(driver_id)
|
||||
if job_id is not None:
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
return self._error_messages(job_id)
|
||||
|
||||
error_table_keys = self.redis_client.keys(
|
||||
gcs_utils.TablePrefix_ERROR_INFO_string + "*")
|
||||
driver_ids = [
|
||||
job_ids = [
|
||||
key[len(gcs_utils.TablePrefix_ERROR_INFO_string):]
|
||||
for key in error_table_keys
|
||||
]
|
||||
|
||||
return {
|
||||
binary_to_hex(driver_id): self._error_messages(
|
||||
ray.DriverID(driver_id))
|
||||
for driver_id in driver_ids
|
||||
binary_to_hex(job_id): self._error_messages(ray.JobID(job_id))
|
||||
for job_id in job_ids
|
||||
}
|
||||
|
||||
def actor_checkpoint_info(self, actor_id):
|
||||
@@ -969,12 +968,12 @@ class DeprecatedGlobalState(object):
|
||||
"instead.")
|
||||
return ray.available_resources()
|
||||
|
||||
def error_messages(self, driver_id=None):
|
||||
def error_messages(self, job_id=None):
|
||||
logger.warning(
|
||||
"ray.global_state.error_messages() is deprecated and will be "
|
||||
"removed in a subsequent release. Use ray.errors() "
|
||||
"instead.")
|
||||
return ray.errors(driver_id=driver_id)
|
||||
return ray.errors(job_id=job_id)
|
||||
|
||||
|
||||
state = GlobalState()
|
||||
@@ -1095,7 +1094,7 @@ def errors(include_cluster_errors=True):
|
||||
Error messages pushed from the cluster.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
error_messages = state.error_messages(driver_id=worker.task_driver_id)
|
||||
error_messages = state.error_messages(job_id=worker.current_job_id)
|
||||
if include_cluster_errors:
|
||||
error_messages += state.error_messages(driver_id=ray.DriverID.nil())
|
||||
error_messages += state.error_messages(job_id=ray.JobID.nil())
|
||||
return error_messages
|
||||
|
||||
@@ -2439,7 +2439,7 @@ def test_global_state_api(shutdown_only):
|
||||
|
||||
assert ray.objects() == {}
|
||||
|
||||
driver_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
|
||||
job_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id)
|
||||
driver_task_id = ray.worker.global_worker.current_task_id.hex()
|
||||
|
||||
# One task is put in the task table which corresponds to this driver.
|
||||
@@ -2453,7 +2453,7 @@ def test_global_state_api(shutdown_only):
|
||||
assert task_spec["TaskID"] == driver_task_id
|
||||
assert task_spec["ActorID"] == nil_id_hex
|
||||
assert task_spec["Args"] == []
|
||||
assert task_spec["DriverID"] == driver_id
|
||||
assert task_spec["JobID"] == job_id
|
||||
assert task_spec["FunctionID"] == nil_id_hex
|
||||
assert task_spec["ReturnObjectIDs"] == []
|
||||
|
||||
@@ -2481,7 +2481,7 @@ def test_global_state_api(shutdown_only):
|
||||
task_spec = task_table[task_id]["TaskSpec"]
|
||||
assert task_spec["ActorID"] == nil_id_hex
|
||||
assert task_spec["Args"] == [1, "hi", x_id]
|
||||
assert task_spec["DriverID"] == driver_id
|
||||
assert task_spec["JobID"] == job_id
|
||||
assert task_spec["ReturnObjectIDs"] == [result_id]
|
||||
|
||||
assert task_table[task_id] == ray.tasks(task_id)
|
||||
@@ -2613,9 +2613,9 @@ def test_workers(shutdown_only):
|
||||
worker_ids = set(ray.get([f.remote() for _ in range(10)]))
|
||||
|
||||
|
||||
def test_specific_driver_id():
|
||||
dummy_driver_id = ray.DriverID(b"00112233445566778899")
|
||||
ray.init(num_cpus=1, driver_id=dummy_driver_id)
|
||||
def test_specific_job_id():
|
||||
dummy_driver_id = ray.JobID(b"00112233445566778899")
|
||||
ray.init(num_cpus=1, job_id=dummy_driver_id)
|
||||
|
||||
# in driver
|
||||
assert dummy_driver_id == ray._get_runtime_context().current_driver_id
|
||||
@@ -2727,7 +2727,7 @@ def test_ray_setproctitle(ray_start_2_cpus):
|
||||
def test_duplicate_error_messages(shutdown_only):
|
||||
ray.init(num_cpus=0)
|
||||
|
||||
driver_id = ray.DriverID.nil()
|
||||
driver_id = ray.WorkerID.nil()
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, "test",
|
||||
"message", 0)
|
||||
|
||||
|
||||
+15
-14
@@ -51,7 +51,7 @@ def format_error_message(exception_message, task_exception=False):
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def push_error_to_driver(worker, error_type, message, driver_id=None):
|
||||
def push_error_to_driver(worker, error_type, message, job_id=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
Args:
|
||||
@@ -59,19 +59,19 @@ def push_error_to_driver(worker, error_type, message, driver_id=None):
|
||||
error_type (str): The type of the error.
|
||||
message (str): The message that will be printed in the background
|
||||
on the driver.
|
||||
driver_id: The ID of the driver to push the error message to. If this
|
||||
job_id: The ID of the driver to push the error message to. If this
|
||||
is None, then the message will be pushed to all drivers.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray.DriverID.nil()
|
||||
worker.raylet_client.push_error(driver_id, error_type, message,
|
||||
time.time())
|
||||
if job_id is None:
|
||||
job_id = ray.JobID.nil()
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
worker.raylet_client.push_error(job_id, error_type, message, time.time())
|
||||
|
||||
|
||||
def push_error_to_driver_through_redis(redis_client,
|
||||
error_type,
|
||||
message,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Push an error message to the driver to be printed in the background.
|
||||
|
||||
Normally the push_error_to_driver function should be used. However, in some
|
||||
@@ -84,19 +84,20 @@ def push_error_to_driver_through_redis(redis_client,
|
||||
error_type (str): The type of the error.
|
||||
message (str): The message that will be printed in the background
|
||||
on the driver.
|
||||
driver_id: The ID of the driver to push the error message to. If this
|
||||
job_id: The ID of the driver to push the error message to. If this
|
||||
is None, then the message will be pushed to all drivers.
|
||||
"""
|
||||
if driver_id is None:
|
||||
driver_id = ray.DriverID.nil()
|
||||
if job_id is None:
|
||||
job_id = ray.JobID.nil()
|
||||
assert isinstance(job_id, ray.JobID)
|
||||
# Do everything in Python and through the Python Redis client instead
|
||||
# of through the raylet.
|
||||
error_data = ray.gcs_utils.construct_error_message(driver_id, error_type,
|
||||
error_data = ray.gcs_utils.construct_error_message(job_id, error_type,
|
||||
message, time.time())
|
||||
redis_client.execute_command(
|
||||
"RAY.TABLE_APPEND", ray.gcs_utils.TablePrefix.Value("ERROR_INFO"),
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"),
|
||||
driver_id.binary(), error_data)
|
||||
ray.gcs_utils.TablePubsub.Value("ERROR_INFO_PUBSUB"), job_id.binary(),
|
||||
error_data)
|
||||
|
||||
|
||||
def is_cython(obj):
|
||||
@@ -443,7 +444,7 @@ def check_oversized_pickle(pickled, name, obj_type, worker):
|
||||
worker,
|
||||
ray_constants.PICKLING_LARGE_OBJECT_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=worker.task_driver_id)
|
||||
job_id=worker.current_job_id)
|
||||
|
||||
|
||||
class _ThreadSafeProxy(object):
|
||||
|
||||
+77
-80
@@ -40,7 +40,8 @@ from ray import (
|
||||
ActorHandleID,
|
||||
ActorID,
|
||||
ClientID,
|
||||
DriverID,
|
||||
WorkerID,
|
||||
JobID,
|
||||
ObjectID,
|
||||
TaskID,
|
||||
)
|
||||
@@ -145,9 +146,9 @@ class Worker(object):
|
||||
# TODO: clean up the SerializationContext once the job finished.
|
||||
self.serialization_context_map = {}
|
||||
self.function_actor_manager = FunctionActorManager(self)
|
||||
# Identity of the driver that this worker is processing.
|
||||
# It is a DriverID.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
# Identity of the job that this worker is processing.
|
||||
# It is a JobID.
|
||||
self.current_job_id = JobID.nil()
|
||||
self._task_context = threading.local()
|
||||
# This event is checked regularly by all of the threads so that they
|
||||
# know when to exit.
|
||||
@@ -227,24 +228,24 @@ class Worker(object):
|
||||
if self.actor_init_error is not None:
|
||||
raise self.actor_init_error
|
||||
|
||||
def get_serialization_context(self, driver_id):
|
||||
"""Get the SerializationContext of the driver that this worker is processing.
|
||||
def get_serialization_context(self, job_id):
|
||||
"""Get the SerializationContext of the job that this worker is processing.
|
||||
|
||||
Args:
|
||||
driver_id: The ID of the driver that indicates which driver to get
|
||||
job_id: The ID of the job that indicates which job to get
|
||||
the serialization context for.
|
||||
|
||||
Returns:
|
||||
The serialization context of the given driver.
|
||||
The serialization context of the given job.
|
||||
"""
|
||||
# This function needs to be proctected by a lock, because it will be
|
||||
# called by`register_class_for_serialization`, as well as the import
|
||||
# thread, from different threads. Also, this function will recursively
|
||||
# call itself, so we use RLock here.
|
||||
with self.lock:
|
||||
if driver_id not in self.serialization_context_map:
|
||||
_initialize_serialization(driver_id)
|
||||
return self.serialization_context_map[driver_id]
|
||||
if job_id not in self.serialization_context_map:
|
||||
_initialize_serialization(job_id)
|
||||
return self.serialization_context_map[job_id]
|
||||
|
||||
def check_connected(self):
|
||||
"""Check if the worker is connected.
|
||||
@@ -314,7 +315,7 @@ class Worker(object):
|
||||
object_id=pyarrow.plasma.ObjectID(object_id.binary()),
|
||||
memcopy_threads=self.memcopy_threads,
|
||||
serialization_context=self.get_serialization_context(
|
||||
self.task_driver_id))
|
||||
self.current_job_id))
|
||||
break
|
||||
except pyarrow.SerializationCallbackError as e:
|
||||
try:
|
||||
@@ -388,17 +389,17 @@ class Worker(object):
|
||||
# should return an error code to the caller instead of printing a
|
||||
# message.
|
||||
logger.info(
|
||||
"The object with ID {} already exists in the object store."
|
||||
.format(object_id))
|
||||
"The object with ID {} already exists in the object store.".
|
||||
format(object_id))
|
||||
except TypeError:
|
||||
# This error can happen because one of the members of the object
|
||||
# may not be serializable for cloudpickle. So we need these extra
|
||||
# fallbacks here to start from the beginning. Hopefully the object
|
||||
# could have a `__reduce__` method.
|
||||
register_custom_serializer(type(value), use_pickle=True)
|
||||
warning_message = ("WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle."
|
||||
.format(type(value)))
|
||||
warning_message = (
|
||||
"WARNING: Serializing the class {} failed, "
|
||||
"so are are falling back to cloudpickle.".format(type(value)))
|
||||
logger.warning(warning_message)
|
||||
self.store_and_register(object_id, value)
|
||||
|
||||
@@ -407,7 +408,7 @@ class Worker(object):
|
||||
# Only send the warning once.
|
||||
warning_sent = False
|
||||
serialization_context = self.get_serialization_context(
|
||||
self.task_driver_id)
|
||||
self.current_job_id)
|
||||
while True:
|
||||
try:
|
||||
# We divide very large get requests into smaller get requests
|
||||
@@ -449,7 +450,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.WAIT_FOR_CLASS_PUSH_ERROR,
|
||||
warning_message,
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
warning_sent = True
|
||||
|
||||
def _deserialize_object_from_arrow(self, data, metadata, object_id,
|
||||
@@ -575,7 +576,7 @@ class Worker(object):
|
||||
num_return_vals=None,
|
||||
resources=None,
|
||||
placement_resources=None,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Submit a remote task to the scheduler.
|
||||
|
||||
Tell the scheduler to schedule the execution of the function with
|
||||
@@ -601,11 +602,11 @@ class Worker(object):
|
||||
placement_resources: The resources required for placing the task.
|
||||
If this is not provided or if it is an empty dictionary, then
|
||||
the placement resources will be equal to resources.
|
||||
driver_id: The ID of the relevant driver. This is almost always the
|
||||
driver ID of the driver that is currently running. However, in
|
||||
job_id: The ID of the relevant job. This is almost always the
|
||||
job ID of the job that is currently running. However, in
|
||||
the exceptional case that an actor task is being dispatched to
|
||||
an actor created by a different driver, this should be the
|
||||
driver ID of the driver that created the actor.
|
||||
an actor created by a different job, this should be the
|
||||
job ID of the job that created the actor.
|
||||
|
||||
Returns:
|
||||
The return object IDs for this task.
|
||||
@@ -642,8 +643,8 @@ class Worker(object):
|
||||
if new_actor_handles is None:
|
||||
new_actor_handles = []
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = self.task_driver_id
|
||||
if job_id is None:
|
||||
job_id = self.current_job_id
|
||||
|
||||
if resources is None:
|
||||
raise ValueError("The resources dictionary is required.")
|
||||
@@ -674,13 +675,13 @@ class Worker(object):
|
||||
assert not self.current_task_id.is_nil()
|
||||
# Current driver id must not be nil when submitting a task.
|
||||
# Because every task must belong to a driver.
|
||||
assert not self.task_driver_id.is_nil()
|
||||
assert not self.current_job_id.is_nil()
|
||||
# Submit the task to raylet.
|
||||
function_descriptor_list = (
|
||||
function_descriptor.get_function_descriptor_list())
|
||||
assert isinstance(driver_id, DriverID)
|
||||
assert isinstance(job_id, JobID)
|
||||
task = ray._raylet.Task(
|
||||
driver_id,
|
||||
job_id,
|
||||
function_descriptor_list,
|
||||
args_for_raylet,
|
||||
num_return_vals,
|
||||
@@ -747,7 +748,7 @@ class Worker(object):
|
||||
# Run the function on all workers.
|
||||
self.redis_client.hmset(
|
||||
key, {
|
||||
"driver_id": self.task_driver_id.binary(),
|
||||
"job_id": self.current_job_id.binary(),
|
||||
"function_id": function_to_run_id,
|
||||
"function": pickled_function,
|
||||
"run_on_other_drivers": str(run_on_other_drivers)
|
||||
@@ -853,17 +854,17 @@ class Worker(object):
|
||||
assert self.task_context.task_index == 0
|
||||
assert self.task_context.put_index == 1
|
||||
if task.actor_id().is_nil():
|
||||
# If this worker is not an actor, check that `task_driver_id`
|
||||
# If this worker is not an actor, check that `current_job_id`
|
||||
# was reset when the worker finished the previous task.
|
||||
assert self.task_driver_id.is_nil()
|
||||
assert self.current_job_id.is_nil()
|
||||
# Set the driver ID of the current running task. This is
|
||||
# needed so that if the task throws an exception, we propagate
|
||||
# the error message to the correct driver.
|
||||
self.task_driver_id = task.driver_id()
|
||||
self.current_job_id = task.job_id()
|
||||
else:
|
||||
# If this worker is an actor, task_driver_id wasn't reset.
|
||||
# If this worker is an actor, current_job_id wasn't reset.
|
||||
# Check that current task's driver ID equals the previous one.
|
||||
assert self.task_driver_id == task.driver_id()
|
||||
assert self.current_job_id == task.job_id()
|
||||
|
||||
self.task_context.current_task_id = task.task_id()
|
||||
|
||||
@@ -945,7 +946,7 @@ class Worker(object):
|
||||
self,
|
||||
ray_constants.TASK_PUSH_ERROR,
|
||||
str(failure_object),
|
||||
driver_id=self.task_driver_id)
|
||||
job_id=self.current_job_id)
|
||||
# Mark the actor init as failed
|
||||
if not self.actor_id.is_nil() and function_name == "__init__":
|
||||
self.mark_actor_init_failed(error)
|
||||
@@ -960,7 +961,7 @@ class Worker(object):
|
||||
"""
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
driver_id = task.driver_id()
|
||||
job_id = task.job_id()
|
||||
|
||||
# TODO(rkn): It would be preferable for actor creation tasks to share
|
||||
# more of the code path with regular task execution.
|
||||
@@ -969,7 +970,7 @@ class Worker(object):
|
||||
self.actor_id = task.actor_creation_id()
|
||||
self.actor_creation_task_id = task.task_id()
|
||||
actor_class = self.function_actor_manager.load_actor_class(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
self.actors[self.actor_id] = actor_class.__new__(actor_class)
|
||||
self.actor_checkpoint_info[self.actor_id] = ActorCheckpointInfo(
|
||||
num_tasks_since_last_checkpoint=0,
|
||||
@@ -978,7 +979,7 @@ class Worker(object):
|
||||
)
|
||||
|
||||
execution_info = self.function_actor_manager.get_execution_info(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
# Execute the task.
|
||||
function_name = execution_info.function_name
|
||||
@@ -1005,20 +1006,20 @@ class Worker(object):
|
||||
self.task_context.task_index = 0
|
||||
self.task_context.put_index = 1
|
||||
if self.actor_id.is_nil():
|
||||
# Don't need to reset task_driver_id if the worker is an
|
||||
# Don't need to reset `current_job_id` if the worker is an
|
||||
# actor. Because the following tasks should all have the
|
||||
# same driver id.
|
||||
self.task_driver_id = DriverID.nil()
|
||||
self.current_job_id = WorkerID.nil()
|
||||
# Reset signal counters so that the next task can get
|
||||
# all past signals.
|
||||
ray_signal.reset()
|
||||
|
||||
# Increase the task execution counter.
|
||||
self.function_actor_manager.increase_task_counter(
|
||||
driver_id, function_descriptor)
|
||||
job_id, function_descriptor)
|
||||
|
||||
reached_max_executions = (self.function_actor_manager.get_task_counter(
|
||||
driver_id, function_descriptor) == execution_info.max_calls)
|
||||
job_id, function_descriptor) == execution_info.max_calls)
|
||||
if reached_max_executions:
|
||||
self.raylet_client.disconnect()
|
||||
sys.exit(0)
|
||||
@@ -1141,7 +1142,7 @@ def print_failed_task(task_status):
|
||||
task_status["error_message"]))
|
||||
|
||||
|
||||
def _initialize_serialization(driver_id, worker=global_worker):
|
||||
def _initialize_serialization(job_id, worker=global_worker):
|
||||
"""Initialize the serialization library.
|
||||
|
||||
This defines a custom serializer for object IDs and also tells ray to
|
||||
@@ -1177,7 +1178,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
custom_serializer=actor_handle_serializer,
|
||||
custom_deserializer=actor_handle_deserializer)
|
||||
|
||||
worker.serialization_context_map[driver_id] = serialization_context
|
||||
worker.serialization_context_map[job_id] = serialization_context
|
||||
|
||||
# Register exception types.
|
||||
for error_cls in RAY_EXCEPTION_TYPES:
|
||||
@@ -1185,7 +1186,7 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
error_cls,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id=error_cls.__module__ + ". " + error_cls.__name__,
|
||||
)
|
||||
# Tell Ray to serialize lambdas with pickle.
|
||||
@@ -1193,22 +1194,18 @@ def _initialize_serialization(driver_id, worker=global_worker):
|
||||
type(lambda: 0),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="lambda")
|
||||
# Tell Ray to serialize types with pickle.
|
||||
register_custom_serializer(
|
||||
type(int),
|
||||
use_pickle=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
class_id="type")
|
||||
type(int), use_pickle=True, local=True, job_id=job_id, class_id="type")
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
register_custom_serializer(
|
||||
ray.signature.FunctionSignature,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
driver_id=driver_id,
|
||||
job_id=job_id,
|
||||
class_id="ray.signature.FunctionSignature")
|
||||
|
||||
|
||||
@@ -1231,7 +1228,7 @@ def init(redis_address=None,
|
||||
plasma_directory=None,
|
||||
huge_pages=False,
|
||||
include_webui=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
configure_logging=True,
|
||||
logging_level=logging.INFO,
|
||||
logging_format=ray_constants.LOGGER_FORMAT,
|
||||
@@ -1302,7 +1299,7 @@ def init(redis_address=None,
|
||||
Store with hugetlbfs support. Requires plasma_directory.
|
||||
include_webui: Boolean flag indicating whether to start the web
|
||||
UI, which displays the status of the Ray cluster.
|
||||
driver_id: The ID of driver.
|
||||
job_id: The ID of this job.
|
||||
configure_logging: True if allow the logging cofiguration here.
|
||||
Otherwise, the users may want to configure it by their own.
|
||||
logging_level: Logging level, default will be logging.INFO.
|
||||
@@ -1449,7 +1446,7 @@ def init(redis_address=None,
|
||||
mode=driver_mode,
|
||||
log_to_driver=log_to_driver,
|
||||
worker=global_worker,
|
||||
driver_id=driver_id)
|
||||
job_id=job_id)
|
||||
|
||||
for hook in _post_init_hooks:
|
||||
hook()
|
||||
@@ -1660,10 +1657,10 @@ def listen_error_messages_raylet(worker, task_error_queue, threads_stopped):
|
||||
assert len(gcs_entry.entries) == 1
|
||||
error_data = ray.gcs_utils.ErrorTableData.FromString(
|
||||
gcs_entry.entries[0])
|
||||
driver_id = error_data.driver_id
|
||||
if driver_id not in [
|
||||
worker.task_driver_id.binary(),
|
||||
DriverID.nil().binary()
|
||||
job_id = error_data.job_id
|
||||
if job_id not in [
|
||||
worker.current_job_id.binary(),
|
||||
JobID.nil().binary()
|
||||
]:
|
||||
continue
|
||||
|
||||
@@ -1691,7 +1688,7 @@ def connect(node,
|
||||
mode=WORKER_MODE,
|
||||
log_to_driver=False,
|
||||
worker=global_worker,
|
||||
driver_id=None):
|
||||
job_id=None):
|
||||
"""Connect this worker to the raylet, to Plasma, and to Redis.
|
||||
|
||||
Args:
|
||||
@@ -1701,7 +1698,7 @@ def connect(node,
|
||||
log_to_driver (bool): If true, then output from all of the worker
|
||||
processes on all nodes will be directed to the driver.
|
||||
worker: The ray.Worker instance.
|
||||
driver_id: The ID of driver. If it's None, then we will generate one.
|
||||
job_id: The ID of job. If it's None, then we will generate one.
|
||||
"""
|
||||
# Do some basic checking to make sure we didn't call ray.init twice.
|
||||
error_message = "Perhaps you called ray.init twice by accident?"
|
||||
@@ -1721,20 +1718,20 @@ def connect(node,
|
||||
setproctitle.setproctitle("ray_worker")
|
||||
else:
|
||||
# This is the code path of driver mode.
|
||||
if driver_id is None:
|
||||
driver_id = DriverID.from_random()
|
||||
if job_id is None:
|
||||
job_id = JobID.from_random()
|
||||
|
||||
if not isinstance(driver_id, DriverID):
|
||||
raise TypeError("The type of given driver id must be DriverID.")
|
||||
if not isinstance(job_id, JobID):
|
||||
raise TypeError("The type of given job id must be JobID.")
|
||||
|
||||
worker.worker_id = driver_id.binary()
|
||||
worker.worker_id = job_id.binary()
|
||||
|
||||
# When tasks are executed on remote workers in the context of multiple
|
||||
# drivers, the task driver ID is used to keep track of which driver is
|
||||
# drivers, the current job ID is used to keep track of which driver is
|
||||
# responsible for the task so that error messages will be propagated to
|
||||
# the correct driver.
|
||||
if mode != WORKER_MODE:
|
||||
worker.task_driver_id = DriverID(worker.worker_id)
|
||||
worker.current_job_id = JobID(worker.worker_id)
|
||||
|
||||
# All workers start out as non-actors. A worker can be turned into an actor
|
||||
# after it is created.
|
||||
@@ -1766,7 +1763,7 @@ def connect(node,
|
||||
worker.redis_client,
|
||||
ray_constants.VERSION_MISMATCH_PUSH_ERROR,
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
job_id=None)
|
||||
|
||||
worker.lock = threading.RLock()
|
||||
|
||||
@@ -1831,7 +1828,7 @@ def connect(node,
|
||||
# Create an object store client.
|
||||
worker.plasma_client = thread_safe_client(
|
||||
plasma.connect(node.plasma_store_socket_name, None, 0, 300))
|
||||
driver_id_str = _random_string()
|
||||
job_id_str = _random_string()
|
||||
|
||||
# If this is a driver, set the current task ID, the task driver ID, and set
|
||||
# the task index to 0.
|
||||
@@ -1859,11 +1856,11 @@ def connect(node,
|
||||
|
||||
function_descriptor = FunctionDescriptor.for_driver_task()
|
||||
driver_task = ray._raylet.Task(
|
||||
worker.task_driver_id,
|
||||
worker.current_job_id,
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
[], # arguments.
|
||||
0, # num_returns.
|
||||
TaskID(driver_id_str[:TaskID.size()]), # parent_task_id.
|
||||
TaskID(job_id_str[:TaskID.size()]), # parent_task_id.
|
||||
0, # parent_counter.
|
||||
ActorID.nil(), # actor_creation_id.
|
||||
ObjectID.nil(), # actor_creation_dummy_object_id.
|
||||
@@ -1895,7 +1892,7 @@ def connect(node,
|
||||
node.raylet_socket_name,
|
||||
ClientID(worker.worker_id),
|
||||
(mode == WORKER_MODE),
|
||||
DriverID(driver_id_str),
|
||||
JobID(job_id_str),
|
||||
)
|
||||
|
||||
# Start the import thread
|
||||
@@ -2057,7 +2054,7 @@ def register_custom_serializer(cls,
|
||||
serializer=None,
|
||||
deserializer=None,
|
||||
local=False,
|
||||
driver_id=None,
|
||||
job_id=None,
|
||||
class_id=None):
|
||||
"""Enable serialization and deserialization for a particular class.
|
||||
|
||||
@@ -2078,7 +2075,7 @@ def register_custom_serializer(cls,
|
||||
if and only if use_pickle and use_dict are False.
|
||||
local: True if the serializers should only be registered on the current
|
||||
worker. This should usually be False.
|
||||
driver_id: ID of the driver that we want to register the class for.
|
||||
job_id: ID of the job that we want to register the class for.
|
||||
class_id: ID of the class that we are registering. If this is not
|
||||
specified, we will calculate a new one inside the function.
|
||||
|
||||
@@ -2126,9 +2123,9 @@ def register_custom_serializer(cls,
|
||||
# Make sure class_id is a string.
|
||||
class_id = ray.utils.binary_to_hex(class_id)
|
||||
|
||||
if driver_id is None:
|
||||
driver_id = worker.task_driver_id
|
||||
assert isinstance(driver_id, DriverID)
|
||||
if job_id is None:
|
||||
job_id = worker.current_job_id
|
||||
assert isinstance(job_id, JobID)
|
||||
|
||||
def register_class_for_serialization(worker_info):
|
||||
# TODO(rkn): We need to be more thoughtful about what to do if custom
|
||||
@@ -2138,7 +2135,7 @@ def register_custom_serializer(cls,
|
||||
# system.
|
||||
|
||||
serialization_context = worker_info[
|
||||
"worker"].get_serialization_context(driver_id)
|
||||
"worker"].get_serialization_context(job_id)
|
||||
serialization_context.register_type(
|
||||
cls,
|
||||
class_id,
|
||||
|
||||
@@ -102,7 +102,7 @@ if __name__ == "__main__":
|
||||
ray.worker.global_worker,
|
||||
"worker_crash",
|
||||
traceback_str,
|
||||
driver_id=None)
|
||||
job_id=None)
|
||||
# TODO(rkn): Note that if the worker was in the middle of executing
|
||||
# a task, then any worker or driver that is blocking in a get call
|
||||
# and waiting for the output of that task will hang. We need to
|
||||
|
||||
+13
-3
@@ -85,7 +85,7 @@ uint64_t MurmurHash64A(const void *key, int len, unsigned int seed) {
|
||||
return h;
|
||||
}
|
||||
|
||||
TaskID TaskID::GetDriverTaskID(const DriverID &driver_id) {
|
||||
TaskID TaskID::GetDriverTaskID(const WorkerID &driver_id) {
|
||||
std::string driver_id_str = driver_id.Binary();
|
||||
driver_id_str.resize(Size());
|
||||
return TaskID::FromBinary(driver_id_str);
|
||||
@@ -113,12 +113,12 @@ ObjectID ObjectID::ForTaskReturn(const TaskID &task_id, int64_t return_index) {
|
||||
return object_id;
|
||||
}
|
||||
|
||||
const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id,
|
||||
const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id,
|
||||
int parent_task_counter) {
|
||||
// Compute hashes.
|
||||
SHA256_CTX ctx;
|
||||
sha256_init(&ctx);
|
||||
sha256_update(&ctx, reinterpret_cast<const BYTE *>(driver_id.Data()), driver_id.Size());
|
||||
sha256_update(&ctx, reinterpret_cast<const BYTE *>(job_id.Data()), job_id.Size());
|
||||
sha256_update(&ctx, reinterpret_cast<const BYTE *>(parent_task_id.Data()),
|
||||
parent_task_id.Size());
|
||||
sha256_update(&ctx, (const BYTE *)&parent_task_counter, sizeof(parent_task_counter));
|
||||
@@ -129,6 +129,16 @@ const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task
|
||||
return TaskID::FromBinary(std::string(buff, buff + TaskID::Size()));
|
||||
}
|
||||
|
||||
/// Compute the driver's worker ID for the given job.
///
/// \param job_id The job whose driver ID is requested.
/// \return A WorkerID constructed directly from job_id.
const WorkerID ComputeDriverId(const JobID &job_id) {
|
||||
// Currently, a job id equals its driver id.
|
||||
return WorkerID(job_id);
|
||||
}
|
||||
|
||||
/// Compute the job ID that corresponds to the given driver.
///
/// \param driver_id The worker ID of the driver.
/// \return A JobID constructed directly from driver_id.
const JobID ComputeJobId(const WorkerID &driver_id) {
|
||||
// Currently, a job id equals its driver id.
|
||||
return JobID(driver_id);
|
||||
}
|
||||
|
||||
#define ID_OSTREAM_OPERATOR(id_type) \
|
||||
std::ostream &operator<<(std::ostream &os, const id_type &id) { \
|
||||
if (id.IsNil()) { \
|
||||
|
||||
+4
-4
@@ -17,7 +17,7 @@
|
||||
|
||||
namespace ray {
|
||||
|
||||
class DriverID;
|
||||
class WorkerID;
|
||||
class UniqueID;
|
||||
|
||||
// Declaration.
|
||||
@@ -72,7 +72,7 @@ class TaskID : public BaseID<TaskID> {
|
||||
public:
|
||||
TaskID() : BaseID() {}
|
||||
static size_t Size() { return kTaskIDSize; }
|
||||
static TaskID GetDriverTaskID(const DriverID &driver_id);
|
||||
static TaskID GetDriverTaskID(const WorkerID &driver_id);
|
||||
|
||||
private:
|
||||
uint8_t id_[kTaskIDSize];
|
||||
@@ -152,11 +152,11 @@ std::ostream &operator<<(std::ostream &os, const ObjectID &id);
|
||||
|
||||
/// Generate a task ID from the given info.
|
||||
///
|
||||
/// \param driver_id The driver that creates the task.
|
||||
/// \param job_id The job that creates the task.
|
||||
/// \param parent_task_id The parent task of this task.
|
||||
/// \param parent_task_counter The task index of the worker.
|
||||
/// \return The task ID generated from the given info.
|
||||
const TaskID GenerateTaskId(const DriverID &driver_id, const TaskID &parent_task_id,
|
||||
const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id,
|
||||
int parent_task_counter);
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -10,6 +10,6 @@ DEFINE_UNIQUE_ID(ActorID)
|
||||
DEFINE_UNIQUE_ID(ActorHandleID)
|
||||
DEFINE_UNIQUE_ID(ActorCheckpointID)
|
||||
DEFINE_UNIQUE_ID(WorkerID)
|
||||
DEFINE_UNIQUE_ID(DriverID)
|
||||
DEFINE_UNIQUE_ID(JobID)
|
||||
DEFINE_UNIQUE_ID(ConfigID)
|
||||
DEFINE_UNIQUE_ID(ClientID)
|
||||
|
||||
@@ -38,35 +38,34 @@ struct WorkerThreadContext {
|
||||
thread_local std::unique_ptr<WorkerThreadContext> WorkerContext::thread_context_ =
|
||||
nullptr;
|
||||
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const DriverID &driver_id)
|
||||
WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id)
|
||||
: worker_type(worker_type),
|
||||
worker_id(worker_type == WorkerType::DRIVER
|
||||
? ClientID::FromBinary(driver_id.Binary())
|
||||
: ClientID::FromRandom()),
|
||||
current_driver_id(worker_type == WorkerType::DRIVER ? driver_id : DriverID::Nil()) {
|
||||
worker_id(worker_type == WorkerType::DRIVER ? WorkerID::FromBinary(job_id.Binary())
|
||||
: WorkerID::FromRandom()),
|
||||
current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) {
|
||||
// For worker main thread which initializes the WorkerContext,
|
||||
// set task_id according to whether current worker is a driver.
|
||||
// (For other threads it's set to randmom ID via GetThreadContext).
|
||||
// (For other threads it's set to random ID via GetThreadContext).
|
||||
GetThreadContext().SetCurrentTask(
|
||||
(worker_type == WorkerType::DRIVER) ? TaskID::FromRandom() : TaskID::Nil());
|
||||
}
|
||||
|
||||
/// Return the worker type this context was constructed with.
const WorkerType WorkerContext::GetWorkerType() const { return worker_type; }
|
||||
|
||||
const ClientID &WorkerContext::GetWorkerID() const { return worker_id; }
|
||||
const WorkerID &WorkerContext::GetWorkerID() const { return worker_id; }
|
||||
|
||||
/// Return the next task index, as provided by the WorkerThreadContext that
/// GetThreadContext() returns.
int WorkerContext::GetNextTaskIndex() { return GetThreadContext().GetNextTaskIndex(); }
|
||||
|
||||
/// Return the next put index, as provided by the WorkerThreadContext that
/// GetThreadContext() returns.
int WorkerContext::GetNextPutIndex() { return GetThreadContext().GetNextPutIndex(); }
|
||||
|
||||
const DriverID &WorkerContext::GetCurrentDriverID() const { return current_driver_id; }
|
||||
/// Return the job ID currently stored in this context (the current_job_id
/// member).
const JobID &WorkerContext::GetCurrentJobID() const { return current_job_id; }
|
||||
|
||||
/// Return the current task ID held by the WorkerThreadContext that
/// GetThreadContext() returns.
const TaskID &WorkerContext::GetCurrentTaskID() const {
|
||||
return GetThreadContext().GetCurrentTaskID();
|
||||
}
|
||||
|
||||
void WorkerContext::SetCurrentTask(const raylet::TaskSpecification &spec) {
|
||||
current_driver_id = spec.DriverId();
|
||||
current_job_id = spec.JobId();
|
||||
GetThreadContext().SetCurrentTask(spec);
|
||||
}
|
||||
|
||||
|
||||
@@ -10,13 +10,13 @@ struct WorkerThreadContext;
|
||||
|
||||
class WorkerContext {
|
||||
public:
|
||||
WorkerContext(WorkerType worker_type, const DriverID &driver_id);
|
||||
WorkerContext(WorkerType worker_type, const JobID &job_id);
|
||||
|
||||
const WorkerType GetWorkerType() const;
|
||||
|
||||
const ClientID &GetWorkerID() const;
|
||||
const WorkerID &GetWorkerID() const;
|
||||
|
||||
const DriverID &GetCurrentDriverID() const;
|
||||
const JobID &GetCurrentJobID() const;
|
||||
|
||||
const TaskID &GetCurrentTaskID() const;
|
||||
|
||||
@@ -31,10 +31,10 @@ class WorkerContext {
|
||||
const WorkerType worker_type;
|
||||
|
||||
/// ID for this worker.
|
||||
const ClientID worker_id;
|
||||
const WorkerID worker_id;
|
||||
|
||||
/// Driver ID for this worker.
|
||||
DriverID current_driver_id;
|
||||
/// Job ID for this worker.
|
||||
JobID current_job_id;
|
||||
|
||||
private:
|
||||
static WorkerThreadContext &GetThreadContext();
|
||||
|
||||
@@ -6,15 +6,16 @@ namespace ray {
|
||||
CoreWorker::CoreWorker(const enum WorkerType worker_type,
|
||||
const enum WorkerLanguage language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
DriverID driver_id)
|
||||
const JobID &job_id)
|
||||
: worker_type_(worker_type),
|
||||
language_(language),
|
||||
store_socket_(store_socket),
|
||||
raylet_socket_(raylet_socket),
|
||||
worker_context_(worker_type, driver_id),
|
||||
raylet_client_(raylet_socket_, worker_context_.GetWorkerID(),
|
||||
worker_context_(worker_type, job_id),
|
||||
raylet_client_(raylet_socket_,
|
||||
ClientID::FromBinary(worker_context_.GetWorkerID().Binary()),
|
||||
(worker_type_ == ray::WorkerType::WORKER),
|
||||
worker_context_.GetCurrentDriverID(), ToTaskLanguage(language_)),
|
||||
worker_context_.GetCurrentJobID(), ToTaskLanguage(language_)),
|
||||
task_interface_(*this),
|
||||
object_interface_(*this),
|
||||
task_execution_interface_(*this) {
|
||||
|
||||
@@ -24,7 +24,7 @@ class CoreWorker {
|
||||
/// NOTE(zhijunfu): the constructor would throw if a failure happens.
|
||||
CoreWorker(const WorkerType worker_type, const WorkerLanguage language,
|
||||
const std::string &store_socket, const std::string &raylet_socket,
|
||||
DriverID driver_id = DriverID::Nil());
|
||||
const JobID &job_id = JobID::Nil());
|
||||
|
||||
/// Type of this worker.
|
||||
enum WorkerType WorkerType() const { return worker_type_; }
|
||||
|
||||
@@ -126,7 +126,7 @@ class CoreWorkerTest : public ::testing::Test {
|
||||
void TestNormalTask(const std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
DriverID::FromRandom());
|
||||
JobID::FromRandom());
|
||||
|
||||
// Test pass by value.
|
||||
{
|
||||
@@ -184,7 +184,7 @@ class CoreWorkerTest : public ::testing::Test {
|
||||
void TestActorTask(const std::unordered_map<std::string, double> &resources) {
|
||||
CoreWorker driver(WorkerType::DRIVER, WorkerLanguage::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
DriverID::FromRandom());
|
||||
JobID::FromRandom());
|
||||
|
||||
std::unique_ptr<ActorHandle> actor_handle;
|
||||
|
||||
@@ -275,9 +275,9 @@ TEST_F(ZeroNodeTest, TestTaskArg) {
|
||||
}
|
||||
|
||||
TEST_F(ZeroNodeTest, TestWorkerContext) {
|
||||
auto driver_id = DriverID::FromRandom();
|
||||
auto job_id = JobID::FromRandom();
|
||||
|
||||
WorkerContext context(WorkerType::WORKER, driver_id);
|
||||
WorkerContext context(WorkerType::WORKER, job_id);
|
||||
ASSERT_TRUE(context.GetCurrentTaskID().IsNil());
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 1);
|
||||
ASSERT_EQ(context.GetNextTaskIndex(), 2);
|
||||
@@ -302,7 +302,7 @@ TEST_F(ZeroNodeTest, TestWorkerContext) {
|
||||
TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
DriverID::FromRandom());
|
||||
JobID::FromRandom());
|
||||
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
@@ -358,11 +358,11 @@ TEST_F(SingleNodeTest, TestObjectInterface) {
|
||||
TEST_F(TwoNodeTest, TestObjectInterfaceCrossNodes) {
|
||||
CoreWorker worker1(WorkerType::DRIVER, WorkerLanguage::PYTHON,
|
||||
raylet_store_socket_names_[0], raylet_socket_names_[0],
|
||||
DriverID::FromRandom());
|
||||
JobID::FromRandom());
|
||||
|
||||
CoreWorker worker2(WorkerType::DRIVER, WorkerLanguage::PYTHON,
|
||||
raylet_store_socket_names_[1], raylet_socket_names_[1],
|
||||
DriverID::FromRandom());
|
||||
JobID::FromRandom());
|
||||
|
||||
uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8};
|
||||
uint8_t array2[] = {10, 11, 12, 13, 14, 15};
|
||||
@@ -446,7 +446,7 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) {
|
||||
TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) {
|
||||
try {
|
||||
CoreWorker core_worker(WorkerType::DRIVER, WorkerLanguage::PYTHON, "",
|
||||
raylet_socket_names_[0], DriverID::FromRandom());
|
||||
raylet_socket_names_[0], JobID::FromRandom());
|
||||
} catch (const std::exception &e) {
|
||||
std::cout << "Caught exception when constructing core worker: " << e.what();
|
||||
}
|
||||
|
||||
@@ -17,7 +17,7 @@ class MockWorker {
|
||||
public:
|
||||
MockWorker(const std::string &store_socket, const std::string &raylet_socket)
|
||||
: worker_(WorkerType::WORKER, WorkerLanguage::PYTHON, store_socket, raylet_socket,
|
||||
DriverID::FromRandom()) {}
|
||||
JobID::FromRandom()) {}
|
||||
|
||||
void Run() {
|
||||
auto executor_func = [this](const RayFunction &ray_function,
|
||||
|
||||
@@ -20,7 +20,7 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
|
||||
std::vector<ObjectID> *return_ids) {
|
||||
auto &context = core_worker_.worker_context_;
|
||||
auto next_task_index = context.GetNextTaskIndex();
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentDriverID(),
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
|
||||
context.GetCurrentTaskID(), next_task_index);
|
||||
|
||||
auto num_returns = task_options.num_returns;
|
||||
@@ -32,7 +32,7 @@ Status CoreWorkerTaskInterface::SubmitTask(const RayFunction &function,
|
||||
auto task_arguments = BuildTaskArguments(args);
|
||||
auto language = core_worker_.ToTaskLanguage(function.language);
|
||||
|
||||
ray::raylet::TaskSpecification spec(context.GetCurrentDriverID(),
|
||||
ray::raylet::TaskSpecification spec(context.GetCurrentJobID(),
|
||||
context.GetCurrentTaskID(), next_task_index,
|
||||
task_arguments, num_returns, task_options.resources,
|
||||
language, function.function_descriptor);
|
||||
@@ -48,7 +48,7 @@ Status CoreWorkerTaskInterface::CreateActor(
|
||||
std::unique_ptr<ActorHandle> *actor_handle) {
|
||||
auto &context = core_worker_.worker_context_;
|
||||
auto next_task_index = context.GetNextTaskIndex();
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentDriverID(),
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
|
||||
context.GetCurrentTaskID(), next_task_index);
|
||||
|
||||
std::vector<ObjectID> return_ids;
|
||||
@@ -66,7 +66,7 @@ Status CoreWorkerTaskInterface::CreateActor(
|
||||
// Note that the caller is supposed to specify required placement resources
|
||||
// correctly via actor_creation_options.resources.
|
||||
ray::raylet::TaskSpecification spec(
|
||||
context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index,
|
||||
context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
|
||||
actor_creation_id, ObjectID::Nil(), actor_creation_options.max_reconstructions,
|
||||
ActorID::Nil(), ActorHandleID::Nil(), 0, {}, task_arguments, 1,
|
||||
actor_creation_options.resources, actor_creation_options.resources, language,
|
||||
@@ -84,7 +84,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle,
|
||||
std::vector<ObjectID> *return_ids) {
|
||||
auto &context = core_worker_.worker_context_;
|
||||
auto next_task_index = context.GetNextTaskIndex();
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentDriverID(),
|
||||
const auto task_id = GenerateTaskId(context.GetCurrentJobID(),
|
||||
context.GetCurrentTaskID(), next_task_index);
|
||||
|
||||
// add one for actor cursor object id.
|
||||
@@ -102,7 +102,7 @@ Status CoreWorkerTaskInterface::SubmitActorTask(ActorHandle &actor_handle,
|
||||
|
||||
std::vector<ActorHandleID> new_actor_handles;
|
||||
ray::raylet::TaskSpecification spec(
|
||||
context.GetCurrentDriverID(), context.GetCurrentTaskID(), next_task_index,
|
||||
context.GetCurrentJobID(), context.GetCurrentTaskID(), next_task_index,
|
||||
ActorID::Nil(), actor_creation_dummy_object_id, 0, actor_handle.ActorID(),
|
||||
actor_handle.ActorHandleID(), actor_handle.IncreaseTaskCounter(), new_actor_handles,
|
||||
task_arguments, num_returns, task_options.resources, task_options.resources,
|
||||
|
||||
@@ -109,7 +109,7 @@ AsyncGcsClient::AsyncGcsClient(const std::string &address, int port,
|
||||
actor_table_.reset(new ActorTable({primary_context_}, this));
|
||||
client_table_.reset(new ClientTable({primary_context_}, this, client_id));
|
||||
error_table_.reset(new ErrorTable({primary_context_}, this));
|
||||
driver_table_.reset(new DriverTable({primary_context_}, this));
|
||||
job_table_.reset(new JobTable({primary_context_}, this));
|
||||
heartbeat_batch_table_.reset(new HeartbeatBatchTable({primary_context_}, this));
|
||||
// Tables below would be sharded.
|
||||
object_table_.reset(new ObjectTable(shard_contexts_, this));
|
||||
@@ -188,7 +188,7 @@ std::string AsyncGcsClient::DebugString() const {
|
||||
result << "\n- ErrorTable: " << error_table_->DebugString();
|
||||
result << "\n- ProfileTable: " << profile_table_->DebugString();
|
||||
result << "\n- ClientTable: " << client_table_->DebugString();
|
||||
result << "\n- DriverTable: " << driver_table_->DebugString();
|
||||
result << "\n- JobTable: " << job_table_->DebugString();
|
||||
return result.str();
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ HeartbeatBatchTable &AsyncGcsClient::heartbeat_batch_table() {
|
||||
|
||||
/// Return a reference to the ErrorTable held in error_table_.
ErrorTable &AsyncGcsClient::error_table() { return *error_table_; }
|
||||
|
||||
DriverTable &AsyncGcsClient::driver_table() { return *driver_table_; }
|
||||
/// Return a reference to the JobTable held in job_table_.
JobTable &AsyncGcsClient::job_table() { return *job_table_; }
|
||||
|
||||
/// Return a reference to the ProfileTable held in profile_table_.
ProfileTable &AsyncGcsClient::profile_table() { return *profile_table_; }
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
HeartbeatTable &heartbeat_table();
|
||||
HeartbeatBatchTable &heartbeat_batch_table();
|
||||
ErrorTable &error_table();
|
||||
DriverTable &driver_table();
|
||||
JobTable &job_table();
|
||||
ProfileTable &profile_table();
|
||||
ActorCheckpointTable &actor_checkpoint_table();
|
||||
ActorCheckpointIdTable &actor_checkpoint_id_table();
|
||||
@@ -64,8 +64,8 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
// driver (to set the PYTHONPATH)
|
||||
|
||||
using GetExportCallback = std::function<void(const std::string &data)>;
|
||||
Status AddExport(const std::string &driver_id, std::string &export_data);
|
||||
Status GetExport(const std::string &driver_id, int64_t export_index,
|
||||
Status AddExport(const std::string &job_id, std::string &export_data);
|
||||
Status GetExport(const std::string &job_id, int64_t export_index,
|
||||
const GetExportCallback &done_callback);
|
||||
|
||||
std::vector<std::shared_ptr<RedisContext>> shard_contexts() { return shard_contexts_; }
|
||||
@@ -96,7 +96,7 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
std::vector<std::unique_ptr<RedisAsioClient>> shard_asio_subscribe_clients_;
|
||||
// The following context writes everything to the primary shard
|
||||
std::shared_ptr<RedisContext> primary_context_;
|
||||
std::unique_ptr<DriverTable> driver_table_;
|
||||
std::unique_ptr<JobTable> job_table_;
|
||||
std::unique_ptr<RedisAsioClient> asio_async_auxiliary_client_;
|
||||
std::unique_ptr<RedisAsioClient> asio_subscribe_auxiliary_client_;
|
||||
CommandType command_type_;
|
||||
@@ -105,14 +105,14 @@ class RAY_EXPORT AsyncGcsClient {
|
||||
class SyncGcsClient {
|
||||
Status LogEvent(const std::string &key, const std::string &value, double timestamp);
|
||||
Status NotifyError(const std::map<std::string, std::string> &error_info);
|
||||
Status RegisterFunction(const DriverID &driver_id, const FunctionID &function_id,
|
||||
Status RegisterFunction(const JobID &job_id, const FunctionID &function_id,
|
||||
const std::string &language, const std::string &name,
|
||||
const std::string &data);
|
||||
Status RetrieveFunction(const DriverID &driver_id, const FunctionID &function_id,
|
||||
Status RetrieveFunction(const JobID &job_id, const FunctionID &function_id,
|
||||
std::string *name, std::string *data);
|
||||
|
||||
Status AddExport(const std::string &driver_id, std::string &export_data);
|
||||
Status GetExport(const std::string &driver_id, int64_t export_index, std::string *data);
|
||||
Status AddExport(const std::string &job_id, std::string &export_data);
|
||||
Status GetExport(const std::string &job_id, int64_t export_index, std::string *data);
|
||||
};
|
||||
|
||||
} // namespace gcs
|
||||
|
||||
+202
-214
File diff suppressed because it is too large
Load Diff
@@ -18,8 +18,8 @@ table Arg {
|
||||
}
|
||||
|
||||
table TaskInfo {
|
||||
// ID of the driver that created this task.
|
||||
driver_id: string;
|
||||
// ID of the job that created this task.
|
||||
job_id: string;
|
||||
// Task ID of the task.
|
||||
task_id: string;
|
||||
// Task ID of the parent task.
|
||||
|
||||
+50
-51
@@ -39,7 +39,7 @@ namespace ray {
|
||||
namespace gcs {
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Append(const DriverID &driver_id, const ID &id,
|
||||
Status Log<ID, Data>::Append(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) {
|
||||
num_appends_++;
|
||||
auto callback = [this, id, data, done](const CallbackReply &reply) {
|
||||
@@ -58,7 +58,7 @@ Status Log<ID, Data>::Append(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::AppendAt(const DriverID &driver_id, const ID &id,
|
||||
Status Log<ID, Data>::AppendAt(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done,
|
||||
const WriteCallback &failure, int log_length) {
|
||||
num_appends_++;
|
||||
@@ -81,8 +81,7 @@ Status Log<ID, Data>::AppendAt(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
const Callback &lookup) {
|
||||
Status Log<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup) {
|
||||
num_lookups_++;
|
||||
auto callback = [this, id, lookup](const CallbackReply &reply) {
|
||||
if (lookup != nullptr) {
|
||||
@@ -106,7 +105,7 @@ Status Log<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const Callback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
auto subscribe_wrapper = [subscribe](AsyncGcsClient *client, const ID &id,
|
||||
@@ -115,11 +114,11 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
|
||||
RAY_CHECK(change_mode != GcsChangeMode::REMOVE);
|
||||
subscribe(client, id, data);
|
||||
};
|
||||
return Subscribe(driver_id, client_id, subscribe_wrapper, done);
|
||||
return Subscribe(job_id, client_id, subscribe_wrapper, done);
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Log<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const NotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
RAY_CHECK(subscribe_callback_index_ == -1)
|
||||
@@ -160,7 +159,7 @@ Status Log<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clien
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::RequestNotifications(const DriverID &driver_id, const ID &id,
|
||||
Status Log<ID, Data>::RequestNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id) {
|
||||
RAY_CHECK(subscribe_callback_index_ >= 0)
|
||||
<< "Client requested notifications on a key before Subscribe completed";
|
||||
@@ -170,7 +169,7 @@ Status Log<ID, Data>::RequestNotifications(const DriverID &driver_id, const ID &
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Log<ID, Data>::CancelNotifications(const DriverID &driver_id, const ID &id,
|
||||
Status Log<ID, Data>::CancelNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id) {
|
||||
RAY_CHECK(subscribe_callback_index_ >= 0)
|
||||
<< "Client canceled notifications on a key before Subscribe completed";
|
||||
@@ -180,7 +179,7 @@ Status Log<ID, Data>::CancelNotifications(const DriverID &driver_id, const ID &i
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
void Log<ID, Data>::Delete(const DriverID &driver_id, const std::vector<ID> &ids) {
|
||||
void Log<ID, Data>::Delete(const JobID &job_id, const std::vector<ID> &ids) {
|
||||
if (ids.empty()) {
|
||||
return;
|
||||
}
|
||||
@@ -214,8 +213,8 @@ void Log<ID, Data>::Delete(const DriverID &driver_id, const std::vector<ID> &ids
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
void Log<ID, Data>::Delete(const DriverID &driver_id, const ID &id) {
|
||||
Delete(driver_id, std::vector<ID>({id}));
|
||||
void Log<ID, Data>::Delete(const JobID &job_id, const ID &id) {
|
||||
Delete(job_id, std::vector<ID>({id}));
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
@@ -226,7 +225,7 @@ std::string Log<ID, Data>::DebugString() const {
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Table<ID, Data>::Add(const DriverID &driver_id, const ID &id,
|
||||
Status Table<ID, Data>::Add(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) {
|
||||
num_adds_++;
|
||||
auto callback = [this, id, data, done](const CallbackReply &reply) {
|
||||
@@ -241,10 +240,10 @@ Status Table<ID, Data>::Add(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Table<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
const Callback &lookup, const FailureCallback &failure) {
|
||||
Status Table<ID, Data>::Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
|
||||
const FailureCallback &failure) {
|
||||
num_lookups_++;
|
||||
return Log<ID, Data>::Lookup(driver_id, id,
|
||||
return Log<ID, Data>::Lookup(job_id, id,
|
||||
[lookup, failure](AsyncGcsClient *client, const ID &id,
|
||||
const std::vector<Data> &data) {
|
||||
if (data.empty()) {
|
||||
@@ -261,12 +260,12 @@ Status Table<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Table<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Table<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const Callback &subscribe,
|
||||
const FailureCallback &failure,
|
||||
const SubscriptionCallback &done) {
|
||||
return Log<ID, Data>::Subscribe(
|
||||
driver_id, client_id,
|
||||
job_id, client_id,
|
||||
[subscribe, failure](AsyncGcsClient *client, const ID &id,
|
||||
const std::vector<Data> &data) {
|
||||
RAY_CHECK(data.empty() || data.size() == 1);
|
||||
@@ -289,8 +288,8 @@ std::string Table<ID, Data>::DebugString() const {
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Set<ID, Data>::Add(const DriverID &driver_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) {
|
||||
Status Set<ID, Data>::Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done) {
|
||||
num_adds_++;
|
||||
auto callback = [this, id, data, done](const CallbackReply &reply) {
|
||||
if (done != nullptr) {
|
||||
@@ -303,7 +302,7 @@ Status Set<ID, Data>::Add(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Set<ID, Data>::Remove(const DriverID &driver_id, const ID &id,
|
||||
Status Set<ID, Data>::Remove(const JobID &job_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) {
|
||||
num_removes_++;
|
||||
auto callback = [this, id, data, done](const CallbackReply &reply) {
|
||||
@@ -325,8 +324,8 @@ std::string Set<ID, Data>::DebugString() const {
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Update(const DriverID &driver_id, const ID &id,
|
||||
const DataMap &data_map, const HashCallback &done) {
|
||||
Status Hash<ID, Data>::Update(const JobID &job_id, const ID &id, const DataMap &data_map,
|
||||
const HashCallback &done) {
|
||||
num_adds_++;
|
||||
auto callback = [this, id, data_map, done](const CallbackReply &reply) {
|
||||
if (done != nullptr) {
|
||||
@@ -346,7 +345,7 @@ Status Hash<ID, Data>::Update(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
Status Hash<ID, Data>::RemoveEntries(const JobID &job_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) {
|
||||
num_removes_++;
|
||||
@@ -375,7 +374,7 @@ std::string Hash<ID, Data>::DebugString() const {
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
Status Hash<ID, Data>::Lookup(const JobID &job_id, const ID &id,
|
||||
const HashCallback &lookup) {
|
||||
num_lookups_++;
|
||||
auto callback = [this, id, lookup](const CallbackReply &reply) {
|
||||
@@ -403,7 +402,7 @@ Status Hash<ID, Data>::Lookup(const DriverID &driver_id, const ID &id,
|
||||
}
|
||||
|
||||
template <typename ID, typename Data>
|
||||
Status Hash<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Hash<ID, Data>::Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
RAY_CHECK(subscribe_callback_index_ == -1)
|
||||
@@ -450,25 +449,25 @@ Status Hash<ID, Data>::Subscribe(const DriverID &driver_id, const ClientID &clie
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ErrorTable::PushErrorToDriver(const DriverID &driver_id, const std::string &type,
|
||||
Status ErrorTable::PushErrorToDriver(const JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
auto data = std::make_shared<ErrorTableData>();
|
||||
data->set_driver_id(driver_id.Binary());
|
||||
data->set_job_id(job_id.Binary());
|
||||
data->set_type(type);
|
||||
data->set_error_message(error_message);
|
||||
data->set_timestamp(timestamp);
|
||||
return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr);
|
||||
return Append(job_id, job_id, data, /*done_callback=*/nullptr);
|
||||
}
|
||||
|
||||
std::string ErrorTable::DebugString() const {
|
||||
return Log<DriverID, ErrorTableData>::DebugString();
|
||||
return Log<JobID, ErrorTableData>::DebugString();
|
||||
}
|
||||
|
||||
Status ProfileTable::AddProfileEventBatch(const ProfileTableData &profile_events) {
|
||||
// TODO(hchen): Change the parameter to shared_ptr to avoid copying data.
|
||||
auto data = std::make_shared<ProfileTableData>();
|
||||
data->CopyFrom(profile_events);
|
||||
return Append(DriverID::Nil(), UniqueID::FromRandom(), data,
|
||||
return Append(JobID::Nil(), UniqueID::FromRandom(), data,
|
||||
/*done_callback=*/nullptr);
|
||||
}
|
||||
|
||||
@@ -476,11 +475,11 @@ std::string ProfileTable::DebugString() const {
|
||||
return Log<UniqueID, ProfileTableData>::DebugString();
|
||||
}
|
||||
|
||||
Status DriverTable::AppendDriverData(const DriverID &driver_id, bool is_dead) {
|
||||
auto data = std::make_shared<DriverTableData>();
|
||||
data->set_driver_id(driver_id.Binary());
|
||||
Status JobTable::AppendJobData(const JobID &job_id, bool is_dead) {
|
||||
auto data = std::make_shared<JobTableData>();
|
||||
data->set_job_id(job_id.Binary());
|
||||
data->set_is_dead(is_dead);
|
||||
return Append(DriverID(driver_id), driver_id, data, /*done_callback=*/nullptr);
|
||||
return Append(JobID(job_id), job_id, data, /*done_callback=*/nullptr);
|
||||
}
|
||||
|
||||
void ClientTable::RegisterClientAddedCallback(const ClientTableCallback &callback) {
|
||||
@@ -694,13 +693,13 @@ Status ClientTable::Connect(const ClientTableData &local_client) {
|
||||
// Callback to request notifications from the client table once we've
|
||||
// successfully subscribed.
|
||||
auto subscription_callback = [this](AsyncGcsClient *c) {
|
||||
RAY_CHECK_OK(RequestNotifications(DriverID::Nil(), client_log_key_, client_id_));
|
||||
RAY_CHECK_OK(RequestNotifications(JobID::Nil(), client_log_key_, client_id_));
|
||||
};
|
||||
// Subscribe to the client table.
|
||||
RAY_CHECK_OK(Subscribe(DriverID::Nil(), client_id_, notification_callback,
|
||||
RAY_CHECK_OK(Subscribe(JobID::Nil(), client_id_, notification_callback,
|
||||
subscription_callback));
|
||||
};
|
||||
return Append(DriverID::Nil(), client_log_key_, data, add_callback);
|
||||
return Append(JobID::Nil(), client_log_key_, data, add_callback);
|
||||
}
|
||||
|
||||
Status ClientTable::Disconnect(const DisconnectCallback &callback) {
|
||||
@@ -709,12 +708,12 @@ Status ClientTable::Disconnect(const DisconnectCallback &callback) {
|
||||
auto add_callback = [this, callback](AsyncGcsClient *client, const ClientID &id,
|
||||
const ClientTableData &data) {
|
||||
HandleConnected(client, data);
|
||||
RAY_CHECK_OK(CancelNotifications(DriverID::Nil(), client_log_key_, id));
|
||||
RAY_CHECK_OK(CancelNotifications(JobID::Nil(), client_log_key_, id));
|
||||
if (callback != nullptr) {
|
||||
callback();
|
||||
}
|
||||
};
|
||||
RAY_RETURN_NOT_OK(Append(DriverID::Nil(), client_log_key_, data, add_callback));
|
||||
RAY_RETURN_NOT_OK(Append(JobID::Nil(), client_log_key_, data, add_callback));
|
||||
// We successfully added the deletion entry. Mark ourselves as disconnected.
|
||||
disconnected_ = true;
|
||||
return Status::OK();
|
||||
@@ -724,7 +723,7 @@ ray::Status ClientTable::MarkDisconnected(const ClientID &dead_client_id) {
|
||||
auto data = std::make_shared<ClientTableData>();
|
||||
data->set_client_id(dead_client_id.Binary());
|
||||
data->set_entry_type(ClientTableData::DELETION);
|
||||
return Append(DriverID::Nil(), client_log_key_, data, nullptr);
|
||||
return Append(JobID::Nil(), client_log_key_, data, nullptr);
|
||||
}
|
||||
|
||||
void ClientTable::GetClient(const ClientID &client_id,
|
||||
@@ -744,7 +743,7 @@ const std::unordered_map<ClientID, ClientTableData> &ClientTable::GetAllClients(
|
||||
|
||||
Status ClientTable::Lookup(const Callback &lookup) {
|
||||
RAY_CHECK(lookup != nullptr);
|
||||
return Log::Lookup(DriverID::Nil(), client_log_key_, lookup);
|
||||
return Log::Lookup(JobID::Nil(), client_log_key_, lookup);
|
||||
}
|
||||
|
||||
std::string ClientTable::DebugString() const {
|
||||
@@ -755,10 +754,10 @@ std::string ClientTable::DebugString() const {
|
||||
return result.str();
|
||||
}
|
||||
|
||||
Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id,
|
||||
Status ActorCheckpointIdTable::AddCheckpointId(const JobID &job_id,
|
||||
const ActorID &actor_id,
|
||||
const ActorCheckpointID &checkpoint_id) {
|
||||
auto lookup_callback = [this, checkpoint_id, driver_id, actor_id](
|
||||
auto lookup_callback = [this, checkpoint_id, job_id, actor_id](
|
||||
ray::gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
const ActorCheckpointIdData &data) {
|
||||
std::shared_ptr<ActorCheckpointIdData> copy =
|
||||
@@ -772,20 +771,20 @@ Status ActorCheckpointIdTable::AddCheckpointId(const DriverID &driver_id,
|
||||
RAY_LOG(DEBUG) << "Deleting checkpoint " << to_delete << " for actor " << actor_id;
|
||||
copy->mutable_checkpoint_ids()->erase(copy->mutable_checkpoint_ids()->begin());
|
||||
copy->mutable_timestamps()->erase(copy->mutable_timestamps()->begin());
|
||||
client_->actor_checkpoint_table().Delete(driver_id, to_delete);
|
||||
client_->actor_checkpoint_table().Delete(job_id, to_delete);
|
||||
}
|
||||
RAY_CHECK_OK(Add(driver_id, actor_id, copy, nullptr));
|
||||
RAY_CHECK_OK(Add(job_id, actor_id, copy, nullptr));
|
||||
};
|
||||
auto failure_callback = [this, checkpoint_id, driver_id, actor_id](
|
||||
auto failure_callback = [this, checkpoint_id, job_id, actor_id](
|
||||
ray::gcs::AsyncGcsClient *client, const UniqueID &id) {
|
||||
std::shared_ptr<ActorCheckpointIdData> data =
|
||||
std::make_shared<ActorCheckpointIdData>();
|
||||
data->set_actor_id(id.Binary());
|
||||
data->add_timestamps(current_sys_time_ms());
|
||||
*data->add_checkpoint_ids() = checkpoint_id.Binary();
|
||||
RAY_CHECK_OK(Add(driver_id, actor_id, data, nullptr));
|
||||
RAY_CHECK_OK(Add(job_id, actor_id, data, nullptr));
|
||||
};
|
||||
return Lookup(driver_id, actor_id, lookup_callback, failure_callback);
|
||||
return Lookup(job_id, actor_id, lookup_callback, failure_callback);
|
||||
}
|
||||
|
||||
template class Log<ObjectID, ObjectTableData>;
|
||||
@@ -797,9 +796,9 @@ template class Log<TaskID, TaskReconstructionData>;
|
||||
template class Table<TaskID, TaskLeaseData>;
|
||||
template class Table<ClientID, HeartbeatTableData>;
|
||||
template class Table<ClientID, HeartbeatBatchTableData>;
|
||||
template class Log<DriverID, ErrorTableData>;
|
||||
template class Log<JobID, ErrorTableData>;
|
||||
template class Log<ClientID, ClientTableData>;
|
||||
template class Log<DriverID, DriverTableData>;
|
||||
template class Log<JobID, JobTableData>;
|
||||
template class Log<UniqueID, ProfileTableData>;
|
||||
template class Table<ActorCheckpointID, ActorCheckpointData>;
|
||||
template class Table<ActorID, ActorCheckpointIdData>;
|
||||
|
||||
+77
-82
@@ -24,12 +24,12 @@ using rpc::ActorCheckpointData;
|
||||
using rpc::ActorCheckpointIdData;
|
||||
using rpc::ActorTableData;
|
||||
using rpc::ClientTableData;
|
||||
using rpc::DriverTableData;
|
||||
using rpc::ErrorTableData;
|
||||
using rpc::GcsChangeMode;
|
||||
using rpc::GcsEntry;
|
||||
using rpc::HeartbeatBatchTableData;
|
||||
using rpc::HeartbeatTableData;
|
||||
using rpc::JobTableData;
|
||||
using rpc::ObjectTableData;
|
||||
using rpc::ProfileTableData;
|
||||
using rpc::RayResource;
|
||||
@@ -55,9 +55,9 @@ enum class CommandType { kRegular, kChain };
|
||||
template <typename ID>
|
||||
class PubsubInterface {
|
||||
public:
|
||||
virtual Status RequestNotifications(const DriverID &driver_id, const ID &id,
|
||||
virtual Status RequestNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id) = 0;
|
||||
virtual Status CancelNotifications(const DriverID &driver_id, const ID &id,
|
||||
virtual Status CancelNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id) = 0;
|
||||
virtual ~PubsubInterface(){};
|
||||
};
|
||||
@@ -67,9 +67,9 @@ class LogInterface {
|
||||
public:
|
||||
using WriteCallback =
|
||||
std::function<void(AsyncGcsClient *client, const ID &id, const Data &data)>;
|
||||
virtual Status Append(const DriverID &driver_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
|
||||
virtual Status AppendAt(const DriverID &driver_id, const ID &task_id,
|
||||
virtual Status Append(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual Status AppendAt(const JobID &job_id, const ID &task_id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done,
|
||||
const WriteCallback &failure, int log_length) = 0;
|
||||
virtual ~LogInterface(){};
|
||||
@@ -119,20 +119,20 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
|
||||
/// Append a log entry to a key.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param data Data to append to the log. TODO(rkn): This can be made const,
|
||||
/// right?
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Append(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
Status Append(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
/// Append a log entry to a key if and only if the log has the given number
|
||||
/// of entries.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param data Data to append to the log.
|
||||
/// \param done Callback that is called if the data was appended to the log.
|
||||
@@ -141,25 +141,22 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// \param log_length The number of entries that the log must have for the
|
||||
/// append to succeed.
|
||||
/// \return Status
|
||||
Status AppendAt(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
Status AppendAt(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done, const WriteCallback &failure,
|
||||
int log_length);
|
||||
|
||||
/// Lookup the log values at a key asynchronously.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is looked up in the GCS.
|
||||
/// \param lookup Callback that is called after lookup. If the callback is
|
||||
/// called with an empty vector, then there was no data at the key.
|
||||
/// \return Status
|
||||
Status Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup);
|
||||
|
||||
Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup);
|
||||
/// Subscribe to any Append operations to this table. The caller may choose
|
||||
/// to subscribe to all Appends, or to subscribe only to keys that it
|
||||
/// requests notifications for. This may only be called once per Log
|
||||
/// instance.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param client_id The type of update to listen to. If this is nil, then a
|
||||
/// message for each Add to the table will be received. Else, only
|
||||
/// messages for the given client will be received. In the latter
|
||||
@@ -170,7 +167,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// \param done Callback that is called when subscription is complete and we
|
||||
/// are ready to receive messages.
|
||||
/// \return Status
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const Callback &subscribe, const SubscriptionCallback &done);
|
||||
|
||||
/// Request notifications about a key in this table.
|
||||
@@ -182,37 +179,37 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// notifications can be requested, the caller must first call `Subscribe`,
|
||||
/// with the same `client_id`.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the key to request notifications for.
|
||||
/// \param client_id The client who is requesting notifications. Before
|
||||
/// notifications can be requested, a call to `Subscribe` to this
|
||||
/// table with the same `client_id` must complete successfully.
|
||||
/// \return Status
|
||||
Status RequestNotifications(const DriverID &driver_id, const ID &id,
|
||||
Status RequestNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id);
|
||||
|
||||
/// Cancel notifications about a key in this table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the key to request notifications for.
|
||||
/// \param client_id The client who originally requested notifications.
|
||||
/// \return Status
|
||||
Status CancelNotifications(const DriverID &driver_id, const ID &id,
|
||||
Status CancelNotifications(const JobID &job_id, const ID &id,
|
||||
const ClientID &client_id);
|
||||
|
||||
/// Delete an entire key from redis.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data to delete from the GCS.
|
||||
/// \return Void.
|
||||
void Delete(const DriverID &driver_id, const ID &id);
|
||||
void Delete(const JobID &job_id, const ID &id);
|
||||
|
||||
/// Delete several keys from redis.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param ids The vector of IDs to delete from the GCS.
|
||||
/// \return Void.
|
||||
void Delete(const DriverID &driver_id, const std::vector<ID> &ids);
|
||||
void Delete(const JobID &job_id, const std::vector<ID> &ids);
|
||||
|
||||
/// Returns debug string for class.
|
||||
///
|
||||
@@ -232,7 +229,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// an additional parameter change_mode in NotificationCallback. Therefore this
|
||||
/// function supports notifications of remove operations.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param client_id The type of update to listen to. If this is nil, then a
|
||||
/// message for each Add to the table will be received. Else, only
|
||||
/// messages for the given client will be received. In the latter
|
||||
@@ -243,7 +240,7 @@ class Log : public LogInterface<ID, Data>, virtual public PubsubInterface<ID> {
|
||||
/// \param done Callback that is called when subscription is complete and we
|
||||
/// are ready to receive messages.
|
||||
/// \return Status
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const NotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done);
|
||||
|
||||
@@ -275,8 +272,8 @@ template <typename ID, typename Data>
|
||||
class TableInterface {
|
||||
public:
|
||||
using WriteCallback = typename Log<ID, Data>::WriteCallback;
|
||||
virtual Status Add(const DriverID &driver_id, const ID &task_id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
|
||||
virtual Status Add(const JobID &job_id, const ID &task_id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual ~TableInterface(){};
|
||||
};
|
||||
|
||||
@@ -312,32 +309,32 @@ class Table : private Log<ID, Data>,
|
||||
|
||||
/// Add an entry to the table. This overwrites any existing data at the key.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param data Data that is added to the GCS.
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
/// Lookup an entry asynchronously.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is looked up in the GCS.
|
||||
/// \param lookup Callback that is called after lookup if there was data at the
|
||||
/// key.
|
||||
/// \param failure Callback that is called after lookup if there was no data
|
||||
/// at the key.
|
||||
/// \return Status
|
||||
Status Lookup(const DriverID &driver_id, const ID &id, const Callback &lookup,
|
||||
Status Lookup(const JobID &job_id, const ID &id, const Callback &lookup,
|
||||
const FailureCallback &failure);
|
||||
|
||||
/// Subscribe to any Add operations to this table. The caller may choose to
|
||||
/// subscribe to all Adds, or to subscribe only to keys that it requests
|
||||
/// notifications for. This may only be called once per Table instance.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param client_id The type of update to listen to. If this is nil, then a
|
||||
/// message for each Add to the table will be received. Else, only
|
||||
/// messages for the given client will be received. In the latter
|
||||
@@ -350,16 +347,14 @@ class Table : private Log<ID, Data>,
|
||||
/// \param done Callback that is called when subscription is complete and we
|
||||
/// are ready to receive messages.
|
||||
/// \return Status
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const Callback &subscribe, const FailureCallback &failure,
|
||||
const SubscriptionCallback &done);
|
||||
|
||||
void Delete(const DriverID &driver_id, const ID &id) {
|
||||
Log<ID, Data>::Delete(driver_id, id);
|
||||
}
|
||||
void Delete(const JobID &job_id, const ID &id) { Log<ID, Data>::Delete(job_id, id); }
|
||||
|
||||
void Delete(const DriverID &driver_id, const std::vector<ID> &ids) {
|
||||
Log<ID, Data>::Delete(driver_id, ids);
|
||||
void Delete(const JobID &job_id, const std::vector<ID> &ids) {
|
||||
Log<ID, Data>::Delete(job_id, ids);
|
||||
}
|
||||
|
||||
/// Returns debug string for class.
|
||||
@@ -383,10 +378,10 @@ template <typename ID, typename Data>
|
||||
class SetInterface {
|
||||
public:
|
||||
using WriteCallback = typename Log<ID, Data>::WriteCallback;
|
||||
virtual Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
virtual Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual Status Remove(const DriverID &driver_id, const ID &id,
|
||||
std::shared_ptr<Data> &data, const WriteCallback &done) = 0;
|
||||
virtual Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done) = 0;
|
||||
virtual ~SetInterface(){};
|
||||
};
|
||||
|
||||
@@ -419,30 +414,30 @@ class Set : private Log<ID, Data>,
|
||||
|
||||
/// Add an entry to the set.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param data Data to add to the set.
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Add(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
Status Add(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
/// Remove an entry from the set.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is removed from the GCS.
|
||||
/// \param data Data to remove from the set.
|
||||
/// \param done Callback that is called once the data has been written to the
|
||||
/// GCS.
|
||||
/// \return Status
|
||||
Status Remove(const DriverID &driver_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
Status Remove(const JobID &job_id, const ID &id, std::shared_ptr<Data> &data,
|
||||
const WriteCallback &done);
|
||||
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const NotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) {
|
||||
return Log<ID, Data>::Subscribe(driver_id, client_id, subscribe, done);
|
||||
return Log<ID, Data>::Subscribe(job_id, client_id, subscribe, done);
|
||||
}
|
||||
|
||||
/// Returns debug string for class.
|
||||
@@ -499,40 +494,40 @@ class HashInterface {
|
||||
|
||||
/// Add entries of a hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is added to the GCS.
|
||||
/// \param pairs Map data to add to the hash table.
|
||||
/// \param done HashCallback that is called once the request data has been written to
|
||||
/// the GCS.
|
||||
/// \return Status
|
||||
virtual Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
|
||||
virtual Status Update(const JobID &job_id, const ID &id, const DataMap &pairs,
|
||||
const HashCallback &done) = 0;
|
||||
|
||||
/// Remove entries from the hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is removed from the GCS.
|
||||
/// \param keys The entry keys of the hash table.
|
||||
/// \param remove_callback HashRemoveCallback that is called once the data has been
|
||||
/// written to the GCS no matter whether the key exists in the hash table.
|
||||
/// \return Status
|
||||
virtual Status RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
virtual Status RemoveEntries(const JobID &job_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) = 0;
|
||||
|
||||
/// Lookup the map data of a hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param id The ID of the data that is looked up in the GCS.
|
||||
/// \param lookup HashCallback that is called after lookup. If the callback is
|
||||
/// called with an empty hash table, then there was no data at the key.
|
||||
/// \return Status
|
||||
virtual Status Lookup(const DriverID &driver_id, const ID &id,
|
||||
virtual Status Lookup(const JobID &job_id, const ID &id,
|
||||
const HashCallback &lookup) = 0;
|
||||
|
||||
/// Subscribe to any Update or Remove operations to this hash table.
|
||||
///
|
||||
/// \param driver_id The ID of the driver.
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param client_id The type of update to listen to. If this is nil, then a
|
||||
/// message for each Update to the table will be received. Else, only
|
||||
/// messages for the given client will be received. In the latter
|
||||
@@ -542,7 +537,7 @@ class HashInterface {
|
||||
/// \param done SubscriptionCallback that is called when subscription is complete and
|
||||
/// we are ready to receive messages.
|
||||
/// \return Status
|
||||
virtual Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
virtual Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) = 0;
|
||||
|
||||
@@ -567,17 +562,16 @@ class Hash : private Log<ID, Data>,
|
||||
using Log<ID, Data>::RequestNotifications;
|
||||
using Log<ID, Data>::CancelNotifications;
|
||||
|
||||
Status Update(const DriverID &driver_id, const ID &id, const DataMap &pairs,
|
||||
Status Update(const JobID &job_id, const ID &id, const DataMap &pairs,
|
||||
const HashCallback &done) override;
|
||||
|
||||
Status Subscribe(const DriverID &driver_id, const ClientID &client_id,
|
||||
Status Subscribe(const JobID &job_id, const ClientID &client_id,
|
||||
const HashNotificationCallback &subscribe,
|
||||
const SubscriptionCallback &done) override;
|
||||
|
||||
Status Lookup(const DriverID &driver_id, const ID &id,
|
||||
const HashCallback &lookup) override;
|
||||
Status Lookup(const JobID &job_id, const ID &id, const HashCallback &lookup) override;
|
||||
|
||||
Status RemoveEntries(const DriverID &driver_id, const ID &id,
|
||||
Status RemoveEntries(const JobID &job_id, const ID &id,
|
||||
const std::vector<std::string> &keys,
|
||||
const HashRemoveCallback &remove_callback) override;
|
||||
|
||||
@@ -645,23 +639,23 @@ class HeartbeatBatchTable : public Table<ClientID, HeartbeatBatchTableData> {
|
||||
virtual ~HeartbeatBatchTable() {}
|
||||
};
|
||||
|
||||
class DriverTable : public Log<DriverID, DriverTableData> {
|
||||
class JobTable : public Log<JobID, JobTableData> {
|
||||
public:
|
||||
DriverTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
|
||||
AsyncGcsClient *client)
|
||||
JobTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
|
||||
AsyncGcsClient *client)
|
||||
: Log(contexts, client) {
|
||||
pubsub_channel_ = TablePubsub::DRIVER_PUBSUB;
|
||||
prefix_ = TablePrefix::DRIVER;
|
||||
pubsub_channel_ = TablePubsub::JOB_PUBSUB;
|
||||
prefix_ = TablePrefix::JOB;
|
||||
};
|
||||
|
||||
virtual ~DriverTable() {}
|
||||
virtual ~JobTable() {}
|
||||
|
||||
/// Appends driver data to the driver table.
|
||||
/// Appends job data to the job table.
|
||||
///
|
||||
/// \param driver_id The driver id.
|
||||
/// \param is_dead Whether the driver is dead.
|
||||
/// \param job_id The job id.
|
||||
/// \param is_dead Whether the job is dead.
|
||||
/// \return The return status.
|
||||
Status AppendDriverData(const DriverID &driver_id, bool is_dead);
|
||||
Status AppendJobData(const JobID &job_id, bool is_dead);
|
||||
};
|
||||
|
||||
/// Actor table starts with an ALIVE entry, which represents the first time the actor
|
||||
@@ -697,9 +691,9 @@ class TaskLeaseTable : public Table<TaskID, TaskLeaseData> {
|
||||
prefix_ = TablePrefix::TASK_LEASE;
|
||||
}
|
||||
|
||||
Status Add(const DriverID &driver_id, const TaskID &id,
|
||||
std::shared_ptr<TaskLeaseData> &data, const WriteCallback &done) override {
|
||||
RAY_RETURN_NOT_OK((Table<TaskID, TaskLeaseData>::Add(driver_id, id, data, done)));
|
||||
Status Add(const JobID &job_id, const TaskID &id, std::shared_ptr<TaskLeaseData> &data,
|
||||
const WriteCallback &done) override {
|
||||
RAY_RETURN_NOT_OK((Table<TaskID, TaskLeaseData>::Add(job_id, id, data, done)));
|
||||
// Mark the entry for expiration in Redis. It's okay if this command fails
|
||||
// since the lease entry itself contains the expiration period. In the
|
||||
// worst case, if the command fails, then a client that looks up the lease
|
||||
@@ -733,11 +727,11 @@ class ActorCheckpointIdTable : public Table<ActorID, ActorCheckpointIdData> {
|
||||
/// Add a checkpoint id to an actor, and remove a previous checkpoint if the
|
||||
/// total number of checkpoints in GCS exceeds the max allowed value.
|
||||
///
|
||||
/// \param driver_id The ID of the job (= driver).
|
||||
/// \param job_id The ID of the job.
|
||||
/// \param actor_id ID of the actor.
|
||||
/// \param checkpoint_id ID of the checkpoint.
|
||||
/// \return Status.
|
||||
Status AddCheckpointId(const DriverID &driver_id, const ActorID &actor_id,
|
||||
Status AddCheckpointId(const JobID &job_id, const ActorID &actor_id,
|
||||
const ActorCheckpointID &checkpoint_id);
|
||||
};
|
||||
|
||||
@@ -761,7 +755,7 @@ class TaskTable : public Table<TaskID, TaskTableData> {
|
||||
|
||||
} // namespace raylet
|
||||
|
||||
class ErrorTable : private Log<DriverID, ErrorTableData> {
|
||||
class ErrorTable : private Log<JobID, ErrorTableData> {
|
||||
public:
|
||||
ErrorTable(const std::vector<std::shared_ptr<RedisContext>> &contexts,
|
||||
AsyncGcsClient *client)
|
||||
@@ -770,19 +764,20 @@ class ErrorTable : private Log<DriverID, ErrorTableData> {
|
||||
prefix_ = TablePrefix::ERROR_INFO;
|
||||
};
|
||||
|
||||
/// Push an error message for a specific job.
|
||||
/// Push an error message for the driver of a specific job.
|
||||
///
|
||||
/// TODO(rkn): We need to make sure that the errors are unique because
|
||||
/// duplicate messages currently cause failures (the GCS doesn't allow it). A
|
||||
/// natural way to do this is to have finer-grained time stamps.
|
||||
///
|
||||
/// \param driver_id The ID of the job that generated the error. If the error
|
||||
/// should be pushed to all jobs, then this should be nil.
|
||||
/// \param job_id The ID of the job that generated the error. If the error
|
||||
/// should be pushed to all drivers, then this should be nil.
|
||||
/// \param type The type of the error.
|
||||
/// \param error_message The error message to push.
|
||||
/// \param timestamp The timestamp of the error.
|
||||
/// \return Status.
|
||||
Status PushErrorToDriver(const DriverID &driver_id, const std::string &type,
|
||||
// TODO(qwang): refactor this API to implement broadcast.
|
||||
Status PushErrorToDriver(const JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp);
|
||||
|
||||
/// Returns debug string for class.
|
||||
|
||||
@@ -74,7 +74,7 @@ void ObjectDirectory::RegisterBackend() {
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_->object_table().Subscribe(
|
||||
DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
object_notification_callback, nullptr));
|
||||
}
|
||||
|
||||
@@ -87,7 +87,7 @@ ray::Status ObjectDirectory::ReportObjectAdded(
|
||||
data->set_manager(client_id.Binary());
|
||||
data->set_object_size(object_info.data_size);
|
||||
ray::Status status =
|
||||
gcs_client_->object_table().Add(DriverID::Nil(), object_id, data, nullptr);
|
||||
gcs_client_->object_table().Add(JobID::Nil(), object_id, data, nullptr);
|
||||
return status;
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ ray::Status ObjectDirectory::ReportObjectRemoved(
|
||||
data->set_manager(client_id.Binary());
|
||||
data->set_object_size(object_info.data_size);
|
||||
ray::Status status =
|
||||
gcs_client_->object_table().Remove(DriverID::Nil(), object_id, data, nullptr);
|
||||
gcs_client_->object_table().Remove(JobID::Nil(), object_id, data, nullptr);
|
||||
return status;
|
||||
};
|
||||
|
||||
@@ -159,7 +159,7 @@ ray::Status ObjectDirectory::SubscribeObjectLocations(const UniqueID &callback_i
|
||||
if (it == listeners_.end()) {
|
||||
it = listeners_.emplace(object_id, LocationListenerState()).first;
|
||||
status = gcs_client_->object_table().RequestNotifications(
|
||||
DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
}
|
||||
auto &listener_state = it->second;
|
||||
// TODO(hme): Make this fatal after implementing Pull suppression.
|
||||
@@ -187,7 +187,7 @@ ray::Status ObjectDirectory::UnsubscribeObjectLocations(const UniqueID &callback
|
||||
entry->second.callbacks.erase(callback_id);
|
||||
if (entry->second.callbacks.empty()) {
|
||||
status = gcs_client_->object_table().CancelNotifications(
|
||||
DriverID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
JobID::Nil(), object_id, gcs_client_->client_table().GetLocalClientId());
|
||||
listeners_.erase(entry);
|
||||
}
|
||||
return status;
|
||||
@@ -210,7 +210,7 @@ ray::Status ObjectDirectory::LookupLocations(const ObjectID &object_id,
|
||||
// SubscribeObjectLocations call, so look up the object's locations
|
||||
// directly from the GCS.
|
||||
status = gcs_client_->object_table().Lookup(
|
||||
DriverID::Nil(), object_id,
|
||||
JobID::Nil(), object_id,
|
||||
[this, callback](gcs::AsyncGcsClient *client, const ObjectID &object_id,
|
||||
const std::vector<ObjectTableData> &location_updates) {
|
||||
// Build the set of current locations based on the entries in the log.
|
||||
|
||||
@@ -25,7 +25,7 @@ enum TablePrefix {
|
||||
HEARTBEAT = 9;
|
||||
HEARTBEAT_BATCH = 10;
|
||||
ERROR_INFO = 11;
|
||||
DRIVER = 12;
|
||||
JOB = 12;
|
||||
PROFILE = 13;
|
||||
TASK_LEASE = 14;
|
||||
ACTOR_CHECKPOINT = 15;
|
||||
@@ -47,7 +47,7 @@ enum TablePubsub {
|
||||
HEARTBEAT_BATCH_PUBSUB = 8;
|
||||
ERROR_INFO_PUBSUB = 9;
|
||||
TASK_LEASE_PUBSUB = 10;
|
||||
DRIVER_PUBSUB = 11;
|
||||
JOB_PUBSUB = 11;
|
||||
NODE_RESOURCE_PUBSUB = 12;
|
||||
TABLE_PUBSUB_MAX = 13;
|
||||
}
|
||||
@@ -102,8 +102,8 @@ message ActorTableData {
|
||||
// dies, then this is the object that should be reconstructed for the actor
|
||||
// to be recreated.
|
||||
bytes actor_creation_dummy_object_id = 2;
|
||||
// The ID of the driver that created the actor.
|
||||
bytes driver_id = 3;
|
||||
// The ID of the job that created the actor.
|
||||
bytes job_id = 3;
|
||||
// The ID of the node manager that created the actor.
|
||||
bytes node_manager_id = 4;
|
||||
// Current state of this actor.
|
||||
@@ -115,8 +115,8 @@ message ActorTableData {
|
||||
}
|
||||
|
||||
message ErrorTableData {
|
||||
// The ID of the driver that the error is for.
|
||||
bytes driver_id = 1;
|
||||
// The ID of the job that the error is for.
|
||||
bytes job_id = 1;
|
||||
// The type of the error.
|
||||
string type = 2;
|
||||
// The error message.
|
||||
@@ -222,9 +222,9 @@ message TaskLeaseData {
|
||||
uint64 timeout = 3;
|
||||
}
|
||||
|
||||
message DriverTableData {
|
||||
// The driver ID.
|
||||
bytes driver_id = 1;
|
||||
message JobTableData {
|
||||
// The job ID.
|
||||
bytes job_id = 1;
|
||||
// Whether it's dead.
|
||||
bool is_dead = 2;
|
||||
}
|
||||
|
||||
@@ -43,8 +43,8 @@ const ObjectID ActorRegistration::GetExecutionDependency() const {
|
||||
return execution_dependency_;
|
||||
}
|
||||
|
||||
const DriverID ActorRegistration::GetDriverId() const {
|
||||
return DriverID::FromBinary(actor_table_data_.driver_id());
|
||||
const JobID ActorRegistration::GetJobId() const {
|
||||
return JobID::FromBinary(actor_table_data_.job_id());
|
||||
}
|
||||
|
||||
const int64_t ActorRegistration::GetMaxReconstructions() const {
|
||||
|
||||
@@ -73,8 +73,8 @@ class ActorRegistration {
|
||||
/// \return The execution dependency returned by the actor's creation task.
|
||||
const ObjectID GetActorCreationDependency() const;
|
||||
|
||||
/// Get actor's driver ID.
|
||||
const DriverID GetDriverId() const;
|
||||
/// Get actor's job ID.
|
||||
const JobID GetJobId() const;
|
||||
|
||||
/// Get the max number of times this actor should be reconstructed.
|
||||
const int64_t GetMaxReconstructions() const;
|
||||
|
||||
@@ -135,6 +135,7 @@ table RegisterClientRequest {
|
||||
// The process ID of this worker.
|
||||
worker_pid: long;
|
||||
// The driver ID. This is non-nil if the client is a driver.
|
||||
// TODO(qwang): rename this to driver_task_id.
|
||||
driver_id: string;
|
||||
// Language of this worker.
|
||||
language: Language;
|
||||
@@ -196,7 +197,7 @@ table WaitReply {
|
||||
// This struct is the same as ErrorTableData.
|
||||
table PushErrorRequest {
|
||||
// The ID of the job that the error is for.
|
||||
driver_id: string;
|
||||
job_id: string;
|
||||
// The type of the error.
|
||||
type: string;
|
||||
// The error message.
|
||||
|
||||
@@ -43,12 +43,12 @@ inline bool ThrowRayExceptionIfNotOK(JNIEnv *env, const ray::Status &status) {
|
||||
*/
|
||||
JNIEXPORT jlong JNICALL Java_org_ray_runtime_raylet_RayletClientImpl_nativeInit(
|
||||
JNIEnv *env, jclass, jstring sockName, jbyteArray workerId, jboolean isWorker,
|
||||
jbyteArray driverId) {
|
||||
jbyteArray jobId) {
|
||||
UniqueIdFromJByteArray<ClientID> worker_id(env, workerId);
|
||||
UniqueIdFromJByteArray<DriverID> driver_id(env, driverId);
|
||||
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
|
||||
const char *nativeString = env->GetStringUTFChars(sockName, JNI_FALSE);
|
||||
auto raylet_client = new RayletClient(nativeString, worker_id.GetId(), isWorker,
|
||||
driver_id.GetId(), Language::JAVA);
|
||||
job_id.GetId(), Language::JAVA);
|
||||
env->ReleaseStringUTFChars(sockName, nativeString);
|
||||
return reinterpret_cast<jlong>(raylet_client);
|
||||
}
|
||||
@@ -224,13 +224,13 @@ Java_org_ray_runtime_raylet_RayletClientImpl_nativeWaitObject(
|
||||
*/
|
||||
JNIEXPORT jbyteArray JNICALL
|
||||
Java_org_ray_runtime_raylet_RayletClientImpl_nativeGenerateTaskId(
|
||||
JNIEnv *env, jclass, jbyteArray driverId, jbyteArray parentTaskId,
|
||||
JNIEnv *env, jclass, jbyteArray jobId, jbyteArray parentTaskId,
|
||||
jint parent_task_counter) {
|
||||
UniqueIdFromJByteArray<DriverID> driver_id(env, driverId);
|
||||
UniqueIdFromJByteArray<JobID> job_id(env, jobId);
|
||||
UniqueIdFromJByteArray<TaskID> parent_task_id(env, parentTaskId);
|
||||
|
||||
TaskID task_id =
|
||||
ray::GenerateTaskId(driver_id.GetId(), parent_task_id.GetId(), parent_task_counter);
|
||||
ray::GenerateTaskId(job_id.GetId(), parent_task_id.GetId(), parent_task_counter);
|
||||
jbyteArray result = env->NewByteArray(task_id.Size());
|
||||
if (nullptr == result) {
|
||||
return nullptr;
|
||||
|
||||
@@ -275,9 +275,8 @@ void LineageCache::FlushTask(const TaskID &task_id) {
|
||||
// TODO(swang): Make this better...
|
||||
auto task_data = std::make_shared<TaskTableData>();
|
||||
task_data->set_task(task->TaskData().Serialize());
|
||||
RAY_CHECK_OK(
|
||||
task_storage_.Add(DriverID(task->TaskData().GetTaskSpecification().DriverId()),
|
||||
task_id, task_data, task_callback));
|
||||
RAY_CHECK_OK(task_storage_.Add(JobID(task->TaskData().GetTaskSpecification().JobId()),
|
||||
task_id, task_data, task_callback));
|
||||
|
||||
// We successfully wrote the task, so mark it as committing.
|
||||
// TODO(swang): Use a batched interface and write with all object entries.
|
||||
@@ -290,7 +289,7 @@ bool LineageCache::SubscribeTask(const TaskID &task_id) {
|
||||
if (unsubscribed) {
|
||||
// Request notifications for the task if we haven't already requested
|
||||
// notifications for it.
|
||||
RAY_CHECK_OK(task_pubsub_.RequestNotifications(DriverID::Nil(), task_id, client_id_));
|
||||
RAY_CHECK_OK(task_pubsub_.RequestNotifications(JobID::Nil(), task_id, client_id_));
|
||||
}
|
||||
// Return whether we were previously unsubscribed to this task and are now
|
||||
// subscribed.
|
||||
@@ -303,7 +302,7 @@ bool LineageCache::UnsubscribeTask(const TaskID &task_id) {
|
||||
if (subscribed) {
|
||||
// Cancel notifications for the task if we previously requested
|
||||
// notifications for it.
|
||||
RAY_CHECK_OK(task_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_));
|
||||
RAY_CHECK_OK(task_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_));
|
||||
subscribed_tasks_.erase(it);
|
||||
}
|
||||
// Return whether we were previously subscribed to this task and are now
|
||||
|
||||
@@ -22,7 +22,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
|
||||
notification_callback_ = notification_callback;
|
||||
}
|
||||
|
||||
Status Add(const DriverID &driver_id, const TaskID &task_id,
|
||||
Status Add(const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<TaskTableData> &task_data,
|
||||
const gcs::TableInterface<TaskID, TaskTableData>::WriteCallback &done) {
|
||||
task_table_[task_id] = task_data;
|
||||
@@ -57,10 +57,10 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
|
||||
notification_callback_(client, task_id, data);
|
||||
}
|
||||
};
|
||||
return Add(DriverID::Nil(), task_id, task_data, callback);
|
||||
return Add(JobID::Nil(), task_id, task_data, callback);
|
||||
}
|
||||
|
||||
Status RequestNotifications(const DriverID &driver_id, const TaskID &task_id,
|
||||
Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
|
||||
const ClientID &client_id) {
|
||||
subscribed_tasks_.insert(task_id);
|
||||
if (task_table_.count(task_id) == 1) {
|
||||
@@ -70,7 +70,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskTableData>,
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
Status CancelNotifications(const DriverID &driver_id, const TaskID &task_id,
|
||||
Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
|
||||
const ClientID &client_id) {
|
||||
subscribed_tasks_.erase(task_id);
|
||||
return ray::Status::OK();
|
||||
@@ -133,7 +133,7 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
|
||||
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
|
||||
}
|
||||
std::vector<std::string> function_descriptor(3);
|
||||
auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments,
|
||||
auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
|
||||
num_returns, required_resources, Language::PYTHON,
|
||||
function_descriptor);
|
||||
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
|
||||
|
||||
@@ -35,7 +35,7 @@ void Monitor::Start() {
|
||||
HandleHeartbeat(id, heartbeat_data);
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_.heartbeat_table().Subscribe(
|
||||
DriverID::Nil(), ClientID::Nil(), heartbeat_callback, nullptr, nullptr));
|
||||
JobID::Nil(), ClientID::Nil(), heartbeat_callback, nullptr, nullptr));
|
||||
Tick();
|
||||
}
|
||||
|
||||
@@ -68,9 +68,9 @@ void Monitor::Tick() {
|
||||
error_message << "The node with client ID " << client_id
|
||||
<< " has been marked dead because the monitor"
|
||||
<< " has missed too many heartbeats from it.";
|
||||
// We use the nil DriverID to broadcast the message to all drivers.
|
||||
// We use the nil JobID to broadcast the message to all drivers.
|
||||
RAY_CHECK_OK(gcs_client_.error_table().PushErrorToDriver(
|
||||
DriverID::Nil(), type, error_message.str(), current_time_ms()));
|
||||
JobID::Nil(), type, error_message.str(), current_time_ms()));
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_.client_table().Lookup(lookup_callback));
|
||||
@@ -88,7 +88,7 @@ void Monitor::Tick() {
|
||||
for (const auto &heartbeat : heartbeat_buffer_) {
|
||||
batch->add_batch()->CopyFrom(heartbeat.second);
|
||||
}
|
||||
RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(DriverID::Nil(), ClientID::Nil(),
|
||||
RAY_CHECK_OK(gcs_client_.heartbeat_batch_table().Add(JobID::Nil(), ClientID::Nil(),
|
||||
batch, nullptr));
|
||||
heartbeat_buffer_.clear();
|
||||
}
|
||||
|
||||
@@ -136,7 +136,7 @@ ray::Status NodeManager::RegisterGcs() {
|
||||
lineage_cache_.HandleEntryCommitted(task_id);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(gcs_client_->raylet_task_table().Subscribe(
|
||||
DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
task_committed_callback, nullptr, nullptr));
|
||||
|
||||
const auto task_lease_notification_callback = [this](gcs::AsyncGcsClient *client,
|
||||
@@ -160,7 +160,7 @@ ray::Status NodeManager::RegisterGcs() {
|
||||
reconstruction_policy_.HandleTaskLeaseNotification(task_id, 0);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(gcs_client_->task_lease_table().Subscribe(
|
||||
DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(),
|
||||
task_lease_notification_callback, task_lease_empty_callback, nullptr));
|
||||
|
||||
// Register a callback to handle actor notifications.
|
||||
@@ -175,7 +175,7 @@ ray::Status NodeManager::RegisterGcs() {
|
||||
};
|
||||
|
||||
RAY_RETURN_NOT_OK(gcs_client_->actor_table().Subscribe(
|
||||
DriverID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr));
|
||||
JobID::Nil(), ClientID::Nil(), actor_notification_callback, nullptr));
|
||||
|
||||
// Register a callback on the client table for new clients.
|
||||
auto node_manager_client_added = [this](gcs::AsyncGcsClient *client, const UniqueID &id,
|
||||
@@ -210,18 +210,17 @@ ray::Status NodeManager::RegisterGcs() {
|
||||
HeartbeatBatchAdded(heartbeat_batch);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(gcs_client_->heartbeat_batch_table().Subscribe(
|
||||
DriverID::Nil(), ClientID::Nil(), heartbeat_batch_added,
|
||||
JobID::Nil(), ClientID::Nil(), heartbeat_batch_added,
|
||||
/*subscribe_callback=*/nullptr,
|
||||
/*done_callback=*/nullptr));
|
||||
|
||||
// Subscribe to job table updates.
|
||||
const auto driver_table_handler =
|
||||
[this](gcs::AsyncGcsClient *client, const DriverID &client_id,
|
||||
const std::vector<DriverTableData> &driver_data) {
|
||||
HandleDriverTableUpdate(client_id, driver_data);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(gcs_client_->driver_table().Subscribe(
|
||||
DriverID::Nil(), ClientID::Nil(), driver_table_handler, nullptr));
|
||||
const auto job_table_handler = [this](gcs::AsyncGcsClient *client, const JobID &job_id,
|
||||
const std::vector<JobTableData> &job_data) {
|
||||
HandleJobTableUpdate(job_id, job_data);
|
||||
};
|
||||
RAY_RETURN_NOT_OK(gcs_client_->job_table().Subscribe(JobID::Nil(), ClientID::Nil(),
|
||||
job_table_handler, nullptr));
|
||||
|
||||
// Start sending heartbeats to the GCS.
|
||||
last_heartbeat_at_ms_ = current_time_ms();
|
||||
@@ -252,14 +251,14 @@ void NodeManager::KillWorker(std::shared_ptr<Worker> worker) {
|
||||
});
|
||||
}
|
||||
|
||||
void NodeManager::HandleDriverTableUpdate(
|
||||
const DriverID &id, const std::vector<DriverTableData> &driver_data) {
|
||||
for (const auto &entry : driver_data) {
|
||||
RAY_LOG(DEBUG) << "HandleDriverTableUpdate "
|
||||
<< UniqueID::FromBinary(entry.driver_id()) << " " << entry.is_dead();
|
||||
void NodeManager::HandleJobTableUpdate(const JobID &id,
|
||||
const std::vector<JobTableData> &job_data) {
|
||||
for (const auto &entry : job_data) {
|
||||
RAY_LOG(DEBUG) << "HandleJobTableUpdate " << UniqueID::FromBinary(entry.job_id())
|
||||
<< " " << entry.is_dead();
|
||||
if (entry.is_dead()) {
|
||||
auto driver_id = DriverID::FromBinary(entry.driver_id());
|
||||
auto workers = worker_pool_.GetWorkersRunningTasksForDriver(driver_id);
|
||||
auto job_id = JobID::FromBinary(entry.job_id());
|
||||
auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id);
|
||||
|
||||
// Kill all the workers. The actual cleanup for these workers is done
|
||||
// later when we receive the DisconnectClient message from them.
|
||||
@@ -271,11 +270,11 @@ void NodeManager::HandleDriverTableUpdate(
|
||||
KillWorker(worker);
|
||||
}
|
||||
|
||||
// Remove all tasks for this driver from the scheduling queues, mark
|
||||
// Remove all tasks for this job from the scheduling queues, mark
|
||||
// the results for these tasks as not required, cancel any attempts
|
||||
// at reconstruction. Note that at this time the workers are likely
|
||||
// alive because of the delay in killing workers.
|
||||
CleanUpTasksForDeadDriver(driver_id);
|
||||
CleanUpTasksForFinishedJob(job_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -313,7 +312,7 @@ void NodeManager::Heartbeat() {
|
||||
}
|
||||
|
||||
ray::Status status = heartbeat_table.Add(
|
||||
DriverID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
|
||||
JobID::Nil(), gcs_client_->client_table().GetLocalClientId(), heartbeat_data,
|
||||
/*success_callback=*/nullptr);
|
||||
RAY_CHECK_OK_PREPEND(status, "Heartbeat failed");
|
||||
|
||||
@@ -605,7 +604,7 @@ void NodeManager::PublishActorStateTransition(
|
||||
RAY_CHECK_OK(redis_context->RunArgvAsync(args));
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(DriverID::Nil(), actor_id,
|
||||
RAY_CHECK_OK(gcs_client_->actor_table().AppendAt(JobID::Nil(), actor_id,
|
||||
actor_notification, success_callback,
|
||||
failure_callback, log_length));
|
||||
}
|
||||
@@ -690,8 +689,8 @@ void NodeManager::HandleActorStateTransition(const ActorID &actor_id,
|
||||
}
|
||||
}
|
||||
|
||||
void NodeManager::CleanUpTasksForDeadDriver(const DriverID &driver_id) {
|
||||
auto tasks_to_remove = local_queues_.GetTaskIdsForDriver(driver_id);
|
||||
void NodeManager::CleanUpTasksForFinishedJob(const JobID &job_id) {
|
||||
auto tasks_to_remove = local_queues_.GetTaskIdsForJob(job_id);
|
||||
task_dependency_manager_.RemoveTasksAndRelatedObjects(tasks_to_remove);
|
||||
// NOTE(swang): SchedulingQueue::RemoveTasks modifies its argument so we must
|
||||
// call it last.
|
||||
@@ -749,7 +748,7 @@ void NodeManager::ProcessClientMessage(
|
||||
<< (registered_worker ? std::to_string(registered_worker->Pid())
|
||||
: "nil");
|
||||
if (registered_worker && registered_worker->IsDead()) {
|
||||
// For a worker that is marked as dead (because the driver has died already),
|
||||
// For a worker that is marked as dead (because the job has died already),
|
||||
// all the messages are ignored except DisconnectClient.
|
||||
if ((message_type_value != protocol::MessageType::DisconnectClient) &&
|
||||
(message_type_value != protocol::MessageType::IntentionalDisconnectClient)) {
|
||||
@@ -824,7 +823,7 @@ void NodeManager::ProcessClientMessage(
|
||||
for (const auto &object_id : object_ids) {
|
||||
creating_task_ids.push_back(object_id.TaskId());
|
||||
}
|
||||
gcs_client_->raylet_task_table().Delete(DriverID::Nil(), creating_task_ids);
|
||||
gcs_client_->raylet_task_table().Delete(JobID::Nil(), creating_task_ids);
|
||||
}
|
||||
} break;
|
||||
case protocol::MessageType::PrepareActorCheckpointRequest: {
|
||||
@@ -857,10 +856,11 @@ void NodeManager::ProcessRegisterClientRequestMessage(
|
||||
// message is actually the ID of the driver task, while client_id represents the
|
||||
// real driver ID, which can associate all the tasks/actors for a given driver,
|
||||
// which is set to the worker ID.
|
||||
const DriverID driver_id = from_flatbuf<DriverID>(*message->driver_id());
|
||||
// TODO(qwang): Use driver_task_id instead here.
|
||||
const WorkerID driver_id = from_flatbuf<WorkerID>(*message->driver_id());
|
||||
TaskID driver_task_id = TaskID::GetDriverTaskID(driver_id);
|
||||
worker->AssignTaskId(driver_task_id);
|
||||
worker->AssignDriverId(from_flatbuf<DriverID>(*message->client_id()));
|
||||
worker->AssignJobId(from_flatbuf<JobID>(*message->client_id()));
|
||||
worker_pool_.RegisterDriver(std::move(worker));
|
||||
local_queues_.AddDriverTaskId(driver_task_id);
|
||||
}
|
||||
@@ -992,14 +992,14 @@ void NodeManager::ProcessDisconnectClientMessage(
|
||||
|
||||
if (!intentional_disconnect) {
|
||||
// Push the error to driver.
|
||||
const DriverID &driver_id = worker->GetAssignedDriverId();
|
||||
const JobID &job_id = worker->GetAssignedJobId();
|
||||
// TODO(rkn): Define this constant somewhere else.
|
||||
std::string type = "worker_died";
|
||||
std::ostringstream error_message;
|
||||
error_message << "A worker died or was killed while executing task " << task_id
|
||||
<< ".";
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
driver_id, type, error_message.str(), current_time_ms()));
|
||||
job_id, type, error_message.str(), current_time_ms()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1022,22 +1022,21 @@ void NodeManager::ProcessDisconnectClientMessage(
|
||||
worker->ResetLifetimeResourceIds();
|
||||
|
||||
RAY_LOG(DEBUG) << "Worker (pid=" << worker->Pid() << ") is disconnected. "
|
||||
<< "driver_id: " << worker->GetAssignedDriverId();
|
||||
<< "job_id: " << worker->GetAssignedJobId();
|
||||
|
||||
// Since some resources may have been released, we can try to dispatch more tasks.
|
||||
DispatchTasks(local_queues_.GetReadyTasksWithResources());
|
||||
} else if (is_driver) {
|
||||
// The client is a driver.
|
||||
RAY_CHECK_OK(
|
||||
gcs_client_->driver_table().AppendDriverData(DriverID(client->GetClientId()),
|
||||
/*is_dead=*/true));
|
||||
auto driver_id = worker->GetAssignedTaskId();
|
||||
RAY_CHECK(!driver_id.IsNil());
|
||||
local_queues_.RemoveDriverTaskId(driver_id);
|
||||
RAY_CHECK_OK(gcs_client_->job_table().AppendJobData(JobID(client->GetClientId()),
|
||||
/*is_dead=*/true));
|
||||
auto job_id = worker->GetAssignedTaskId();
|
||||
RAY_CHECK(!job_id.IsNil());
|
||||
local_queues_.RemoveDriverTaskId(job_id);
|
||||
worker_pool_.DisconnectDriver(worker);
|
||||
|
||||
RAY_LOG(DEBUG) << "Driver (pid=" << worker->Pid() << ") is disconnected. "
|
||||
<< "driver_id: " << worker->GetAssignedDriverId();
|
||||
<< "job_id: " << worker->GetAssignedJobId();
|
||||
}
|
||||
|
||||
// TODO(rkn): Tell the object manager that this client has disconnected so
|
||||
@@ -1142,13 +1141,13 @@ void NodeManager::ProcessWaitRequestMessage(
|
||||
void NodeManager::ProcessPushErrorRequestMessage(const uint8_t *message_data) {
|
||||
auto message = flatbuffers::GetRoot<protocol::PushErrorRequest>(message_data);
|
||||
|
||||
DriverID driver_id = from_flatbuf<DriverID>(*message->driver_id());
|
||||
JobID job_id = from_flatbuf<JobID>(*message->job_id());
|
||||
auto const &type = string_from_flatbuf(*message->type());
|
||||
auto const &error_message = string_from_flatbuf(*message->error_message());
|
||||
double timestamp = message->timestamp();
|
||||
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(driver_id, type,
|
||||
error_message, timestamp));
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(job_id, type, error_message,
|
||||
timestamp));
|
||||
}
|
||||
|
||||
void NodeManager::ProcessPrepareActorCheckpointRequest(
|
||||
@@ -1173,7 +1172,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
|
||||
|
||||
// Write checkpoint data to GCS.
|
||||
RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Add(
|
||||
DriverID::Nil(), checkpoint_id, checkpoint_data,
|
||||
JobID::Nil(), checkpoint_id, checkpoint_data,
|
||||
[worker, actor_id, this](ray::gcs::AsyncGcsClient *client,
|
||||
const ActorCheckpointID &checkpoint_id,
|
||||
const ActorCheckpointData &data) {
|
||||
@@ -1182,7 +1181,7 @@ void NodeManager::ProcessPrepareActorCheckpointRequest(
|
||||
// Save this actor-to-checkpoint mapping, and remove old checkpoints associated
|
||||
// with this actor.
|
||||
RAY_CHECK_OK(gcs_client_->actor_checkpoint_id_table().AddCheckpointId(
|
||||
DriverID::Nil(), actor_id, checkpoint_id));
|
||||
JobID::Nil(), actor_id, checkpoint_id));
|
||||
// Send reply to worker.
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto reply = ray::protocol::CreatePrepareActorCheckpointReply(
|
||||
@@ -1284,7 +1283,7 @@ void NodeManager::ProcessSetResourceRequest(
|
||||
auto data_shared_ptr = std::make_shared<ClientTableData>(data);
|
||||
auto client_table = gcs_client_->client_table();
|
||||
RAY_CHECK_OK(gcs_client_->client_table().Append(
|
||||
DriverID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr));
|
||||
JobID::Nil(), client_table.client_log_key_, data_shared_ptr, nullptr));
|
||||
}
|
||||
|
||||
void NodeManager::ScheduleTasks(
|
||||
@@ -1354,7 +1353,7 @@ void NodeManager::ScheduleTasks(
|
||||
<< task.GetTaskSpecification().GetRequiredPlacementResources().ToString()
|
||||
<< " for placement. Check the client table to view node resources.";
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
task.GetTaskSpecification().DriverId(), type, error_message.str(),
|
||||
task.GetTaskSpecification().JobId(), type, error_message.str(),
|
||||
current_time_ms()));
|
||||
}
|
||||
// Assert that this placeable task is not feasible locally (necessary but not
|
||||
@@ -1415,8 +1414,7 @@ void NodeManager::TreatTaskAsFailed(const Task &task, const ErrorType &error_typ
|
||||
std::string error_message = stream.str();
|
||||
RAY_LOG(WARNING) << error_message;
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
task.GetTaskSpecification().DriverId(), "task", error_message,
|
||||
current_time_ms()));
|
||||
task.GetTaskSpecification().JobId(), "task", error_message, current_time_ms()));
|
||||
}
|
||||
}
|
||||
task_dependency_manager_.TaskCanceled(spec.TaskId());
|
||||
@@ -1558,7 +1556,7 @@ void NodeManager::SubmitTask(const Task &task, const Lineage &uncommitted_lineag
|
||||
HandleActorStateTransition(actor_id, ActorRegistration(data.back()));
|
||||
}
|
||||
};
|
||||
RAY_CHECK_OK(gcs_client_->actor_table().Lookup(DriverID::Nil(), spec.ActorId(),
|
||||
RAY_CHECK_OK(gcs_client_->actor_table().Lookup(JobID::Nil(), spec.ActorId(),
|
||||
lookup_callback));
|
||||
actor_creation_dummy_object = spec.ActorCreationDummyObjectId();
|
||||
} else {
|
||||
@@ -1783,7 +1781,7 @@ bool NodeManager::AssignTask(const Task &task) {
|
||||
auto spec = assigned_task.GetTaskSpecification();
|
||||
// We successfully assigned the task to the worker.
|
||||
worker->AssignTaskId(spec.TaskId());
|
||||
worker->AssignDriverId(spec.DriverId());
|
||||
worker->AssignJobId(spec.JobId());
|
||||
// Actor tasks require extra accounting to track the actor's state.
|
||||
if (spec.IsActorTask()) {
|
||||
auto actor_entry = actor_registry_.find(spec.ActorId());
|
||||
@@ -1870,10 +1868,10 @@ void NodeManager::FinishAssignedTask(Worker &worker) {
|
||||
|
||||
// Unset the worker's assigned task.
|
||||
worker.AssignTaskId(TaskID::Nil());
|
||||
// Unset the worker's assigned driver Id if this is not an actor.
|
||||
// Unset the worker's assigned job ID if this is not an actor.
|
||||
if (!task.GetTaskSpecification().IsActorCreationTask() &&
|
||||
!task.GetTaskSpecification().IsActorTask()) {
|
||||
worker.AssignDriverId(DriverID::Nil());
|
||||
worker.AssignJobId(JobID::Nil());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1892,7 +1890,7 @@ ActorTableData NodeManager::CreateActorTableDataFromCreationTask(const Task &tas
|
||||
new_actor_data.set_actor_id(actor_id.Binary());
|
||||
new_actor_data.set_actor_creation_dummy_object_id(
|
||||
task.GetTaskSpecification().ActorDummyObject().Binary());
|
||||
new_actor_data.set_driver_id(task.GetTaskSpecification().DriverId().Binary());
|
||||
new_actor_data.set_job_id(task.GetTaskSpecification().JobId().Binary());
|
||||
new_actor_data.set_max_reconstructions(
|
||||
task.GetTaskSpecification().MaxActorReconstructions());
|
||||
// This is the first time that the actor has been created, so the number
|
||||
@@ -1948,7 +1946,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
|
||||
RAY_LOG(DEBUG) << "Looking up checkpoint " << checkpoint_id << " for actor "
|
||||
<< actor_id;
|
||||
RAY_CHECK_OK(gcs_client_->actor_checkpoint_table().Lookup(
|
||||
DriverID::Nil(), checkpoint_id,
|
||||
JobID::Nil(), checkpoint_id,
|
||||
[this, actor_id, new_actor_data](ray::gcs::AsyncGcsClient *client,
|
||||
const UniqueID &checkpoint_id,
|
||||
const ActorCheckpointData &checkpoint_data) {
|
||||
@@ -2017,7 +2015,7 @@ void NodeManager::FinishAssignedActorTask(Worker &worker, const Task &task) {
|
||||
void NodeManager::HandleTaskReconstruction(const TaskID &task_id) {
|
||||
// Retrieve the task spec in order to re-execute the task.
|
||||
RAY_CHECK_OK(gcs_client_->raylet_task_table().Lookup(
|
||||
DriverID::Nil(), task_id,
|
||||
JobID::Nil(), task_id,
|
||||
/*success_callback=*/
|
||||
[this](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
|
||||
const TaskTableData &task_data) {
|
||||
@@ -2072,7 +2070,7 @@ void NodeManager::ResubmitTask(const Task &task) {
|
||||
<< " is a driver task and so the object created by ray.put "
|
||||
<< "could not be reconstructed.";
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
task.GetTaskSpecification().DriverId(), type, error_message.str(),
|
||||
task.GetTaskSpecification().JobId(), type, error_message.str(),
|
||||
current_time_ms()));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -27,10 +27,10 @@ namespace raylet {
|
||||
|
||||
using rpc::ActorTableData;
|
||||
using rpc::ClientTableData;
|
||||
using rpc::DriverTableData;
|
||||
using rpc::ErrorType;
|
||||
using rpc::HeartbeatBatchTableData;
|
||||
using rpc::HeartbeatTableData;
|
||||
using rpc::JobTableData;
|
||||
|
||||
struct NodeManagerConfig {
|
||||
/// The node's resource configuration.
|
||||
@@ -326,12 +326,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
const ActorID &actor_id, const ActorTableData &data,
|
||||
const ray::gcs::ActorTable::WriteCallback &failure_callback);
|
||||
|
||||
/// When a driver dies, loop over all of the queued tasks for that driver and
|
||||
/// When a job finishes, loop over all of the queued tasks for that job and
|
||||
/// treat them as failed.
|
||||
///
|
||||
/// \param driver_id The driver that died.
|
||||
/// \param job_id The job that exited.
|
||||
/// \return Void.
|
||||
void CleanUpTasksForDeadDriver(const DriverID &driver_id);
|
||||
void CleanUpTasksForFinishedJob(const JobID &job_id);
|
||||
|
||||
/// Handle an object becoming local. This updates any local accounting, but
|
||||
/// does not write to any global accounting in the GCS.
|
||||
@@ -346,13 +346,12 @@ class NodeManager : public rpc::NodeManagerServiceHandler {
|
||||
/// \return Void.
|
||||
void HandleObjectMissing(const ObjectID &object_id);
|
||||
|
||||
/// Handles updates to driver table.
|
||||
/// Handles updates to job table.
|
||||
///
|
||||
/// \param id An unused value. TODO(rkn): Should this be removed?
|
||||
/// \param driver_data Data associated with a driver table event.
|
||||
/// \param job_data Data associated with a job table event.
|
||||
/// \return Void.
|
||||
void HandleDriverTableUpdate(const DriverID &id,
|
||||
const std::vector<DriverTableData> &driver_data);
|
||||
void HandleJobTableUpdate(const JobID &id, const std::vector<JobTableData> &job_data);
|
||||
|
||||
/// Check if certain invariants associated with the task dependency manager
|
||||
/// and the local queues are satisfied. This is only used for debugging
|
||||
|
||||
@@ -202,18 +202,14 @@ ray::Status RayletConnection::AtomicRequestReply(
|
||||
}
|
||||
|
||||
RayletClient::RayletClient(const std::string &raylet_socket, const ClientID &client_id,
|
||||
bool is_worker, const DriverID &driver_id,
|
||||
const Language &language)
|
||||
: client_id_(client_id),
|
||||
is_worker_(is_worker),
|
||||
driver_id_(driver_id),
|
||||
language_(language) {
|
||||
bool is_worker, const JobID &job_id, const Language &language)
|
||||
: client_id_(client_id), is_worker_(is_worker), job_id_(job_id), language_(language) {
|
||||
// For C++14, we could use std::make_unique
|
||||
conn_ = std::unique_ptr<RayletConnection>(new RayletConnection(raylet_socket, -1, -1));
|
||||
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreateRegisterClientRequest(
|
||||
fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, driver_id),
|
||||
fbb, is_worker, to_flatbuf(fbb, client_id), getpid(), to_flatbuf(fbb, job_id),
|
||||
language);
|
||||
fbb.Finish(message);
|
||||
// Register the process ID with the raylet.
|
||||
@@ -323,11 +319,11 @@ ray::Status RayletClient::Wait(const std::vector<ObjectID> &object_ids, int num_
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
ray::Status RayletClient::PushError(const DriverID &driver_id, const std::string &type,
|
||||
ray::Status RayletClient::PushError(const ray::JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp) {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
auto message = ray::protocol::CreatePushErrorRequest(
|
||||
fbb, to_flatbuf(fbb, driver_id), fbb.CreateString(type),
|
||||
fbb, to_flatbuf(fbb, job_id), fbb.CreateString(type),
|
||||
fbb.CreateString(error_message), timestamp);
|
||||
fbb.Finish(message);
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
using ray::ActorCheckpointID;
|
||||
using ray::ActorID;
|
||||
using ray::ClientID;
|
||||
using ray::DriverID;
|
||||
using ray::JobID;
|
||||
using ray::ObjectID;
|
||||
using ray::TaskID;
|
||||
using ray::UniqueID;
|
||||
@@ -30,7 +30,7 @@ class RayletConnection {
|
||||
/// \param worker_id A unique ID to represent the worker.
|
||||
/// \param is_worker Whether this client is a worker. If it is a worker, an
|
||||
/// additional message will be sent to register as one.
|
||||
/// \param driver_id The ID of the driver. This is non-nil if the client is a
|
||||
/// \param job_id The ID of the job. This is non-nil if the client is a
|
||||
/// driver.
|
||||
/// \return The connection information.
|
||||
RayletConnection(const std::string &raylet_socket, int num_retries, int64_t timeout);
|
||||
@@ -66,10 +66,10 @@ class RayletClient {
|
||||
/// \param worker_id A unique ID to represent the worker.
|
||||
/// \param is_worker Whether this client is a worker. If it is a worker, an
|
||||
/// additional message will be sent to register as one.
|
||||
/// \param driver_id The ID of the driver. This is non-nil if the client is a driver.
|
||||
/// \param job_id The ID of the job. This is non-nil if the client is a driver.
|
||||
/// \return The connection information.
|
||||
RayletClient(const std::string &raylet_socket, const ClientID &client_id,
|
||||
bool is_worker, const DriverID &driver_id, const Language &language);
|
||||
bool is_worker, const JobID &job_id, const Language &language);
|
||||
|
||||
ray::Status Disconnect() { return conn_->Disconnect(); };
|
||||
|
||||
@@ -125,12 +125,12 @@ class RayletClient {
|
||||
|
||||
/// Push an error to the relevant driver.
|
||||
///
|
||||
/// \param The ID of the job that the error is for.
|
||||
/// \param job_id The ID of the job that the error is for.
|
||||
/// \param The type of the error.
|
||||
/// \param The error message.
|
||||
/// \param The timestamp of the error.
|
||||
/// \return ray::Status.
|
||||
ray::Status PushError(const DriverID &driver_id, const std::string &type,
|
||||
ray::Status PushError(const ray::JobID &job_id, const std::string &type,
|
||||
const std::string &error_message, double timestamp);
|
||||
|
||||
/// Store some profile events in the GCS.
|
||||
@@ -177,7 +177,7 @@ class RayletClient {
|
||||
|
||||
ClientID GetClientID() const { return client_id_; }
|
||||
|
||||
DriverID GetDriverID() const { return driver_id_; }
|
||||
JobID GetJobID() const { return job_id_; }
|
||||
|
||||
bool IsWorker() const { return is_worker_; }
|
||||
|
||||
@@ -186,7 +186,7 @@ class RayletClient {
|
||||
private:
|
||||
const ClientID client_id_;
|
||||
const bool is_worker_;
|
||||
const DriverID driver_id_;
|
||||
const JobID job_id_;
|
||||
const Language language_;
|
||||
/// A map from resource name to the resource IDs that are currently reserved
|
||||
/// for this worker. Each pair consists of the resource ID and the fraction
|
||||
|
||||
@@ -52,7 +52,7 @@ void ReconstructionPolicy::SetTaskTimeout(
|
||||
// required by the task are no longer needed soon after. If the
|
||||
// task is still required after this initial period, then we now
|
||||
// subscribe to task lease notifications.
|
||||
RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(DriverID::Nil(), task_id,
|
||||
RAY_CHECK_OK(task_lease_pubsub_.RequestNotifications(JobID::Nil(), task_id,
|
||||
client_id_));
|
||||
it->second.subscribed = true;
|
||||
}
|
||||
@@ -110,7 +110,7 @@ void ReconstructionPolicy::AttemptReconstruction(const TaskID &task_id,
|
||||
reconstruction_entry->set_num_reconstructions(reconstruction_attempt);
|
||||
reconstruction_entry->set_node_manager_id(client_id_.Binary());
|
||||
RAY_CHECK_OK(task_reconstruction_log_.AppendAt(
|
||||
DriverID::Nil(), task_id, reconstruction_entry,
|
||||
JobID::Nil(), task_id, reconstruction_entry,
|
||||
/*success_callback=*/
|
||||
[this](gcs::AsyncGcsClient *client, const TaskID &task_id,
|
||||
const TaskReconstructionData &data) {
|
||||
@@ -199,7 +199,7 @@ void ReconstructionPolicy::Cancel(const ObjectID &object_id) {
|
||||
// Cancel notifications for the task lease if we were subscribed to them.
|
||||
if (it->second.subscribed) {
|
||||
RAY_CHECK_OK(
|
||||
task_lease_pubsub_.CancelNotifications(DriverID::Nil(), task_id, client_id_));
|
||||
task_lease_pubsub_.CancelNotifications(JobID::Nil(), task_id, client_id_));
|
||||
}
|
||||
listening_tasks_.erase(it);
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
|
||||
failure_callback_ = failure_callback;
|
||||
}
|
||||
|
||||
void Add(const DriverID &driver_id, const TaskID &task_id,
|
||||
void Add(const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<TaskLeaseData> &task_lease_data) {
|
||||
task_lease_table_[task_id] = task_lease_data;
|
||||
if (subscribed_tasks_.count(task_id) == 1) {
|
||||
@@ -92,7 +92,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
|
||||
}
|
||||
}
|
||||
|
||||
Status RequestNotifications(const DriverID &driver_id, const TaskID &task_id,
|
||||
Status RequestNotifications(const JobID &job_id, const TaskID &task_id,
|
||||
const ClientID &client_id) {
|
||||
subscribed_tasks_.insert(task_id);
|
||||
auto entry = task_lease_table_.find(task_id);
|
||||
@@ -104,14 +104,14 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
Status CancelNotifications(const DriverID &driver_id, const TaskID &task_id,
|
||||
Status CancelNotifications(const JobID &job_id, const TaskID &task_id,
|
||||
const ClientID &client_id) {
|
||||
subscribed_tasks_.erase(task_id);
|
||||
return ray::Status::OK();
|
||||
}
|
||||
|
||||
Status AppendAt(
|
||||
const DriverID &driver_id, const TaskID &task_id,
|
||||
const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<TaskReconstructionData> &task_data,
|
||||
const ray::gcs::LogInterface<TaskID, TaskReconstructionData>::WriteCallback
|
||||
&success_callback,
|
||||
@@ -134,7 +134,7 @@ class MockGcs : public gcs::PubsubInterface<TaskID>,
|
||||
MOCK_METHOD4(
|
||||
Append,
|
||||
ray::Status(
|
||||
const DriverID &, const TaskID &, std::shared_ptr<TaskReconstructionData> &,
|
||||
const JobID &, const TaskID &, std::shared_ptr<TaskReconstructionData> &,
|
||||
const ray::gcs::LogInterface<TaskID, TaskReconstructionData>::WriteCallback &));
|
||||
|
||||
private:
|
||||
@@ -320,7 +320,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionSuppressed) {
|
||||
task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary());
|
||||
task_lease_data->set_acquired_at(current_sys_time_ms());
|
||||
task_lease_data->set_timeout(2 * test_period);
|
||||
mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data);
|
||||
mock_gcs_.Add(JobID::Nil(), task_id, task_lease_data);
|
||||
|
||||
// Listen for an object.
|
||||
reconstruction_policy_->ListenAndMaybeReconstruct(object_id);
|
||||
@@ -347,7 +347,7 @@ TEST_F(ReconstructionPolicyTest, TestReconstructionContinuallySuppressed) {
|
||||
task_lease_data->set_node_manager_id(ClientID::FromRandom().Binary());
|
||||
task_lease_data->set_acquired_at(current_sys_time_ms());
|
||||
task_lease_data->set_timeout(reconstruction_timeout_ms_);
|
||||
mock_gcs_.Add(DriverID::Nil(), task_id, task_lease_data);
|
||||
mock_gcs_.Add(JobID::Nil(), task_id, task_lease_data);
|
||||
});
|
||||
// Run the test for much longer than the reconstruction timeout.
|
||||
Run(reconstruction_timeout_ms_ * 2);
|
||||
@@ -399,7 +399,7 @@ TEST_F(ReconstructionPolicyTest, TestSimultaneousReconstructionSuppressed) {
|
||||
task_reconstruction_data->set_node_manager_id(ClientID::FromRandom().Binary());
|
||||
task_reconstruction_data->set_num_reconstructions(0);
|
||||
RAY_CHECK_OK(
|
||||
mock_gcs_.AppendAt(DriverID::Nil(), task_id, task_reconstruction_data, nullptr,
|
||||
mock_gcs_.AppendAt(JobID::Nil(), task_id, task_reconstruction_data, nullptr,
|
||||
/*failure_callback=*/
|
||||
[](ray::gcs::AsyncGcsClient *client, const TaskID &task_id,
|
||||
const TaskReconstructionData &data) { ASSERT_TRUE(false); },
|
||||
|
||||
@@ -19,15 +19,14 @@ inline const char *GetTaskStateString(ray::raylet::TaskState task_state) {
|
||||
return task_state_strings[static_cast<int>(task_state)];
|
||||
}
|
||||
|
||||
// Helper function to get tasks for a driver from a given state.
|
||||
// Helper function to get tasks for a job from a given state.
|
||||
template <typename TaskQueue>
|
||||
inline void GetDriverTasksFromQueue(const TaskQueue &queue,
|
||||
const ray::DriverID &driver_id,
|
||||
inline void GetTasksForJobFromQueue(const TaskQueue &queue, const ray::JobID &job_id,
|
||||
std::unordered_set<ray::TaskID> &task_ids) {
|
||||
const auto &tasks = queue.GetTasks();
|
||||
for (const auto &task : tasks) {
|
||||
auto const &spec = task.GetTaskSpecification();
|
||||
if (driver_id == spec.DriverId()) {
|
||||
if (job_id == spec.JobId()) {
|
||||
task_ids.insert(spec.TaskId());
|
||||
}
|
||||
}
|
||||
@@ -187,9 +186,9 @@ void SchedulingQueue::FilterState(std::unordered_set<TaskID> &task_ids,
|
||||
}
|
||||
} break;
|
||||
case TaskState::DRIVER: {
|
||||
const auto driver_ids = GetDriverTaskIds();
|
||||
const auto driver_task_ids = GetDriverTaskIds();
|
||||
for (auto it = task_ids.begin(); it != task_ids.end();) {
|
||||
if (driver_ids.count(*it) == 1) {
|
||||
if (driver_task_ids.count(*it) == 1) {
|
||||
it = task_ids.erase(it);
|
||||
} else {
|
||||
it++;
|
||||
@@ -356,11 +355,10 @@ bool SchedulingQueue::HasTask(const TaskID &task_id) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::unordered_set<TaskID> SchedulingQueue::GetTaskIdsForDriver(
|
||||
const DriverID &driver_id) const {
|
||||
std::unordered_set<TaskID> SchedulingQueue::GetTaskIdsForJob(const JobID &job_id) const {
|
||||
std::unordered_set<TaskID> task_ids;
|
||||
for (const auto &task_queue : task_queues_) {
|
||||
GetDriverTasksFromQueue(*task_queue, driver_id, task_ids);
|
||||
GetTasksForJobFromQueue(*task_queue, job_id, task_ids);
|
||||
}
|
||||
return task_ids;
|
||||
}
|
||||
@@ -394,15 +392,15 @@ void SchedulingQueue::RemoveBlockedTaskId(const TaskID &task_id) {
|
||||
RAY_CHECK(erased == 1);
|
||||
}
|
||||
|
||||
void SchedulingQueue::AddDriverTaskId(const TaskID &driver_id) {
|
||||
RAY_LOG(DEBUG) << "Added driver task " << driver_id;
|
||||
auto inserted = driver_task_ids_.insert(driver_id);
|
||||
void SchedulingQueue::AddDriverTaskId(const TaskID &task_id) {
|
||||
RAY_LOG(DEBUG) << "Added driver task " << task_id;
|
||||
auto inserted = driver_task_ids_.insert(task_id);
|
||||
RAY_CHECK(inserted.second);
|
||||
}
|
||||
|
||||
void SchedulingQueue::RemoveDriverTaskId(const TaskID &driver_id) {
|
||||
RAY_LOG(DEBUG) << "Removed driver task " << driver_id;
|
||||
auto erased = driver_task_ids_.erase(driver_id);
|
||||
void SchedulingQueue::RemoveDriverTaskId(const TaskID &task_id) {
|
||||
RAY_LOG(DEBUG) << "Removed driver task " << task_id;
|
||||
auto erased = driver_task_ids_.erase(task_id);
|
||||
RAY_CHECK(erased == 1);
|
||||
}
|
||||
|
||||
|
||||
@@ -283,11 +283,11 @@ class SchedulingQueue {
|
||||
/// \param filter_state The task state to filter out.
|
||||
void FilterState(std::unordered_set<TaskID> &task_ids, TaskState filter_state) const;
|
||||
|
||||
/// \brief Get all the task IDs for a driver.
|
||||
/// \brief Get all the task IDs for a job.
|
||||
///
|
||||
/// \param driver_id All the tasks that have the given driver_id are returned.
|
||||
/// \return All the tasks that have the given driver ID.
|
||||
std::unordered_set<TaskID> GetTaskIdsForDriver(const DriverID &driver_id) const;
|
||||
/// \param job_id All the tasks that have the given job_id are returned.
|
||||
/// \return All the tasks that have the given job ID.
|
||||
std::unordered_set<TaskID> GetTaskIdsForJob(const JobID &job_id) const;
|
||||
|
||||
/// \brief Get all the task IDs for an actor.
|
||||
///
|
||||
|
||||
@@ -265,7 +265,7 @@ void TaskDependencyManager::AcquireTaskLease(const TaskID &task_id) {
|
||||
task_lease_data->set_node_manager_id(client_id_.Hex());
|
||||
task_lease_data->set_acquired_at(current_sys_time_ms());
|
||||
task_lease_data->set_timeout(it->second.lease_period);
|
||||
RAY_CHECK_OK(task_lease_table_.Add(DriverID::Nil(), task_id, task_lease_data, nullptr));
|
||||
RAY_CHECK_OK(task_lease_table_.Add(JobID::Nil(), task_id, task_lease_data, nullptr));
|
||||
|
||||
auto period = boost::posix_time::milliseconds(it->second.lease_period / 2);
|
||||
it->second.lease_timer->expires_from_now(period);
|
||||
|
||||
@@ -29,7 +29,7 @@ class MockGcs : public gcs::TableInterface<TaskID, TaskLeaseData> {
|
||||
public:
|
||||
MOCK_METHOD4(
|
||||
Add,
|
||||
ray::Status(const DriverID &driver_id, const TaskID &task_id,
|
||||
ray::Status(const JobID &job_id, const TaskID &task_id,
|
||||
std::shared_ptr<TaskLeaseData> &task_data,
|
||||
const gcs::TableInterface<TaskID, TaskLeaseData>::WriteCallback &done));
|
||||
};
|
||||
@@ -75,7 +75,7 @@ static inline Task ExampleTask(const std::vector<ObjectID> &arguments,
|
||||
task_arguments.emplace_back(std::make_shared<TaskArgumentByReference>(references));
|
||||
}
|
||||
std::vector<std::string> function_descriptor(3);
|
||||
auto spec = TaskSpecification(DriverID::Nil(), TaskID::FromRandom(), 0, task_arguments,
|
||||
auto spec = TaskSpecification(JobID::Nil(), TaskID::FromRandom(), 0, task_arguments,
|
||||
num_returns, required_resources, Language::PYTHON,
|
||||
function_descriptor);
|
||||
auto execution_spec = TaskExecutionSpecification(std::vector<ObjectID>());
|
||||
|
||||
@@ -61,18 +61,18 @@ TaskSpecification::TaskSpecification(const uint8_t *spec, size_t spec_size) {
|
||||
}
|
||||
|
||||
TaskSpecification::TaskSpecification(
|
||||
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments, int64_t num_returns,
|
||||
const std::unordered_map<std::string, double> &required_resources,
|
||||
const Language &language, const std::vector<std::string> &function_descriptor)
|
||||
: TaskSpecification(driver_id, parent_task_id, parent_counter, ActorID::Nil(),
|
||||
: TaskSpecification(job_id, parent_task_id, parent_counter, ActorID::Nil(),
|
||||
ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), -1, {},
|
||||
task_arguments, num_returns, required_resources,
|
||||
std::unordered_map<std::string, double>(), language,
|
||||
function_descriptor) {}
|
||||
|
||||
TaskSpecification::TaskSpecification(
|
||||
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
|
||||
const int64_t max_actor_reconstructions, const ActorID &actor_id,
|
||||
const ActorHandleID &actor_handle_id, int64_t actor_counter,
|
||||
@@ -85,7 +85,7 @@ TaskSpecification::TaskSpecification(
|
||||
: spec_() {
|
||||
flatbuffers::FlatBufferBuilder fbb;
|
||||
|
||||
TaskID task_id = GenerateTaskId(driver_id, parent_task_id, parent_counter);
|
||||
TaskID task_id = GenerateTaskId(job_id, parent_task_id, parent_counter);
|
||||
// Add argument object IDs.
|
||||
std::vector<flatbuffers::Offset<Arg>> arguments;
|
||||
for (auto &argument : task_arguments) {
|
||||
@@ -94,7 +94,7 @@ TaskSpecification::TaskSpecification(
|
||||
|
||||
// Serialize the TaskSpecification.
|
||||
auto spec = CreateTaskInfo(
|
||||
fbb, to_flatbuf(fbb, driver_id), to_flatbuf(fbb, task_id),
|
||||
fbb, to_flatbuf(fbb, job_id), to_flatbuf(fbb, task_id),
|
||||
to_flatbuf(fbb, parent_task_id), parent_counter, to_flatbuf(fbb, actor_creation_id),
|
||||
to_flatbuf(fbb, actor_creation_dummy_object_id), max_actor_reconstructions,
|
||||
to_flatbuf(fbb, actor_id), to_flatbuf(fbb, actor_handle_id), actor_counter,
|
||||
@@ -123,9 +123,9 @@ TaskID TaskSpecification::TaskId() const {
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
return from_flatbuf<TaskID>(*message->task_id());
|
||||
}
|
||||
DriverID TaskSpecification::DriverId() const {
|
||||
JobID TaskSpecification::JobId() const {
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
return from_flatbuf<DriverID>(*message->driver_id());
|
||||
return from_flatbuf<JobID>(*message->job_id());
|
||||
}
|
||||
TaskID TaskSpecification::ParentTaskId() const {
|
||||
auto message = flatbuffers::GetRoot<TaskInfo>(spec_.data());
|
||||
|
||||
@@ -86,7 +86,7 @@ class TaskSpecification {
|
||||
/// Create a task specification from the raw fields. This constructor omits
|
||||
/// some values and sets them to sensible defaults.
|
||||
///
|
||||
/// \param driver_id The driver ID, representing the job that this task is a
|
||||
/// \param job_id The ID of the job that this task is a
|
||||
/// part of.
|
||||
/// \param parent_task_id The task ID of the task that spawned this task.
|
||||
/// \param parent_counter The number of tasks that this task's parent spawned
|
||||
@@ -96,7 +96,7 @@ class TaskSpecification {
|
||||
/// \param num_returns The number of values returned by the task.
|
||||
/// \param required_resources The task's resource demands.
|
||||
/// \param language The language of the worker that must execute the function.
|
||||
TaskSpecification(const DriverID &driver_id, const TaskID &parent_task_id,
|
||||
TaskSpecification(const JobID &job_id, const TaskID &parent_task_id,
|
||||
int64_t parent_counter,
|
||||
const std::vector<std::shared_ptr<TaskArgument>> &task_arguments,
|
||||
int64_t num_returns,
|
||||
@@ -107,7 +107,7 @@ class TaskSpecification {
|
||||
// TODO(swang): Define an actor task constructor.
|
||||
/// Create a task specification from the raw fields.
|
||||
///
|
||||
/// \param driver_id The driver ID, representing the job that this task is a
|
||||
/// \param job_id The ID of the job that this task is a
|
||||
/// part of.
|
||||
/// \param parent_task_id The task ID of the task that spawned this task.
|
||||
/// \param parent_counter The number of tasks that this task's parent spawned
|
||||
@@ -130,7 +130,7 @@ class TaskSpecification {
|
||||
/// \param function_descriptor The function descriptor.
|
||||
/// \param dynamic_worker_options The dynamic options for starting an actor worker.
|
||||
TaskSpecification(
|
||||
const DriverID &driver_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const JobID &job_id, const TaskID &parent_task_id, int64_t parent_counter,
|
||||
const ActorID &actor_creation_id, const ObjectID &actor_creation_dummy_object_id,
|
||||
int64_t max_actor_reconstructions, const ActorID &actor_id,
|
||||
const ActorHandleID &actor_handle_id, int64_t actor_counter,
|
||||
@@ -171,7 +171,7 @@ class TaskSpecification {
|
||||
|
||||
// TODO(swang): Finalize and document these methods.
|
||||
TaskID TaskId() const;
|
||||
DriverID DriverId() const;
|
||||
JobID JobId() const;
|
||||
TaskID ParentTaskId() const;
|
||||
int64_t ParentCounter() const;
|
||||
std::vector<std::string> FunctionDescriptor() const;
|
||||
|
||||
@@ -64,7 +64,7 @@ TEST(TaskSpecTest, TaskInfoSize) {
|
||||
}
|
||||
// General task.
|
||||
auto spec = CreateTaskInfo(
|
||||
fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id),
|
||||
fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id),
|
||||
to_flatbuf(fbb, TaskID::FromRandom()), 0, to_flatbuf(fbb, ActorID::Nil()),
|
||||
to_flatbuf(fbb, ObjectID::Nil()), 0, to_flatbuf(fbb, ActorID::Nil()),
|
||||
to_flatbuf(fbb, ActorHandleID::Nil()), 0,
|
||||
@@ -83,7 +83,7 @@ TEST(TaskSpecTest, TaskInfoSize) {
|
||||
}
|
||||
// General task.
|
||||
auto spec = CreateTaskInfo(
|
||||
fbb, to_flatbuf(fbb, DriverID::FromRandom()), to_flatbuf(fbb, task_id),
|
||||
fbb, to_flatbuf(fbb, JobID::FromRandom()), to_flatbuf(fbb, task_id),
|
||||
to_flatbuf(fbb, TaskID::FromRandom()), 10, to_flatbuf(fbb, ActorID::FromRandom()),
|
||||
to_flatbuf(fbb, ObjectID::FromRandom()), 10000000,
|
||||
to_flatbuf(fbb, ActorID::FromRandom()),
|
||||
|
||||
@@ -50,11 +50,9 @@ const std::unordered_set<TaskID> &Worker::GetBlockedTaskIds() const {
|
||||
return blocked_task_ids_;
|
||||
}
|
||||
|
||||
void Worker::AssignDriverId(const DriverID &driver_id) {
|
||||
assigned_driver_id_ = driver_id;
|
||||
}
|
||||
void Worker::AssignJobId(const JobID &job_id) { assigned_job_id_ = job_id; }
|
||||
|
||||
const DriverID &Worker::GetAssignedDriverId() const { return assigned_driver_id_; }
|
||||
const JobID &Worker::GetAssignedJobId() const { return assigned_job_id_; }
|
||||
|
||||
void Worker::AssignActorId(const ActorID &actor_id) {
|
||||
RAY_CHECK(actor_id_.IsNil())
|
||||
|
||||
@@ -34,8 +34,8 @@ class Worker {
|
||||
bool AddBlockedTaskId(const TaskID &task_id);
|
||||
bool RemoveBlockedTaskId(const TaskID &task_id);
|
||||
const std::unordered_set<TaskID> &GetBlockedTaskIds() const;
|
||||
void AssignDriverId(const DriverID &driver_id);
|
||||
const DriverID &GetAssignedDriverId() const;
|
||||
void AssignJobId(const JobID &job_id);
|
||||
const JobID &GetAssignedJobId() const;
|
||||
void AssignActorId(const ActorID &actor_id);
|
||||
const ActorID &GetActorId() const;
|
||||
/// Return the worker's connection.
|
||||
@@ -60,8 +60,8 @@ class Worker {
|
||||
std::shared_ptr<LocalClientConnection> connection_;
|
||||
/// The worker's currently assigned task.
|
||||
TaskID assigned_task_id_;
|
||||
/// Driver ID for the worker's current assigned task.
|
||||
DriverID assigned_driver_id_;
|
||||
/// Job ID for the worker's current assigned task.
|
||||
JobID assigned_job_id_;
|
||||
/// The worker's actor ID. If this is nil, then the worker is not an actor.
|
||||
ActorID actor_id_;
|
||||
/// Whether the worker is dead.
|
||||
|
||||
@@ -319,13 +319,13 @@ inline WorkerPool::State &WorkerPool::GetStateForLanguage(const Language &langua
|
||||
return state->second;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForDriver(
|
||||
const DriverID &driver_id) const {
|
||||
std::vector<std::shared_ptr<Worker>> WorkerPool::GetWorkersRunningTasksForJob(
|
||||
const JobID &job_id) const {
|
||||
std::vector<std::shared_ptr<Worker>> workers;
|
||||
|
||||
for (const auto &entry : states_by_lang_) {
|
||||
for (const auto &worker : entry.second.registered_workers) {
|
||||
if (worker->GetAssignedDriverId() == driver_id) {
|
||||
if (worker->GetAssignedJobId() == job_id) {
|
||||
workers.push_back(worker);
|
||||
}
|
||||
}
|
||||
@@ -355,7 +355,7 @@ void WorkerPool::WarnAboutSize() {
|
||||
<< "(see https://github.com/ray-project/ray/issues/3644) for "
|
||||
<< "some a discussion of workarounds.";
|
||||
RAY_CHECK_OK(gcs_client_->error_table().PushErrorToDriver(
|
||||
DriverID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms()));
|
||||
JobID::Nil(), "worker_pool_large", warning_message.str(), current_time_ms()));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -102,12 +102,12 @@ class WorkerPool {
|
||||
/// \return The total count of all workers (actor and non-actor) in the pool.
|
||||
uint32_t Size(const Language &language) const;
|
||||
|
||||
/// Get all the workers which are running tasks for a given driver.
|
||||
/// Get all the workers which are running tasks for a given job.
|
||||
///
|
||||
/// \param driver_id The driver ID.
|
||||
/// \return A list containing all the workers which are running tasks for the driver.
|
||||
std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForDriver(
|
||||
const DriverID &driver_id) const;
|
||||
/// \param job_id The job ID.
|
||||
/// \return A list containing all the workers which are running tasks for the job.
|
||||
std::vector<std::shared_ptr<Worker>> GetWorkersRunningTasksForJob(
|
||||
const JobID &job_id) const;
|
||||
|
||||
/// Whether there is a pending worker for the given task.
|
||||
/// Note that, this is only used for actor creation task with dynamic options.
|
||||
|
||||
@@ -109,7 +109,7 @@ static inline TaskSpecification ExampleTaskSpec(
|
||||
const ActorID actor_id = ActorID::Nil(), const Language &language = Language::PYTHON,
|
||||
const ActorID actor_creation_id = ActorID::Nil()) {
|
||||
std::vector<std::string> function_descriptor(3);
|
||||
return TaskSpecification(DriverID::Nil(), TaskID::Nil(), 0, actor_creation_id,
|
||||
return TaskSpecification(JobID::Nil(), TaskID::Nil(), 0, actor_creation_id,
|
||||
ObjectID::Nil(), 0, actor_id, ActorHandleID::Nil(), 0, {}, {},
|
||||
0, {}, {}, language, function_descriptor);
|
||||
}
|
||||
@@ -226,7 +226,7 @@ TEST_F(WorkerPoolTest, StartWorkerWithDynamicOptionsCommand) {
|
||||
SetWorkerCommands({{Language::PYTHON, {"dummy_py_worker_command"}},
|
||||
{Language::JAVA, java_worker_command}});
|
||||
|
||||
TaskSpecification task_spec(DriverID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(),
|
||||
TaskSpecification task_spec(JobID::Nil(), TaskID::Nil(), 0, ActorID::FromRandom(),
|
||||
ObjectID::Nil(), 0, ActorID::Nil(), ActorHandleID::Nil(), 0,
|
||||
{}, {}, 0, {}, {}, Language::JAVA, {"", "", ""},
|
||||
{"test_op_0", "test_op_1"});
|
||||
|
||||
Reference in New Issue
Block a user