diff --git a/java/api/src/main/java/org/ray/api/id/JobId.java b/java/api/src/main/java/org/ray/api/id/JobId.java new file mode 100644 index 000000000..53157cd5a --- /dev/null +++ b/java/api/src/main/java/org/ray/api/id/JobId.java @@ -0,0 +1,65 @@ +package org.ray.api.id; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Arrays; + +/** + * Represents the id of a Ray job. + */ +public class JobId extends BaseId implements Serializable { + + // The largest job id value, (2^32 - 1), is reserved: its byte pattern (all 0xFF) is NIL. + public static final Long MAX_VALUE = (1L << 32) - 1; + + public static final int LENGTH = 4; + + public static final JobId NIL = genNil(); + + /** + * Create a JobId instance according to the given bytes. + */ + private JobId(byte[] id) { + super(id); + } + + /** + * Create a JobId from a given hex string. + */ + public static JobId fromHexString(String hex) { + return new JobId(hexString2Bytes(hex)); + } + + /** + * Creates a JobId from the given ByteBuffer. + */ + public static JobId fromByteBuffer(ByteBuffer bb) { + return new JobId(byteBuffer2Bytes(bb)); + } + + /** + * Create a JobId whose bytes are the little-endian encoding of the given int. + */ + public static JobId fromInt(int value) { + byte[] bytes = new byte[JobId.LENGTH]; + ByteBuffer wbb = ByteBuffer.wrap(bytes); + wbb.order(ByteOrder.LITTLE_ENDIAN); + wbb.putInt(value); + return new JobId(bytes); + } + + /** + * Generate a nil JobId. 
+ */ + private static JobId genNil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new JobId(b); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java index 6e0feee10..5ce1fc383 100644 --- a/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java +++ b/java/api/src/main/java/org/ray/api/runtimecontext/RuntimeContext.java @@ -1,6 +1,7 @@ package org.ray.api.runtimecontext; import java.util.List; +import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; /** @@ -11,7 +12,7 @@ public interface RuntimeContext { /** * Get the current Job ID. */ - UniqueId getCurrentJobId(); + JobId getCurrentJobId(); /** * Get the current actor ID. diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index f77cd6658..88c281f79 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -76,8 +76,6 @@ public abstract class AbstractRayRuntime implements RayRuntime { this.rayConfig = rayConfig; functionManager = new FunctionManager(rayConfig.jobResourcePath); worker = new Worker(this); - workerContext = new WorkerContext(rayConfig.workerMode, - rayConfig.jobId, rayConfig.runMode); runtimeContext = new RuntimeContextImpl(this); } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java index e5d7b20b1..a53d59bc8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java @@ -1,5 +1,7 @@ package org.ray.runtime; +import java.util.concurrent.atomic.AtomicInteger; +import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; 
import org.ray.runtime.objectstore.MockObjectStore; import org.ray.runtime.objectstore.ObjectStoreProxy; @@ -13,9 +15,16 @@ public class RayDevRuntime extends AbstractRayRuntime { private MockObjectStore store; + private final AtomicInteger jobCounter = new AtomicInteger(0); + @Override public void start() { store = new MockObjectStore(this); + if (rayConfig.getJobId().isNil()) { + rayConfig.setJobId(nextJobId()); + } + workerContext = new WorkerContext(rayConfig.workerMode, + rayConfig.getJobId(), rayConfig.runMode); objectStoreProxy = new ObjectStoreProxy(this, null); rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime); } @@ -33,4 +42,8 @@ public class RayDevRuntime extends AbstractRayRuntime { public Worker getWorker() { return ((MockRayletClient) rayletClient).getCurrentWorker(); } + + private JobId nextJobId() { + return JobId.fromInt(jobCounter.getAndIncrement()); + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 49e4f6c39..8d98b18f4 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -11,6 +11,7 @@ import java.nio.file.Paths; import java.nio.file.StandardCopyOption; import java.util.HashMap; import java.util.Map; +import org.ray.api.id.JobId; import org.ray.runtime.config.RayConfig; import org.ray.runtime.config.WorkerMode; import org.ray.runtime.gcs.GcsClient; @@ -94,6 +95,13 @@ public final class RayNativeRuntime extends AbstractRayRuntime { gcsClient = new GcsClient(rayConfig.getRedisAddress(), rayConfig.redisPassword); + // Compare by value: a nil id parsed from the config is not the JobId.NIL instance. + if (JobId.NIL.equals(rayConfig.getJobId())) { + rayConfig.setJobId(gcsClient.nextJobId()); + } + + workerContext = new WorkerContext(rayConfig.workerMode, + rayConfig.getJobId(), rayConfig.runMode); // TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis. 
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName); diff --git a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java index c9815fd26..3286359ba 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java @@ -2,6 +2,7 @@ package org.ray.runtime; import com.google.common.base.Preconditions; import java.util.List; +import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.api.runtimecontext.RuntimeContext; @@ -17,7 +18,7 @@ public class RuntimeContextImpl implements RuntimeContext { } @Override - public UniqueId getCurrentJobId() { + public JobId getCurrentJobId() { return runtime.getWorkerContext().getCurrentJobId(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 3dc2be7ed..828d39cb5 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -1,11 +1,13 @@ package org.ray.runtime; import com.google.common.base.Preconditions; +import org.ray.api.id.JobId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.runtime.config.RunMode; import org.ray.runtime.config.WorkerMode; import org.ray.runtime.task.TaskSpec; +import org.ray.runtime.util.IdUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -29,7 +31,7 @@ public class WorkerContext { private ThreadLocal currentTask; - private UniqueId currentJobId; + private JobId currentJobId; private ClassLoader currentClassLoader; @@ -43,7 +45,7 @@ public class WorkerContext { */ private RunMode runMode; - public WorkerContext(WorkerMode workerMode, UniqueId jobId, RunMode runMode) { + public WorkerContext(WorkerMode workerMode, JobId jobId, RunMode 
runMode) { mainThreadId = Thread.currentThread().getId(); taskIndex = ThreadLocal.withInitial(() -> 0); putIndex = ThreadLocal.withInitial(() -> 0); @@ -52,15 +54,13 @@ public class WorkerContext { currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { - // TODO(qwang): Assign the driver id to worker id - // once we treat driver id as a special worker id. - workerId = jobId; + workerId = IdUtil.computeDriverId(jobId); currentTaskId.set(TaskId.randomId()); currentJobId = jobId; } else { workerId = UniqueId.randomId(); this.currentTaskId.set(TaskId.NIL); - this.currentJobId = UniqueId.NIL; + this.currentJobId = JobId.NIL; } } @@ -119,7 +119,7 @@ public class WorkerContext { /** * The ID of the current job. */ - public UniqueId getCurrentJobId() { + public JobId getCurrentJobId() { return currentJobId; } diff --git a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java index 27a4ce38d..e67c88d59 100644 --- a/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java +++ b/java/runtime/src/main/java/org/ray/runtime/config/RayConfig.java @@ -10,7 +10,7 @@ import java.io.File; import java.util.ArrayList; import java.util.List; import java.util.Map; -import org.ray.api.id.UniqueId; +import org.ray.api.id.JobId; import org.ray.runtime.util.NetworkUtil; import org.ray.runtime.util.ResourceUtil; import org.ray.runtime.util.StringUtil; @@ -32,7 +32,7 @@ public class RayConfig { public final WorkerMode workerMode; public final RunMode runMode; public final Map resources; - public final UniqueId jobId; + private JobId jobId; public final String logDir; public final boolean redirectOutput; public final List libraryPath; @@ -108,9 +108,9 @@ public class RayConfig { // Job id. 
String jobId = config.getString("ray.job.id"); if (!jobId.isEmpty()) { - this.jobId = UniqueId.fromHexString(jobId); + this.jobId = JobId.fromHexString(jobId); } else { - this.jobId = UniqueId.randomId(); + this.jobId = JobId.NIL; } // Log dir. logDir = removeTrailingSlash(config.getString("ray.log-dir")); @@ -198,6 +198,14 @@ public class RayConfig { return redisPort; } + public void setJobId(JobId jobId) { + this.jobId = jobId; + } + + public JobId getJobId() { + return this.jobId; + } + @Override public String toString() { return "RayConfig{" diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index 988dac794..230768933 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -24,7 +24,7 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.objectweb.asm.Type; import org.ray.api.function.RayFunc; -import org.ray.api.id.UniqueId; +import org.ray.api.id.JobId; import org.ray.runtime.util.LambdaUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +48,7 @@ public class FunctionManager { /** * Mapping from the job id to the functions that belong to this job. */ - private Map jobFunctionTables = new HashMap<>(); + private Map jobFunctionTables = new HashMap<>(); /** * The resource path which we can load the job's jar resources. @@ -72,7 +72,7 @@ public class FunctionManager { * @param func The lambda. * @return A RayFunction object. 
*/ - public RayFunction getFunction(UniqueId jobId, RayFunc func) { + public RayFunction getFunction(JobId jobId, RayFunc func) { JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass()); if (functionDescriptor == null) { SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func); @@ -92,7 +92,7 @@ public class FunctionManager { * @param functionDescriptor The function descriptor. * @return A RayFunction object. */ - public RayFunction getFunction(UniqueId jobId, JavaFunctionDescriptor functionDescriptor) { + public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) { JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId); if (jobFunctionTable == null) { ClassLoader classLoader; diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java index c5f849a75..97f98eaab 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/GcsClient.java @@ -11,6 +11,7 @@ import java.util.stream.Collectors; import org.apache.commons.lang3.ArrayUtils; import org.ray.api.Checkpointable.Checkpoint; import org.ray.api.id.BaseId; +import org.ray.api.id.JobId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; import org.ray.api.runtimecontext.NodeInfo; @@ -164,6 +165,11 @@ public class GcsClient { return checkpoints; } + public JobId nextJobId() { + int jobCounter = (int) primary.incr("JobCounter".getBytes()); + return JobId.fromInt(jobCounter); + } + private RedisClient getShardClient(BaseId key) { return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key), shards.size())); diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java index dbce9750e..5e0a7b978 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java @@ -107,4 +107,12 @@ public class RedisClient { } } + /** + * Issue a Redis INCR on the given key and return the resulting value without narrowing it. + */ + public long incr(byte[] key) { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.incr(key); + } + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java index 9d014b72e..d329459e0 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/MockRayletClient.java @@ -17,6 +17,7 @@ import java.util.concurrent.Executors; import org.apache.commons.lang3.NotImplementedException; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -164,7 +165,7 @@ public class MockRayletClient implements RayletClient { } @Override - public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) { + public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { return TaskId.randomId(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java index 0ef2163f7..8c6abcd5a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClient.java @@ -3,6 +3,7 @@ package org.ray.runtime.raylet; import java.util.List; import org.ray.api.RayObject; import org.ray.api.WaitResult; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -21,7 +22,7 @@ public interface RayletClient { void notifyUnblocked(TaskId currentTaskId); - TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex); + TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int 
taskIndex); WaitResult wait(List> waitFor, int numReturns, int timeoutMs, TaskId currentTaskId); diff --git a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java index 8c00f718c..9577cf2e2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java +++ b/java/runtime/src/main/java/org/ray/runtime/raylet/RayletClientImpl.java @@ -15,6 +15,7 @@ import java.util.stream.Collectors; import org.ray.api.RayObject; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -39,7 +40,7 @@ public class RayletClientImpl implements RayletClient { // TODO(qwang): JobId parameter can be removed once we embed jobId in driverId. public RayletClientImpl(String schedulerSockName, UniqueId clientId, - boolean isWorker, UniqueId jobId) { + boolean isWorker, JobId jobId) { client = nativeInit(schedulerSockName, clientId.getBytes(), isWorker, jobId.getBytes()); } @@ -107,7 +108,7 @@ public class RayletClientImpl implements RayletClient { } @Override - public TaskId generateTaskId(UniqueId jobId, TaskId parentTaskId, int taskIndex) { + public TaskId generateTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) { byte[] bytes = nativeGenerateTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex); return new TaskId(bytes); } @@ -146,7 +147,7 @@ public class RayletClientImpl implements RayletClient { } // Parse common fields. 
- UniqueId jobId = UniqueId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer()); + JobId jobId = JobId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer()); TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer()); TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer()); int parentCounter = (int) taskSpec.getParentCounter(); diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 773499fcf..839dff95d 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -189,6 +189,8 @@ public class RunManager { client.auth(rayConfig.headRedisPassword); } client.set("UseRaylet", "1"); + // Set job counter to compute job id. + client.set("JobCounter", "0"); // Register the number of Redis shards in the primary shard, so that clients // know how many redis shards to expect under RedisShards. client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards)); diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java index 449ff6111..8d6353b46 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/TaskSpec.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; @@ -19,7 +20,7 @@ import org.ray.runtime.util.IdUtil; public class TaskSpec { // ID of the job that created this task. - public final UniqueId jobId; + public final JobId jobId; // Task ID of the task. 
public final TaskId taskId; @@ -81,7 +82,7 @@ public class TaskSpec { } public TaskSpec( - UniqueId jobId, + JobId jobId, TaskId taskId, TaskId parentTaskId, int parentCounter, diff --git a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java index 67df09fa1..8a96bc57a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java +++ b/java/runtime/src/main/java/org/ray/runtime/util/IdUtil.java @@ -3,13 +3,14 @@ package org.ray.runtime.util; import com.google.common.base.Preconditions; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.Arrays; import java.util.List; import org.ray.api.id.BaseId; +import org.ray.api.id.JobId; import org.ray.api.id.ObjectId; import org.ray.api.id.TaskId; import org.ray.api.id.UniqueId; - /** * Helper method for different Ids. * Note: any changes to these methods must be synced with C++ helper functions @@ -153,6 +154,16 @@ public class IdUtil { } + /** + * Compute the driver id from the given job id: the job id bytes followed by 0xFF padding. + */ + public static UniqueId computeDriverId(JobId jobId) { + byte[] bytes = new byte[UniqueId.LENGTH]; + System.arraycopy(jobId.getBytes(), 0, bytes, 0, jobId.size()); + Arrays.fill(bytes, jobId.size(), UniqueId.LENGTH, (byte) 0xFF); + return new UniqueId(bytes); + } + /** * Compute the murmur hash code of this ID. 
*/ diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index 7c30ee755..41440b50d 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -4,6 +4,7 @@ import java.io.File; import java.nio.file.Files; import java.nio.file.Paths; import java.util.Map; +import java.util.Random; import javax.tools.JavaCompiler; import javax.tools.ToolProvider; import org.apache.commons.io.FileUtils; @@ -12,6 +13,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.ray.api.annotation.RayRemote; import org.ray.api.function.RayFunc0; import org.ray.api.function.RayFunc1; +import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.ray.runtime.functionmanager.FunctionManager.JobFunctionTable; import org.testng.Assert; @@ -64,19 +66,19 @@ public class FunctionManagerTest { public void testGetFunctionFromRayFunc() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. 
- RayFunction func = functionManager.getFunction(UniqueId.NIL, fooFunc); + RayFunction func = functionManager.getFunction(JobId.NIL, fooFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); Assert.assertNotNull(func.getRayRemoteAnnotation()); // Test actor method - func = functionManager.getFunction(UniqueId.NIL, barFunc); + func = functionManager.getFunction(JobId.NIL, barFunc); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); Assert.assertNull(func.getRayRemoteAnnotation()); // Test actor constructor - func = functionManager.getFunction(UniqueId.NIL, barConstructor); + func = functionManager.getFunction(JobId.NIL, barConstructor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); Assert.assertNotNull(func.getRayRemoteAnnotation()); @@ -86,19 +88,19 @@ public class FunctionManagerTest { public void testGetFunctionFromFunctionDescriptor() { final FunctionManager functionManager = new FunctionManager(null); // Test normal function. 
- RayFunction func = functionManager.getFunction(UniqueId.NIL, fooDescriptor); + RayFunction func = functionManager.getFunction(JobId.NIL, fooDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), fooDescriptor); Assert.assertNotNull(func.getRayRemoteAnnotation()); // Test actor method - func = functionManager.getFunction(UniqueId.NIL, barDescriptor); + func = functionManager.getFunction(JobId.NIL, barDescriptor); Assert.assertFalse(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barDescriptor); Assert.assertNull(func.getRayRemoteAnnotation()); // Test actor constructor - func = functionManager.getFunction(UniqueId.NIL, barConstructorDescriptor); + func = functionManager.getFunction(JobId.NIL, barConstructorDescriptor); Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); Assert.assertNotNull(func.getRayRemoteAnnotation()); @@ -119,7 +121,7 @@ public class FunctionManagerTest { @Test public void testGetFunctionFromLocalResource() throws Exception { - UniqueId jobId = UniqueId.randomId(); + JobId jobId = JobId.fromInt(1); final String resourcePath = FileUtils.getTempDirectoryPath() + "/ray_test_resources"; final String jobResourcePath = resourcePath + "/" + jobId.toString(); File jobResourceDir = new File(jobResourcePath); diff --git a/java/test/src/main/java/org/ray/api/test/GcsClientTest.java b/java/test/src/main/java/org/ray/api/test/GcsClientTest.java index 5fa9e60f0..04b08b64b 100644 --- a/java/test/src/main/java/org/ray/api/test/GcsClientTest.java +++ b/java/test/src/main/java/org/ray/api/test/GcsClientTest.java @@ -4,6 +4,7 @@ import com.google.common.base.Preconditions; import java.util.List; import org.ray.api.Ray; import org.ray.api.TestUtils; +import org.ray.api.id.JobId; import org.ray.api.runtimecontext.NodeInfo; import org.ray.runtime.AbstractRayRuntime; import org.ray.runtime.config.RayConfig; @@ -39,4 +40,17 @@ 
public class GcsClientTest extends BaseTest { Assert.assertEquals(allNodeInfo.get(0).resources.get("A"), 8.0); } + @Test + public void testNextJob() { + TestUtils.skipTestUnderSingleProcess(); + RayConfig config = ((AbstractRayRuntime)Ray.internal()).getRayConfig(); + // The value of job id of this driver in cluster should be `1L`. + Assert.assertEquals(config.getJobId(), JobId.fromInt(1)); + + GcsClient gcsClient = ((AbstractRayRuntime)Ray.internal()).getGcsClient(); + for (int i = 2; i < 100; ++i) { + Assert.assertEquals(gcsClient.nextJobId(), JobId.fromInt(i)); + } + + } } diff --git a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java index 33e2a345e..f7efe9eae 100644 --- a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java +++ b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java @@ -3,6 +3,7 @@ package org.ray.api.test; import org.ray.api.Ray; import org.ray.api.RayActor; import org.ray.api.annotation.RayRemote; +import org.ray.api.id.JobId; import org.ray.api.id.UniqueId; import org.testng.Assert; import org.testng.annotations.AfterClass; @@ -11,8 +12,7 @@ import org.testng.annotations.Test; public class RuntimeContextTest extends BaseTest { - private static UniqueId JOB_ID = - UniqueId.fromHexString("0011223344556677889900112233445566778899"); + private static JobId JOB_ID = JobId.fromHexString("00112233"); private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket"; diff --git a/python/ray/actor.py b/python/ray/actor.py index 2df1adf4e..8ca4a90fc 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -185,11 +185,11 @@ class ActorClass(object): task. _resources: The default resources required by the actor creation task. _actor_method_cpus: The number of CPUs required by actor method tasks. 
- _last_job_id_exported_for: The ID of the job of the last Ray - session during which this actor class definition was exported. This - is an imperfect mechanism used to determine if we need to export - the remote function again. It is imperfect in the sense that the - actor class definition could be exported multiple times by + _last_export_session_and_job: A pair of the last exported session + and job to help us to know whether this function was exported. + This is an imperfect mechanism used to determine if we need to + export the remote function again. It is imperfect in the sense that + the actor class definition could be exported multiple times by different workers. _actor_methods: The actor methods. _method_decorators: Optional decorators that should be applied to the @@ -211,7 +211,7 @@ class ActorClass(object): self._num_cpus = num_cpus self._num_gpus = num_gpus self._resources = resources - self._last_job_id_exported_for = None + self._last_export_session_and_job = None self._actor_methods = inspect.getmembers( self._modified_class, ray.utils.is_function_or_method) @@ -344,12 +344,13 @@ class ActorClass(object): *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if (self._last_job_id_exported_for is None or - self._last_job_id_exported_for != worker.current_job_id): - # If this actor class was exported in a previous session, we - # need to export this function again, because current GCS + if (self._last_export_session_and_job != + worker.current_session_and_job): + # If this actor class was not exported in this session and job, + # we need to export this function again, because current GCS # doesn't have it. 
- self._last_job_id_exported_for = worker.current_job_id + self._last_export_session_and_job = ( + worker.current_session_and_job) worker.function_actor_manager.export_actor_class( self._modified_class, self._actor_method_names) @@ -387,7 +388,8 @@ class ActorClass(object): actor_id, self._modified_class.__module__, self._class_name, actor_cursor, self._actor_method_names, self._method_decorators, self._method_signatures, self._actor_method_num_return_vals, - actor_cursor, actor_method_cpu, worker.current_job_id) + actor_cursor, actor_method_cpu, worker.current_job_id, + worker.current_session_and_job) # We increment the actor counter by 1 to account for the actor creation # task. actor_handle._ray_actor_counter += 1 @@ -465,6 +467,7 @@ class ActorHandle(object): actor_creation_dummy_object_id, actor_method_cpus, actor_job_id, + session_and_job, actor_handle_id=None): assert isinstance(actor_id, ActorID) assert isinstance(actor_job_id, ray.JobID) @@ -490,6 +493,7 @@ class ActorHandle(object): actor_creation_dummy_object_id) self._ray_actor_method_cpus = actor_method_cpus self._ray_actor_job_id = actor_job_id + self._ray_session_and_job = session_and_job self._ray_new_actor_handles = [] self._ray_actor_lock = threading.Lock() @@ -610,8 +614,10 @@ class ActorHandle(object): # there are ANY handles in scope in the process that created the actor, # not just the first one. worker = ray.worker.get_global_worker() + exported_in_current_session_and_job = ( + self._ray_session_and_job == worker.current_session_and_job) if (worker.mode == ray.worker.SCRIPT_MODE - and self._ray_actor_job_id.binary() != worker.worker_id): + and not exported_in_current_session_and_job): # If the worker is a driver and driver id has changed because # Ray was shut down re-initialized, the actor is already cleaned up # and we don't need to send `__ray_terminate__` again. 
@@ -729,6 +735,7 @@ class ActorHandle(object): # This is the ID of the job that owns the actor, not # necessarily the job that owns this actor handle. state["actor_job_id"], + worker.current_session_and_job, actor_handle_id=actor_handle_id) def __getstate__(self): diff --git a/python/ray/includes/unique_ids.pxd b/python/ray/includes/unique_ids.pxd index 91410f1ae..6c662d9e5 100644 --- a/python/ray/includes/unique_ids.pxd +++ b/python/ray/includes/unique_ids.pxd @@ -1,6 +1,6 @@ from libcpp cimport bool as c_bool from libcpp.string cimport string as c_string -from libc.stdint cimport uint8_t, int64_t +from libc.stdint cimport uint8_t, uint32_t, int64_t cdef extern from "ray/common/id.h" namespace "ray" nogil: cdef cppclass CBaseID[T]: @@ -78,11 +78,20 @@ cdef extern from "ray/common/id.h" namespace "ray" nogil: @staticmethod CFunctionID FromBinary(const c_string &binary) - cdef cppclass CJobID "ray::JobID"(CUniqueID): + cdef cppclass CJobID "ray::JobID"(CBaseID[CJobID]): @staticmethod CJobID FromBinary(const c_string &binary) + @staticmethod + const CJobID Nil() + + @staticmethod + size_t Size() + + @staticmethod + CJobID FromInt(uint32_t value) + cdef cppclass CTaskID "ray::TaskID"(CBaseID[CTaskID]): @staticmethod diff --git a/python/ray/includes/unique_ids.pxi b/python/ray/includes/unique_ids.pxi index 76a61e177..76cc34513 100644 --- a/python/ray/includes/unique_ids.pxi +++ b/python/ray/includes/unique_ids.pxi @@ -109,7 +109,6 @@ cdef class UniqueID(BaseID): def nil(cls): return cls(CUniqueID.Nil().Binary()) - @classmethod def from_random(cls): return cls(os.urandom(CUniqueID.Size())) @@ -194,7 +193,7 @@ cdef class TaskID(BaseID): return cls(CTaskID.Nil().Binary()) @classmethod - def size(cla): + def size(cls): return CTaskID.Size() @classmethod @@ -212,15 +211,43 @@ cdef class ClientID(UniqueID): return self.data -cdef class JobID(UniqueID): +cdef class JobID(BaseID): + cdef CJobID data def __init__(self, id): - check_id(id) + check_id(id, CJobID.Size()) self.data 
= CJobID.FromBinary(id) cdef CJobID native(self): return self.data + @classmethod + def from_int(cls, value): + return cls(CJobID.FromInt(value).Binary()) + + @classmethod + def nil(cls): + return cls(CJobID.Nil().Binary()) + + @classmethod + def size(cls): + return CJobID.Size() + + def binary(self): + return self.data.Binary() + + def hex(self): + return decode(self.data.Hex()) + + def size(self): + return CJobID.Size() + + def is_nil(self): + return self.data.IsNil() + + cdef size_t hash(self): + return self.data.Hash() + cdef class WorkerID(UniqueID): def __init__(self, id): diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 38a382e3c..a709d89c1 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -43,8 +43,8 @@ class RemoteFunction(object): return the resulting ObjectIDs. For an example, see "test_decorated_function" in "python/ray/tests/test_basic.py". _function_signature: The function signature. - _last_job_id_exported_for: The ID of the job ID of the last Ray - session during which this remote function definition was exported. + _last_export_session_and_job: A pair of the last exported session + and job to help us to know whether this function was exported. This is an imperfect mechanism used to determine if we need to export the remote function again. 
It is imperfect in the sense that the actor class definition could be exported multiple times by @@ -71,9 +71,7 @@ class RemoteFunction(object): ray.signature.check_signature_supported(self._function) self._function_signature = ray.signature.extract_signature( self._function) - - self._last_job_id_exported_for = None - + self._last_export_session_and_job = None # Override task.remote's signature and docstring @wraps(function) def _remote_proxy(*args, **kwargs): @@ -114,11 +112,11 @@ class RemoteFunction(object): worker = ray.worker.get_global_worker() worker.check_connected() - if (self._last_job_id_exported_for is None - or self._last_job_id_exported_for != worker.current_job_id): - # If this function was exported in a previous session, we need to - # export this function again, because current GCS doesn't have it. - self._last_job_id_exported_for = worker.current_job_id + if self._last_export_session_and_job != worker.current_session_and_job: + # If this function was not exported in this session and job, + # we need to export this function again, because current GCS + # doesn't have it. + self._last_export_session_and_job = worker.current_session_and_job worker.function_actor_manager.export(self) kwargs = {} if kwargs is None else kwargs diff --git a/python/ray/services.py b/python/ray/services.py index c39271130..7252a0523 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -615,6 +615,9 @@ def start_redis(node_ip_address, # can access it and know whether or not to enable cross-languages. primary_redis_client.set("INCLUDE_JAVA", 1 if include_java else 0) + # Init job counter to GCS. + primary_redis_client.set("JobCounter", 0) + # Store version information in the primary Redis shard. 
_put_version_info_in_redis(primary_redis_client) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 87b8043c4..f75fa3644 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -2510,7 +2510,8 @@ def test_global_state_api(shutdown_only): assert ray.objects() == {} - job_id = ray.utils.binary_to_hex(ray.worker.global_worker.worker_id) + job_id = ray.utils.compute_job_id_from_driver( + ray.WorkerID(ray.worker.global_worker.worker_id)) driver_task_id = ray.worker.global_worker.current_task_id.hex() # One task is put in the task table which corresponds to this driver. @@ -2524,7 +2525,7 @@ def test_global_state_api(shutdown_only): assert task_spec["TaskID"] == driver_task_id assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [] - assert task_spec["JobID"] == job_id + assert task_spec["JobID"] == job_id.hex() assert task_spec["FunctionID"] == nil_id_hex assert task_spec["ReturnObjectIDs"] == [] @@ -2552,7 +2553,7 @@ def test_global_state_api(shutdown_only): task_spec = task_table[task_id]["TaskSpec"] assert task_spec["ActorID"] == nil_id_hex assert task_spec["Args"] == [1, "hi", x_id] - assert task_spec["JobID"] == job_id + assert task_spec["JobID"] == job_id.hex() assert task_spec["ReturnObjectIDs"] == [result_id] assert task_table[task_id] == ray.tasks(task_id) @@ -2583,7 +2584,7 @@ def test_global_state_api(shutdown_only): job_table = ray.jobs() assert len(job_table) == 1 - assert job_table[0]["JobID"] == job_id + assert job_table[0]["JobID"] == job_id.hex() assert job_table[0]["NodeManagerAddress"] == node_ip_address @@ -2691,7 +2692,7 @@ def test_workers(shutdown_only): def test_specific_job_id(): - dummy_driver_id = ray.JobID(b"00112233445566778899") + dummy_driver_id = ray.JobID.from_int(1) ray.init(num_cpus=1, job_id=dummy_driver_id) # in driver diff --git a/python/ray/utils.py b/python/ray/utils.py index 5b6f07f65..954693661 100644 --- a/python/ray/utils.py +++ b/python/ray/utils.py 
@@ -232,6 +232,20 @@ def hex_to_binary(hex_identifier): return binascii.unhexlify(hex_identifier) +# TODO(qwang): Remove these helper functions +# once we separate `WorkerID` from `UniqueID`. +def compute_job_id_from_driver(driver_id): + assert isinstance(driver_id, ray.WorkerID) + return ray.JobID(driver_id.binary()[0:ray.JobID.size()]) + + +def compute_driver_id_from_job(job_id): + assert isinstance(job_id, ray.JobID) + rest_length = ray_constants.ID_SIZE - job_id.size() + driver_id_str = job_id.binary() + (rest_length * b"\xff") + return ray.WorkerID(driver_id_str) + + def get_cuda_visible_devices(): """Get the device IDs in the CUDA_VISIBLE_DEVICES environment variable. diff --git a/python/ray/worker.py b/python/ray/worker.py index 5a3146255..d35e2d326 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -18,6 +18,7 @@ import sys import threading import time import traceback +import random # Ray modules import pyarrow @@ -217,6 +218,13 @@ class Worker(object): def current_task_id(self): return self.task_context.current_task_id + @property + def current_session_and_job(self): + """Get the current session index and job id as a pair.""" + assert isinstance(self._session_index, int) + assert isinstance(self.current_job_id, ray.JobID) + return self._session_index, self.current_job_id + def mark_actor_init_failed(self, error): """Called to mark this actor as failed during initialization.""" @@ -1718,27 +1726,44 @@ def connect(node, worker.profiler = profiling.Profiler(worker, worker.threads_stopped) + if mode is not LOCAL_MODE: + # Create a Redis client to primary. + # The Redis client can safely be shared between threads. However, + # that is not true of Redis pubsub clients. See the documentation at + # https://github.com/andymccurdy/redis-py#thread-safety. + worker.redis_client = node.create_redis_client() + # Initialize some fields. if mode is WORKER_MODE: + # We should not specify the job_id if it's `WORKER_MODE`.
+ assert job_id is None + job_id = JobID.nil() + # TODO(qwang): Rename this to `worker_id_str` or type to `WorkerID` worker.worker_id = _random_string() if setproctitle: setproctitle.setproctitle("ray_worker") + elif mode is LOCAL_MODE: + # Code path of local mode + if job_id is None: + job_id = JobID.from_int(random.randint(1, 100000)) + worker.worker_id = ray.utils.compute_driver_id_from_job( + job_id).binary() else: # This is the code path of driver mode. if job_id is None: - job_id = JobID.from_random() + # TODO(qwang): use `GcsClient::GenerateJobId()` here. + job_id = JobID.from_int( + int(worker.redis_client.incr("JobCounter"))) + # When tasks are executed on remote workers in the context of multiple + # drivers, the current job ID is used to keep track of which job is + # responsible for the task so that error messages will be propagated to + # the correct driver. + worker.worker_id = ray.utils.compute_driver_id_from_job( + job_id).binary() - if not isinstance(job_id, JobID): - raise TypeError("The type of given job id must be JobID.") - - worker.worker_id = job_id.binary() - - # When tasks are executed on remote workers in the context of multiple - # drivers, the current job ID is used to keep track of which driver is - # responsible for the task so that error messages will be propagated to - # the correct driver. - if mode != WORKER_MODE: - worker.current_job_id = JobID(worker.worker_id) + if not isinstance(job_id, JobID): + raise TypeError("The type of given job id must be JobID.") + worker.current_job_id = job_id # All workers start out as non-actors. A worker can be turned into an actor # after it is created. @@ -1752,12 +1777,6 @@ def connect(node, worker.local_mode_manager = LocalModeManager() return - # Create a Redis client. - # The Redis client can safely be shared between threads. However, that is - # not true of Redis pubsub clients. See the documentation at - # https://github.com/andymccurdy/redis-py#thread-safety. 
- worker.redis_client = node.create_redis_client() - # For driver's check that the version information matches the version # information that the Ray cluster was started with. try: @@ -1836,7 +1855,6 @@ def connect(node, # Create an object store client. worker.plasma_client = thread_safe_client( plasma.connect(node.plasma_store_socket_name, None, 0, 300)) - job_id_str = _random_string() # If this is a driver, set the current task ID, the task driver ID, and set # the task index to 0. @@ -1868,7 +1886,7 @@ def connect(node, function_descriptor.get_function_descriptor_list(), [], # arguments. 0, # num_returns. - TaskID(job_id_str[:TaskID.size()]), # parent_task_id. + TaskID(worker.worker_id[:TaskID.size()]), # parent_task_id. 0, # parent_counter. ActorID.nil(), # actor_creation_id. ObjectID.nil(), # actor_creation_dummy_object_id. @@ -1901,7 +1919,7 @@ def connect(node, node.raylet_socket_name, ClientID(worker.worker_id), (mode == WORKER_MODE), - JobID(job_id_str), + worker.current_job_id, ) # Start the import thread diff --git a/src/ray/common/constants.h b/src/ray/common/constants.h index 1f50b8025..99c4f89fa 100644 --- a/src/ray/common/constants.h +++ b/src/ray/common/constants.h @@ -5,7 +5,7 @@ #include /// Length of Ray full-length IDs in bytes. -constexpr int64_t kUniqueIDSize = 20; +constexpr size_t kUniqueIDSize = 20; /// An ObjectID's bytes are split into the task ID itself and the index of the /// object's creation. This is the maximum width of the object index in bits. 
diff --git a/src/ray/common/id.cc b/src/ray/common/id.cc index 00964094a..38183f860 100644 --- a/src/ray/common/id.cc +++ b/src/ray/common/id.cc @@ -26,6 +26,14 @@ std::mt19937 RandomlySeededMersenneTwister() { uint64_t MurmurHash64A(const void *key, int len, unsigned int seed); +WorkerID ComputeDriverIdFromJob(const JobID &job_id) { + std::vector data(WorkerID::Size(), 0); + std::memcpy(data.data(), job_id.Data(), JobID::Size()); + std::fill_n(data.data() + JobID::Size(), WorkerID::Size() - JobID::Size(), 0xFF); + return WorkerID::FromBinary( + std::string(reinterpret_cast(data.data()), data.size())); +} + plasma::UniqueID ObjectID::ToPlasmaId() const { plasma::UniqueID result; std::memcpy(result.mutable_data(), Data(), kUniqueIDSize); @@ -129,16 +137,6 @@ const TaskID GenerateTaskId(const JobID &job_id, const TaskID &parent_task_id, return TaskID::FromBinary(std::string(buff, buff + TaskID::Size())); } -const WorkerID ComputeDriverId(const JobID &job_id) { - // Currently, a job id equals its driver id. - return WorkerID(job_id); -} - -const JobID ComputeJobId(const WorkerID &driver_id) { - // Currently, a job id equals its driver id. - return JobID(driver_id); -} - const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_id, int64_t num_forks) { // Compute hashes. 
@@ -155,6 +153,13 @@ const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_i return ActorHandleID::FromBinary(std::string(buff, buff + ActorHandleID::Size())); } +JobID JobID::FromInt(uint32_t value) { + std::vector data(JobID::Size(), 0); + std::memcpy(data.data(), &value, JobID::Size()); + return JobID::FromBinary( + std::string(reinterpret_cast(data.data()), data.size())); +} + #define ID_OSTREAM_OPERATOR(id_type) \ std::ostream &operator<<(std::ostream &os, const id_type &id) { \ if (id.IsNil()) { \ @@ -166,6 +171,7 @@ const ActorHandleID ComputeNextActorHandleId(const ActorHandleID &actor_handle_i } ID_OSTREAM_OPERATOR(UniqueID); +ID_OSTREAM_OPERATOR(JobID); ID_OSTREAM_OPERATOR(TaskID); ID_OSTREAM_OPERATOR(ObjectID); diff --git a/src/ray/common/id.h b/src/ray/common/id.h index 539a0774c..5f6539153 100644 --- a/src/ray/common/id.h +++ b/src/ray/common/id.h @@ -19,6 +19,13 @@ namespace ray { class WorkerID; class UniqueID; +class JobID; + +/// TODO(qwang): These 2 helper functions should be removed +/// once we separate the `WorkerID` from `UniqueID`. +/// +/// A helper function that gets the `DriverID` of the given job. +WorkerID ComputeDriverIdFromJob(const JobID &job_id); // Declaration. std::mt19937 RandomlySeededMersenneTwister(); @@ -58,9 +65,10 @@ class BaseID { class UniqueID : public BaseID { public: - UniqueID() : BaseID(){}; static size_t Size() { return kUniqueIDSize; } + UniqueID() : BaseID() {} + protected: UniqueID(const std::string &binary); @@ -68,6 +76,23 @@ class UniqueID : public BaseID { uint8_t id_[kUniqueIDSize]; }; +class JobID : public BaseID { + public: + static constexpr int64_t length = 4; + + // TODO(qwang): Use `uint32_t` to store the data.
+ static JobID FromInt(uint32_t value); + + static size_t Size() { return length; } + + static JobID FromRandom() = delete; + + JobID() : BaseID() {} + + private: + uint8_t id_[length]; +}; + class TaskID : public BaseID { public: TaskID() : BaseID() {} @@ -116,12 +141,15 @@ class ObjectID : public BaseID { int32_t index_; }; +static_assert(sizeof(JobID) == JobID::length + sizeof(size_t), + "JobID size is not as expected"); static_assert(sizeof(TaskID) == kTaskIDSize + sizeof(size_t), "TaskID size is not as expected"); static_assert(sizeof(ObjectID) == sizeof(int32_t) + sizeof(TaskID), "ObjectID size is not as expected"); std::ostream &operator<<(std::ostream &os, const UniqueID &id); +std::ostream &operator<<(std::ostream &os, const JobID &id); std::ostream &operator<<(std::ostream &os, const TaskID &id); std::ostream &operator<<(std::ostream &os, const ObjectID &id); @@ -192,6 +220,7 @@ T BaseID::FromRandom() { template T BaseID::FromBinary(const std::string &binary) { + RAY_CHECK(binary.size() == T::Size()); T t = T::Nil(); std::memcpy(t.MutableData(), binary.data(), T::Size()); return t; @@ -272,6 +301,7 @@ namespace std { }; DEFINE_UNIQUE_ID(UniqueID); +DEFINE_UNIQUE_ID(JobID); DEFINE_UNIQUE_ID(TaskID); DEFINE_UNIQUE_ID(ObjectID); #include "id_def.h" diff --git a/src/ray/common/id_def.h b/src/ray/common/id_def.h index 7344793f6..d3e079482 100644 --- a/src/ray/common/id_def.h +++ b/src/ray/common/id_def.h @@ -10,6 +10,5 @@ DEFINE_UNIQUE_ID(ActorID) DEFINE_UNIQUE_ID(ActorHandleID) DEFINE_UNIQUE_ID(ActorCheckpointID) DEFINE_UNIQUE_ID(WorkerID) -DEFINE_UNIQUE_ID(JobID) DEFINE_UNIQUE_ID(ConfigID) DEFINE_UNIQUE_ID(ClientID) diff --git a/src/ray/core_worker/context.cc b/src/ray/core_worker/context.cc index af44d3c8b..b8aac0da4 100644 --- a/src/ray/core_worker/context.cc +++ b/src/ray/core_worker/context.cc @@ -40,7 +40,7 @@ thread_local std::unique_ptr WorkerContext::thread_context_ WorkerContext::WorkerContext(WorkerType worker_type, const JobID &job_id) : 
worker_type(worker_type), - worker_id(worker_type == WorkerType::DRIVER ? WorkerID::FromBinary(job_id.Binary()) + worker_id(worker_type == WorkerType::DRIVER ? ComputeDriverIdFromJob(job_id) : WorkerID::FromRandom()), current_job_id(worker_type == WorkerType::DRIVER ? job_id : JobID::Nil()) { // For worker main thread which initializes the WorkerContext, diff --git a/src/ray/core_worker/core_worker_test.cc b/src/ray/core_worker/core_worker_test.cc index 3eb58595a..07aedad36 100644 --- a/src/ray/core_worker/core_worker_test.cc +++ b/src/ray/core_worker/core_worker_test.cc @@ -125,7 +125,7 @@ class CoreWorkerTest : public ::testing::Test { void TestNormalTask(const std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromRandom()); + raylet_socket_names_[0], JobID::FromInt(1)); // Test pass by value. { @@ -184,7 +184,7 @@ class CoreWorkerTest : public ::testing::Test { void TestActorTask(const std::unordered_map &resources) { CoreWorker driver(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromRandom()); + raylet_socket_names_[0], JobID::FromInt(1)); std::unique_ptr actor_handle; @@ -276,7 +276,7 @@ TEST_F(ZeroNodeTest, TestTaskArg) { } TEST_F(ZeroNodeTest, TestWorkerContext) { - auto job_id = JobID::FromRandom(); + auto job_id = JobID::JobID::FromInt(1); WorkerContext context(WorkerType::WORKER, job_id); ASSERT_TRUE(context.GetCurrentTaskID().IsNil()); @@ -335,7 +335,7 @@ TEST_F(ZeroNodeTest, TestActorHandle) { TEST_F(SingleNodeTest, TestObjectInterface) { CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], raylet_socket_names_[0], - JobID::FromRandom()); + JobID::JobID::FromInt(1)); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -398,10 +398,10 @@ TEST_F(SingleNodeTest, TestObjectInterface) { TEST_F(TwoNodeTest, 
TestObjectInterfaceCrossNodes) { CoreWorker worker1(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[0], - raylet_socket_names_[0], JobID::FromRandom()); + raylet_socket_names_[0], JobID::JobID::FromInt(1)); CoreWorker worker2(WorkerType::DRIVER, Language::PYTHON, raylet_store_socket_names_[1], - raylet_socket_names_[1], JobID::FromRandom()); + raylet_socket_names_[1], JobID::JobID::FromInt(1)); uint8_t array1[] = {1, 2, 3, 4, 5, 6, 7, 8}; uint8_t array2[] = {10, 11, 12, 13, 14, 15}; @@ -487,7 +487,7 @@ TEST_F(TwoNodeTest, TestActorTaskCrossNodes) { TEST_F(SingleNodeTest, TestCoreWorkerConstructorFailure) { try { CoreWorker core_worker(WorkerType::DRIVER, Language::PYTHON, "", - raylet_socket_names_[0], JobID::FromRandom()); + raylet_socket_names_[0], JobID::FromInt(1)); } catch (const std::exception &e) { std::cout << "Caught exception when constructing core worker: " << e.what(); } diff --git a/src/ray/core_worker/mock_worker.cc b/src/ray/core_worker/mock_worker.cc index 5097f6bd5..c171e5df8 100644 --- a/src/ray/core_worker/mock_worker.cc +++ b/src/ray/core_worker/mock_worker.cc @@ -18,7 +18,7 @@ class MockWorker { public: MockWorker(const std::string &store_socket, const std::string &raylet_socket) : worker_(WorkerType::WORKER, Language::PYTHON, store_socket, raylet_socket, - JobID::FromRandom()) {} + JobID::JobID::FromInt(1)) {} void Run() { auto executor_func = [this](const RayFunction &ray_function, diff --git a/src/ray/gcs/client_test.cc b/src/ray/gcs/client_test.cc index 08d4ac3d3..a63439833 100644 --- a/src/ray/gcs/client_test.cc +++ b/src/ray/gcs/client_test.cc @@ -20,12 +20,18 @@ static inline void flushall_redis(void) { redisFree(context); } +/// A helper function to generate a unique JobID.
+inline JobID NextJobID() { + static int32_t counter = 0; + return JobID::FromInt(++counter); +} + class TestGcs : public ::testing::Test { public: TestGcs(CommandType command_type) : num_callbacks_(0), command_type_(command_type) { client_ = std::make_shared("127.0.0.1", 6379, command_type_, /*is_test_client=*/true); - job_id_ = JobID::FromRandom(); + job_id_ = NextJobID(); } virtual ~TestGcs() { @@ -571,7 +577,7 @@ void TestLogSubscribeAll(const JobID &job_id, std::shared_ptr client) { std::vector job_ids; for (int i = 0; i < 3; i++) { - job_ids.emplace_back(JobID::FromRandom()); + job_ids.emplace_back(NextJobID()); } // Callback for a notification. auto notification_callback = [job_ids](gcs::AsyncGcsClient *client, const JobID &id, @@ -770,14 +776,14 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeId); void TestLogSubscribeId(const JobID &job_id, std::shared_ptr client) { // Add a log entry. - JobID job_id1 = JobID::FromRandom(); + JobID job_id1 = NextJobID(); std::vector job_ids1 = {"abc", "def", "ghi"}; auto data1 = std::make_shared(); data1->set_job_id(job_ids1[0]); RAY_CHECK_OK(client->job_table().Append(job_id, job_id1, data1, nullptr)); // Add a log entry at a second key. - JobID job_id2 = JobID::FromRandom(); + JobID job_id2 = NextJobID(); std::vector job_ids2 = {"jkl", "mno", "pqr"}; auto data2 = std::make_shared(); data2->set_job_id(job_ids2[0]); @@ -786,7 +792,7 @@ void TestLogSubscribeId(const JobID &job_id, // The callback for a notification from the table. This should only be // received for keys that we requested notifications for. auto notification_callback = [job_id2, job_ids2]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const JobID &id, const std::vector &data) { // Check that we only get notifications for the requested key. 
ASSERT_EQ(id, job_id2); @@ -992,7 +998,7 @@ TEST_MACRO(TestGcsWithChainAsio, TestTableSubscribeCancel); void TestLogSubscribeCancel(const JobID &job_id, std::shared_ptr client) { // Add a log entry. - JobID random_job_id = JobID::FromRandom(); + JobID random_job_id = NextJobID(); std::vector job_ids = {"jkl", "mno", "pqr"}; auto data = std::make_shared(); data->set_job_id(job_ids[0]); @@ -1001,7 +1007,7 @@ void TestLogSubscribeCancel(const JobID &job_id, // The callback for a notification from the object table. This should only be // received for the object that we requested notifications for. auto notification_callback = [random_job_id, job_ids]( - gcs::AsyncGcsClient *client, const UniqueID &id, + gcs::AsyncGcsClient *client, const JobID &id, const std::vector &data) { ASSERT_EQ(id, random_job_id); // Check that we get a duplicate notification for the first write. We get a diff --git a/src/ray/raylet/node_manager.cc b/src/ray/raylet/node_manager.cc index 419bf4e3c..280d718ad 100644 --- a/src/ray/raylet/node_manager.cc +++ b/src/ray/raylet/node_manager.cc @@ -268,8 +268,8 @@ void NodeManager::KillWorker(std::shared_ptr worker) { void NodeManager::HandleJobTableUpdate(const JobID &id, const std::vector &job_data) { for (const auto &entry : job_data) { - RAY_LOG(DEBUG) << "HandleJobTableUpdate " << UniqueID::FromBinary(entry.job_id()) - << " " << entry.is_dead(); + RAY_LOG(DEBUG) << "HandleJobTableUpdate " << JobID::FromBinary(entry.job_id()) << " " + << entry.is_dead(); if (entry.is_dead()) { auto job_id = JobID::FromBinary(entry.job_id()); auto workers = worker_pool_.GetWorkersRunningTasksForJob(job_id); @@ -869,9 +869,8 @@ void NodeManager::ProcessRegisterClientRequestMessage( worker_pool_.RegisterDriver(std::move(worker)); local_queues_.AddDriverTaskId(driver_task_id); RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( - JobID(driver_id), - /*is_dead=*/false, std::time(nullptr), initial_config_.node_manager_address, - message->worker_pid())); + job_id, 
/*is_dead=*/false, std::time(nullptr), + initial_config_.node_manager_address, message->worker_pid())); } } @@ -1039,17 +1038,17 @@ void NodeManager::ProcessDisconnectClientMessage( DispatchTasks(local_queues_.GetReadyTasksWithResources()); } else if (is_driver) { // The client is a driver. - RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( - JobID(client->GetClientId()), - /*is_dead=*/true, std::time(nullptr), initial_config_.node_manager_address, - worker->Pid())); - auto job_id = worker->GetAssignedTaskId(); + const auto job_id = worker->GetAssignedJobId(); + const auto driver_id = ComputeDriverIdFromJob(job_id); RAY_CHECK(!job_id.IsNil()); - local_queues_.RemoveDriverTaskId(job_id); + RAY_CHECK_OK(gcs_client_->job_table().AppendJobData( + job_id, /*is_dead=*/true, std::time(nullptr), + initial_config_.node_manager_address, worker->Pid())); + local_queues_.RemoveDriverTaskId(TaskID::ComputeDriverTaskId(driver_id)); worker_pool_.DisconnectDriver(worker); RAY_LOG(DEBUG) << "Driver (pid=" << worker->Pid() << ") is disconnected. " - << "job_id: " << worker->GetAssignedJobId(); + << "job_id: " << job_id; } client->Close();