mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 20:22:39 +08:00
[Java worker] Migrate task execution and submission on top of core worker (#5370)
This commit is contained in:
@@ -1,6 +0,0 @@
|
||||
package org.ray.api;
|
||||
|
||||
public enum ObjectType {
|
||||
PUT_OBJECT,
|
||||
RETURN_OBJECT,
|
||||
}
|
||||
@@ -2,14 +2,11 @@ package org.ray.api.id;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
|
||||
public class ActorId extends BaseId implements Serializable {
|
||||
private static final int UNIQUE_BYTES_LENGTH = 4;
|
||||
|
||||
public static final int LENGTH = UNIQUE_BYTES_LENGTH + JobId.LENGTH;
|
||||
public static final int LENGTH = 8;
|
||||
|
||||
public static final ActorId NIL = nil();
|
||||
|
||||
@@ -25,19 +22,6 @@ public class ActorId extends BaseId implements Serializable {
|
||||
return new ActorId(bytes);
|
||||
}
|
||||
|
||||
public static ActorId generateActorId(JobId jobId) {
|
||||
byte[] uniqueBytes = new byte[ActorId.UNIQUE_BYTES_LENGTH];
|
||||
new Random().nextBytes(uniqueBytes);
|
||||
|
||||
byte[] bytes = new byte[ActorId.LENGTH];
|
||||
ByteBuffer wbb = ByteBuffer.wrap(bytes);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
System.arraycopy(uniqueBytes, 0, bytes, 0, ActorId.UNIQUE_BYTES_LENGTH);
|
||||
System.arraycopy(jobId.getBytes(), 0, bytes, ActorId.UNIQUE_BYTES_LENGTH, JobId.LENGTH);
|
||||
return new ActorId(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a nil ActorId.
|
||||
*/
|
||||
@@ -47,6 +31,15 @@ public class ActorId extends BaseId implements Serializable {
|
||||
return new ActorId(b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an ActorId with random value. Used for local mode and test only.
|
||||
*/
|
||||
public static ActorId fromRandom() {
|
||||
byte[] b = new byte[LENGTH];
|
||||
new Random().nextBytes(b);
|
||||
return new ActorId(b);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return LENGTH;
|
||||
|
||||
@@ -2,10 +2,8 @@ package org.ray.api.id;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import org.ray.api.ObjectType;
|
||||
|
||||
/**
|
||||
* Represents the id of a Ray object.
|
||||
@@ -16,20 +14,6 @@ public class ObjectId extends BaseId implements Serializable {
|
||||
|
||||
public static final ObjectId NIL = genNil();
|
||||
|
||||
private static int CREATED_BY_TASK_FLAG_BITS_OFFSET = 15;
|
||||
|
||||
private static int OBJECT_TYPE_FLAG_BITS_OFFSET = 14;
|
||||
|
||||
private static int TRANSPORT_TYPE_FLAG_BITS_OFFSET = 11;
|
||||
|
||||
private static int FLAGS_BYTES_POS = TaskId.LENGTH;
|
||||
|
||||
private static int FLAGS_BYTES_LENGTH = 2;
|
||||
|
||||
private static int INDEX_BYTES_POS = FLAGS_BYTES_POS + FLAGS_BYTES_LENGTH;
|
||||
|
||||
private static int INDEX_BYTES_LENGTH = 4;
|
||||
|
||||
/**
|
||||
* Create an ObjectId from a ByteBuffer.
|
||||
*/
|
||||
@@ -55,48 +39,6 @@ public class ObjectId extends BaseId implements Serializable {
|
||||
return new ObjectId(b);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the object ID of an object put by the task.
|
||||
*/
|
||||
public static ObjectId forPut(TaskId taskId, int putIndex) {
|
||||
short flags = 0;
|
||||
flags = setCreatedByTaskFlag(flags, true);
|
||||
// Set a default transport type with value 0.
|
||||
flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET));
|
||||
flags = setObjectTypeFlag(flags, ObjectType.PUT_OBJECT);
|
||||
|
||||
byte[] bytes = new byte[ObjectId.LENGTH];
|
||||
System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH);
|
||||
|
||||
ByteBuffer wbb = ByteBuffer.wrap(bytes);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
wbb.putShort(FLAGS_BYTES_POS, flags);
|
||||
|
||||
wbb.putInt(INDEX_BYTES_POS, putIndex);
|
||||
return new ObjectId(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the object ID of an object return by the task.
|
||||
*/
|
||||
public static ObjectId forReturn(TaskId taskId, int returnIndex) {
|
||||
short flags = 0;
|
||||
flags = setCreatedByTaskFlag(flags, true);
|
||||
// Set a default transport type with value 0.
|
||||
flags = (short) (flags | (0x0 << TRANSPORT_TYPE_FLAG_BITS_OFFSET));
|
||||
flags = setObjectTypeFlag(flags, ObjectType.RETURN_OBJECT);
|
||||
|
||||
byte[] bytes = new byte[ObjectId.LENGTH];
|
||||
System.arraycopy(taskId.getBytes(), 0, bytes, 0, TaskId.LENGTH);
|
||||
|
||||
ByteBuffer wbb = ByteBuffer.wrap(bytes);
|
||||
wbb.order(ByteOrder.LITTLE_ENDIAN);
|
||||
wbb.putShort(FLAGS_BYTES_POS, flags);
|
||||
|
||||
wbb.putInt(INDEX_BYTES_POS, returnIndex);
|
||||
return new ObjectId(bytes);
|
||||
}
|
||||
|
||||
public ObjectId(byte[] id) {
|
||||
super(id);
|
||||
}
|
||||
@@ -106,25 +48,4 @@ public class ObjectId extends BaseId implements Serializable {
|
||||
return LENGTH;
|
||||
}
|
||||
|
||||
public TaskId getTaskId() {
|
||||
byte[] taskIdBytes = Arrays.copyOf(getBytes(), TaskId.LENGTH);
|
||||
return TaskId.fromBytes(taskIdBytes);
|
||||
}
|
||||
|
||||
private static short setCreatedByTaskFlag(short flags, boolean createdByTask) {
|
||||
if (createdByTask) {
|
||||
return (short) (flags | (0x1 << CREATED_BY_TASK_FLAG_BITS_OFFSET));
|
||||
} else {
|
||||
return (short) (flags | (0x0 << CREATED_BY_TASK_FLAG_BITS_OFFSET));
|
||||
}
|
||||
}
|
||||
|
||||
private static short setObjectTypeFlag(short flags, ObjectType objectType) {
|
||||
if (objectType == ObjectType.RETURN_OBJECT) {
|
||||
return (short)(flags | (0x1 << OBJECT_TYPE_FLAG_BITS_OFFSET));
|
||||
} else {
|
||||
return (short)(flags | (0x0 << OBJECT_TYPE_FLAG_BITS_OFFSET));
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -11,9 +11,7 @@ import java.util.Random;
|
||||
*/
|
||||
public class TaskId extends BaseId implements Serializable {
|
||||
|
||||
private static final int UNIQUE_BYTES_LENGTH = 6;
|
||||
|
||||
public static final int LENGTH = UNIQUE_BYTES_LENGTH + ActorId.LENGTH;
|
||||
public static final int LENGTH = 14;
|
||||
|
||||
public static final TaskId NIL = genNil();
|
||||
|
||||
@@ -38,15 +36,6 @@ public class TaskId extends BaseId implements Serializable {
|
||||
return new TaskId(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the id of the actor to which this task belongs
|
||||
*/
|
||||
public ActorId getActorId() {
|
||||
byte[] actorIdBytes = new byte[ActorId.LENGTH];
|
||||
System.arraycopy(getBytes(), UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH);
|
||||
return ActorId.fromByteBuffer(ByteBuffer.wrap(actorIdBytes));
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a nil TaskId.
|
||||
*/
|
||||
|
||||
@@ -1,45 +1,35 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.BaseTaskOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.api.runtime.RayRuntime;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.RuntimeContextImpl;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.FunctionManager;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.ray.runtime.task.TaskSubmitter;
|
||||
import org.ray.runtime.util.StringUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -50,71 +40,24 @@ import org.slf4j.LoggerFactory;
|
||||
public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
public static final String PYTHON_INIT_METHOD_NAME = "__init__";
|
||||
protected RayConfig rayConfig;
|
||||
protected WorkerContext workerContext;
|
||||
protected Worker worker;
|
||||
protected RayletClient rayletClient;
|
||||
protected ObjectStoreProxy objectStoreProxy;
|
||||
protected TaskExecutor taskExecutor;
|
||||
protected FunctionManager functionManager;
|
||||
protected RuntimeContext runtimeContext;
|
||||
protected GcsClient gcsClient;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"core_worker_library_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
}
|
||||
protected ObjectStore objectStore;
|
||||
protected TaskSubmitter taskSubmitter;
|
||||
protected RayletClient rayletClient;
|
||||
protected WorkerContext workerContext;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
functionManager = new FunctionManager(rayConfig.jobResourcePath);
|
||||
worker = new Worker(this);
|
||||
runtimeContext = new RuntimeContextImpl(this);
|
||||
}
|
||||
|
||||
protected void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start runtime.
|
||||
*/
|
||||
@@ -125,31 +68,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> RayObject<T> put(T obj) {
|
||||
ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(),
|
||||
workerContext.nextPutIndex());
|
||||
put(objectId, obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
}
|
||||
|
||||
public <T> void put(ObjectId objectId, T obj) {
|
||||
TaskId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.put(objectId, obj);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Store a serialized object in the object store.
|
||||
*
|
||||
* @param obj The serialized Java object to be stored.
|
||||
* @return A RayObject instance that represents the in-store object.
|
||||
*/
|
||||
public RayObject<Object> putSerialized(byte[] obj) {
|
||||
ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(),
|
||||
workerContext.nextPutIndex());
|
||||
TaskId taskId = workerContext.getCurrentTaskId();
|
||||
LOGGER.debug("Putting serialized object {}, for task {} ", objectId, taskId);
|
||||
objectStoreProxy.putSerialized(objectId, obj);
|
||||
ObjectId objectId = objectStore.put(obj);
|
||||
return new RayObjectImpl<>(objectId);
|
||||
}
|
||||
|
||||
@@ -161,12 +80,12 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> List<T> get(List<ObjectId> objectIds) {
|
||||
return objectStoreProxy.get(objectIds);
|
||||
return objectStore.get(objectIds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
rayletClient.freePlasmaObjects(objectIds, localOnly, deleteCreatingTasks);
|
||||
objectStore.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -180,44 +99,33 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
return rayletClient.wait(waitList, numReturns,
|
||||
timeoutMs, workerContext.getCurrentTaskId());
|
||||
return objectStore.wait(waitList, numReturns, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, Object[] args, CallOptions options) {
|
||||
TaskSpec spec = createTaskSpec(func, null, RayActorImpl.NIL, args, false, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
return callNormalFunction(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject call(RayFunc func, RayActor<?> actor, Object[] args) {
|
||||
if (!(actor instanceof RayActorImpl)) {
|
||||
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
|
||||
}
|
||||
RayActorImpl<?> actorImpl = (RayActorImpl) actor;
|
||||
TaskSpec spec;
|
||||
synchronized (actor) {
|
||||
spec = createTaskSpec(func, null, actorImpl, args, false, true, null);
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
}
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.functionDescriptor;
|
||||
return callActorFunction(actor, functionDescriptor, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
TaskSpec spec = createTaskSpec(actorFactoryFunc, null, RayActorImpl.NIL,
|
||||
args, true, false, options);
|
||||
RayActorImpl<?> actor = new RayActorImpl(spec.taskId.getActorId());
|
||||
actor.increaseTaskCounter();
|
||||
actor.setTaskCursor(spec.returnIds[0]);
|
||||
rayletClient.submitTask(spec);
|
||||
return (RayActor<T>) actor;
|
||||
FunctionDescriptor functionDescriptor =
|
||||
functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc)
|
||||
.functionDescriptor;
|
||||
return (RayActor<T>) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
private void checkPyArguments(Object[] args) {
|
||||
@@ -233,146 +141,69 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
public RayObject callPy(String moduleName, String functionName, Object[] args,
|
||||
CallOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, "", functionName);
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, false, false, options);
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, "",
|
||||
functionName);
|
||||
return callNormalFunction(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayObject callPy(RayPyActor pyActor, String functionName, Object... args) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
|
||||
pyActor.getClassName(), functionName);
|
||||
RayPyActorImpl actorImpl = (RayPyActorImpl) pyActor;
|
||||
TaskSpec spec;
|
||||
synchronized (pyActor) {
|
||||
spec = createTaskSpec(null, desc, actorImpl, args, false, true, null);
|
||||
actorImpl.setTaskCursor(spec.returnIds[1]);
|
||||
actorImpl.clearNewActorHandles();
|
||||
}
|
||||
rayletClient.submitTask(spec);
|
||||
return new RayObjectImpl(spec.returnIds[0]);
|
||||
return callActorFunction(pyActor, functionDescriptor, args);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayPyActor createPyActor(String moduleName, String className, Object[] args,
|
||||
ActorCreationOptions options) {
|
||||
checkPyArguments(args);
|
||||
PyFunctionDescriptor desc = new PyFunctionDescriptor(moduleName, className, "__init__");
|
||||
TaskSpec spec = createTaskSpec(null, desc, RayPyActorImpl.NIL, args, true, false, options);
|
||||
RayPyActorImpl actor = new RayPyActorImpl(spec.actorCreationId, moduleName, className);
|
||||
actor.increaseTaskCounter();
|
||||
actor.setTaskCursor(spec.returnIds[0]);
|
||||
rayletClient.submitTask(spec);
|
||||
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(moduleName, className,
|
||||
PYTHON_INIT_METHOD_NAME);
|
||||
return (RayPyActor) createActorImpl(functionDescriptor, args, options);
|
||||
}
|
||||
|
||||
private RayObject callNormalFunction(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, CallOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
|
||||
functionArgs, 1, options);
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
}
|
||||
|
||||
private RayObject callActorFunction(RayActor rayActor,
|
||||
FunctionDescriptor functionDescriptor, Object[] args) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
|
||||
functionDescriptor, functionArgs, 1, null);
|
||||
return new RayObjectImpl(returnIds.get(0));
|
||||
}
|
||||
|
||||
private RayActor createActorImpl(FunctionDescriptor functionDescriptor,
|
||||
Object[] args, ActorCreationOptions options) {
|
||||
List<FunctionArg> functionArgs = ArgumentsBuilder
|
||||
.wrap(args, functionDescriptor.getLanguage() != Language.JAVA);
|
||||
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
|
||||
Preconditions.checkState(StringUtil.isNullOrEmpty(options.jvmOptions));
|
||||
}
|
||||
RayActor actor = taskSubmitter
|
||||
.createActor(functionDescriptor, functionArgs,
|
||||
options);
|
||||
return actor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the task specification.
|
||||
*
|
||||
* @param func The target remote function.
|
||||
* @param pyFunctionDescriptor Descriptor of the target Python function, if the task is a Python
|
||||
* task.
|
||||
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
|
||||
* @param args The arguments for the remote function.
|
||||
* @param isActorCreationTask Whether this task is an actor creation task.
|
||||
* @param isActorTask Whether this task is an actor task.
|
||||
* @return A TaskSpec object.
|
||||
*/
|
||||
private TaskSpec createTaskSpec(RayFunc func, PyFunctionDescriptor pyFunctionDescriptor,
|
||||
RayActorImpl<?> actor, Object[] args,
|
||||
boolean isActorCreationTask, boolean isActorTask, BaseTaskOptions taskOptions) {
|
||||
Preconditions.checkArgument((func == null) != (pyFunctionDescriptor == null));
|
||||
|
||||
ActorId actorCreationId = ActorId.NIL;
|
||||
TaskId taskId = null;
|
||||
final JobId currentJobId = workerContext.getCurrentJobId();
|
||||
final TaskId currentTaskId = workerContext.getCurrentTaskId();
|
||||
final int taskIndex = workerContext.nextTaskIndex();
|
||||
if (isActorCreationTask) {
|
||||
taskId = RayletClientImpl.generateActorCreationTaskId(currentJobId, currentTaskId, taskIndex);
|
||||
actorCreationId = taskId.getActorId();
|
||||
} else if (isActorTask) {
|
||||
taskId = RayletClientImpl.generateActorTaskId(currentJobId, currentTaskId, taskIndex, actor.getId());
|
||||
} else {
|
||||
taskId = RayletClientImpl.generateNormalTaskId(currentJobId, currentTaskId, taskIndex);
|
||||
}
|
||||
|
||||
int numReturns = actor.getId().isNil() ? 1 : 2;
|
||||
|
||||
Map<String, Double> resources;
|
||||
if (null == taskOptions) {
|
||||
resources = new HashMap<>();
|
||||
} else {
|
||||
resources = new HashMap<>(taskOptions.resources);
|
||||
}
|
||||
|
||||
int maxActorReconstruction = 0;
|
||||
List<String> dynamicWorkerOptions = ImmutableList.of();
|
||||
if (taskOptions instanceof ActorCreationOptions) {
|
||||
maxActorReconstruction = ((ActorCreationOptions) taskOptions).maxReconstructions;
|
||||
String jvmOptions = ((ActorCreationOptions) taskOptions).jvmOptions;
|
||||
if (!StringUtil.isNullOrEmpty(jvmOptions)) {
|
||||
dynamicWorkerOptions = ImmutableList.of(((ActorCreationOptions) taskOptions).jvmOptions);
|
||||
}
|
||||
}
|
||||
|
||||
TaskLanguage language;
|
||||
FunctionDescriptor functionDescriptor;
|
||||
if (func != null) {
|
||||
language = TaskLanguage.JAVA;
|
||||
functionDescriptor = functionManager.getFunction(workerContext.getCurrentJobId(), func)
|
||||
.getFunctionDescriptor();
|
||||
} else {
|
||||
language = TaskLanguage.PYTHON;
|
||||
functionDescriptor = pyFunctionDescriptor;
|
||||
}
|
||||
|
||||
ObjectId previousActorTaskDummyObjectId = ObjectId.NIL;
|
||||
if (isActorTask) {
|
||||
previousActorTaskDummyObjectId = actor.getTaskCursor();
|
||||
}
|
||||
|
||||
return new TaskSpec(
|
||||
workerContext.getCurrentJobId(),
|
||||
taskId,
|
||||
workerContext.getCurrentTaskId(),
|
||||
-1,
|
||||
actorCreationId,
|
||||
maxActorReconstruction,
|
||||
actor.getId(),
|
||||
actor.getHandleId(),
|
||||
actor.increaseTaskCounter(),
|
||||
previousActorTaskDummyObjectId,
|
||||
actor.getNewActorHandles().toArray(new UniqueId[0]),
|
||||
ArgumentsBuilder.wrap(args, language == TaskLanguage.PYTHON),
|
||||
numReturns,
|
||||
resources,
|
||||
language,
|
||||
functionDescriptor,
|
||||
dynamicWorkerOptions
|
||||
);
|
||||
}
|
||||
|
||||
public void loop() {
|
||||
worker.loop();
|
||||
}
|
||||
|
||||
public Worker getWorker() {
|
||||
return worker;
|
||||
}
|
||||
|
||||
public WorkerContext getWorkerContext() {
|
||||
return workerContext;
|
||||
}
|
||||
|
||||
public RayletClient getRayletClient() {
|
||||
return rayletClient;
|
||||
public ObjectStore getObjectStore() {
|
||||
return objectStore;
|
||||
}
|
||||
|
||||
public ObjectStoreProxy getObjectStoreProxy() {
|
||||
return objectStoreProxy;
|
||||
public RayletClient getRayletClient() {
|
||||
return rayletClient;
|
||||
}
|
||||
|
||||
public FunctionManager getFunctionManager() {
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import java.io.Externalizable;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.util.Sha1Digestor;
|
||||
|
||||
public class RayActorImpl<T> implements RayActor<T>, Externalizable {
|
||||
|
||||
public static final RayActorImpl NIL = new RayActorImpl();
|
||||
|
||||
/**
|
||||
* Id of this actor.
|
||||
*/
|
||||
protected ActorId id;
|
||||
/**
|
||||
* Handle id of this actor.
|
||||
*/
|
||||
protected UniqueId handleId;
|
||||
/**
|
||||
* The number of tasks that have been invoked on this actor.
|
||||
*/
|
||||
protected int taskCounter;
|
||||
/**
|
||||
* The unique id of the last return of the last task.
|
||||
* It's used as a dependency for the next task.
|
||||
*/
|
||||
protected ObjectId taskCursor;
|
||||
/**
|
||||
* The number of times that this actor handle has been forked.
|
||||
* It's used to make sure ids of actor handles are unique.
|
||||
*/
|
||||
protected int numForks;
|
||||
|
||||
/**
|
||||
* The new actor handles that were created from this handle
|
||||
* since the last task on this handle was submitted. This is
|
||||
* used to garbage-collect dummy objects that are no longer
|
||||
* necessary in the backend.
|
||||
*/
|
||||
protected List<UniqueId> newActorHandles;
|
||||
|
||||
public RayActorImpl() {
|
||||
this(ActorId.NIL, UniqueId.NIL);
|
||||
}
|
||||
|
||||
public RayActorImpl(ActorId id) {
|
||||
this(id, UniqueId.NIL);
|
||||
}
|
||||
|
||||
public RayActorImpl(ActorId id, UniqueId handleId) {
|
||||
this.id = id;
|
||||
this.handleId = handleId;
|
||||
this.taskCounter = 0;
|
||||
this.taskCursor = null;
|
||||
this.newActorHandles = new ArrayList<>();
|
||||
numForks = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getId() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getHandleId() {
|
||||
return handleId;
|
||||
}
|
||||
|
||||
public void setTaskCursor(ObjectId taskCursor) {
|
||||
this.taskCursor = taskCursor;
|
||||
}
|
||||
|
||||
public List<UniqueId> getNewActorHandles() {
|
||||
return this.newActorHandles;
|
||||
}
|
||||
|
||||
public void clearNewActorHandles() {
|
||||
this.newActorHandles.clear();
|
||||
}
|
||||
|
||||
public ObjectId getTaskCursor() {
|
||||
return taskCursor;
|
||||
}
|
||||
|
||||
public int increaseTaskCounter() {
|
||||
return taskCounter++;
|
||||
}
|
||||
|
||||
public RayActorImpl<T> fork() {
|
||||
RayActorImpl<T> ret = new RayActorImpl<>();
|
||||
ret.id = this.id;
|
||||
ret.taskCounter = 0;
|
||||
ret.numForks = 0;
|
||||
ret.taskCursor = this.taskCursor;
|
||||
ret.handleId = this.computeNextActorHandleId();
|
||||
newActorHandles.add(ret.handleId);
|
||||
return ret;
|
||||
}
|
||||
|
||||
protected UniqueId computeNextActorHandleId() {
|
||||
byte[] bytes = Sha1Digestor.digest(handleId.getBytes(), ++numForks);
|
||||
return new UniqueId(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(this.id);
|
||||
out.writeObject(this.handleId);
|
||||
out.writeObject(this.taskCursor);
|
||||
out.writeObject(this.taskCounter);
|
||||
out.writeObject(this.numForks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
this.id = (ActorId) in.readObject();
|
||||
this.handleId = (UniqueId) in.readObject();
|
||||
this.taskCursor = (ObjectId) in.readObject();
|
||||
this.taskCounter = (int) in.readObject();
|
||||
this.numForks = (int) in.readObject();
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,11 @@ package org.ray.runtime;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.MockRayletClient;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.ray.runtime.raylet.LocalModeRayletClient;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
|
||||
public class RayDevRuntime extends AbstractRayRuntime {
|
||||
|
||||
@@ -13,37 +15,26 @@ public class RayDevRuntime extends AbstractRayRuntime {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
private MockObjectInterface objectInterface;
|
||||
|
||||
private AtomicInteger jobCounter = new AtomicInteger(0);
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// Reset library path at runtime.
|
||||
resetLibraryPath();
|
||||
|
||||
objectInterface = new MockObjectInterface(workerContext);
|
||||
if (rayConfig.getJobId().isNil()) {
|
||||
rayConfig.setJobId(nextJobId());
|
||||
}
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterface);
|
||||
rayletClient = new MockRayletClient(this, rayConfig.numberExecThreadsForDevRuntime);
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
|
||||
objectStore = new LocalModeObjectStore(workerContext);
|
||||
taskSubmitter = new LocalModeTaskSubmitter(this, (LocalModeObjectStore) objectStore,
|
||||
rayConfig.numberExecThreadsForDevRuntime);
|
||||
((LocalModeObjectStore) objectStore).addObjectPutCallback(
|
||||
objectId -> ((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId));
|
||||
rayletClient = new LocalModeRayletClient();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
rayletClient.destroy();
|
||||
}
|
||||
|
||||
public MockObjectInterface getObjectInterface() {
|
||||
return objectInterface;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Worker getWorker() {
|
||||
return ((MockRayletClient) rayletClient).getCurrentWorker();
|
||||
taskExecutor = null;
|
||||
}
|
||||
|
||||
private JobId nextJobId() {
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.base.Strings;
|
||||
import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.context.NativeWorkerContext;
|
||||
import org.ray.runtime.gcs.GcsClient;
|
||||
import org.ray.runtime.gcs.GcsClientOptions;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.objectstore.ObjectInterfaceImpl;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.object.NativeObjectStore;
|
||||
import org.ray.runtime.raylet.NativeRayletClient;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.task.NativeTaskSubmitter;
|
||||
import org.ray.runtime.task.TaskExecutor;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -23,12 +35,65 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private RunManager manager = null;
|
||||
|
||||
private ObjectInterfaceImpl objectInterfaceImpl = null;
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer;
|
||||
|
||||
static {
|
||||
try {
|
||||
LOGGER.debug("Loading native libraries.");
|
||||
// Load native libraries.
|
||||
String[] libraries = new String[]{"core_worker_library_java"};
|
||||
for (String library : libraries) {
|
||||
String fileName = System.mapLibraryName(library);
|
||||
// Copy the file from resources to a temp dir, and load the native library.
|
||||
File file = File.createTempFile(fileName, "");
|
||||
file.deleteOnExit();
|
||||
InputStream in = AbstractRayRuntime.class.getResourceAsStream("/" + fileName);
|
||||
Preconditions.checkNotNull(in, "{} doesn't exist.", fileName);
|
||||
Files.copy(in, Paths.get(file.getAbsolutePath()), StandardCopyOption.REPLACE_EXISTING);
|
||||
System.load(file.getAbsolutePath());
|
||||
}
|
||||
LOGGER.debug("Native libraries loaded.");
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Couldn't load native libraries.", e);
|
||||
}
|
||||
nativeSetup(RayConfig.create().logDir);
|
||||
Runtime.getRuntime().addShutdownHook(new Thread(RayNativeRuntime::nativeShutdownHook));
|
||||
}
|
||||
|
||||
public RayNativeRuntime(RayConfig rayConfig) {
|
||||
super(rayConfig);
|
||||
}
|
||||
|
||||
protected void resetLibraryPath() {
|
||||
if (rayConfig.libraryPath.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
String path = System.getProperty("java.library.path");
|
||||
if (Strings.isNullOrEmpty(path)) {
|
||||
path = "";
|
||||
} else {
|
||||
path += ":";
|
||||
}
|
||||
path += String.join(":", rayConfig.libraryPath);
|
||||
|
||||
// This is a hack to reset library path at runtime,
|
||||
// see https://stackoverflow.com/questions/15409223/.
|
||||
System.setProperty("java.library.path", path);
|
||||
// Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed.
|
||||
final Field sysPathsField;
|
||||
try {
|
||||
sysPathsField = ClassLoader.class.getDeclaredField("sys_paths");
|
||||
sysPathsField.setAccessible(true);
|
||||
sysPathsField.set(null, null);
|
||||
} catch (NoSuchFieldException | IllegalAccessException e) {
|
||||
LOGGER.error("Failed to set library path.", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void start() {
|
||||
// Reset library path at runtime.
|
||||
@@ -44,20 +109,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
if (rayConfig.getJobId() == JobId.NIL) {
|
||||
rayConfig.setJobId(gcsClient.nextJobId());
|
||||
}
|
||||
|
||||
workerContext = new WorkerContext(rayConfig.workerMode,
|
||||
rayConfig.getJobId(), rayConfig.runMode);
|
||||
rayletClient = new RayletClientImpl(
|
||||
rayConfig.rayletSocketName,
|
||||
workerContext.getCurrentWorkerId(),
|
||||
rayConfig.workerMode == WorkerType.WORKER,
|
||||
workerContext.getCurrentJobId()
|
||||
);
|
||||
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectInterfaceImpl = new ObjectInterfaceImpl(workerContext, rayletClient,
|
||||
rayConfig.objectStoreSocketName);
|
||||
objectStoreProxy = new ObjectStoreProxy(workerContext, objectInterfaceImpl);
|
||||
nativeCoreWorkerPointer = nativeInitCoreWorker(rayConfig.workerMode.getNumber(),
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
|
||||
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
|
||||
new GcsClientOptions(rayConfig));
|
||||
Preconditions.checkState(nativeCoreWorkerPointer != 0);
|
||||
|
||||
taskExecutor = new TaskExecutor(this);
|
||||
workerContext = new NativeWorkerContext(nativeCoreWorkerPointer);
|
||||
objectStore = new NativeObjectStore(workerContext, nativeCoreWorkerPointer);
|
||||
taskSubmitter = new NativeTaskSubmitter(nativeCoreWorkerPointer);
|
||||
rayletClient = new NativeRayletClient(nativeCoreWorkerPointer);
|
||||
|
||||
// register
|
||||
registerWorker();
|
||||
@@ -71,8 +134,14 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
if (null != manager) {
|
||||
manager.cleanup();
|
||||
}
|
||||
objectInterfaceImpl.destroy();
|
||||
workerContext.destroy();
|
||||
if (nativeCoreWorkerPointer != 0) {
|
||||
nativeDestroyCoreWorker(nativeCoreWorkerPointer);
|
||||
nativeCoreWorkerPointer = 0;
|
||||
}
|
||||
}
|
||||
|
||||
public void run() {
|
||||
nativeRunTaskExecutor(nativeCoreWorkerPointer, taskExecutor);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -99,4 +168,16 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
redisClient.hmset("Workers:" + workerId, workerInfo);
|
||||
}
|
||||
}
|
||||
|
||||
private static native long nativeInitCoreWorker(int workerMode, String storeSocket,
|
||||
String rayletSocket, byte[] jobId, GcsClientOptions gcsClientOptions);
|
||||
|
||||
private static native void nativeRunTaskExecutor(long nativeCoreWorkerPointer,
|
||||
TaskExecutor taskExecutor);
|
||||
|
||||
private static native void nativeDestroyCoreWorker(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native void nativeSetup(String logDir);
|
||||
|
||||
private static native void nativeShutdownHook();
|
||||
}
|
||||
|
||||
@@ -1,71 +0,0 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
|
||||
public class RayPyActorImpl extends RayActorImpl implements RayPyActor {
|
||||
|
||||
public static final RayPyActorImpl NIL = new RayPyActorImpl(ActorId.NIL, null, null);
|
||||
|
||||
/**
|
||||
* Module name of the Python actor class.
|
||||
*/
|
||||
private String moduleName;
|
||||
|
||||
/**
|
||||
* Name of the Python actor class.
|
||||
*/
|
||||
private String className;
|
||||
|
||||
// Note that this empty constructor must be public
|
||||
// since it'll be needed when deserializing.
|
||||
public RayPyActorImpl() {}
|
||||
|
||||
public RayPyActorImpl(ActorId id, String moduleName, String className) {
|
||||
super(id);
|
||||
this.moduleName = moduleName;
|
||||
this.className = className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
return moduleName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
return className;
|
||||
}
|
||||
|
||||
public RayPyActorImpl fork() {
|
||||
RayPyActorImpl ret = new RayPyActorImpl();
|
||||
ret.id = this.id;
|
||||
ret.taskCounter = 0;
|
||||
ret.numForks = 0;
|
||||
ret.taskCursor = this.taskCursor;
|
||||
ret.moduleName = this.moduleName;
|
||||
ret.className = this.className;
|
||||
ret.handleId = this.computeNextActorHandleId();
|
||||
newActorHandles.add(ret.handleId);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
super.writeExternal(out);
|
||||
out.writeObject(this.moduleName);
|
||||
out.writeObject(this.className);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
super.readExternal(in);
|
||||
this.moduleName = (String) in.readObject();
|
||||
this.className = (String) in.readObject();
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.generated.Common.WorkerType;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for worker context of core worker.
|
||||
*/
|
||||
public class WorkerContext {
|
||||
|
||||
/**
|
||||
* The native pointer of worker context of core worker.
|
||||
*/
|
||||
private final long nativeWorkerContextPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
/**
|
||||
* The ID of main thread which created the worker context.
|
||||
*/
|
||||
private long mainThreadId;
|
||||
|
||||
/**
|
||||
* The run-mode of this worker.
|
||||
*/
|
||||
private RunMode runMode;
|
||||
|
||||
public WorkerContext(WorkerType workerType, JobId jobId, RunMode runMode) {
|
||||
this.nativeWorkerContextPointer = nativeCreateWorkerContext(workerType.getNumber(), jobId.getBytes());
|
||||
mainThreadId = Thread.currentThread().getId();
|
||||
this.runMode = runMode;
|
||||
currentClassLoader = null;
|
||||
}
|
||||
|
||||
public long getNativeWorkerContext() {
|
||||
return nativeWorkerContextPointer;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return For the main thread, this method returns the ID of this worker's current running task;
|
||||
* for other threads, this method returns a random ID.
|
||||
*/
|
||||
public TaskId getCurrentTaskId() {
|
||||
return TaskId.fromBytes(nativeGetCurrentTaskId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the current task which is being executed by the current worker. Note, this method can only
|
||||
* be called from the main thread.
|
||||
*/
|
||||
public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
|
||||
if (runMode == RunMode.CLUSTER) {
|
||||
Preconditions.checkState(
|
||||
Thread.currentThread().getId() == mainThreadId,
|
||||
"This method should only be called from the main thread."
|
||||
);
|
||||
}
|
||||
|
||||
Preconditions.checkNotNull(task);
|
||||
byte[] taskSpec = RayletClientImpl.convertTaskSpecToProtobuf(task);
|
||||
nativeSetCurrentTask(nativeWorkerContextPointer, taskSpec);
|
||||
currentClassLoader = classLoader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the put index and return the new value.
|
||||
*/
|
||||
public int nextPutIndex() {
|
||||
return nativeGetNextPutIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* Increment the task index and return the new value.
|
||||
*/
|
||||
public int nextTaskIndex() {
|
||||
return nativeGetNextTaskIndex(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The ID of the current worker.
|
||||
*/
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return new UniqueId(nativeGetCurrentWorkerId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
* The ID of the current job.
|
||||
*/
|
||||
public JobId getCurrentJobId() {
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeWorkerContextPointer));
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The class loader which is associated with the current job.
|
||||
*/
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current task.
|
||||
*/
|
||||
public TaskSpec getCurrentTask() {
|
||||
byte[] bytes = nativeGetCurrentTask(nativeWorkerContextPointer);
|
||||
if (bytes == null) {
|
||||
return null;
|
||||
}
|
||||
return RayletClientImpl.parseTaskSpecFromProtobuf(bytes);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeWorkerContextPointer);
|
||||
}
|
||||
|
||||
private static native long nativeCreateWorkerContext(int workerType, byte[] jobId);
|
||||
|
||||
private static native byte[] nativeGetCurrentTaskId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeSetCurrentTask(long nativeWorkerContextPointer, byte[] taskSpec);
|
||||
|
||||
private static native byte[] nativeGetCurrentTask(long nativeWorkerContextPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native byte[] nativeGetCurrentWorkerId(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextTaskIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native int nativeGetNextPutIndex(long nativeWorkerContextPointer);
|
||||
|
||||
private static native void nativeDestroy(long nativeWorkerContextPointer);
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package org.ray.runtime.actor;
|
||||
|
||||
import java.io.Externalizable;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* RayActor implementation for local mode.
|
||||
*/
|
||||
public class LocalModeRayActor implements RayActor, Externalizable {
|
||||
|
||||
private ActorId actorId;
|
||||
|
||||
private AtomicReference<ObjectId> previousActorTaskDummyObjectId = new AtomicReference<>();
|
||||
|
||||
public LocalModeRayActor(ActorId actorId, ObjectId previousActorTaskDummyObjectId) {
|
||||
this.actorId = actorId;
|
||||
this.previousActorTaskDummyObjectId.set(previousActorTaskDummyObjectId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Required by FST
|
||||
*/
|
||||
public LocalModeRayActor() {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getId() {
|
||||
return actorId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getHandleId() {
|
||||
return UniqueId.NIL;
|
||||
}
|
||||
|
||||
public ObjectId exchangePreviousActorTaskDummyObjectId(ObjectId previousActorTaskDummyObjectId) {
|
||||
return this.previousActorTaskDummyObjectId.getAndSet(previousActorTaskDummyObjectId);
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(actorId);
|
||||
out.writeObject(previousActorTaskDummyObjectId.get());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
actorId = (ActorId) in.readObject();
|
||||
previousActorTaskDummyObjectId.set((ObjectId) in.readObject());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
package org.ray.runtime.actor;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.io.Externalizable;
|
||||
import java.io.IOException;
|
||||
import java.io.ObjectInput;
|
||||
import java.io.ObjectOutput;
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
* RayActor implementation for cluster mode. This is a wrapper class for C++ ActorHandle.
|
||||
*/
|
||||
public class NativeRayActor implements RayActor, RayPyActor, Externalizable {
|
||||
|
||||
/**
|
||||
* Address of native actor handle.
|
||||
*/
|
||||
private long nativeActorHandle;
|
||||
|
||||
public NativeRayActor(long nativeActorHandle) {
|
||||
Preconditions.checkState(nativeActorHandle != 0);
|
||||
this.nativeActorHandle = nativeActorHandle;
|
||||
}
|
||||
|
||||
/**
|
||||
* Required by FST
|
||||
*/
|
||||
public NativeRayActor() {
|
||||
}
|
||||
|
||||
public long getNativeActorHandle() {
|
||||
return nativeActorHandle;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getId() {
|
||||
return ActorId.fromBytes(nativeGetActorId(nativeActorHandle));
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getHandleId() {
|
||||
return new UniqueId(nativeGetActorHandleId(nativeActorHandle));
|
||||
}
|
||||
|
||||
public Language getLanguage() {
|
||||
return Language.forNumber(nativeGetLanguage(nativeActorHandle));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getModuleName() {
|
||||
Preconditions.checkState(getLanguage() == Language.PYTHON);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeActorHandle).get(0);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getClassName() {
|
||||
Preconditions.checkState(getLanguage() == Language.PYTHON);
|
||||
return nativeGetActorCreationTaskFunctionDescriptor(nativeActorHandle).get(1);
|
||||
}
|
||||
|
||||
public NativeRayActor fork() {
|
||||
return new NativeRayActor(nativeFork(nativeActorHandle));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeExternal(ObjectOutput out) throws IOException {
|
||||
out.writeObject(nativeSerialize(nativeActorHandle));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
|
||||
nativeActorHandle = nativeDeserialize((byte[]) in.readObject());
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void finalize() {
|
||||
nativeFree(nativeActorHandle);
|
||||
}
|
||||
|
||||
private static native long nativeFork(long nativeActorHandle);
|
||||
|
||||
private static native byte[] nativeGetActorId(long nativeActorHandle);
|
||||
|
||||
private static native byte[] nativeGetActorHandleId(long nativeActorHandle);
|
||||
|
||||
private static native int nativeGetLanguage(long nativeActorHandle);
|
||||
|
||||
private static native List<String> nativeGetActorCreationTaskFunctionDescriptor(
|
||||
long nativeActorHandle);
|
||||
|
||||
private static native byte[] nativeSerialize(long nativeActorHandle);
|
||||
|
||||
private static native long nativeDeserialize(byte[] data);
|
||||
|
||||
private static native void nativeFree(long nativeActorHandle);
|
||||
}
|
||||
+8
-6
@@ -1,4 +1,4 @@
|
||||
package org.ray.runtime.util;
|
||||
package org.ray.runtime.actor;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.nustaq.serialization.FSTBasicObjectSerializer;
|
||||
@@ -6,20 +6,22 @@ import org.nustaq.serialization.FSTClazzInfo;
|
||||
import org.nustaq.serialization.FSTClazzInfo.FSTFieldInfo;
|
||||
import org.nustaq.serialization.FSTObjectInput;
|
||||
import org.nustaq.serialization.FSTObjectOutput;
|
||||
import org.ray.runtime.RayActorImpl;
|
||||
|
||||
public class RayActorSerializer extends FSTBasicObjectSerializer {
|
||||
/**
|
||||
* To deal with serialization about {@link NativeRayActor}.
|
||||
*/
|
||||
public class NativeRayActorSerializer extends FSTBasicObjectSerializer {
|
||||
|
||||
@Override
|
||||
public void writeObject(FSTObjectOutput out, Object toWrite, FSTClazzInfo clzInfo,
|
||||
FSTClazzInfo.FSTFieldInfo referencedBy, int streamPosition) throws IOException {
|
||||
((RayActorImpl) toWrite).fork().writeExternal(out);
|
||||
((NativeRayActor) toWrite).fork().writeExternal(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void readObject(FSTObjectInput in, Object toRead, FSTClazzInfo clzInfo,
|
||||
FSTFieldInfo referencedBy) throws Exception {
|
||||
FSTFieldInfo referencedBy) throws Exception {
|
||||
super.readObject(in, toRead, clzInfo, referencedBy);
|
||||
((RayActorImpl) toRead).readExternal(in);
|
||||
((NativeRayActor) toRead).readExternal(in);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package org.ray.runtime.context;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.TaskSpec;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.task.LocalModeTaskSubmitter;
|
||||
import sun.reflect.generics.reflectiveObjects.NotImplementedException;
|
||||
|
||||
/**
|
||||
* Worker context for local mode.
|
||||
*/
|
||||
public class LocalModeWorkerContext implements WorkerContext {
|
||||
|
||||
private final JobId jobId;
|
||||
private ThreadLocal<TaskSpec> currentTask = new ThreadLocal<>();
|
||||
|
||||
public LocalModeWorkerContext(JobId jobId) {
|
||||
this.jobId = jobId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public JobId getCurrentJobId() {
|
||||
return jobId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
if (taskSpec == null) {
|
||||
return ActorId.NIL;
|
||||
}
|
||||
return LocalModeTaskSubmitter.getActorId(taskSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
Preconditions.checkNotNull(taskSpec, "Current task is not set.");
|
||||
return taskSpec.getType();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
TaskSpec taskSpec = currentTask.get();
|
||||
Preconditions.checkState(taskSpec != null);
|
||||
return TaskId.fromBytes(taskSpec.getTaskId().toByteArray());
|
||||
}
|
||||
|
||||
public void setCurrentTask(TaskSpec taskSpec) {
|
||||
currentTask.set(taskSpec);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,72 @@
|
||||
package org.ray.runtime.context;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
/**
|
||||
* Worker context for cluster mode. This is a wrapper class for worker context of core worker.
|
||||
*/
|
||||
public class NativeWorkerContext implements WorkerContext {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
private ClassLoader currentClassLoader;
|
||||
|
||||
public NativeWorkerContext(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId getCurrentWorkerId() {
|
||||
return UniqueId.fromByteBuffer(nativeGetCurrentWorkerId(nativeCoreWorkerPointer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public JobId getCurrentJobId() {
|
||||
return JobId.fromByteBuffer(nativeGetCurrentJobId(nativeCoreWorkerPointer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
return ActorId.fromByteBuffer(nativeGetCurrentActorId(nativeCoreWorkerPointer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public ClassLoader getCurrentClassLoader() {
|
||||
return currentClassLoader;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
|
||||
if (this.currentClassLoader != currentClassLoader) {
|
||||
this.currentClassLoader = currentClassLoader;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskType getCurrentTaskType() {
|
||||
return TaskType.forNumber(nativeGetCurrentTaskType(nativeCoreWorkerPointer));
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskId getCurrentTaskId() {
|
||||
return TaskId.fromByteBuffer(nativeGetCurrentTaskId(nativeCoreWorkerPointer));
|
||||
}
|
||||
|
||||
private static native int nativeGetCurrentTaskType(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentTaskId(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentJobId(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentWorkerId(long nativeCoreWorkerPointer);
|
||||
|
||||
private static native ByteBuffer nativeGetCurrentActorId(long nativeCoreWorkerPointer);
|
||||
}
|
||||
+8
-8
@@ -1,14 +1,14 @@
|
||||
package org.ray.runtime;
|
||||
package org.ray.runtime.context;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.runtimecontext.NodeInfo;
|
||||
import org.ray.api.runtimecontext.RuntimeContext;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
public class RuntimeContextImpl implements RuntimeContext {
|
||||
|
||||
@@ -25,16 +25,16 @@ public class RuntimeContextImpl implements RuntimeContext {
|
||||
|
||||
@Override
|
||||
public ActorId getCurrentActorId() {
|
||||
Worker worker = runtime.getWorker();
|
||||
Preconditions.checkState(worker != null && !worker.getCurrentActorId().isNil(),
|
||||
ActorId actorId = runtime.getWorkerContext().getCurrentActorId();
|
||||
Preconditions.checkState(actorId != null && !actorId.isNil(),
|
||||
"This method should only be called from an actor.");
|
||||
return worker.getCurrentActorId();
|
||||
return actorId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean wasCurrentActorReconstructed() {
|
||||
TaskSpec currentTask = runtime.getWorkerContext().getCurrentTask();
|
||||
Preconditions.checkState(currentTask != null && currentTask.isActorCreationTask(),
|
||||
TaskType currentTaskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
Preconditions.checkState(currentTaskType == TaskType.ACTOR_CREATION_TASK,
|
||||
"This method can only be called from an actor creation task.");
|
||||
if (isSingleProcess()) {
|
||||
return false;
|
||||
@@ -0,0 +1,49 @@
|
||||
package org.ray.runtime.context;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
|
||||
/**
|
||||
* The context of worker.
|
||||
*/
|
||||
public interface WorkerContext {
|
||||
|
||||
/**
|
||||
* ID of the current worker.
|
||||
*/
|
||||
UniqueId getCurrentWorkerId();
|
||||
|
||||
/**
|
||||
* ID of the current job.
|
||||
*/
|
||||
JobId getCurrentJobId();
|
||||
|
||||
/**
|
||||
* ID of the current actor.
|
||||
*/
|
||||
ActorId getCurrentActorId();
|
||||
|
||||
/**
|
||||
* The class loader that is associated with the current job. It's used for locating classes when
|
||||
* dealing with serialization and deserialization in {@link org.ray.runtime.util.Serializer}.
|
||||
*/
|
||||
ClassLoader getCurrentClassLoader();
|
||||
|
||||
/**
|
||||
* Set the current class loader.
|
||||
*/
|
||||
void setCurrentClassLoader(ClassLoader currentClassLoader);
|
||||
|
||||
/**
|
||||
* Type of the current task.
|
||||
*/
|
||||
TaskType getCurrentTaskType();
|
||||
|
||||
/**
|
||||
* ID of the current task.
|
||||
*/
|
||||
TaskId getCurrentTaskId();
|
||||
}
|
||||
@@ -1,5 +1,8 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
* Base interface of a Ray task's function descriptor.
|
||||
*
|
||||
@@ -8,4 +11,13 @@ package org.ray.runtime.functionmanager;
|
||||
*/
|
||||
public interface FunctionDescriptor {
|
||||
|
||||
/**
|
||||
* @return A list of strings represents the functions.
|
||||
*/
|
||||
List<String> toList();
|
||||
|
||||
/**
|
||||
* @return The language of the function.
|
||||
*/
|
||||
Language getLanguage();
|
||||
}
|
||||
|
||||
+13
@@ -1,6 +1,9 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import com.google.common.base.Objects;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.List;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
* Represents metadata of Java function.
|
||||
@@ -49,4 +52,14 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
|
||||
public int hashCode() {
|
||||
return Objects.hashCode(className, name, typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> toList() {
|
||||
return ImmutableList.of(className, name, typeDescriptor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.JAVA;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
package org.ray.runtime.functionmanager;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
|
||||
/**
|
||||
* Represents metadata of a Python function.
|
||||
*/
|
||||
@@ -21,5 +25,15 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
|
||||
public String toString() {
|
||||
return moduleName + "." + className + "." + functionName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<String> toList() {
|
||||
return Arrays.asList(moduleName, className, functionName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Language getLanguage() {
|
||||
return Language.PYTHON;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package org.ray.runtime.gcs;
|
||||
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
|
||||
/**
|
||||
* Options to create GCS Client.
|
||||
*/
|
||||
public class GcsClientOptions {
|
||||
public String ip;
|
||||
public int port;
|
||||
public String password;
|
||||
|
||||
public GcsClientOptions(RayConfig rayConfig) {
|
||||
ip = rayConfig.getRedisIp();
|
||||
port = rayConfig.getRedisPort();
|
||||
password = rayConfig.redisPassword;
|
||||
}
|
||||
}
|
||||
+14
-13
@@ -1,4 +1,4 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
@@ -8,22 +8,24 @@ import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.function.Consumer;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class MockObjectInterface implements ObjectInterface {
|
||||
/**
|
||||
* Object store methods for local mode.
|
||||
*/
|
||||
public class LocalModeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockObjectInterface.class);
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
|
||||
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
public MockObjectInterface(WorkerContext workerContext) {
|
||||
this.workerContext = workerContext;
|
||||
public LocalModeObjectStore(WorkerContext workerContext) {
|
||||
super(workerContext);
|
||||
}
|
||||
|
||||
public void addObjectPutCallback(Consumer<ObjectId> callback) {
|
||||
@@ -35,15 +37,14 @@ public class MockObjectInterface implements ObjectInterface {
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
ObjectId objectId = ObjectId.forPut(workerContext.getCurrentTaskId(),
|
||||
workerContext.nextPutIndex());
|
||||
put(obj, objectId);
|
||||
public ObjectId putRaw(NativeRayObject obj) {
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
putRaw(obj, objectId);
|
||||
return objectId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
public void putRaw(NativeRayObject obj, ObjectId objectId) {
|
||||
Preconditions.checkNotNull(obj);
|
||||
Preconditions.checkNotNull(objectId);
|
||||
pool.putIfAbsent(objectId, obj);
|
||||
@@ -53,7 +54,7 @@ public class MockObjectInterface implements ObjectInterface {
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
|
||||
waitInternal(objectIds, objectIds.size(), timeoutMs);
|
||||
return objectIds.stream().map(pool::get).collect(Collectors.toList());
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.id.BaseId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Object store methods for cluster mode. This is a wrapper class for core worker object interface.
|
||||
*/
|
||||
public class NativeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(NativeObjectStore.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeObjectStore(WorkerContext workerContext, long nativeCoreWorkerPointer) {
|
||||
super(workerContext);
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId putRaw(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeCoreWorkerPointer, obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putRaw(NativeRayObject obj, ObjectId objectId) {
|
||||
nativePut(nativeCoreWorkerPointer, objectId.getBytes(), obj);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeCoreWorkerPointer, toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeCoreWorkerPointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeCoreWorkerPointer, toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native byte[] nativePut(long nativeCoreWorkerPointer, NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeCoreWorkerPointer, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeCoreWorkerPointer,
|
||||
List<byte[]> ids, long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeCoreWorkerPointer,
|
||||
List<byte[]> objectIds, int numObjects, long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeCoreWorkerPointer, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
}
|
||||
+4
-1
@@ -1,5 +1,8 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
package org.ray.runtime.object;
|
||||
|
||||
/**
|
||||
* Binary representation of ray object.
|
||||
*/
|
||||
public class NativeRayObject {
|
||||
|
||||
public byte[] data;
|
||||
@@ -0,0 +1,217 @@
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.context.WorkerContext;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
|
||||
/**
|
||||
* A class that is used to put/get objects to/from the object store.
|
||||
*/
|
||||
public abstract class ObjectStore {
|
||||
|
||||
private static final byte[] WORKER_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
|
||||
private static final byte[] ACTOR_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
public ObjectStore(WorkerContext workerContext) {
|
||||
this.workerContext = workerContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Put a raw object into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @return Generated ID of the object.
|
||||
*/
|
||||
public abstract ObjectId putRaw(NativeRayObject obj);
|
||||
|
||||
/**
|
||||
* Put a raw object with specified ID into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @param objectId Object ID specified by user.
|
||||
*/
|
||||
public abstract void putRaw(NativeRayObject obj, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Serialize and put an object to the object store.
|
||||
*
|
||||
* @param object The object to put.
|
||||
* @return Id of the object.
|
||||
*/
|
||||
public ObjectId put(Object object) {
|
||||
return putRaw(serialize(object));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of raw objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to get.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return Result list of objects data.
|
||||
*/
|
||||
public abstract List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param ids List of the object ids.
|
||||
* @param <T> Type of these objects.
|
||||
* @return A list of GetResult objects.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> List<T> get(List<ObjectId> ids) {
|
||||
// Pass -1 as timeout to wait until all objects are available in object store.
|
||||
List<NativeRayObject> dataAndMetaList = getRaw(ids, -1);
|
||||
|
||||
List<T> results = new ArrayList<>();
|
||||
for (int i = 0; i < dataAndMetaList.size(); i++) {
|
||||
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
|
||||
Object object = null;
|
||||
if (dataAndMeta != null) {
|
||||
object = deserialize(dataAndMeta, ids.get(i));
|
||||
}
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
throw (RayException) object;
|
||||
}
|
||||
results.add((T) object);
|
||||
}
|
||||
// This check must be placed after the throw exception statement.
|
||||
// Because if there was any exception, The get operation would return early
|
||||
// and wouldn't wait until all objects exist.
|
||||
Preconditions.checkState(dataAndMetaList.stream().allMatch(Objects::nonNull));
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for a list of objects to appear in the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to wait for.
|
||||
* @param numObjects Number of objects that should appear.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return A bitset that indicates each object has appeared or not.
|
||||
*/
|
||||
public abstract List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Wait for a list of RayObjects to be locally available, until specified number of objects are
|
||||
* ready, or specified timeout has passed.
|
||||
*
|
||||
* @param waitList A list of RayObject to wait for.
|
||||
* @param numReturns The number of objects that should be returned.
|
||||
* @param timeoutMs The maximum time in milliseconds to wait before returning.
|
||||
* @return Two lists, one containing locally available objects, one containing the rest.
|
||||
*/
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
|
||||
Preconditions.checkNotNull(waitList);
|
||||
if (waitList.isEmpty()) {
|
||||
return new WaitResult<>(Collections.emptyList(), Collections.emptyList());
|
||||
}
|
||||
|
||||
List<ObjectId> ids = waitList.stream().map(RayObject::getId).collect(Collectors.toList());
|
||||
|
||||
List<Boolean> ready = wait(ids, numReturns, timeoutMs);
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < ready.size(); i++) {
|
||||
if (ready.get(i)) {
|
||||
readyList.add(waitList.get(i));
|
||||
} else {
|
||||
unreadyList.add(waitList.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
return new WaitResult<>(readyList, unreadyList);
|
||||
}
|
||||
|
||||
/**
|
||||
* Delete a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
|
||||
/**
|
||||
* Deserialize an object.
|
||||
*
|
||||
* @param nativeRayObject The object to deserialize.
|
||||
* @param objectId The associated object ID of the object.
|
||||
* @return The deserialized object.
|
||||
*/
|
||||
public Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId) {
|
||||
byte[] meta = nativeRayObject.metadata;
|
||||
byte[] data = nativeRayObject.data;
|
||||
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
return Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize an object.
|
||||
*
|
||||
* @param object The object to serialize.
|
||||
* @return The serialized object.
|
||||
*/
|
||||
public NativeRayObject serialize(Object object) {
|
||||
if (object instanceof NativeRayObject) {
|
||||
return (NativeRayObject) object;
|
||||
} else if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
return new NativeRayObject((byte[]) object, RAW_TYPE_META);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
return new NativeRayObject(Serializer.encode(object),
|
||||
TASK_EXECUTION_EXCEPTION_META);
|
||||
} else {
|
||||
return new NativeRayObject(Serializer.encode(object), null);
|
||||
}
|
||||
}
|
||||
}
|
||||
+4
-1
@@ -1,10 +1,13 @@
|
||||
package org.ray.runtime;
|
||||
package org.ray.runtime.object;
|
||||
|
||||
import java.io.Serializable;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
||||
/**
|
||||
* Implementation of {@link RayObject}.
|
||||
*/
|
||||
public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
|
||||
|
||||
private final ObjectId id;
|
||||
@@ -1,54 +0,0 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.id.ObjectId;
|
||||
|
||||
/**
|
||||
* The interface that contains all worker methods that are related to object store.
|
||||
*/
|
||||
public interface ObjectInterface {
|
||||
|
||||
/**
|
||||
* Put an object into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @return Generated ID of the object.
|
||||
*/
|
||||
ObjectId put(NativeRayObject obj);
|
||||
|
||||
/**
|
||||
* Put an object with specified ID into object store.
|
||||
*
|
||||
* @param obj The ray object.
|
||||
* @param objectId Object ID specified by user.
|
||||
*/
|
||||
void put(NativeRayObject obj, ObjectId objectId);
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to get.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return Result list of objects data.
|
||||
*/
|
||||
List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Wait for a list of objects to appear in the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to wait for.
|
||||
* @param numObjects Number of objects that should appear.
|
||||
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
|
||||
* @return A bitset that indicates each object has appeared or not.
|
||||
*/
|
||||
List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
|
||||
|
||||
/**
|
||||
* Delete a list of objects from the object store.
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
|
||||
}
|
||||
@@ -1,92 +0,0 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.BaseId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.raylet.RayletClient;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* This is a wrapper class for core worker object interface.
|
||||
*/
|
||||
public class ObjectInterfaceImpl implements ObjectInterface {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
|
||||
|
||||
/**
|
||||
* The native pointer of core worker object interface.
|
||||
*/
|
||||
private final long nativeObjectInterfacePointer;
|
||||
|
||||
public ObjectInterfaceImpl(WorkerContext workerContext, RayletClient rayletClient,
|
||||
String storeSocketName) {
|
||||
this.nativeObjectInterfacePointer =
|
||||
nativeCreateObjectInterface(workerContext.getNativeWorkerContext(),
|
||||
((RayletClientImpl) rayletClient).getClient(), storeSocketName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ObjectId put(NativeRayObject obj) {
|
||||
return new ObjectId(nativePut(nativeObjectInterfacePointer, obj));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void put(NativeRayObject obj, ObjectId objectId) {
|
||||
try {
|
||||
nativePut(nativeObjectInterfacePointer, objectId.getBytes(), obj);
|
||||
} catch (RayException e) {
|
||||
LOGGER.warn(e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NativeRayObject> get(List<ObjectId> objectIds, long timeoutMs) {
|
||||
return nativeGet(nativeObjectInterfacePointer, toBinaryList(objectIds), timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs) {
|
||||
return nativeWait(nativeObjectInterfacePointer, toBinaryList(objectIds), numObjects, timeoutMs);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(nativeObjectInterfacePointer,
|
||||
toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(nativeObjectInterfacePointer);
|
||||
}
|
||||
|
||||
private static List<byte[]> toBinaryList(List<ObjectId> ids) {
|
||||
return ids.stream().map(BaseId::getBytes).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native long nativeCreateObjectInterface(long nativeObjectInterface,
|
||||
long nativeRayletClient,
|
||||
String storeSocketName);
|
||||
|
||||
private static native byte[] nativePut(long nativeObjectInterface, NativeRayObject obj);
|
||||
|
||||
private static native void nativePut(long nativeObjectInterface, byte[] objectId,
|
||||
NativeRayObject obj);
|
||||
|
||||
private static native List<NativeRayObject> nativeGet(long nativeObjectInterface,
|
||||
List<byte[]> ids,
|
||||
long timeoutMs);
|
||||
|
||||
private static native List<Boolean> nativeWait(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
int numObjects, long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(long nativeObjectInterface, List<byte[]> objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks);
|
||||
|
||||
private static native void nativeDestroy(long nativeObjectInterface);
|
||||
}
|
||||
@@ -1,152 +0,0 @@
|
||||
package org.ray.runtime.objectstore;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import org.ray.api.exception.RayActorException;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.exception.RayWorkerException;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.WorkerContext;
|
||||
import org.ray.runtime.generated.Gcs.ErrorType;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A class that is used to put/get objects to/from the object store.
|
||||
*/
|
||||
public class ObjectStoreProxy {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(ObjectStoreProxy.class);
|
||||
|
||||
private static final byte[] WORKER_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
|
||||
private static final byte[] ACTOR_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
|
||||
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
|
||||
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
|
||||
|
||||
private static final byte[] RAW_TYPE_META = "RAW".getBytes();
|
||||
|
||||
private final WorkerContext workerContext;
|
||||
|
||||
private final ObjectInterface objectInterface;
|
||||
|
||||
public ObjectStoreProxy(WorkerContext workerContext, ObjectInterface objectInterface) {
|
||||
this.workerContext = workerContext;
|
||||
this.objectInterface = objectInterface;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an object from the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param <T> Type of the object.
|
||||
* @return The GetResult object.
|
||||
*/
|
||||
public <T> T get(ObjectId id) {
|
||||
List<T> list = get(ImmutableList.of(id));
|
||||
return list.get(0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a list of objects from the object store.
|
||||
*
|
||||
* @param ids List of the object ids.
|
||||
* @param <T> Type of these objects.
|
||||
* @return A list of GetResult objects.
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> List<T> get(List<ObjectId> ids) {
|
||||
// Pass -1 as timeout to wait until all objects are available in object store.
|
||||
List<NativeRayObject> dataAndMetaList = objectInterface.get(ids, -1);
|
||||
|
||||
List<T> results = new ArrayList<>();
|
||||
for (int i = 0; i < dataAndMetaList.size(); i++) {
|
||||
NativeRayObject dataAndMeta = dataAndMetaList.get(i);
|
||||
Object object = null;
|
||||
if (dataAndMeta != null) {
|
||||
byte[] meta = dataAndMeta.metadata;
|
||||
byte[] data = dataAndMeta.data;
|
||||
if (meta != null && meta.length > 0) {
|
||||
// If meta is not null, deserialize the object from meta.
|
||||
object = deserializeFromMeta(meta, data,
|
||||
workerContext.getCurrentClassLoader(), ids.get(i));
|
||||
} else {
|
||||
// If data is not null, deserialize the Java object.
|
||||
object = Serializer.decode(data, workerContext.getCurrentClassLoader());
|
||||
}
|
||||
if (object instanceof RayException) {
|
||||
// If the object is a `RayException`, it means that an error occurred during task
|
||||
// execution.
|
||||
throw (RayException) object;
|
||||
}
|
||||
}
|
||||
|
||||
results.add((T) object);
|
||||
}
|
||||
// This check must be placed after the throw exception statement.
|
||||
// Because if there was any exception, The get operation would return early
|
||||
// and wouldn't wait until all objects exist.
|
||||
Preconditions.checkState(dataAndMetaList.stream().allMatch(Objects::nonNull));
|
||||
return results;
|
||||
}
|
||||
|
||||
private Object deserializeFromMeta(byte[] meta, byte[] data,
|
||||
ClassLoader classLoader, ObjectId objectId) {
|
||||
if (Arrays.equals(meta, RAW_TYPE_META)) {
|
||||
return data;
|
||||
} else if (Arrays.equals(meta, WORKER_EXCEPTION_META)) {
|
||||
return RayWorkerException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, ACTOR_EXCEPTION_META)) {
|
||||
return RayActorException.INSTANCE;
|
||||
} else if (Arrays.equals(meta, UNRECONSTRUCTABLE_EXCEPTION_META)) {
|
||||
return new UnreconstructableException(objectId);
|
||||
} else if (Arrays.equals(meta, TASK_EXECUTION_EXCEPTION_META)) {
|
||||
return Serializer.decode(data, classLoader);
|
||||
}
|
||||
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize and put an object to the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param object The object to put.
|
||||
*/
|
||||
public void put(ObjectId id, Object object) {
|
||||
if (object instanceof byte[]) {
|
||||
// If the object is a byte array, skip serializing it and use a special metadata to
|
||||
// indicate it's raw binary. So that this object can also be read by Python.
|
||||
objectInterface.put(new NativeRayObject((byte[]) object, RAW_TYPE_META), id);
|
||||
} else if (object instanceof RayTaskException) {
|
||||
objectInterface
|
||||
.put(new NativeRayObject(Serializer.encode(object), TASK_EXECUTION_EXCEPTION_META), id);
|
||||
} else {
|
||||
objectInterface.put(new NativeRayObject(Serializer.encode(object), null), id);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Put an already serialized object to the object store.
|
||||
*
|
||||
* @param id Id of the object.
|
||||
* @param serializedObject The serialized object to put.
|
||||
*/
|
||||
public void putSerialized(ObjectId id, byte[] serializedObject) {
|
||||
objectInterface.put(new NativeRayObject(serializedObject, null), id);
|
||||
}
|
||||
|
||||
public ObjectInterface getObjectInterface() {
|
||||
return objectInterface;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Raylet client for local mode.
|
||||
*/
|
||||
public class LocalModeRayletClient implements RayletClient {
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeRayletClient.class);
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
}
|
||||
@@ -1,204 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.ConcurrentLinkedDeque;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.NotImplementedException;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.Worker;
|
||||
import org.ray.runtime.objectstore.MockObjectInterface;
|
||||
import org.ray.runtime.objectstore.NativeRayObject;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* A mock implementation of RayletClient, used in single process mode.
|
||||
*/
|
||||
public class MockRayletClient implements RayletClient {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(MockRayletClient.class);
|
||||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new ConcurrentHashMap<>();
|
||||
private final MockObjectInterface objectInterface;
|
||||
private final RayDevRuntime runtime;
|
||||
private final ExecutorService exec;
|
||||
private final Deque<Worker> idleWorkers;
|
||||
private final Map<ActorId, Worker> actorWorkers;
|
||||
private final ThreadLocal<Worker> currentWorker;
|
||||
|
||||
public MockRayletClient(RayDevRuntime runtime, int numberThreads) {
|
||||
this.runtime = runtime;
|
||||
this.objectInterface = runtime.getObjectInterface();
|
||||
objectInterface.addObjectPutCallback(this::onObjectPut);
|
||||
// The thread pool that executes tasks in parallel.
|
||||
exec = Executors.newFixedThreadPool(numberThreads);
|
||||
idleWorkers = new ConcurrentLinkedDeque<>();
|
||||
actorWorkers = new HashMap<>();
|
||||
currentWorker = new ThreadLocal<>();
|
||||
}
|
||||
|
||||
public synchronized void onObjectPut(ObjectId id) {
|
||||
Set<TaskSpec> tasks = waitingTasks.get(id);
|
||||
if (tasks != null) {
|
||||
waitingTasks.remove(id);
|
||||
for (TaskSpec taskSpec : tasks) {
|
||||
submitTask(taskSpec);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public Worker getCurrentWorker() {
|
||||
return currentWorker.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a worker from the worker pool to run the given task.
|
||||
*/
|
||||
private synchronized Worker getWorker(TaskSpec task) {
|
||||
Worker worker;
|
||||
if (task.isActorTask()) {
|
||||
worker = actorWorkers.get(task.actorId);
|
||||
} else {
|
||||
if (task.isActorCreationTask()) {
|
||||
worker = new Worker(runtime);
|
||||
actorWorkers.put(task.actorCreationId, worker);
|
||||
} else if (idleWorkers.size() > 0) {
|
||||
worker = idleWorkers.pop();
|
||||
} else {
|
||||
worker = new Worker(runtime);
|
||||
}
|
||||
}
|
||||
currentWorker.set(worker);
|
||||
return worker;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the worker to the worker pool.
|
||||
*/
|
||||
private void returnWorker(Worker worker) {
|
||||
currentWorker.remove();
|
||||
idleWorkers.push(worker);
|
||||
}
|
||||
|
||||
@Override
|
||||
public synchronized void submitTask(TaskSpec task) {
|
||||
LOGGER.debug("Submitting task: {}.", task);
|
||||
Set<ObjectId> unreadyObjects = getUnreadyObjects(task);
|
||||
if (unreadyObjects.isEmpty()) {
|
||||
// If all dependencies are ready, execute this task.
|
||||
exec.submit(() -> {
|
||||
Worker worker = getWorker(task);
|
||||
try {
|
||||
worker.execute(task);
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
if (task.isActorCreationTask() || task.isActorTask()) {
|
||||
ObjectId[] returnIds = task.returnIds;
|
||||
objectInterface.put(new NativeRayObject(new byte[] {}, new byte[] {}),
|
||||
returnIds[returnIds.length - 1]);
|
||||
}
|
||||
} finally {
|
||||
returnWorker(worker);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// If some dependencies aren't ready yet, put this task in waiting list.
|
||||
for (ObjectId id : unreadyObjects) {
|
||||
waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(task);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Set<ObjectId> getUnreadyObjects(TaskSpec spec) {
|
||||
Set<ObjectId> unreadyObjects = new HashSet<>();
|
||||
// Check whether task arguments are ready.
|
||||
for (FunctionArg arg : spec.args) {
|
||||
if (arg.id != null) {
|
||||
if (!objectInterface.isObjectReady(arg.id)) {
|
||||
unreadyObjects.add(arg.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (spec.isActorTask()) {
|
||||
if (!objectInterface.isObjectReady(spec.previousActorTaskDummyObjectId)) {
|
||||
unreadyObjects.add(spec.previousActorTaskDummyObjectId);
|
||||
}
|
||||
}
|
||||
return unreadyObjects;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public TaskSpec getTask() {
|
||||
throw new RuntimeException("invalid execution flow here");
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
if (waitFor == null || waitFor.isEmpty()) {
|
||||
return new WaitResult<>(ImmutableList.of(), ImmutableList.of());
|
||||
}
|
||||
|
||||
List<ObjectId> ids = waitFor.stream().map(RayObject::getId).collect(Collectors.toList());
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
List<Boolean> result = objectInterface.wait(ids, ids.size(), timeoutMs);
|
||||
for (int i = 0; i < waitFor.size(); i++) {
|
||||
if (result.get(i)) {
|
||||
readyList.add(waitFor.get(i));
|
||||
} else {
|
||||
unreadyList.add(waitFor.get(i));
|
||||
}
|
||||
}
|
||||
return new WaitResult<>(readyList, unreadyList);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks) {
|
||||
objectInterface.delete(objectIds, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
throw new NotImplementedException("Not implemented.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
LOGGER.error("Not implemented under SINGLE_PROCESS mode.");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void destroy() {
|
||||
exec.shutdown();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
|
||||
/**
|
||||
* Raylet client for cluster mode. This is a wrapper class for C++ RayletClient.
|
||||
*/
|
||||
public class NativeRayletClient implements RayletClient {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private long nativeCoreWorkerPointer = 0;
|
||||
|
||||
public NativeRayletClient(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
return new UniqueId(nativePrepareCheckpoint(nativeCoreWorkerPointer, actorId.getBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
nativeNotifyActorResumedFromCheckpoint(nativeCoreWorkerPointer, actorId.getBytes(),
|
||||
checkpointId.getBytes());
|
||||
}
|
||||
|
||||
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
nativeSetResource(nativeCoreWorkerPointer, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
/// Native method declarations.
|
||||
///
|
||||
/// If you change the signature of any native methods, please re-generate
|
||||
/// the C++ header file and update the C++ implementation accordingly:
|
||||
///
|
||||
/// Suppose that $Dir is your ray root directory.
|
||||
/// 1) pushd $Dir/java/runtime/target/classes
|
||||
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.NativeRayletClient
|
||||
/// 3) clang-format -i org_ray_runtime_raylet_NativeRayletClient.h
|
||||
/// 4) cp org_ray_runtime_raylet_NativeRayletClient.h $Dir/src/ray/core_worker/lib/java/
|
||||
/// 5) vim $Dir/src/ray/core_worker/lib/java/org_ray_runtime_raylet_NativeRayletClient.cc
|
||||
/// 6) popd
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
|
||||
byte[] checkpointId);
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId) throws RayException;
|
||||
}
|
||||
@@ -1,33 +1,16 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
|
||||
/**
|
||||
* Client to the Raylet backend.
|
||||
*/
|
||||
public interface RayletClient {
|
||||
|
||||
void submitTask(TaskSpec task);
|
||||
|
||||
TaskSpec getTask();
|
||||
|
||||
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId);
|
||||
|
||||
void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks);
|
||||
|
||||
UniqueId prepareCheckpoint(ActorId actorId);
|
||||
|
||||
void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId);
|
||||
|
||||
void setResource(String resourceName, double capacity, UniqueId nodeId);
|
||||
|
||||
void destroy();
|
||||
}
|
||||
|
||||
@@ -1,343 +0,0 @@
|
||||
package org.ray.runtime.raylet;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.protobuf.ByteString;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Common;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.task.FunctionArg;
|
||||
import org.ray.runtime.task.TaskLanguage;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
public class RayletClientImpl implements RayletClient {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayletClientImpl.class);
|
||||
|
||||
/**
|
||||
* The pointer to c++'s raylet client.
|
||||
*/
|
||||
private long client = 0;
|
||||
|
||||
// TODO(qwang): JobId parameter can be removed once we embed jobId in driverId.
|
||||
public RayletClientImpl(String schedulerSockName, UniqueId workerId,
|
||||
boolean isWorker, JobId jobId) {
|
||||
client = nativeInit(schedulerSockName, workerId.getBytes(),
|
||||
isWorker, jobId.getBytes());
|
||||
}
|
||||
|
||||
public long getClient() {
|
||||
return client;
|
||||
}
|
||||
|
||||
@Override
|
||||
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int
|
||||
timeoutMs, TaskId currentTaskId) {
|
||||
Preconditions.checkNotNull(waitFor);
|
||||
if (waitFor.isEmpty()) {
|
||||
return new WaitResult<>(new ArrayList<>(), new ArrayList<>());
|
||||
}
|
||||
|
||||
List<ObjectId> ids = new ArrayList<>();
|
||||
for (RayObject<T> element : waitFor) {
|
||||
ids.add(element.getId());
|
||||
}
|
||||
|
||||
boolean[] ready = nativeWaitObject(client, IdUtil.getIdBytes(ids),
|
||||
numReturns, timeoutMs, false, currentTaskId.getBytes());
|
||||
List<RayObject<T>> readyList = new ArrayList<>();
|
||||
List<RayObject<T>> unreadyList = new ArrayList<>();
|
||||
|
||||
for (int i = 0; i < ready.length; i++) {
|
||||
if (ready[i]) {
|
||||
readyList.add(waitFor.get(i));
|
||||
} else {
|
||||
unreadyList.add(waitFor.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
return new WaitResult<>(readyList, unreadyList);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void submitTask(TaskSpec spec) {
|
||||
LOGGER.debug("Submitting task: {}", spec);
|
||||
Preconditions.checkState(!spec.parentTaskId.isNil());
|
||||
Preconditions.checkState(!spec.jobId.isNil());
|
||||
|
||||
byte[] taskSpec = convertTaskSpecToProtobuf(spec);
|
||||
nativeSubmitTask(client, taskSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TaskSpec getTask() {
|
||||
byte[] bytes = nativeGetTask(client);
|
||||
assert (null != bytes);
|
||||
return parseTaskSpecFromProtobuf(bytes);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void freePlasmaObjects(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks) {
|
||||
byte[][] objectIdsArray = IdUtil.getIdBytes(objectIds);
|
||||
nativeFreePlasmaObjects(client, objectIdsArray, localOnly, deleteCreatingTasks);
|
||||
}
|
||||
|
||||
@Override
|
||||
public UniqueId prepareCheckpoint(ActorId actorId) {
|
||||
return new UniqueId(nativePrepareCheckpoint(client, actorId.getBytes()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void notifyActorResumedFromCheckpoint(ActorId actorId, UniqueId checkpointId) {
|
||||
nativeNotifyActorResumedFromCheckpoint(client, actorId.getBytes(), checkpointId.getBytes());
|
||||
}
|
||||
|
||||
public static TaskId generateActorCreationTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) {
|
||||
byte[] bytes = nativeGenerateActorCreationTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex);
|
||||
return TaskId.fromBytes(bytes);
|
||||
}
|
||||
|
||||
public static TaskId generateActorTaskId(JobId jobId, TaskId parentTaskId, int taskIndex, ActorId actorId) {
|
||||
byte[] bytes = nativeGenerateActorTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex, actorId.getBytes());
|
||||
return TaskId.fromBytes(bytes);
|
||||
}
|
||||
|
||||
public static TaskId generateNormalTaskId(JobId jobId, TaskId parentTaskId, int taskIndex) {
|
||||
byte[] bytes = nativeGenerateNormalTaskId(jobId.getBytes(), parentTaskId.getBytes(), taskIndex);
|
||||
return TaskId.fromBytes(bytes);
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse `TaskSpec` protobuf bytes.
|
||||
*/
|
||||
public static TaskSpec parseTaskSpecFromProtobuf(byte[] bytes) {
|
||||
Common.TaskSpec taskSpec;
|
||||
try {
|
||||
taskSpec = Common.TaskSpec.parseFrom(bytes);
|
||||
} catch (InvalidProtocolBufferException e) {
|
||||
throw new RuntimeException("Invalid protobuf data.");
|
||||
}
|
||||
|
||||
// Parse common fields.
|
||||
JobId jobId = JobId.fromByteBuffer(taskSpec.getJobId().asReadOnlyByteBuffer());
|
||||
TaskId taskId = TaskId.fromByteBuffer(taskSpec.getTaskId().asReadOnlyByteBuffer());
|
||||
TaskId parentTaskId = TaskId.fromByteBuffer(taskSpec.getParentTaskId().asReadOnlyByteBuffer());
|
||||
int parentCounter = (int) taskSpec.getParentCounter();
|
||||
int numReturns = (int) taskSpec.getNumReturns();
|
||||
Map<String, Double> resources = taskSpec.getRequiredResourcesMap();
|
||||
|
||||
// Parse args.
|
||||
FunctionArg[] args = new FunctionArg[taskSpec.getArgsCount()];
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
Common.TaskArg arg = taskSpec.getArgs(i);
|
||||
int objectIdsLength = arg.getObjectIdsCount();
|
||||
if (objectIdsLength > 0) {
|
||||
Preconditions.checkArgument(objectIdsLength == 1,
|
||||
"This arg has more than one id: {}", objectIdsLength);
|
||||
args[i] = FunctionArg
|
||||
.passByReference(ObjectId.fromByteBuffer(arg.getObjectIds(0).asReadOnlyByteBuffer()));
|
||||
} else {
|
||||
args[i] = FunctionArg.passByValue(arg.getData().toByteArray());
|
||||
}
|
||||
}
|
||||
|
||||
// Parse function descriptor
|
||||
Preconditions.checkArgument(taskSpec.getLanguage() == Common.Language.JAVA);
|
||||
Preconditions.checkArgument(taskSpec.getFunctionDescriptorCount() == 3);
|
||||
JavaFunctionDescriptor functionDescriptor = new JavaFunctionDescriptor(
|
||||
taskSpec.getFunctionDescriptor(0).toString(Charset.defaultCharset()),
|
||||
taskSpec.getFunctionDescriptor(1).toString(Charset.defaultCharset()),
|
||||
taskSpec.getFunctionDescriptor(2).toString(Charset.defaultCharset())
|
||||
);
|
||||
|
||||
// Parse ActorCreationTaskSpec.
|
||||
ActorId actorCreationId = ActorId.NIL;
|
||||
int maxActorReconstructions = 0;
|
||||
UniqueId[] newActorHandles = new UniqueId[0];
|
||||
List<String> dynamicWorkerOptions = new ArrayList<>();
|
||||
if (taskSpec.getType() == Common.TaskType.ACTOR_CREATION_TASK) {
|
||||
Common.ActorCreationTaskSpec actorCreationTaskSpec = taskSpec.getActorCreationTaskSpec();
|
||||
actorCreationId = ActorId
|
||||
.fromByteBuffer(actorCreationTaskSpec.getActorId().asReadOnlyByteBuffer());
|
||||
maxActorReconstructions = (int) actorCreationTaskSpec.getMaxActorReconstructions();
|
||||
dynamicWorkerOptions = ImmutableList
|
||||
.copyOf(actorCreationTaskSpec.getDynamicWorkerOptionsList());
|
||||
}
|
||||
|
||||
// Parse ActorTaskSpec.
|
||||
ActorId actorId = ActorId.NIL;
|
||||
UniqueId actorHandleId = UniqueId.NIL;
|
||||
ObjectId previousActorTaskDummyObjectId = ObjectId.NIL;
|
||||
int actorCounter = 0;
|
||||
if (taskSpec.getType() == Common.TaskType.ACTOR_TASK) {
|
||||
Common.ActorTaskSpec actorTaskSpec = taskSpec.getActorTaskSpec();
|
||||
actorId = ActorId.fromByteBuffer(actorTaskSpec.getActorId().asReadOnlyByteBuffer());
|
||||
actorHandleId = UniqueId
|
||||
.fromByteBuffer(actorTaskSpec.getActorHandleId().asReadOnlyByteBuffer());
|
||||
actorCounter = (int) actorTaskSpec.getActorCounter();
|
||||
previousActorTaskDummyObjectId = ObjectId.fromByteBuffer(
|
||||
actorTaskSpec.getPreviousActorTaskDummyObjectId().asReadOnlyByteBuffer());
|
||||
newActorHandles = actorTaskSpec.getNewActorHandlesList().stream()
|
||||
.map(byteString -> UniqueId.fromByteBuffer(byteString.asReadOnlyByteBuffer()))
|
||||
.toArray(UniqueId[]::new);
|
||||
}
|
||||
|
||||
return new TaskSpec(jobId, taskId, parentTaskId, parentCounter, actorCreationId,
|
||||
maxActorReconstructions, actorId, actorHandleId, actorCounter,
|
||||
previousActorTaskDummyObjectId, newActorHandles, args, numReturns, resources,
|
||||
TaskLanguage.JAVA, functionDescriptor, dynamicWorkerOptions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a `TaskSpec` to protobuf-serialized bytes.
|
||||
*/
|
||||
public static byte[] convertTaskSpecToProtobuf(TaskSpec task) {
|
||||
// Set common fields.
|
||||
Common.TaskSpec.Builder builder = Common.TaskSpec.newBuilder()
|
||||
.setJobId(ByteString.copyFrom(task.jobId.getBytes()))
|
||||
.setTaskId(ByteString.copyFrom(task.taskId.getBytes()))
|
||||
.setParentTaskId(ByteString.copyFrom(task.parentTaskId.getBytes()))
|
||||
.setParentCounter(task.parentCounter)
|
||||
.setNumReturns(task.numReturns)
|
||||
.putAllRequiredResources(task.resources);
|
||||
|
||||
// Set args
|
||||
builder.addAllArgs(
|
||||
Arrays.stream(task.args).map(arg -> {
|
||||
Common.TaskArg.Builder argBuilder = Common.TaskArg.newBuilder();
|
||||
if (arg.id != null) {
|
||||
argBuilder.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build();
|
||||
} else {
|
||||
argBuilder.setData(ByteString.copyFrom(arg.data)).build();
|
||||
}
|
||||
return argBuilder.build();
|
||||
}).collect(Collectors.toList())
|
||||
);
|
||||
|
||||
// Set function descriptor and language.
|
||||
if (task.language == TaskLanguage.JAVA) {
|
||||
builder.setLanguage(Common.Language.JAVA);
|
||||
builder.addAllFunctionDescriptor(ImmutableList.of(
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().className.getBytes()),
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().name.getBytes()),
|
||||
ByteString.copyFrom(task.getJavaFunctionDescriptor().typeDescriptor.getBytes())
|
||||
));
|
||||
} else {
|
||||
builder.setLanguage(Common.Language.PYTHON);
|
||||
builder.addAllFunctionDescriptor(ImmutableList.of(
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().moduleName.getBytes()),
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().className.getBytes()),
|
||||
ByteString.copyFrom(task.getPyFunctionDescriptor().functionName.getBytes()),
|
||||
ByteString.EMPTY
|
||||
));
|
||||
}
|
||||
|
||||
if (!task.actorCreationId.isNil()) {
|
||||
// Actor creation task.
|
||||
builder.setType(TaskType.ACTOR_CREATION_TASK);
|
||||
builder.setActorCreationTaskSpec(
|
||||
Common.ActorCreationTaskSpec.newBuilder()
|
||||
.setActorId(ByteString.copyFrom(task.actorCreationId.getBytes()))
|
||||
.setMaxActorReconstructions(task.maxActorReconstructions)
|
||||
.addAllDynamicWorkerOptions(task.dynamicWorkerOptions)
|
||||
);
|
||||
} else if (!task.actorId.isNil()) {
|
||||
// Actor task.
|
||||
builder.setType(TaskType.ACTOR_TASK);
|
||||
List<ByteString> newHandles = Arrays.stream(task.newActorHandles)
|
||||
.map(id -> ByteString.copyFrom(id.getBytes())).collect(Collectors.toList());
|
||||
final ObjectId actorCreationDummyObjectId = IdUtil.computeActorCreationDummyObjectId(
|
||||
ActorId.fromByteBuffer(ByteBuffer.wrap(task.actorId.getBytes())));
|
||||
builder.setActorTaskSpec(
|
||||
Common.ActorTaskSpec.newBuilder()
|
||||
.setActorId(ByteString.copyFrom(task.actorId.getBytes()))
|
||||
.setActorHandleId(ByteString.copyFrom(task.actorHandleId.getBytes()))
|
||||
.setActorCreationDummyObjectId(
|
||||
ByteString.copyFrom(actorCreationDummyObjectId.getBytes()))
|
||||
.setPreviousActorTaskDummyObjectId(
|
||||
ByteString.copyFrom(task.previousActorTaskDummyObjectId.getBytes()))
|
||||
.setActorCounter(task.actorCounter)
|
||||
.addAllNewActorHandles(newHandles)
|
||||
);
|
||||
} else {
|
||||
// Normal task.
|
||||
builder.setType(TaskType.NORMAL_TASK);
|
||||
}
|
||||
|
||||
return builder.build().toByteArray();
|
||||
}
|
||||
|
||||
public void setResource(String resourceName, double capacity, UniqueId nodeId) {
|
||||
nativeSetResource(client, resourceName, capacity, nodeId.getBytes());
|
||||
}
|
||||
|
||||
public void destroy() {
|
||||
nativeDestroy(client);
|
||||
}
|
||||
|
||||
/// Native method declarations.
|
||||
///
|
||||
/// If you change the signature of any native methods, please re-generate
|
||||
/// the C++ header file and update the C++ implementation accordingly:
|
||||
///
|
||||
/// Suppose that $Dir is your ray root directory.
|
||||
/// 1) pushd $Dir/java/runtime/target/classes
|
||||
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl
|
||||
/// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h
|
||||
/// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/ray/raylet/lib/java/
|
||||
/// 5) vim $Dir/src/ray/raylet/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
|
||||
/// 6) popd
|
||||
|
||||
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
|
||||
boolean isWorker, byte[] driverTaskId);
|
||||
|
||||
private static native void nativeSubmitTask(long client, byte[] taskSpec)
|
||||
throws RayException;
|
||||
|
||||
private static native byte[] nativeGetTask(long client) throws RayException;
|
||||
|
||||
private static native void nativeDestroy(long client) throws RayException;
|
||||
|
||||
private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
|
||||
int numReturns, int timeout, boolean waitLocal, byte[] currentTaskId) throws RayException;
|
||||
|
||||
private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
|
||||
boolean localOnly, boolean deleteCreatingTasks) throws RayException;
|
||||
|
||||
private static native byte[] nativePrepareCheckpoint(long conn, byte[] actorId);
|
||||
|
||||
private static native void nativeNotifyActorResumedFromCheckpoint(long conn, byte[] actorId,
|
||||
byte[] checkpointId);
|
||||
|
||||
private static native void nativeSetResource(long conn, String resourceName, double capacity,
|
||||
byte[] nodeId) throws RayException;
|
||||
|
||||
private static native byte[] nativeGenerateActorCreationTaskId(byte[] jobId, byte[] parentTaskId,
|
||||
int taskIndex);
|
||||
|
||||
private static native byte[] nativeGenerateActorTaskId(byte[] jobId, byte[] parentTaskId,
|
||||
int taskIndex, byte[] actorId);
|
||||
|
||||
private static native byte[] nativeGenerateNormalTaskId(byte[] jobId, byte[] parentTaskId,
|
||||
int taskIndex);
|
||||
}
|
||||
@@ -20,6 +20,7 @@ import java.util.Random;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.util.FileUtil;
|
||||
import org.ray.runtime.util.ResourceUtil;
|
||||
@@ -44,17 +45,16 @@ public class RunManager {
|
||||
|
||||
private Random random;
|
||||
|
||||
private List<Process> processes;
|
||||
private List<Pair<String, Process>> processes;
|
||||
|
||||
private static final int KILL_PROCESS_WAIT_TIMEOUT_SECONDS = 1;
|
||||
|
||||
private final Map<String, File> tempFiles;
|
||||
private static final Map<String, File> tempFiles = new HashMap<>();
|
||||
|
||||
public RunManager(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
processes = new ArrayList<>();
|
||||
random = new Random();
|
||||
tempFiles = new HashMap<>();
|
||||
}
|
||||
|
||||
public void cleanup() {
|
||||
@@ -63,19 +63,28 @@ public class RunManager {
|
||||
// cannot exit gracefully.
|
||||
|
||||
for (int i = processes.size() - 1; i >= 0; --i) {
|
||||
Process p = processes.get(i);
|
||||
p.destroy();
|
||||
Pair<String, Process> pair = processes.get(i);
|
||||
String name = pair.getLeft();
|
||||
Process p = pair.getRight();
|
||||
|
||||
try {
|
||||
p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while waiting for process {}" +
|
||||
" to be terminated.", processes.get(i));
|
||||
}
|
||||
|
||||
if (p.isAlive()) {
|
||||
p.destroyForcibly();
|
||||
int numAttempts = 0;
|
||||
while (p.isAlive()) {
|
||||
if (numAttempts == 0) {
|
||||
LOGGER.debug("Terminating process {}.", name);
|
||||
p.destroy();
|
||||
} else {
|
||||
LOGGER.debug("Terminating process {} forcibly.", name);
|
||||
p.destroyForcibly();
|
||||
}
|
||||
try {
|
||||
p.waitFor(KILL_PROCESS_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
|
||||
} catch (InterruptedException e) {
|
||||
LOGGER.warn("Got InterruptedException while waiting for process {}" +
|
||||
" to be terminated.", processes.get(i));
|
||||
}
|
||||
numAttempts++;
|
||||
}
|
||||
LOGGER.info("Process {} is now terminated.", name);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,7 +161,7 @@ public class RunManager {
|
||||
if (!p.isAlive()) {
|
||||
throw new RuntimeException("Failed to start " + name);
|
||||
}
|
||||
processes.add(p);
|
||||
processes.add(Pair.of(name, p));
|
||||
LOGGER.info("{} process started", name);
|
||||
}
|
||||
|
||||
@@ -259,7 +268,7 @@ public class RunManager {
|
||||
String.format("--store_socket_name=%s", rayConfig.objectStoreSocketName),
|
||||
String.format("--object_manager_port=%d", 0), // The object manager port.
|
||||
String.format("--node_manager_port=%d", 0), // The node manager port.
|
||||
String.format("--node_ip_address=%s",rayConfig.nodeIp),
|
||||
String.format("--node_ip_address=%s", rayConfig.nodeIp),
|
||||
String.format("--redis_address=%s", rayConfig.getRedisIp()),
|
||||
String.format("--redis_port=%d", rayConfig.getRedisPort()),
|
||||
String.format("--num_initial_workers=%d", 0), // number of initial workers
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
package org.ray.runtime.runner.worker;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayNativeRuntime;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -20,7 +20,7 @@ public class DefaultWorker {
|
||||
});
|
||||
Ray.init();
|
||||
LOGGER.info("Worker started.");
|
||||
((AbstractRayRuntime)Ray.internal()).loop();
|
||||
((RayNativeRuntime)Ray.internal()).run();
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Failed to start worker.", e);
|
||||
}
|
||||
|
||||
@@ -3,12 +3,16 @@ package org.ray.runtime.task;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
|
||||
/**
|
||||
* Helper methods to convert arguments from/to objects.
|
||||
*/
|
||||
public class ArgumentsBuilder {
|
||||
|
||||
/**
|
||||
@@ -20,16 +24,13 @@ public class ArgumentsBuilder {
|
||||
/**
|
||||
* Convert real function arguments to task spec arguments.
|
||||
*/
|
||||
public static FunctionArg[] wrap(Object[] args, boolean crossLanguage) {
|
||||
FunctionArg[] ret = new FunctionArg[args.length];
|
||||
for (int i = 0; i < ret.length; i++) {
|
||||
Object arg = args[i];
|
||||
public static List<FunctionArg> wrap(Object[] args, boolean crossLanguage) {
|
||||
List<FunctionArg> ret = new ArrayList<>();
|
||||
for (Object arg : args) {
|
||||
ObjectId id = null;
|
||||
byte[] data = null;
|
||||
if (arg == null) {
|
||||
data = Serializer.encode(null);
|
||||
} else if (arg instanceof RayActor) {
|
||||
data = Serializer.encode(arg);
|
||||
} else if (arg instanceof RayObject) {
|
||||
id = ((RayObject) arg).getId();
|
||||
} else if (arg instanceof byte[] && crossLanguage) {
|
||||
@@ -40,41 +41,28 @@ public class ArgumentsBuilder {
|
||||
} else {
|
||||
byte[] serialized = Serializer.encode(arg);
|
||||
if (serialized.length > LARGEST_SIZE_PASS_BY_VALUE) {
|
||||
id = ((AbstractRayRuntime) Ray.internal()).putSerialized(serialized).getId();
|
||||
id = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
.put(new NativeRayObject(serialized, null));
|
||||
} else {
|
||||
data = serialized;
|
||||
}
|
||||
}
|
||||
if (id != null) {
|
||||
ret[i] = FunctionArg.passByReference(id);
|
||||
ret.add(FunctionArg.passByReference(id));
|
||||
} else {
|
||||
ret[i] = FunctionArg.passByValue(data);
|
||||
ret.add(FunctionArg.passByValue(data));
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert task spec arguments to real function arguments.
|
||||
* Convert list of NativeRayObject to real function arguments.
|
||||
*/
|
||||
public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) {
|
||||
Object[] realArgs = new Object[task.args.length];
|
||||
List<ObjectId> idsToFetch = new ArrayList<>();
|
||||
List<Integer> indices = new ArrayList<>();
|
||||
for (int i = 0; i < task.args.length; i++) {
|
||||
FunctionArg arg = task.args[i];
|
||||
if (arg.id != null) {
|
||||
// pass by reference
|
||||
idsToFetch.add(arg.id);
|
||||
indices.add(i);
|
||||
} else {
|
||||
// pass by value
|
||||
realArgs[i] = Serializer.decode(arg.data, classLoader);
|
||||
}
|
||||
}
|
||||
List<Object> objects = Ray.get(idsToFetch);
|
||||
for (int i = 0; i < objects.size(); i++) {
|
||||
realArgs[indices.get(i)] = objects.get(i);
|
||||
public static Object[] unwrap(ObjectStore objectStore, List<NativeRayObject> args) {
|
||||
Object[] realArgs = new Object[args.size()];
|
||||
for (int i = 0; i < args.size(); i++) {
|
||||
realArgs[i] = objectStore.deserialize(args.get(i), null);
|
||||
}
|
||||
return realArgs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,298 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.protobuf.ByteString;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayDeque;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Deque;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import org.ray.runtime.context.LocalModeWorkerContext;
|
||||
import org.ray.runtime.RayDevRuntime;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.generated.Common.ActorCreationTaskSpec;
|
||||
import org.ray.runtime.generated.Common.ActorTaskSpec;
|
||||
import org.ray.runtime.generated.Common.Language;
|
||||
import org.ray.runtime.generated.Common.TaskArg;
|
||||
import org.ray.runtime.generated.Common.TaskSpec;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.LocalModeObjectStore;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Task submitter for local mode.
|
||||
*/
|
||||
public class LocalModeTaskSubmitter implements TaskSubmitter {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeTaskSubmitter.class);
|
||||
|
||||
private final Map<ObjectId, Set<TaskSpec>> waitingTasks = new HashMap<>();
|
||||
private final Object taskAndObjectLock = new Object();
|
||||
private final RayDevRuntime runtime;
|
||||
private final LocalModeObjectStore objectStore;
|
||||
private final ExecutorService exec;
|
||||
private final Deque<TaskExecutor> idleTaskExecutors = new ArrayDeque<>();
|
||||
private final Map<ActorId, TaskExecutor> actorTaskExecutors = new HashMap<>();
|
||||
private final Object taskExecutorLock = new Object();
|
||||
private final ThreadLocal<TaskExecutor> currentTaskExecutor = new ThreadLocal<>();
|
||||
|
||||
public LocalModeTaskSubmitter(RayDevRuntime runtime, LocalModeObjectStore objectStore,
|
||||
int numberThreads) {
|
||||
this.runtime = runtime;
|
||||
this.objectStore = objectStore;
|
||||
// The thread pool that executes tasks in parallel.
|
||||
exec = Executors.newFixedThreadPool(numberThreads);
|
||||
}
|
||||
|
||||
public void onObjectPut(ObjectId id) {
|
||||
Set<TaskSpec> tasks;
|
||||
synchronized (taskAndObjectLock) {
|
||||
tasks = waitingTasks.remove(id);
|
||||
if (tasks != null) {
|
||||
for (TaskSpec task : tasks) {
|
||||
Set<ObjectId> unreadyObjects = getUnreadyObjects(task);
|
||||
if (unreadyObjects.isEmpty()) {
|
||||
submitTaskSpec(task);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the worker of current thread. <br> NOTE: Cannot be used for multi-threading in worker.
|
||||
*/
|
||||
public TaskExecutor getCurrentTaskExecutor() {
|
||||
return currentTaskExecutor.get();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a worker from the worker pool to run the given task.
|
||||
*/
|
||||
private TaskExecutor getTaskExecutor(TaskSpec task) {
|
||||
TaskExecutor taskExecutor;
|
||||
synchronized (taskExecutorLock) {
|
||||
if (task.getType() == TaskType.ACTOR_TASK) {
|
||||
taskExecutor = actorTaskExecutors.get(getActorId(task));
|
||||
} else if (task.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
actorTaskExecutors.put(getActorId(task), taskExecutor);
|
||||
} else if (idleTaskExecutors.size() > 0) {
|
||||
taskExecutor = idleTaskExecutors.pop();
|
||||
} else {
|
||||
taskExecutor = new TaskExecutor(runtime);
|
||||
}
|
||||
}
|
||||
currentTaskExecutor.set(taskExecutor);
|
||||
return taskExecutor;
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the worker to the worker pool.
|
||||
*/
|
||||
private void returnTaskExecutor(TaskExecutor worker, TaskSpec taskSpec) {
|
||||
currentTaskExecutor.remove();
|
||||
synchronized (taskExecutorLock) {
|
||||
if (taskSpec.getType() == TaskType.NORMAL_TASK) {
|
||||
idleTaskExecutors.push(worker);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Set<ObjectId> getUnreadyObjects(TaskSpec taskSpec) {
|
||||
Set<ObjectId> unreadyObjects = new HashSet<>();
|
||||
// Check whether task arguments are ready.
|
||||
for (TaskArg arg : taskSpec.getArgsList()) {
|
||||
for (ByteString idByteString : arg.getObjectIdsList()) {
|
||||
ObjectId id = new ObjectId(idByteString.toByteArray());
|
||||
if (!objectStore.isObjectReady(id)) {
|
||||
unreadyObjects.add(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
|
||||
ObjectId dummyObjectId = new ObjectId(
|
||||
taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
|
||||
if (!objectStore.isObjectReady(dummyObjectId)) {
|
||||
unreadyObjects.add(dummyObjectId);
|
||||
}
|
||||
}
|
||||
return unreadyObjects;
|
||||
}
|
||||
|
||||
private TaskSpec.Builder getTaskSpecBuilder(TaskType taskType,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args) {
|
||||
byte[] taskIdBytes = new byte[TaskId.LENGTH];
|
||||
new Random().nextBytes(taskIdBytes);
|
||||
return TaskSpec.newBuilder()
|
||||
.setType(taskType)
|
||||
.setLanguage(Language.JAVA)
|
||||
.setJobId(
|
||||
ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
|
||||
.setTaskId(ByteString.copyFrom(taskIdBytes))
|
||||
.addAllFunctionDescriptor(functionDescriptor.toList().stream().map(ByteString::copyFromUtf8)
|
||||
.collect(Collectors.toList()))
|
||||
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
|
||||
.addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build()
|
||||
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.data)).build())
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(numReturns == 1);
|
||||
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args)
|
||||
.setNumReturns(numReturns)
|
||||
.build();
|
||||
submitTaskSpec(taskSpec);
|
||||
return getReturnIds(taskSpec);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions options) {
|
||||
ActorId actorId = ActorId.fromRandom();
|
||||
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
|
||||
.setNumReturns(1)
|
||||
.setActorCreationTaskSpec(ActorCreationTaskSpec.newBuilder()
|
||||
.setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
|
||||
.build())
|
||||
.build();
|
||||
submitTaskSpec(taskSpec);
|
||||
return new LocalModeRayActor(actorId, getReturnIds(taskSpec).get(0));
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(numReturns == 1);
|
||||
TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args);
|
||||
List<ObjectId> returnIds = getReturnIds(
|
||||
TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
|
||||
TaskSpec taskSpec = builder
|
||||
.setNumReturns(numReturns + 1)
|
||||
.setActorTaskSpec(
|
||||
ActorTaskSpec.newBuilder().setActorId(ByteString.copyFrom(actor.getId().getBytes()))
|
||||
.setPreviousActorTaskDummyObjectId(ByteString.copyFrom(
|
||||
((LocalModeRayActor) actor)
|
||||
.exchangePreviousActorTaskDummyObjectId(returnIds.get(returnIds.size() - 1))
|
||||
.getBytes()))
|
||||
.build())
|
||||
.build();
|
||||
submitTaskSpec(taskSpec);
|
||||
return Collections.singletonList(returnIds.get(0));
|
||||
}
|
||||
|
||||
public static ActorId getActorId(TaskSpec taskSpec) {
|
||||
ByteString actorId = null;
|
||||
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
|
||||
actorId = taskSpec.getActorCreationTaskSpec().getActorId();
|
||||
} else if (taskSpec.getType() == TaskType.ACTOR_TASK) {
|
||||
actorId = taskSpec.getActorTaskSpec().getActorId();
|
||||
}
|
||||
if (actorId == null) {
|
||||
return null;
|
||||
}
|
||||
return ActorId.fromBytes(actorId.toByteArray());
|
||||
}
|
||||
|
||||
private void submitTaskSpec(TaskSpec taskSpec) {
|
||||
LOGGER.debug("Submitting task: {}.", taskSpec);
|
||||
synchronized (taskAndObjectLock) {
|
||||
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
|
||||
if (unreadyObjects.isEmpty()) {
|
||||
// If all dependencies are ready, execute this task.
|
||||
exec.submit(() -> {
|
||||
TaskExecutor taskExecutor = getTaskExecutor(taskSpec);
|
||||
try {
|
||||
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
|
||||
.map(arg -> arg.id != null ?
|
||||
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
|
||||
: new NativeRayObject(arg.data, null))
|
||||
.collect(Collectors.toList());
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
|
||||
List<NativeRayObject> returnObjects = taskExecutor
|
||||
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
|
||||
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(null);
|
||||
List<ObjectId> returnIds = getReturnIds(taskSpec);
|
||||
for (int i = 0; i < returnIds.size(); i++) {
|
||||
NativeRayObject putObject;
|
||||
if (i >= returnObjects.size()) {
|
||||
// If the task is an actor task or an actor creation task,
|
||||
// put the dummy object in object store, so those tasks which depends on it
|
||||
// can be executed.
|
||||
putObject = new NativeRayObject(new byte[]{}, new byte[]{});
|
||||
} else {
|
||||
putObject = returnObjects.get(i);
|
||||
}
|
||||
objectStore.putRaw(putObject, returnIds.get(i));
|
||||
}
|
||||
} finally {
|
||||
returnTaskExecutor(taskExecutor, taskSpec);
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// If some dependencies aren't ready yet, put this task in waiting list.
|
||||
for (ObjectId id : unreadyObjects) {
|
||||
waitingTasks.computeIfAbsent(id, k -> new HashSet<>()).add(taskSpec);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
|
||||
List<ByteString> functionDescriptor = taskSpec.getFunctionDescriptorList();
|
||||
return new JavaFunctionDescriptor(functionDescriptor.get(0).toStringUtf8(),
|
||||
functionDescriptor.get(1).toStringUtf8(), functionDescriptor.get(2).toStringUtf8());
|
||||
}
|
||||
|
||||
private static List<FunctionArg> getFunctionArgs(TaskSpec taskSpec) {
|
||||
List<FunctionArg> functionArgs = new ArrayList<>();
|
||||
for (int i = 0; i < taskSpec.getArgsCount(); i++) {
|
||||
TaskArg arg = taskSpec.getArgs(i);
|
||||
if (arg.getObjectIdsCount() > 0) {
|
||||
functionArgs.add(FunctionArg
|
||||
.passByReference(new ObjectId(arg.getObjectIds(0).toByteArray())));
|
||||
} else {
|
||||
functionArgs.add(FunctionArg.passByValue(arg.getData().toByteArray()));
|
||||
}
|
||||
}
|
||||
return functionArgs;
|
||||
}
|
||||
|
||||
private static List<ObjectId> getReturnIds(TaskSpec taskSpec) {
|
||||
return getReturnIds(TaskId.fromBytes(taskSpec.getTaskId().toByteArray()),
|
||||
taskSpec.getNumReturns());
|
||||
}
|
||||
|
||||
private static List<ObjectId> getReturnIds(TaskId taskId, long numReturns) {
|
||||
List<ObjectId> returnIds = new ArrayList<>();
|
||||
for (int i = 0; i < numReturns; i++) {
|
||||
returnIds.add(ObjectId.fromByteBuffer(
|
||||
(ByteBuffer) ByteBuffer.allocate(ObjectId.LENGTH).put(taskId.getBytes())
|
||||
.putInt(TaskId.LENGTH, i + 1).position(0)));
|
||||
}
|
||||
return returnIds;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
|
||||
/**
|
||||
* Task submitter for cluster mode. This is a wrapper class for core worker task interface.
|
||||
*/
|
||||
public class NativeTaskSubmitter implements TaskSubmitter {
|
||||
|
||||
/**
|
||||
* The native pointer of core worker.
|
||||
*/
|
||||
private final long nativeCoreWorkerPointer;
|
||||
|
||||
public NativeTaskSubmitter(long nativeCoreWorkerPointer) {
|
||||
this.nativeCoreWorkerPointer = nativeCoreWorkerPointer;
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options) {
|
||||
List<byte[]> returnIds = nativeSubmitTask(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
numReturns, options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions options) {
|
||||
long nativeActorHandle = nativeCreateActor(nativeCoreWorkerPointer, functionDescriptor, args,
|
||||
options);
|
||||
return new NativeRayActor(nativeActorHandle);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options) {
|
||||
Preconditions.checkState(actor instanceof NativeRayActor);
|
||||
List<byte[]> returnIds = nativeSubmitActorTask(nativeCoreWorkerPointer,
|
||||
((NativeRayActor) actor).getNativeActorHandle(), functionDescriptor, args, numReturns,
|
||||
options);
|
||||
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static native List<byte[]> nativeSubmitTask(long nativeCoreWorkerPointer,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
|
||||
CallOptions callOptions);
|
||||
|
||||
private static native long nativeCreateActor(long nativeCoreWorkerPointer,
|
||||
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions actorCreationOptions);
|
||||
|
||||
private static native List<byte[]> nativeSubmitActorTask(long nativeCoreWorkerPointer,
|
||||
long nativeActorHandle, FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions callOptions);
|
||||
}
|
||||
+46
-57
@@ -1,4 +1,4 @@
|
||||
package org.ray.runtime;
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.ArrayList;
|
||||
@@ -8,38 +8,34 @@ import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.Checkpointable.CheckpointContext;
|
||||
import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.ray.runtime.util.IdUtil;
|
||||
import org.ray.runtime.generated.Common.TaskType;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* The worker, which pulls tasks from {@link org.ray.runtime.raylet.RayletClient} and executes them
|
||||
* continuously.
|
||||
* The task executor, which executes tasks assigned by raylet continuously.
|
||||
*/
|
||||
public class Worker {
|
||||
public final class TaskExecutor {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
|
||||
|
||||
// TODO(hchen): Use the C++ config.
|
||||
private static final int NUM_ACTOR_CHECKPOINTS_TO_KEEP = 20;
|
||||
|
||||
private final AbstractRayRuntime runtime;
|
||||
protected final AbstractRayRuntime runtime;
|
||||
|
||||
/**
|
||||
* The current actor object, if this worker is an actor, otherwise null.
|
||||
*/
|
||||
private Object currentActor = null;
|
||||
|
||||
/**
|
||||
* Id of the current actor object, if the worker is an actor, otherwise NIL.
|
||||
*/
|
||||
private ActorId currentActorId = ActorId.NIL;
|
||||
protected Object currentActor = null;
|
||||
|
||||
/**
|
||||
* The exception that failed the actor creation task, if any.
|
||||
@@ -61,53 +57,36 @@ public class Worker {
|
||||
*/
|
||||
private long lastCheckpointTimestamp = 0;
|
||||
|
||||
|
||||
public Worker(AbstractRayRuntime runtime) {
|
||||
public TaskExecutor(AbstractRayRuntime runtime) {
|
||||
this.runtime = runtime;
|
||||
}
|
||||
|
||||
public ActorId getCurrentActorId() {
|
||||
return currentActorId;
|
||||
}
|
||||
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
|
||||
List<NativeRayObject> argsBytes) {
|
||||
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
|
||||
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
|
||||
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
|
||||
LOGGER.debug("Executing task {}", taskId);
|
||||
|
||||
public void loop() {
|
||||
while (true) {
|
||||
LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName());
|
||||
TaskSpec task = runtime.getRayletClient().getTask();
|
||||
execute(task);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a task.
|
||||
*/
|
||||
public void execute(TaskSpec spec) {
|
||||
LOGGER.debug("Executing task {}", spec);
|
||||
ObjectId returnId = spec.returnIds[0];
|
||||
List<NativeRayObject> returnObjects = new ArrayList<>();
|
||||
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
|
||||
try {
|
||||
// Get method
|
||||
RayFunction rayFunction = runtime.getFunctionManager()
|
||||
.getFunction(spec.jobId, spec.getJavaFunctionDescriptor());
|
||||
// Set context
|
||||
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
|
||||
.getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo));
|
||||
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
|
||||
|
||||
if (spec.isActorCreationTask()) {
|
||||
currentActorId = spec.taskId.getActorId();
|
||||
}
|
||||
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);
|
||||
|
||||
// Get local actor object and arguments.
|
||||
Object actor = null;
|
||||
if (spec.isActorTask()) {
|
||||
Preconditions.checkState(spec.actorId.equals(currentActorId));
|
||||
if (taskType == TaskType.ACTOR_TASK) {
|
||||
if (actorCreationException != null) {
|
||||
throw actorCreationException;
|
||||
}
|
||||
actor = currentActor;
|
||||
|
||||
}
|
||||
Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader);
|
||||
Object[] args = ArgumentsBuilder.unwrap(runtime.getObjectStore(), argsBytes);
|
||||
// Execute the task.
|
||||
Object result;
|
||||
if (!rayFunction.isConstructor()) {
|
||||
@@ -116,27 +95,37 @@ public class Worker {
|
||||
result = rayFunction.getConstructor().newInstance(args);
|
||||
}
|
||||
// Set result
|
||||
if (!spec.isActorCreationTask()) {
|
||||
if (spec.isActorTask()) {
|
||||
maybeSaveCheckpoint(actor, spec.actorId);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
if (taskType == TaskType.ACTOR_TASK) {
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeSaveCheckpoint(actor, runtime.getWorkerContext().getCurrentActorId());
|
||||
}
|
||||
|
||||
runtime.put(returnId, result);
|
||||
returnObjects.add(runtime.getObjectStore().serialize(result));
|
||||
} else {
|
||||
maybeLoadCheckpoint(result, spec.taskId.getActorId());
|
||||
// TODO (kfstorm): handle checkpoint in core worker.
|
||||
maybeLoadCheckpoint(result, runtime.getWorkerContext().getCurrentActorId());
|
||||
currentActor = result;
|
||||
}
|
||||
LOGGER.debug("Finished executing task {}", spec.taskId);
|
||||
LOGGER.debug("Finished executing task {}", taskId);
|
||||
} catch (Exception e) {
|
||||
LOGGER.error("Error executing task " + spec, e);
|
||||
if (!spec.isActorCreationTask()) {
|
||||
runtime.put(returnId, new RayTaskException("Error executing task " + spec, e));
|
||||
LOGGER.error("Error executing task " + taskId, e);
|
||||
if (taskType != TaskType.ACTOR_CREATION_TASK) {
|
||||
returnObjects.add(runtime.getObjectStore()
|
||||
.serialize(new RayTaskException("Error executing task " + taskId, e)));
|
||||
} else {
|
||||
actorCreationException = e;
|
||||
}
|
||||
} finally {
|
||||
Thread.currentThread().setContextClassLoader(oldLoader);
|
||||
runtime.getWorkerContext().setCurrentClassLoader(null);
|
||||
}
|
||||
return returnObjects;
|
||||
}
|
||||
|
||||
private JavaFunctionDescriptor parseFunctionDescriptor(List<String> rayFunctionInfo) {
|
||||
Preconditions.checkState(rayFunctionInfo != null && rayFunctionInfo.size() == 3);
|
||||
return new JavaFunctionDescriptor(rayFunctionInfo.get(0), rayFunctionInfo.get(1),
|
||||
rayFunctionInfo.get(2));
|
||||
}
|
||||
|
||||
private void maybeSaveCheckpoint(Object actor, ActorId actorId) {
|
||||
@@ -155,7 +144,7 @@ public class Worker {
|
||||
}
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
UniqueId checkpointId = runtime.rayletClient.prepareCheckpoint(actorId);
|
||||
UniqueId checkpointId = runtime.getRayletClient().prepareCheckpoint(actorId);
|
||||
checkpointIds.add(checkpointId);
|
||||
if (checkpointIds.size() > NUM_ACTOR_CHECKPOINTS_TO_KEEP) {
|
||||
((Checkpointable) actor).checkpointExpired(actorId, checkpointIds.get(0));
|
||||
@@ -192,7 +181,7 @@ public class Worker {
|
||||
Preconditions.checkArgument(checkpointValid,
|
||||
"'loadCheckpoint' must return a checkpoint ID that exists in the "
|
||||
+ "'availableCheckpoints' list, or null.");
|
||||
runtime.rayletClient.notifyActorResumedFromCheckpoint(actorId, checkpointId);
|
||||
runtime.getRayletClient().notifyActorResumedFromCheckpoint(actorId, checkpointId);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
/**
|
||||
* Language of a Ray task.
|
||||
*/
|
||||
public enum TaskLanguage {
|
||||
|
||||
JAVA,
|
||||
|
||||
PYTHON,
|
||||
}
|
||||
@@ -1,167 +0,0 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
|
||||
import org.ray.runtime.functionmanager.PyFunctionDescriptor;
|
||||
|
||||
/**
|
||||
* Represents necessary information of a task for scheduling and executing.
|
||||
*/
|
||||
public class TaskSpec {
|
||||
|
||||
// ID of the job that created this task.
|
||||
public final JobId jobId;
|
||||
|
||||
// Task ID of the task.
|
||||
public final TaskId taskId;
|
||||
|
||||
// Task ID of the parent task.
|
||||
public final TaskId parentTaskId;
|
||||
|
||||
// A count of the number of tasks submitted by the parent task before this one.
|
||||
public final int parentCounter;
|
||||
|
||||
// Id for createActor a target actor
|
||||
public final ActorId actorCreationId;
|
||||
|
||||
public final int maxActorReconstructions;
|
||||
|
||||
// Actor ID of the task. This is the actor that this task is executed on
|
||||
// or NIL_ACTOR_ID if the task is just a normal task.
|
||||
public final ActorId actorId;
|
||||
|
||||
// ID per actor client for session consistency
|
||||
public final UniqueId actorHandleId;
|
||||
|
||||
// Number of tasks that have been submitted to this actor so far.
|
||||
public final int actorCounter;
|
||||
|
||||
// Object id returned by the previous task submitted to the same actor.
|
||||
public final ObjectId previousActorTaskDummyObjectId;
|
||||
|
||||
// Task arguments.
|
||||
public final UniqueId[] newActorHandles;
|
||||
|
||||
// Task arguments.
|
||||
public final FunctionArg[] args;
|
||||
|
||||
// number of return objects.
|
||||
public final int numReturns;
|
||||
|
||||
// Return ids.
|
||||
public final ObjectId[] returnIds;
|
||||
|
||||
// The task's resource demands.
|
||||
public final Map<String, Double> resources;
|
||||
|
||||
// Language of this task.
|
||||
public final TaskLanguage language;
|
||||
|
||||
public final List<String> dynamicWorkerOptions;
|
||||
|
||||
// Descriptor of the remote function.
|
||||
// Note, if task language is Java, the type is JavaFunctionDescriptor. If the task language
|
||||
// is Python, the type is PyFunctionDescriptor.
|
||||
private final FunctionDescriptor functionDescriptor;
|
||||
|
||||
public boolean isActorTask() {
|
||||
return !actorId.isNil();
|
||||
}
|
||||
|
||||
public boolean isActorCreationTask() {
|
||||
return !actorCreationId.isNil();
|
||||
}
|
||||
|
||||
public TaskSpec(
|
||||
JobId jobId,
|
||||
TaskId taskId,
|
||||
TaskId parentTaskId,
|
||||
int parentCounter,
|
||||
ActorId actorCreationId,
|
||||
int maxActorReconstructions,
|
||||
ActorId actorId,
|
||||
UniqueId actorHandleId,
|
||||
int actorCounter,
|
||||
ObjectId previousActorTaskDummyObjectId,
|
||||
UniqueId[] newActorHandles,
|
||||
FunctionArg[] args,
|
||||
int numReturns,
|
||||
Map<String, Double> resources,
|
||||
TaskLanguage language,
|
||||
FunctionDescriptor functionDescriptor,
|
||||
List<String> dynamicWorkerOptions) {
|
||||
this.jobId = jobId;
|
||||
this.taskId = taskId;
|
||||
this.parentTaskId = parentTaskId;
|
||||
this.parentCounter = parentCounter;
|
||||
this.actorCreationId = actorCreationId;
|
||||
this.maxActorReconstructions = maxActorReconstructions;
|
||||
this.actorId = actorId;
|
||||
this.actorHandleId = actorHandleId;
|
||||
this.actorCounter = actorCounter;
|
||||
this.previousActorTaskDummyObjectId = previousActorTaskDummyObjectId;
|
||||
this.newActorHandles = newActorHandles;
|
||||
this.args = args;
|
||||
this.numReturns = numReturns;
|
||||
this.dynamicWorkerOptions = dynamicWorkerOptions;
|
||||
|
||||
returnIds = new ObjectId[numReturns];
|
||||
for (int i = 0; i < numReturns; ++i) {
|
||||
returnIds[i] = ObjectId.forReturn(taskId, i + 1);
|
||||
}
|
||||
this.resources = resources;
|
||||
this.language = language;
|
||||
if (language == TaskLanguage.JAVA) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof JavaFunctionDescriptor,
|
||||
"Expect JavaFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else if (language == TaskLanguage.PYTHON) {
|
||||
Preconditions.checkArgument(functionDescriptor instanceof PyFunctionDescriptor,
|
||||
"Expect PyFunctionDescriptor type, but got {}.", functionDescriptor.getClass());
|
||||
} else {
|
||||
Preconditions.checkArgument(false, "Unknown task language: {}.", language);
|
||||
}
|
||||
this.functionDescriptor = functionDescriptor;
|
||||
}
|
||||
|
||||
public JavaFunctionDescriptor getJavaFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.JAVA);
|
||||
return (JavaFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
public PyFunctionDescriptor getPyFunctionDescriptor() {
|
||||
Preconditions.checkState(language == TaskLanguage.PYTHON);
|
||||
return (PyFunctionDescriptor) functionDescriptor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "TaskSpec{" +
|
||||
"jobId=" + jobId +
|
||||
", taskId=" + taskId +
|
||||
", parentTaskId=" + parentTaskId +
|
||||
", parentCounter=" + parentCounter +
|
||||
", actorCreationId=" + actorCreationId +
|
||||
", maxActorReconstructions=" + maxActorReconstructions +
|
||||
", actorId=" + actorId +
|
||||
", actorHandleId=" + actorHandleId +
|
||||
", actorCounter=" + actorCounter +
|
||||
", previousActorTaskDummyObjectId=" + previousActorTaskDummyObjectId +
|
||||
", newActorHandles=" + Arrays.toString(newActorHandles) +
|
||||
", args=" + Arrays.toString(args) +
|
||||
", numReturns=" + numReturns +
|
||||
", resources=" + resources +
|
||||
", language=" + language +
|
||||
", functionDescriptor=" + functionDescriptor +
|
||||
", dynamicWorkerOptions=" + dynamicWorkerOptions +
|
||||
'}';
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package org.ray.runtime.task;
|
||||
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.CallOptions;
|
||||
import org.ray.runtime.functionmanager.FunctionDescriptor;
|
||||
|
||||
/**
|
||||
* A set of methods to submit tasks and create actors.
|
||||
*/
|
||||
public interface TaskSubmitter {
|
||||
|
||||
/**
|
||||
* Submit a normal task.
|
||||
* @param functionDescriptor The remote function to execute.
|
||||
* @param args Arguments of this task.
|
||||
* @param numReturns Return object count.
|
||||
* @param options Options for this task.
|
||||
* @return Ids of the return objects.
|
||||
*/
|
||||
List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
int numReturns, CallOptions options);
|
||||
|
||||
/**
|
||||
* Create an actor.
|
||||
* @param functionDescriptor The remote function that generates the actor object.
|
||||
* @param args Arguments of this task.
|
||||
* @param options Options for this actor creation task.
|
||||
* @return Handle to the actor.
|
||||
*/
|
||||
RayActor createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
|
||||
ActorCreationOptions options);
|
||||
|
||||
/**
|
||||
* Submit an actor task.
|
||||
* @param actor Handle to the actor.
|
||||
* @param functionDescriptor The remote function to execute.
|
||||
* @param args Arguments of this task.
|
||||
* @param numReturns Return object count.
|
||||
* @param options Options for this task.
|
||||
* @return Ids of the return objects.
|
||||
*/
|
||||
List<ObjectId> submitActorTask(RayActor actor, FunctionDescriptor functionDescriptor,
|
||||
List<FunctionArg> args, int numReturns, CallOptions options);
|
||||
}
|
||||
@@ -1,28 +1,13 @@
|
||||
package org.ray.runtime.util;
|
||||
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.ray.api.id.BaseId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.api.id.ActorId;
|
||||
|
||||
/**
|
||||
* Helper method for different Ids.
|
||||
* Note: any changes to these methods must be synced with C++ helper functions
|
||||
* in src/ray/common/id.h
|
||||
* Helper method for different Ids. Note: any changes to these methods must be synced with C++
|
||||
* helper functions in src/ray/common/id.h
|
||||
*/
|
||||
public class IdUtil {
|
||||
|
||||
public static <T extends BaseId> byte[][] getIdBytes(List<T> objectIds) {
|
||||
int size = objectIds.size();
|
||||
byte[][] ids = new byte[size][];
|
||||
for (int i = 0; i < size; i++) {
|
||||
ids[i] = objectIds.get(i).getBytes();
|
||||
}
|
||||
return ids;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the murmur hash code of this ID.
|
||||
*/
|
||||
@@ -43,14 +28,14 @@ public class IdUtil {
|
||||
|
||||
for (int i = 0; i < length8; i++) {
|
||||
final int i8 = i * 8;
|
||||
long k = ((long)data[i8] & 0xff)
|
||||
+ (((long)data[i8 + 1] & 0xff) << 8)
|
||||
+ (((long)data[i8 + 2] & 0xff) << 16)
|
||||
+ (((long)data[i8 + 3] & 0xff) << 24)
|
||||
+ (((long)data[i8 + 4] & 0xff) << 32)
|
||||
+ (((long)data[i8 + 5] & 0xff) << 40)
|
||||
+ (((long)data[i8 + 6] & 0xff) << 48)
|
||||
+ (((long)data[i8 + 7] & 0xff) << 56);
|
||||
long k = ((long) data[i8] & 0xff)
|
||||
+ (((long) data[i8 + 1] & 0xff) << 8)
|
||||
+ (((long) data[i8 + 2] & 0xff) << 16)
|
||||
+ (((long) data[i8 + 3] & 0xff) << 24)
|
||||
+ (((long) data[i8 + 4] & 0xff) << 32)
|
||||
+ (((long) data[i8 + 5] & 0xff) << 40)
|
||||
+ (((long) data[i8 + 6] & 0xff) << 48)
|
||||
+ (((long) data[i8 + 7] & 0xff) << 56);
|
||||
|
||||
k *= m;
|
||||
k ^= k >>> r;
|
||||
@@ -90,16 +75,4 @@ public class IdUtil {
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
/*
|
||||
* A helper function to compute actor creation dummy object id according
|
||||
* the given actor id.
|
||||
*/
|
||||
public static ObjectId computeActorCreationDummyObjectId(ActorId actorId) {
|
||||
byte[] bytes = new byte[ObjectId.LENGTH];
|
||||
System.arraycopy(actorId.getBytes(), 0, bytes, 0, ActorId.LENGTH);
|
||||
Arrays.fill(bytes, ActorId.LENGTH, bytes.length, (byte) 0xFF);
|
||||
return ObjectId.fromByteBuffer(ByteBuffer.wrap(bytes));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
package org.ray.runtime.util;
|
||||
|
||||
import org.nustaq.serialization.FSTConfiguration;
|
||||
import org.ray.runtime.RayActorImpl;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.actor.NativeRayActorSerializer;
|
||||
|
||||
/**
|
||||
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
|
||||
@@ -10,7 +11,7 @@ public class Serializer {
|
||||
|
||||
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
|
||||
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
|
||||
conf.registerSerializer(RayActorImpl.class, new RayActorSerializer(), true);
|
||||
conf.registerSerializer(NativeRayActor.class, new NativeRayActorSerializer(), true);
|
||||
return conf;
|
||||
});
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import java.io.File;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import javax.tools.JavaCompiler;
|
||||
import javax.tools.ToolProvider;
|
||||
import org.apache.commons.io.FileUtils;
|
||||
@@ -14,7 +13,6 @@ import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.function.RayFunc0;
|
||||
import org.ray.api.function.RayFunc1;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.functionmanager.FunctionManager.JobFunctionTable;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.BeforeClass;
|
||||
|
||||
+4
-2
@@ -1,5 +1,8 @@
|
||||
package org.ray.streaming.schedule.impl;
|
||||
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.actor.LocalModeRayActor;
|
||||
import org.ray.streaming.api.partition.impl.RoundRobinPartition;
|
||||
import org.ray.streaming.core.graph.ExecutionEdge;
|
||||
import org.ray.streaming.core.graph.ExecutionGraph;
|
||||
@@ -12,7 +15,6 @@ import org.ray.streaming.schedule.ITaskAssign;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.ray.api.RayActor;
|
||||
import org.ray.runtime.RayActorImpl;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
@@ -29,7 +31,7 @@ public class TaskAssignImplTest {
|
||||
|
||||
List<RayActor<StreamWorker>> workers = new ArrayList<>();
|
||||
for(int i = 0; i < plan.getPlanVertexList().size(); i++) {
|
||||
workers.add(new RayActorImpl<>());
|
||||
workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom()));
|
||||
}
|
||||
|
||||
ITaskAssign taskAssign = new TaskAssignImpl();
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayActor;
|
||||
@@ -10,8 +12,8 @@ import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.exception.UnreconstructableException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.RayActorImpl;
|
||||
import org.ray.runtime.objectstore.NativeRayObject;
|
||||
import org.ray.runtime.actor.NativeRayActor;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -72,6 +74,12 @@ public class ActorTest extends BaseTest {
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@RayRemote
|
||||
public static int testActorAsFieldOfParameter(List<RayActor<Counter>> actor, int delta) {
|
||||
RayObject<Integer> res = Ray.call(Counter::increase, actor.get(0), delta);
|
||||
return res.get();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPassActorAsParameter() {
|
||||
RayActor<Counter> actor = Ray.createActor(Counter::new, 0);
|
||||
@@ -79,13 +87,17 @@ public class ActorTest extends BaseTest {
|
||||
Ray.call(ActorTest::testActorAsFirstParameter, actor, 1).get());
|
||||
Assert.assertEquals(Integer.valueOf(11),
|
||||
Ray.call(ActorTest::testActorAsSecondParameter, 10, actor).get());
|
||||
Assert.assertEquals(Integer.valueOf(111),
|
||||
Ray.call(ActorTest::testActorAsFieldOfParameter, Collections.singletonList(actor), 100)
|
||||
.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testForkingActorHandle() {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
RayActor<Counter> counter = Ray.createActor(Counter::new, 100);
|
||||
Assert.assertEquals(Integer.valueOf(101), Ray.call(Counter::increase, counter, 1).get());
|
||||
RayActor<Counter> counter2 = ((RayActorImpl<Counter>) counter).fork();
|
||||
RayActor<Counter> counter2 = ((NativeRayActor) counter).fork();
|
||||
Assert.assertEquals(Integer.valueOf(103), Ray.call(Counter::increase, counter2, 2).get());
|
||||
}
|
||||
|
||||
@@ -100,9 +112,8 @@ public class ActorTest extends BaseTest {
|
||||
Ray.internal().free(ImmutableList.of(value.getId()), false, false);
|
||||
// Wait until the object is deleted, because the above free operation is async.
|
||||
while (true) {
|
||||
NativeRayObject result = ((AbstractRayRuntime)
|
||||
Ray.internal()).getObjectStoreProxy().getObjectInterface()
|
||||
.get(ImmutableList.of(value.getId()), 0).get(0);
|
||||
NativeRayObject result = ((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
.getRaw(ImmutableList.of(value.getId()), 0).get(0);
|
||||
if (result == null) {
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.RayObjectImpl;
|
||||
import org.ray.runtime.object.RayObjectImpl;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.testng.Assert;
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import java.util.Arrays;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayObject;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.annotation.RayRemote;
|
||||
import org.ray.api.id.TaskId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
@@ -24,8 +26,8 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
Ray.internal().free(ImmutableList.of(helloId.getId()), true, false);
|
||||
|
||||
final boolean result = TestUtils.waitForCondition(() ->
|
||||
((AbstractRayRuntime) Ray.internal()).getObjectStoreProxy().getObjectInterface()
|
||||
.get(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50);
|
||||
((AbstractRayRuntime) Ray.internal()).getObjectStore()
|
||||
.getRaw(ImmutableList.of(helloId.getId()), 0).get(0) == null, 50);
|
||||
Assert.assertTrue(result);
|
||||
}
|
||||
|
||||
@@ -36,9 +38,10 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
Assert.assertEquals("hello", helloId.get());
|
||||
Ray.internal().free(ImmutableList.of(helloId.getId()), true, true);
|
||||
|
||||
TaskId taskId = TaskId.fromBytes(Arrays.copyOf(helloId.getId().getBytes(), TaskId.LENGTH));
|
||||
final boolean result = TestUtils.waitForCondition(
|
||||
() -> !(((AbstractRayRuntime)Ray.internal()).getGcsClient())
|
||||
.rayletTaskExistsInGcs(helloId.getId().getTaskId()), 50);
|
||||
() -> !(((AbstractRayRuntime) Ray.internal()).getGcsClient())
|
||||
.rayletTaskExistsInGcs(taskId), 50);
|
||||
Assert.assertTrue(result);
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import java.util.Collections;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.TestUtils;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -15,11 +17,15 @@ public class PlasmaStoreTest extends BaseTest {
|
||||
TestUtils.skipTestUnderSingleProcess();
|
||||
ObjectId objectId = ObjectId.fromRandom();
|
||||
AbstractRayRuntime runtime = (AbstractRayRuntime) Ray.internal();
|
||||
ObjectStoreProxy objectInterface = runtime.getObjectStoreProxy();
|
||||
objectInterface.put(objectId, 1);
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId), (Integer) 1);
|
||||
objectInterface.put(objectId, 2);
|
||||
ObjectStore objectStore = runtime.getObjectStore();
|
||||
objectStore.putRaw(new NativeRayObject(new byte[]{1}, null), objectId);
|
||||
Assert.assertEquals(
|
||||
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
|
||||
(byte) 1);
|
||||
objectStore.putRaw(new NativeRayObject(new byte[]{2}, null), objectId);
|
||||
// Putting 2 objects with duplicate ID should fail but ignored.
|
||||
Assert.assertEquals(objectInterface.<Integer>get(objectId), (Integer) 1);
|
||||
Assert.assertEquals(
|
||||
objectStore.getRaw(Collections.singletonList(objectId), -1).get(0).data[0],
|
||||
(byte) 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
package org.ray.api.test;
|
||||
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RayPyActor;
|
||||
import org.ray.api.id.ActorId;
|
||||
import org.ray.api.id.JobId;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.RayPyActorImpl;
|
||||
import org.ray.runtime.util.Serializer;
|
||||
import org.ray.api.id.ObjectId;
|
||||
import org.ray.runtime.AbstractRayRuntime;
|
||||
import org.ray.runtime.object.NativeRayObject;
|
||||
import org.ray.runtime.object.ObjectStore;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
public class RaySerializerTest {
|
||||
public class RaySerializerTest extends BaseMultiLanguageTest {
|
||||
|
||||
@Test
|
||||
public void testSerializePyActor() {
|
||||
final ActorId pyActorId = ActorId.generateActorId(JobId.fromInt(1));
|
||||
RayPyActor pyActor = new RayPyActorImpl(pyActorId, "test", "RaySerializerTest");
|
||||
byte[] bytes = Serializer.encode(pyActor);
|
||||
RayPyActor result = Serializer.decode(bytes);
|
||||
Assert.assertEquals(result.getId(), pyActorId);
|
||||
RayPyActor pyActor = Ray.createPyActor("test", "RaySerializerTest");
|
||||
ObjectStore objectStore = ((AbstractRayRuntime) Ray.internal()).getObjectStore();
|
||||
NativeRayObject nativeRayObject = objectStore.serialize(pyActor);
|
||||
RayPyActor result = (RayPyActor) objectStore
|
||||
.deserialize(nativeRayObject, ObjectId.fromRandom());
|
||||
Assert.assertEquals(result.getId(), pyActor.getId());
|
||||
Assert.assertEquals(result.getModuleName(), "test");
|
||||
Assert.assertEquals(result.getClassName(), "RaySerializerTest");
|
||||
}
|
||||
|
||||
@@ -50,36 +50,10 @@ public class UniqueIdTest {
|
||||
Assert.assertTrue(id6.isNil());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputeReturnId() {
|
||||
// Mock a taskId, and the lowest 4 bytes should be 0.
|
||||
TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE");
|
||||
|
||||
ObjectId returnId = ObjectId.forReturn(taskId, 1);
|
||||
Assert.assertEquals("123456789abcde123456789abcde00c001000000", returnId.toString());
|
||||
Assert.assertEquals(returnId.getTaskId(), taskId);
|
||||
|
||||
returnId = ObjectId.forReturn(taskId, 0x01020304);
|
||||
Assert.assertEquals("123456789abcde123456789abcde00c004030201", returnId.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testComputePutId() {
|
||||
// Mock a taskId, the lowest 4 bytes should be 0.
|
||||
TaskId taskId = TaskId.fromHexString("123456789ABCDE123456789ABCDE");
|
||||
|
||||
ObjectId putId = ObjectId.forPut(taskId, 1);
|
||||
Assert.assertEquals("123456789abcde123456789abcde008001000000".toLowerCase(), putId.toString());
|
||||
|
||||
putId = ObjectId.forPut(taskId, 0x01020304);
|
||||
Assert.assertEquals("123456789abcde123456789abcde008004030201".toLowerCase(), putId.toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMurmurHash() {
|
||||
UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232");
|
||||
long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000);
|
||||
Assert.assertEquals(remainder, 787616861);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user