[Java] Format ray java code (#13056)

This commit is contained in:
chaokunyang
2020-12-29 10:36:16 +08:00
committed by GitHub
parent cc1c2c3dc9
commit d1dd3410c8
422 changed files with 4384 additions and 5035 deletions
@@ -45,9 +45,7 @@ import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Core functionality to implement Ray APIs.
*/
/** Core functionality to implement Ray APIs. */
public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private static final Logger LOGGER = LoggerFactory.getLogger(AbstractRayRuntime.class);
@@ -63,9 +61,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
protected TaskSubmitter taskSubmitter;
protected WorkerContext workerContext;
/**
* Whether the required thread context is set on the current thread.
*/
/** Whether the required thread context is set on the current thread. */
final ThreadLocal<Boolean> isContextSet = ThreadLocal.withInitial(() -> false);
public AbstractRayRuntime(RayConfig rayConfig) {
@@ -78,7 +74,7 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public <T> ObjectRef<T> put(T obj) {
ObjectId objectId = objectStore.put(obj);
return new ObjectRefImpl<T>(objectId, (Class<T>)(obj == null ? Object.class : obj.getClass()));
return new ObjectRefImpl<T>(objectId, (Class<T>) (obj == null ? Object.class : obj.getClass()));
}
@Override
@@ -101,8 +97,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(
Collectors.toList()), localOnly);
objectStore.delete(
objectRefs.stream()
.map(ref -> ((ObjectRefImpl<?>) ref).getId())
.collect(Collectors.toList()),
localOnly);
}
@Override
@@ -120,13 +119,11 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public ObjectRef call(PyFunction pyFunction, Object[] args, CallOptions options) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyFunction.moduleName,
"",
pyFunction.functionName);
PyFunctionDescriptor functionDescriptor =
new PyFunctionDescriptor(pyFunction.moduleName, "", pyFunction.functionName);
// Python functions always have a return value, even if it's `None`.
return callNormalFunction(functionDescriptor, args,
/*returnType=*/Optional.of(pyFunction.returnType), options);
return callNormalFunction(
functionDescriptor, args, /*returnType=*/ Optional.of(pyFunction.returnType), options);
}
@Override
@@ -139,17 +136,18 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
@Override
public ObjectRef callActor(PyActorHandle pyActor, PyActorMethod pyActorMethod, Object... args) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(pyActor.getModuleName(),
pyActor.getClassName(), pyActorMethod.methodName);
PyFunctionDescriptor functionDescriptor =
new PyFunctionDescriptor(
pyActor.getModuleName(), pyActor.getClassName(), pyActorMethod.methodName);
// Python functions always have a return value, even if it's `None`.
return callActorFunction(pyActor, functionDescriptor, args,
/*returnType=*/Optional.of(pyActorMethod.returnType));
return callActorFunction(
pyActor, functionDescriptor, args, /*returnType=*/ Optional.of(pyActorMethod.returnType));
}
@Override
@SuppressWarnings("unchecked")
public <T> ActorHandle<T> createActor(RayFunc actorFactoryFunc,
Object[] args, ActorCreationOptions options) {
public <T> ActorHandle<T> createActor(
RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options) {
FunctionDescriptor functionDescriptor =
functionManager.getFunction(workerContext.getCurrentJobId(), actorFactoryFunc)
.functionDescriptor;
@@ -157,24 +155,24 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
@Override
public PyActorHandle createActor(PyActorClass pyActorClass, Object[] args,
ActorCreationOptions options) {
PyFunctionDescriptor functionDescriptor = new PyFunctionDescriptor(
pyActorClass.moduleName,
pyActorClass.className,
PYTHON_INIT_METHOD_NAME);
public PyActorHandle createActor(
PyActorClass pyActorClass, Object[] args, ActorCreationOptions options) {
PyFunctionDescriptor functionDescriptor =
new PyFunctionDescriptor(
pyActorClass.moduleName, pyActorClass.className, PYTHON_INIT_METHOD_NAME);
return (PyActorHandle) createActorImpl(functionDescriptor, args, options);
}
@Override
public PlacementGroup createPlacementGroup(String name,
List<Map<String, Double>> bundles, PlacementStrategy strategy) {
boolean bundleResourceValid = bundles.stream().allMatch(
bundle -> bundle.values().stream().allMatch(resource -> resource > 0));
public PlacementGroup createPlacementGroup(
String name, List<Map<String, Double>> bundles, PlacementStrategy strategy) {
boolean bundleResourceValid =
bundles.stream()
.allMatch(bundle -> bundle.values().stream().allMatch(resource -> resource > 0));
if (bundles.isEmpty() || !bundleResourceValid) {
throw new IllegalArgumentException(
"Bundles cannot be empty or bundle's resource must be positive.");
"Bundles cannot be empty or bundle's resource must be positive.");
}
return taskSubmitter.createPlacementGroup(name, bundles, strategy);
}
@@ -264,12 +262,13 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
private ObjectRef callNormalFunction(
FunctionDescriptor functionDescriptor,
Object[] args, Optional<Class<?>> returnType, CallOptions options) {
Object[] args,
Optional<Class<?>> returnType,
CallOptions options) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds = taskSubmitter.submitTask(functionDescriptor,
functionArgs, numReturns, options);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds =
taskSubmitter.submitTask(functionDescriptor, functionArgs, numReturns, options);
Preconditions.checkState(returnIds.size() == numReturns);
if (returnIds.isEmpty()) {
return null;
@@ -284,10 +283,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
Object[] args,
Optional<Class<?>> returnType) {
int numReturns = returnType.isPresent() ? 1 : 0;
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds = taskSubmitter.submitActorTask(rayActor,
functionDescriptor, functionArgs, numReturns, null);
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
List<ObjectId> returnIds =
taskSubmitter.submitActorTask(rayActor, functionDescriptor, functionArgs, numReturns, null);
Preconditions.checkState(returnIds.size() == numReturns);
if (returnIds.isEmpty()) {
return null;
@@ -296,10 +294,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
}
private BaseActorHandle createActorImpl(FunctionDescriptor functionDescriptor,
Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder
.wrap(args, functionDescriptor.getLanguage());
private BaseActorHandle createActorImpl(
FunctionDescriptor functionDescriptor, Object[] args, ActorCreationOptions options) {
List<FunctionArg> functionArgs = ArgumentsBuilder.wrap(args, functionDescriptor.getLanguage());
if (functionDescriptor.getLanguage() != Language.JAVA && options != null) {
Preconditions.checkState(Strings.isNullOrEmpty(options.jvmOptions));
}
@@ -9,9 +9,7 @@ import io.ray.runtime.util.LoggingUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The default Ray runtime factory. It produces an instance of RayRuntime.
*/
/** The default Ray runtime factory. It produces an instance of RayRuntime. */
public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
@Override
@@ -22,19 +20,22 @@ public class DefaultRayRuntimeFactory implements RayRuntimeFactory {
if (rayConfig.workerMode == WorkerType.WORKER) {
// Handle the uncaught exceptions thrown from user-spawned threads.
Thread.setDefaultUncaughtExceptionHandler((Thread t, Throwable e) -> {
logger.error(String.format("Uncaught worker exception in thread %s", t), e);
});
Thread.setDefaultUncaughtExceptionHandler(
(Thread t, Throwable e) -> {
logger.error(String.format("Uncaught worker exception in thread %s", t), e);
});
}
try {
logger.debug("Initializing runtime with config: {}", rayConfig);
AbstractRayRuntime innerRuntime = rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime = rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
AbstractRayRuntime innerRuntime =
rayConfig.runMode == RunMode.SINGLE_PROCESS
? new RayDevRuntime(rayConfig)
: new RayNativeRuntime(rayConfig);
RayRuntimeInternal runtime =
rayConfig.numWorkersPerProcess > 1
? RayRuntimeProxy.newInstance(innerRuntime)
: innerRuntime;
runtime.start();
return runtime;
} catch (Exception e) {
@@ -35,14 +35,15 @@ public class RayDevRuntime extends AbstractRayRuntime {
taskExecutor = new LocalModeTaskExecutor(this);
workerContext = new LocalModeWorkerContext(rayConfig.getJobId());
objectStore = new LocalModeObjectStore(workerContext);
taskSubmitter = new LocalModeTaskSubmitter(this, taskExecutor,
(LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore).addObjectPutCallback(
objectId -> {
if (taskSubmitter != null) {
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
}
});
taskSubmitter =
new LocalModeTaskSubmitter(this, taskExecutor, (LocalModeObjectStore) objectStore);
((LocalModeObjectStore) objectStore)
.addObjectPutCallback(
objectId -> {
if (taskSubmitter != null) {
((LocalModeTaskSubmitter) taskSubmitter).onObjectPut(objectId);
}
});
}
@Override
@@ -72,7 +73,7 @@ public class RayDevRuntime extends AbstractRayRuntime {
@SuppressWarnings("unchecked")
@Override
public <T extends BaseActorHandle> Optional<T> getActor(String name, boolean global) {
return (Optional<T>) ((LocalModeTaskSubmitter)taskSubmitter).getActor(name, global);
return (Optional<T>) ((LocalModeTaskSubmitter) taskSubmitter).getActor(name, global);
}
@Override
@@ -87,24 +88,21 @@ public class RayDevRuntime extends AbstractRayRuntime {
}
@Override
public PlacementGroup getPlacementGroup(
PlacementGroupId id) {
//@TODO(clay4444): We need a LocalGcsClient before implements this.
public PlacementGroup getPlacementGroup(PlacementGroupId id) {
// @TODO(clay4444): We need a LocalGcsClient before implements this.
throw new UnsupportedOperationException(
"Ray doesn't support placement group operations in local mode.");
"Ray doesn't support placement group operations in local mode.");
}
@Override
public List<PlacementGroup> getAllPlacementGroups() {
//@TODO(clay4444): We need a LocalGcsClient before implements this.
// @TODO(clay4444): We need a LocalGcsClient before implements this.
throw new UnsupportedOperationException(
"Ray doesn't support placement group operations in local mode.");
"Ray doesn't support placement group operations in local mode.");
}
@Override
public void exitActor() {
}
public void exitActor() {}
private JobId nextJobId() {
return JobId.fromInt(jobCounter.getAndIncrement());
@@ -29,9 +29,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Native runtime for cluster mode.
*/
/** Native runtime for cluster mode. */
public final class RayNativeRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
@@ -39,10 +37,9 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private boolean startRayHead = false;
/**
* In Java, GC runs in a standalone thread, and we can't control the exact
* timing of garbage collection. By using this lock, when
* {@link NativeObjectStore#nativeRemoveLocalReference} is executing, the core
* worker will not be shut down, therefore it guarantees some kind of
* In Java, GC runs in a standalone thread, and we can't control the exact timing of garbage
* collection. By using this lock, when {@link NativeObjectStore#nativeRemoveLocalReference} is
* executing, the core worker will not be shut down, therefore it guarantees some kind of
* thread-safety. Note that this guarantee only works for driver.
*/
private final ReadWriteLock shutdownLock = new ReentrantReadWriteLock();
@@ -117,21 +114,29 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayletConfigStringMap.put(entry.getKey(), entry.getValue().toString());
}
nativeInitialize(rayConfig.workerMode.getNumber(),
rayConfig.nodeIp, rayConfig.getNodeManagerPort(),
nativeInitialize(
rayConfig.workerMode.getNumber(),
rayConfig.nodeIp,
rayConfig.getNodeManagerPort(),
rayConfig.workerMode == WorkerType.DRIVER ? System.getProperty("user.dir") : "",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName,
rayConfig.objectStoreSocketName,
rayConfig.rayletSocketName,
(rayConfig.workerMode == WorkerType.DRIVER ? rayConfig.getJobId() : JobId.NIL).getBytes(),
new GcsClientOptions(rayConfig), numWorkersPerProcess,
rayConfig.logDir, rayletConfigStringMap, serializedJobConfig);
new GcsClientOptions(rayConfig),
numWorkersPerProcess,
rayConfig.logDir,
rayletConfigStringMap,
serializedJobConfig);
taskExecutor = new NativeTaskExecutor(this);
workerContext = new NativeWorkerContext();
objectStore = new NativeObjectStore(workerContext, shutdownLock);
taskSubmitter = new NativeTaskSubmitter();
LOGGER.debug("RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
LOGGER.debug(
"RayNativeRuntime started with store {}, raylet {}",
rayConfig.objectStoreSocketName,
rayConfig.rayletSocketName);
} catch (Exception e) {
if (startRayHead) {
try {
@@ -201,8 +206,8 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
@Override
public Object getAsyncContext() {
return new AsyncContext(workerContext.getCurrentWorkerId(),
workerContext.getCurrentClassLoader());
return new AsyncContext(
workerContext.getCurrentWorkerId(), workerContext.getCurrentClassLoader());
}
@Override
@@ -229,10 +234,18 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
private static native void nativeInitialize(
int workerMode, String ndoeIpAddress,
int nodeManagerPort, String driverName, String storeSocket, String rayletSocket,
byte[] jobId, GcsClientOptions gcsClientOptions, int numWorkersPerProcess,
String logDir, Map<String, String> rayletConfigParameters, byte[] serializedJobConfig);
int workerMode,
String ndoeIpAddress,
int nodeManagerPort,
String driverName,
String storeSocket,
String rayletSocket,
byte[] jobId,
GcsClientOptions gcsClientOptions,
int numWorkersPerProcess,
String logDir,
Map<String, String> rayletConfigParameters,
byte[] serializedJobConfig);
private static native void nativeRunTaskExecutor(TaskExecutor taskExecutor);
@@ -8,14 +8,10 @@ import io.ray.runtime.gcs.GcsClient;
import io.ray.runtime.object.ObjectStore;
import io.ray.runtime.task.TaskExecutor;
/**
* This interface is required to make {@link RayRuntimeProxy} work.
*/
/** This interface is required to make {@link RayRuntimeProxy} work. */
public interface RayRuntimeInternal extends RayRuntime {
/**
* Start runtime.
*/
/** Start runtime. */
void start();
WorkerContext getWorkerContext();
@@ -13,9 +13,7 @@ import java.lang.reflect.Method;
*/
public class RayRuntimeProxy implements InvocationHandler {
/**
* The original runtime.
*/
/** The original runtime. */
private AbstractRayRuntime obj;
private RayRuntimeProxy(AbstractRayRuntime obj) {
@@ -26,19 +24,20 @@ public class RayRuntimeProxy implements InvocationHandler {
return obj;
}
/**
* Generate a new instance of {@link RayRuntimeInternal} with additional context check.
*/
/** Generate a new instance of {@link RayRuntimeInternal} with additional context check. */
static RayRuntimeInternal newInstance(AbstractRayRuntime obj) {
return (RayRuntimeInternal) java.lang.reflect.Proxy
.newProxyInstance(obj.getClass().getClassLoader(), new Class<?>[]{RayRuntimeInternal.class},
return (RayRuntimeInternal)
java.lang.reflect.Proxy.newProxyInstance(
obj.getClass().getClassLoader(),
new Class<?>[] {RayRuntimeInternal.class},
new RayRuntimeProxy(obj));
}
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
if (isInterfaceMethod(method) && !method.getName().equals("shutdown") && !method.getName()
.equals("setAsyncContext")) {
if (isInterfaceMethod(method)
&& !method.getName().equals("shutdown")
&& !method.getName().equals("setAsyncContext")) {
checkIsContextSet();
}
try {
@@ -52,9 +51,7 @@ public class RayRuntimeProxy implements InvocationHandler {
}
}
/**
* Whether the method is defined in the {@link RayRuntime} interface.
*/
/** Whether the method is defined in the {@link RayRuntime} interface. */
private boolean isInterfaceMethod(Method method) {
try {
RayRuntime.class.getMethod(method.getName(), method.getParameterTypes());
@@ -66,8 +63,8 @@ public class RayRuntimeProxy implements InvocationHandler {
/**
* Check if thread context is set.
* <p/>
* This method should be invoked at the beginning of most public methods of {@link RayRuntime},
*
* <p>This method should be invoked at the beginning of most public methods of {@link RayRuntime},
* otherwise the native code might crash due to thread local core worker was not set. We check it
* for {@link AbstractRayRuntime} instead of {@link RayNativeRuntime} because we want to catch the
* error even if the application runs in {@link RunMode#SINGLE_PROCESS} mode.
@@ -9,9 +9,7 @@ import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.concurrent.atomic.AtomicReference;
/**
* Implementation of actor handle for local mode.
*/
/** Implementation of actor handle for local mode. */
public class LocalModeActorHandle implements ActorHandle, Externalizable {
private ActorId actorId;
@@ -23,11 +21,8 @@ public class LocalModeActorHandle implements ActorHandle, Externalizable {
this.previousActorTaskDummyObjectId.set(previousActorTaskDummyObjectId);
}
/**
* Required by FST
*/
public LocalModeActorHandle() {
}
/** Required by FST. */
public LocalModeActorHandle() {}
@Override
public ActorId getId() {
@@ -16,9 +16,7 @@ import java.util.List;
*/
public abstract class NativeActorHandle implements BaseActorHandle, Externalizable {
/**
* ID of the actor.
*/
/** ID of the actor. */
byte[] actorId;
private Language language;
@@ -29,11 +27,8 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
this.language = language;
}
/**
* Required by FST
*/
NativeActorHandle() {
}
/** Required by FST. */
NativeActorHandle() {}
public static NativeActorHandle create(byte[] actorId) {
Language language = Language.forNumber(nativeGetLanguage(actorId));
@@ -76,7 +71,7 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
/**
* Serialize this actor handle to bytes.
*
* @return the bytes of the actor handle
* <p>Returns the bytes of the actor handle
*/
public byte[] toBytes() {
return nativeSerialize(actorId);
@@ -85,7 +80,7 @@ public abstract class NativeActorHandle implements BaseActorHandle, Externalizab
/**
* Deserialize an actor handle from bytes.
*
* @return the bytes of an actor handle
* <p>Returns the bytes of an actor handle
*/
public static NativeActorHandle fromBytes(byte[] bytes) {
byte[] actorId = nativeDeserialize(bytes);
@@ -7,20 +7,24 @@ import org.nustaq.serialization.FSTClazzInfo.FSTFieldInfo;
import org.nustaq.serialization.FSTObjectInput;
import org.nustaq.serialization.FSTObjectOutput;
/**
* To deal with serialization about {@link NativeActorHandle}.
*/
/** To deal with serialization about {@link NativeActorHandle}. */
public class NativeActorHandleSerializer extends FSTBasicObjectSerializer {
@Override
public void writeObject(FSTObjectOutput out, Object toWrite, FSTClazzInfo clzInfo,
FSTClazzInfo.FSTFieldInfo referencedBy, int streamPosition) throws IOException {
public void writeObject(
FSTObjectOutput out,
Object toWrite,
FSTClazzInfo clzInfo,
FSTClazzInfo.FSTFieldInfo referencedBy,
int streamPosition)
throws IOException {
((NativeActorHandle) toWrite).writeExternal(out);
}
@Override
public void readObject(FSTObjectInput in, Object toRead, FSTClazzInfo clzInfo,
FSTFieldInfo referencedBy) throws Exception {
public void readObject(
FSTObjectInput in, Object toRead, FSTClazzInfo clzInfo, FSTFieldInfo referencedBy)
throws Exception {
super.readObject(in, toRead, clzInfo, referencedBy);
((NativeActorHandle) toRead).readExternal(in);
}
@@ -6,18 +6,14 @@ import io.ray.runtime.generated.Common.Language;
import java.io.IOException;
import java.io.ObjectInput;
/**
* Java implementation of actor handle for cluster mode.
*/
/** Java implementation of actor handle for cluster mode. */
public class NativeJavaActorHandle extends NativeActorHandle implements ActorHandle {
NativeJavaActorHandle(byte[] actorId) {
super(actorId, Language.JAVA);
}
/**
* Required by FST
*/
/** Required by FST. */
public NativeJavaActorHandle() {
super();
}
@@ -6,18 +6,14 @@ import io.ray.runtime.generated.Common.Language;
import java.io.IOException;
import java.io.ObjectInput;
/**
* Python actor handle implementation for cluster mode.
*/
/** Python actor handle implementation for cluster mode. */
public class NativePyActorHandle extends NativeActorHandle implements PyActorHandle {
NativePyActorHandle(byte[] actorId) {
super(actorId, Language.PYTHON);
}
/**
* Required by FST
*/
/** Required by FST. */
public NativePyActorHandle() {
super();
}
@@ -20,10 +20,7 @@ import org.apache.commons.lang3.BooleanUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
/**
* Configurations of Ray runtime.
* See `ray.default.conf` for the meaning of each field.
*/
/** Configurations of Ray runtime. See `ray.default.conf` for the meaning of each field. */
public class RayConfig {
public static final String DEFAULT_CONFIG_FILE = "ray.default.conf";
@@ -31,10 +28,9 @@ public class RayConfig {
private Config config;
/**
* IP of this node. if not provided, IP will be automatically detected.
*/
/** IP of this node. if not provided, IP will be automatically detected. */
public final String nodeIp;
public final WorkerType workerMode;
public final RunMode runMode;
private JobId jobId;
@@ -64,8 +60,8 @@ public class RayConfig {
private void validate() {
if (workerMode == WorkerType.WORKER) {
Preconditions.checkArgument(redisAddress != null,
"Redis address must be set in worker mode.");
Preconditions.checkArgument(
redisAddress != null, "Redis address must be set in worker mode.");
}
}
@@ -145,7 +141,8 @@ public class RayConfig {
if (config.hasPath("ray.raylet.node-manager-port")) {
nodeManagerPort = config.getInt("ray.raylet.node-manager-port");
} else {
Preconditions.checkState(workerMode != WorkerType.WORKER,
Preconditions.checkState(
workerMode != WorkerType.WORKER,
"Worker started by raylet should accept the node manager port from raylet.");
}
@@ -217,9 +214,7 @@ public class RayConfig {
return config;
}
/**
* Renders the config value as a HOCON string.
*/
/** Renders the config value as a HOCON string. */
@Override
public String toString() {
// These items might be dynamically generated or mutated at runtime.
@@ -257,10 +252,8 @@ public class RayConfig {
}
/**
* Create a RayConfig by reading configuration in the following order:
* 1. System properties.
* 2. `ray.conf` file.
* 3. `ray.default.conf` file.
* Create a RayConfig by reading configuration in the following order: 1. System properties. 2.
* `ray.conf` file. 3. `ray.default.conf` file.
*/
public static RayConfig create() {
ConfigFactory.invalidateCaches();
@@ -274,5 +267,4 @@ public class RayConfig {
config = config.withFallback(ConfigFactory.load(DEFAULT_CONFIG_FILE));
return new RayConfig(config.withOnlyPath("ray"));
}
}
@@ -3,13 +3,11 @@ package io.ray.runtime.config;
public enum RunMode {
/**
* Ray is running in one single Java process, without Raylet backend, object store, and GCS.
* It's useful for debug.
* Ray is running in one single Java process, without Raylet backend, object store, and GCS. It's
* useful for debug.
*/
SINGLE_PROCESS,
/**
* Ray is running on one or more nodes, with multiple processes.
*/
/** Ray is running on one or more nodes, with multiple processes. */
CLUSTER,
}
@@ -10,9 +10,7 @@ import io.ray.runtime.generated.Common.TaskSpec;
import io.ray.runtime.generated.Common.TaskType;
import io.ray.runtime.task.LocalModeTaskSubmitter;
/**
* Worker context for local mode.
*/
/** Worker context for local mode. */
public class LocalModeWorkerContext implements WorkerContext {
private final JobId jobId;
@@ -52,8 +50,7 @@ public class LocalModeWorkerContext implements WorkerContext {
}
@Override
public void setCurrentClassLoader(ClassLoader currentClassLoader) {
}
public void setCurrentClassLoader(ClassLoader currentClassLoader) {}
@Override
public TaskType getCurrentTaskType() {
@@ -9,9 +9,7 @@ import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskType;
import java.nio.ByteBuffer;
/**
* Worker context for cluster mode. This is a wrapper class for worker context of core worker.
*/
/** Worker context for cluster mode. This is a wrapper class for worker context of core worker. */
public class NativeWorkerContext implements WorkerContext {
private final ThreadLocal<ClassLoader> currentClassLoader = new ThreadLocal<>();
@@ -26,15 +26,16 @@ public class RuntimeContextImpl implements RuntimeContext {
@Override
public ActorId getCurrentActorId() {
ActorId actorId = runtime.getWorkerContext().getCurrentActorId();
Preconditions.checkState(actorId != null && !actorId.isNil(),
"This method should only be called from an actor.");
Preconditions.checkState(
actorId != null && !actorId.isNil(), "This method should only be called from an actor.");
return actorId;
}
@Override
public boolean wasCurrentActorRestarted() {
TaskType currentTaskType = runtime.getWorkerContext().getCurrentTaskType();
Preconditions.checkState(currentTaskType == TaskType.ACTOR_CREATION_TASK,
Preconditions.checkState(
currentTaskType == TaskType.ACTOR_CREATION_TASK,
"This method can only be called from an actor creation task.");
if (isSingleProcess()) {
return false;
@@ -7,24 +7,16 @@ import io.ray.api.id.UniqueId;
import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.generated.Common.TaskType;
/**
* The context of worker.
*/
/** The context of worker. */
public interface WorkerContext {
/**
* ID of the current worker.
*/
/** ID of the current worker. */
UniqueId getCurrentWorkerId();
/**
* ID of the current job.
*/
/** ID of the current job. */
JobId getCurrentJobId();
/**
* ID of the current actor.
*/
/** ID of the current actor. */
ActorId getCurrentActorId();
/**
@@ -33,19 +25,13 @@ public interface WorkerContext {
*/
ClassLoader getCurrentClassLoader();
/**
* Set the current class loader.
*/
/** Set the current class loader. */
void setCurrentClassLoader(ClassLoader currentClassLoader);
/**
* Type of the current task.
*/
/** Type of the current task. */
TaskType getCurrentTaskType();
/**
* ID of the current task.
*/
/** ID of the current task. */
TaskId getCurrentTaskId();
Address getRpcAddress();
@@ -7,8 +7,10 @@ public class CrossLanguageException extends RayException {
private Language language;
public CrossLanguageException(io.ray.runtime.generated.Common.RayException exception) {
super(String.format("An exception raised from %s:\n%s", exception.getLanguage().name(),
exception.getFormattedExceptionString()));
super(
String.format(
"An exception raised from %s:\n%s",
exception.getLanguage().name(), exception.getFormattedExceptionString()));
this.language = exception.getLanguage();
}
@@ -5,7 +5,7 @@ import io.ray.api.id.ActorId;
/**
* Indicates that the actor died unexpectedly before finishing a task.
*
* This exception could happen either because the actor process dies while executing a task, or
* <p>This exception could happen either because the actor process dies while executing a task, or
* because a task is submitted to a dead actor.
*/
public class RayActorException extends RayException {
@@ -17,8 +17,7 @@ public class RayActorException extends RayException {
}
public RayActorException(ActorId actorId) {
super(String.format(
"The actor %s died unexpectedly before finishing this task.", actorId));
super(String.format("The actor %s died unexpectedly before finishing this task.", actorId));
this.actorId = actorId;
}
@@ -29,5 +28,4 @@ public class RayActorException extends RayException {
public RayActorException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -16,8 +16,8 @@ public class RayException extends RuntimeException {
}
public byte[] toBytes() {
String formattedException = org.apache.commons.lang3.exception.ExceptionUtils
.getStackTrace(this);
String formattedException =
org.apache.commons.lang3.exception.ExceptionUtils.getStackTrace(this);
io.ray.runtime.generated.Common.RayException.Builder builder =
io.ray.runtime.generated.Common.RayException.newBuilder();
builder.setLanguage(Language.JAVA);
@@ -26,13 +26,12 @@ public class RayException extends RuntimeException {
return builder.build().toByteArray();
}
public static RayException fromBytes(byte[] serialized)
throws InvalidProtocolBufferException {
public static RayException fromBytes(byte[] serialized) throws InvalidProtocolBufferException {
io.ray.runtime.generated.Common.RayException exception =
io.ray.runtime.generated.Common.RayException.parseFrom(serialized);
if (exception.getLanguage() == Language.JAVA) {
return Serializer
.decode(exception.getSerializedException().toByteArray(), RayException.class);
return Serializer.decode(
exception.getSerializedException().toByteArray(), RayException.class);
} else {
return new CrossLanguageException(exception);
}
@@ -1,8 +1,6 @@
package io.ray.runtime.exception;
/**
* The exception represents that there is an intentional system exit.
*/
/** The exception represents that there is an intentional system exit. */
public class RayIntentionalSystemExitException extends RuntimeException {
public RayIntentionalSystemExitException(String message) {
@@ -6,8 +6,9 @@ import io.ray.runtime.util.SystemUtil;
public class RayTaskException extends RayException {
public RayTaskException(String message, Throwable cause) {
super(String.format("(pid=%d, ip=%s) %s",
SystemUtil.pid(), NetworkUtil.getIpAddress(null), message), cause);
super(
String.format(
"(pid=%d, ip=%s) %s", SystemUtil.pid(), NetworkUtil.getIpAddress(null), message),
cause);
}
}
@@ -1,8 +1,6 @@
package io.ray.runtime.exception;
/**
* Indicates that the worker died unexpectedly while executing a task.
*/
/** Indicates that the worker died unexpectedly while executing a task. */
public class RayWorkerException extends RayException {
public RayWorkerException() {
@@ -16,5 +14,4 @@ public class RayWorkerException extends RayException {
public RayWorkerException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -3,20 +3,20 @@ package io.ray.runtime.exception;
import io.ray.api.id.ObjectId;
/**
* Indicates that an object is lost (either evicted or explicitly deleted) and cannot be
* restarted.
* Indicates that an object is lost (either evicted or explicitly deleted) and cannot be restarted.
*
* Note, this exception only happens for actor objects. If actor's current state is after object's
* creating task, the actor cannot re-run the task to reconstruct the object.
* <p>Note, this exception only happens for actor objects. If actor's current state is after
* object's creating task, the actor cannot re-run the task to reconstruct the object.
*/
public class UnreconstructableException extends RayException {
public ObjectId objectId;
public UnreconstructableException(ObjectId objectId) {
super(String.format(
"Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.",
objectId));
super(
String.format(
"Object %s is lost (either evicted or explicitly deleted) and cannot be reconstructed.",
objectId));
this.objectId = objectId;
}
@@ -27,5 +27,4 @@ public class UnreconstructableException extends RayException {
public UnreconstructableException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -6,18 +6,14 @@ import java.util.List;
/**
* Base interface of a Ray task's function descriptor.
*
* A function descriptor is a list of strings that can uniquely describe a function. It's used to
* <p>A function descriptor is a list of strings that can uniquely describe a function. It's used to
* load a function in workers.
*/
public interface FunctionDescriptor {
/**
* @return A list of strings represents the functions.
*/
/** Returns A list of strings represents the functions. */
List<String> toList();
/**
* @return The language of the function.
*/
/** Returns The language of the function. */
Language getLanguage();
}
@@ -34,9 +34,7 @@ import org.objectweb.asm.Type;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Manages functions by job id.
*/
/** Manages functions by job id. */
public class FunctionManager {
private static final Logger LOGGER = LoggerFactory.getLogger(FunctionManager.class);
@@ -52,21 +50,16 @@ public class FunctionManager {
private static final ThreadLocal<WeakHashMap<Class<? extends RayFunc>, JavaFunctionDescriptor>>
RAY_FUNC_CACHE = ThreadLocal.withInitial(WeakHashMap::new);
/**
* Mapping from the job id to the functions that belong to this job.
*/
/** Mapping from the job id to the functions that belong to this job. */
private ConcurrentMap<JobId, JobFunctionTable> jobFunctionTables = new ConcurrentHashMap<>();
/**
* The resource path which we can load the job's jar resources.
*/
/** The resource path which we can load the job's jar resources. */
private final List<String> codeSearchPath;
/**
* Construct a FunctionManager with the specified code search path.
*
* @param codeSearchPath The specified job resource that can store the job's
* resources.
* @param codeSearchPath The specified job resource that can store the job's resources.
*/
public FunctionManager(List<String> codeSearchPath) {
this.codeSearchPath = codeSearchPath;
@@ -76,8 +69,7 @@ public class FunctionManager {
* Get the RayFunction from a RayFunc instance (a lambda).
*
* @param jobId current job id.
* @param func The lambda.
* @return A RayFunction object.
* @param func The lambda. Returns A RayFunction object.
*/
public RayFunction getFunction(JobId jobId, RayFunc func) {
JavaFunctionDescriptor functionDescriptor = RAY_FUNC_CACHE.get().get(func.getClass());
@@ -97,12 +89,10 @@ public class FunctionManager {
/**
* Get the RayFunction from a function descriptor.
*
* @param jobId Current job id.
* @param functionDescriptor The function descriptor.
* @return A RayFunction object.
* @param jobId Current job id.
* @param functionDescriptor The function descriptor. Returns A RayFunction object.
*/
public RayFunction getFunction(JobId jobId,
JavaFunctionDescriptor functionDescriptor) {
public RayFunction getFunction(JobId jobId, JavaFunctionDescriptor functionDescriptor) {
JobFunctionTable jobFunctionTable = jobFunctionTables.get(jobId);
if (jobFunctionTable == null) {
synchronized (this) {
@@ -121,30 +111,37 @@ public class FunctionManager {
if (codeSearchPath == null || codeSearchPath.isEmpty()) {
classLoader = getClass().getClassLoader();
} else {
URL[] urls = codeSearchPath.stream()
.filter(p -> StringUtils.isNotBlank(p) && Files.exists(Paths.get(p)))
.flatMap(p -> {
try {
if (!Files.isDirectory(Paths.get(p))) {
if (!p.endsWith(".jar")) {
return Stream.of(Paths.get(p).getParent().toAbsolutePath().toUri().toURL());
} else {
return Stream.of(Paths.get(p).toAbsolutePath().toUri().toURL());
}
} else {
List<URL> subUrls = new ArrayList<>();
subUrls.add(Paths.get(p).toAbsolutePath().toUri().toURL());
Collection<File> jars = FileUtils.listFiles(new File(p),
new RegexFileFilter(".*\\.jar"), DirectoryFileFilter.DIRECTORY);
for (File jar : jars) {
subUrls.add(jar.toPath().toUri().toURL());
}
return subUrls.stream();
}
} catch (MalformedURLException e) {
throw new RuntimeException(String.format("Illegal %s resource path", p));
}
}).toArray(URL[]::new);
URL[] urls =
codeSearchPath.stream()
.filter(p -> StringUtils.isNotBlank(p) && Files.exists(Paths.get(p)))
.flatMap(
p -> {
try {
if (!Files.isDirectory(Paths.get(p))) {
if (!p.endsWith(".jar")) {
return Stream.of(
Paths.get(p).getParent().toAbsolutePath().toUri().toURL());
} else {
return Stream.of(Paths.get(p).toAbsolutePath().toUri().toURL());
}
} else {
List<URL> subUrls = new ArrayList<>();
subUrls.add(Paths.get(p).toAbsolutePath().toUri().toURL());
Collection<File> jars =
FileUtils.listFiles(
new File(p),
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY);
for (File jar : jars) {
subUrls.add(jar.toPath().toUri().toURL());
}
return subUrls.stream();
}
} catch (MalformedURLException e) {
throw new RuntimeException(String.format("Illegal %s resource path", p));
}
})
.toArray(URL[]::new);
classLoader = new URLClassLoader(urls);
LOGGER.debug("Resource loaded for job {} from path {}.", jobId, urls);
}
@@ -152,18 +149,12 @@ public class FunctionManager {
return new JobFunctionTable(classLoader);
}
/**
* Manages all functions that belong to one job.
*/
/** Manages all functions that belong to one job. */
static class JobFunctionTable {
/**
* The job's corresponding class loader.
*/
/** The job's corresponding class loader. */
final ClassLoader classLoader;
/**
* Functions per class, per function name + type descriptor.
*/
/** Functions per class, per function name + type descriptor. */
ConcurrentMap<String, Map<Pair<String, String>, RayFunction>> functions;
JobFunctionTable(ClassLoader classLoader) {
@@ -187,19 +178,18 @@ public class FunctionManager {
if (func == null) {
if (classFunctions.containsKey(key)) {
throw new RuntimeException(
String.format("RayFunction %s is overloaded, the signature can't be empty.",
descriptor.toString()));
String.format(
"RayFunction %s is overloaded, the signature can't be empty.",
descriptor.toString()));
} else {
throw new RuntimeException(
String.format("RayFunction %s not found", descriptor.toString()));
String.format("RayFunction %s not found", descriptor.toString()));
}
}
return func;
}
/**
* Load all functions from a class.
*/
/** Load all functions from a class. */
Map<Pair<String, String>, RayFunction> loadFunctionsForClass(String className) {
// If RayFunction is null, the function is overloaded.
Map<Pair<String, String>, RayFunction> map = new HashMap<>();
@@ -232,8 +222,9 @@ public class FunctionManager {
final Type type =
e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e);
final String signature = type.getDescriptor();
RayFunction rayFunction = new RayFunction(e, classLoader,
new JavaFunctionDescriptor(className, methodName, signature));
RayFunction rayFunction =
new RayFunction(
e, classLoader, new JavaFunctionDescriptor(className, methodName, signature));
map.put(ImmutablePair.of(methodName, signature), rayFunction);
// For cross language call java function without signature
final Pair<String, String> emptyDescriptor = ImmutablePair.of(methodName, "");
@@ -5,22 +5,14 @@ import com.google.common.collect.ImmutableList;
import io.ray.runtime.generated.Common.Language;
import java.util.List;
/**
* Represents metadata of Java function.
*/
/** Represents metadata of Java function. */
public final class JavaFunctionDescriptor implements FunctionDescriptor {
/**
* Function's class name.
*/
/** Function's class name. */
public final String className;
/**
* Function's name.
*/
/** Function's name. */
public final String name;
/**
* Function's signature.
*/
/** Function's signature. */
public final String signature;
public JavaFunctionDescriptor(String className, String name, String signature) {
@@ -43,9 +35,9 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor {
return false;
}
JavaFunctionDescriptor that = (JavaFunctionDescriptor) o;
return Objects.equal(className, that.className) &&
Objects.equal(name, that.name) &&
Objects.equal(signature, that.signature);
return Objects.equal(className, that.className)
&& Objects.equal(name, that.name)
&& Objects.equal(signature, that.signature);
}
@Override
@@ -5,9 +5,7 @@ import io.ray.runtime.generated.Common.Language;
import java.util.Arrays;
import java.util.List;
/**
* Represents metadata of a Python function.
*/
/** Represents metadata of a Python function. */
public class PyFunctionDescriptor implements FunctionDescriptor {
public String moduleName;
@@ -36,9 +34,9 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
return false;
}
PyFunctionDescriptor that = (PyFunctionDescriptor) o;
return Objects.equal(moduleName, that.moduleName) &&
Objects.equal(className, that.className) &&
Objects.equal(functionName, that.functionName);
return Objects.equal(moduleName, that.moduleName)
&& Objects.equal(className, that.className)
&& Objects.equal(functionName, that.functionName);
}
@Override
@@ -56,4 +54,3 @@ public class PyFunctionDescriptor implements FunctionDescriptor {
return Language.PYTHON;
}
}
@@ -5,50 +5,36 @@ import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.util.Optional;
/**
* Represents a Ray function (either a Method or a Constructor in Java) and its metadata.
*/
/** Represents a Ray function (either a Method or a Constructor in Java) and its metadata. */
public class RayFunction {
/**
* The executor object, can be either a Method or a Constructor.
*/
/** The executor object, can be either a Method or a Constructor. */
public final Executable executable;
/**
* This function's class loader.
*/
/** This function's class loader. */
public final ClassLoader classLoader;
/**
* Function's metadata.
*/
/** Function's metadata. */
public final JavaFunctionDescriptor functionDescriptor;
public RayFunction(Executable executable, ClassLoader classLoader,
JavaFunctionDescriptor functionDescriptor) {
public RayFunction(
Executable executable, ClassLoader classLoader, JavaFunctionDescriptor functionDescriptor) {
this.executable = executable;
this.classLoader = classLoader;
this.functionDescriptor = functionDescriptor;
}
/**
* @return True if it's a constructor, otherwise it's a method.
*/
/** Returns True if it's a constructor, otherwise it's a method. */
public boolean isConstructor() {
return executable instanceof Constructor;
}
/**
* @return The underlying constructor object.
*/
/** Returns The underlying constructor object. */
public Constructor<?> getConstructor() {
return (Constructor<?>) executable;
}
/**
* @return The underlying method object.
*/
/** Returns The underlying method object. */
public Method getMethod() {
return (Method) executable;
}
@@ -57,9 +43,7 @@ public class RayFunction {
return functionDescriptor;
}
/**
* @return Whether this function has a return value.
*/
/** Returns Whether this function has a return value. */
public boolean hasReturn() {
if (isConstructor()) {
return true;
@@ -68,9 +52,7 @@ public class RayFunction {
}
}
/**
* @return Return type.
*/
/** Returns Return type. */
public Optional<Class<?>> getReturnType() {
if (hasReturn()) {
return Optional.of(((Method) executable).getReturnType());
@@ -20,9 +20,7 @@ import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* An implementation of GcsClient.
*/
/** An implementation of GcsClient. */
public class GcsClient {
private static Logger LOGGER = LoggerFactory.getLogger(GcsClient.class);
private RedisClient primary;
@@ -35,9 +33,9 @@ public class GcsClient {
}
/**
* Get placement group by {@link PlacementGroupId}
* @param placementGroupId Id of placement group.
* @return The placement group.
* Get placement group by {@link PlacementGroupId}.
*
* @param placementGroupId Id of placement group. Returns The placement group.
*/
public PlacementGroup getPlacementGroupInfo(PlacementGroupId placementGroupId) {
byte[] result = globalStateAccessor.getPlacementGroupInfo(placementGroupId);
@@ -46,7 +44,8 @@ public class GcsClient {
/**
* Get all placement groups in this cluster.
* @return All placement groups.
*
* <p>Returns All placement groups.
*/
public List<PlacementGroup> getAllPlacementGroupInfo() {
List<byte[]> results = globalStateAccessor.getAllPlacementGroupInfo();
@@ -71,15 +70,20 @@ public class GcsClient {
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException("Received invalid protobuf data from GCS.");
}
final UniqueId nodeId = UniqueId
.fromByteBuffer(data.getNodeId().asReadOnlyByteBuffer());
final UniqueId nodeId = UniqueId.fromByteBuffer(data.getNodeId().asReadOnlyByteBuffer());
// NOTE(lingxuan.zlx): we assume no duplicated node id in fetched node list
// and it's only one final state for each node in recorded table.
NodeInfo nodeInfo = new NodeInfo(
nodeId, data.getNodeManagerAddress(), data.getNodeManagerHostname(),
data.getNodeManagerPort(), data.getObjectStoreSocketName(), data.getRayletSocketName(),
data.getState() == GcsNodeInfo.GcsNodeState.ALIVE, new HashMap<>());
NodeInfo nodeInfo =
new NodeInfo(
nodeId,
data.getNodeManagerAddress(),
data.getNodeManagerHostname(),
data.getNodeManagerPort(),
data.getObjectStoreSocketName(),
data.getRayletSocketName(),
data.getState() == GcsNodeInfo.GcsNodeState.ALIVE,
new HashMap<>());
nodes.put(nodeId, nodeInfo);
}
@@ -119,9 +123,7 @@ public class GcsClient {
return resources;
}
/**
* If the actor exists in GCS.
*/
/** If the actor exists in GCS. */
public boolean actorExists(ActorId actorId) {
byte[] result = globalStateAccessor.getActorInfo(actorId);
return result != null;
@@ -149,9 +151,7 @@ public class GcsClient {
return JobId.fromInt(jobCounter);
}
/**
* Destroy global state accessor when ray native runtime will be shutdown.
*/
/** Destroy global state accessor when ray native runtime will be shutdown. */
public void destroy() {
// Only ray shutdown should call gcs client destroy.
LOGGER.debug("Destroying global state accessor.");
@@ -3,9 +3,7 @@ package io.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import io.ray.runtime.config.RayConfig;
/**
* Options to create GCS Client.
*/
/** Options to create GCS Client. */
public class GcsClientOptions {
public String ip;
public int port;
@@ -6,18 +6,15 @@ import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.UniqueId;
import java.util.List;
/**
* `GlobalStateAccessor` is used for accessing information from GCS.
*
**/
/** `GlobalStateAccessor` is used for accessing information from GCS. */
public class GlobalStateAccessor {
// NOTE(lingxuan.zlx): this is a singleton, it can not be changed during a Ray session.
// Native pointer to the C++ GcsStateAccessor.
private Long globalStateAccessorNativePointer = 0L;
private static GlobalStateAccessor globalStateAccessor;
public static synchronized GlobalStateAccessor getInstance(String redisAddress,
String redisPassword) {
public static synchronized GlobalStateAccessor getInstance(
String redisAddress, String redisPassword) {
if (null == globalStateAccessor) {
globalStateAccessor = new GlobalStateAccessor(redisAddress, redisPassword);
}
@@ -32,8 +29,7 @@ public class GlobalStateAccessor {
}
private GlobalStateAccessor(String redisAddress, String redisPassword) {
globalStateAccessorNativePointer =
nativeCreateGlobalStateAccessor(redisAddress, redisPassword);
globalStateAccessorNativePointer = nativeCreateGlobalStateAccessor(redisAddress, redisPassword);
validateGlobalStateAccessorPointer();
connect();
}
@@ -43,13 +39,12 @@ public class GlobalStateAccessor {
}
private void validateGlobalStateAccessorPointer() {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
Preconditions.checkState(
globalStateAccessorNativePointer != 0,
"Global state accessor native pointer must not be 0.");
}
/**
* @return A list of job info with JobInfo protobuf schema.
*/
/** Returns A list of job info with JobInfo protobuf schema. */
public List<byte[]> getAllJobInfo() {
// Fetch a job list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
@@ -58,9 +53,7 @@ public class GlobalStateAccessor {
}
}
/**
* @return A list of node info with GcsNodeInfo protobuf schema.
*/
/** Returns A list of node info with GcsNodeInfo protobuf schema. */
public List<byte[]> getAllNodeInfo() {
// Fetch a node list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
@@ -70,6 +63,8 @@ public class GlobalStateAccessor {
}
/**
* Get node resource info.
*
* @param nodeId node unique id.
* @return A map of node resource info in protobuf schema.
*/
@@ -83,8 +78,8 @@ public class GlobalStateAccessor {
public byte[] getPlacementGroupInfo(PlacementGroupId placementGroupId) {
synchronized (GlobalStateAccessor.class) {
validateGlobalStateAccessorPointer();
return nativeGetPlacementGroupInfo(globalStateAccessorNativePointer,
placementGroupId.getBytes());
return nativeGetPlacementGroupInfo(
globalStateAccessorNativePointer, placementGroupId.getBytes());
}
}
@@ -102,9 +97,7 @@ public class GlobalStateAccessor {
}
}
/**
* @return A list of actor info with ActorInfo protobuf schema.
*/
/** Returns A list of actor info with ActorInfo protobuf schema. */
public List<byte[]> getAllActorInfo() {
// Fetch a actor list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
@@ -113,9 +106,7 @@ public class GlobalStateAccessor {
}
}
/**
* @return An actor info with ActorInfo protobuf schema.
*/
/** Returns An actor info with ActorInfo protobuf schema. */
public byte[] getActorInfo(ActorId actorId) {
// Fetch an actor with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
@@ -152,8 +143,7 @@ public class GlobalStateAccessor {
private native byte[] nativeGetActorInfo(long nativePtr, byte[] actorId);
private native byte[] nativeGetPlacementGroupInfo(long nativePtr,
byte[] placementGroupId);
private native byte[] nativeGetPlacementGroupInfo(long nativePtr, byte[] placementGroupId);
private native List<byte[]> nativeGetAllPlacementGroupInfo(long nativePtr);
}
@@ -7,9 +7,7 @@ import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
/**
* Redis client class.
*/
/** Redis client class. */
public class RedisClient {
private static final int JEDIS_POOL_SIZE = 1;
@@ -23,19 +21,20 @@ public class RedisClient {
public RedisClient(String redisAddress, String password) {
String[] ipAndPort = redisAddress.split(":");
if (ipAndPort.length != 2) {
throw new IllegalArgumentException("The argument redisAddress " +
"should be formatted as ip:port.");
throw new IllegalArgumentException(
"The argument redisAddress " + "should be formatted as ip:port.");
}
JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
jedisPoolConfig.setMaxTotal(JEDIS_POOL_SIZE);
if (Strings.isNullOrEmpty(password)) {
jedisPool = new JedisPool(jedisPoolConfig,
ipAndPort[0], Integer.parseInt(ipAndPort[1]), 30000);
jedisPool =
new JedisPool(jedisPoolConfig, ipAndPort[0], Integer.parseInt(ipAndPort[1]), 30000);
} else {
jedisPool = new JedisPool(jedisPoolConfig, ipAndPort[0],
Integer.parseInt(ipAndPort[1]), 30000, password);
jedisPool =
new JedisPool(
jedisPoolConfig, ipAndPort[0], Integer.parseInt(ipAndPort[1]), 30000, password);
}
}
@@ -89,7 +88,7 @@ public class RedisClient {
/**
* Return the specified elements of the list stored at the specified key.
*
* @return Multi bulk reply, specifically a list of elements in the specified range.
* <p>Returns Multi bulk reply, specifically a list of elements in the specified range.
*/
public List<byte[]> lrange(byte[] key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
@@ -97,9 +96,7 @@ public class RedisClient {
}
}
/**
* Whether the key exists in Redis.
*/
/** Whether the key exists in Redis. */
public boolean exists(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.exists(key);
@@ -1,14 +1,11 @@
package io.ray.runtime.metric;
import com.google.common.base.Preconditions;
import java.util.Map;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.stream.Collectors;
/**
* Count measurement is mapped to count object in stats and counts the number.
*/
/** Count measurement is mapped to count object in stats and counts the number. */
public class Count extends Metric {
private DoubleAdder count;
@@ -16,8 +13,12 @@ public class Count extends Metric {
public Count(String name, String description, String unit, Map<TagKey, String> tags) {
super(name, tags);
count = new DoubleAdder();
metricNativePointer = NativeMetric.registerCountNative(name, description, unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
metricNativePointer =
NativeMetric.registerCountNative(
name,
description,
unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
Preconditions.checkState(metricNativePointer != 0, "Count native pointer must not be 0.");
}
@@ -4,16 +4,17 @@ import com.google.common.base.Preconditions;
import java.util.Map;
import java.util.stream.Collectors;
/**
* Gauge measurement is mapped to gauge object in stats and is recording the last value.
*/
/** Gauge measurement is mapped to gauge object in stats and is recording the last value. */
public class Gauge extends Metric {
public Gauge(String name, String description, String unit, Map<TagKey, String> tags) {
super(name, tags);
metricNativePointer = NativeMetric.registerGaugeNative(name, description, unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
metricNativePointer =
NativeMetric.registerGaugeNative(
name,
description,
unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
Preconditions.checkState(metricNativePointer != 0, "Gauge native pointer must not be 0.");
}
@@ -31,4 +32,3 @@ public class Gauge extends Metric {
this.value.set(value);
}
}
@@ -1,7 +1,6 @@
package io.ray.runtime.metric;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -9,23 +8,30 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
* Histogram measurement is mapped to histogram object in stats.
* In order to reduce JNI calls overhead, a memory historical window is used
* for storing transient value and we assume its max size is 100.
* Histogram measurement is mapped to histogram object in stats. In order to reduce JNI calls
* overhead, a memory historical window is used for storing transient value and we assume its max
* size is 100.
*/
public class Histogram extends Metric {
private List<Double> histogramWindow;
public static final int HISTOGRAM_WINDOW_SIZE = 100;
public Histogram(String name, String description, String unit, List<Double> boundaries,
Map<TagKey, String> tags) {
public Histogram(
String name,
String description,
String unit,
List<Double> boundaries,
Map<TagKey, String> tags) {
super(name, tags);
metricNativePointer = NativeMetric.registerHistogramNative(name, description, unit,
boundaries.stream().mapToDouble(Double::doubleValue).toArray(),
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
Preconditions.checkState(metricNativePointer != 0,
"Histogram native pointer must not be 0.");
metricNativePointer =
NativeMetric.registerHistogramNative(
name,
description,
unit,
boundaries.stream().mapToDouble(Double::doubleValue).toArray(),
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
Preconditions.checkState(metricNativePointer != 0, "Histogram native pointer must not be 0.");
histogramWindow = Collections.synchronizedList(new ArrayList<>());
}
@@ -8,8 +8,8 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
* Class metric is mapped to stats metric object in core worker.
* it must be in categories set [Gague, Count, Sum, Histogram].
* Class metric is mapped to stats metric object in core worker. it must be in categories set
* [Gague, Count, Sum, Histogram].
*/
public abstract class Metric {
protected String name;
@@ -33,9 +33,7 @@ public abstract class Metric {
// Metric data will be flushed into stats view data inside core worker immediately after
// record is called.
/**
* Flush records to stats in last aggregator.
*/
/** Flush records to stats in last aggregator. */
public void record() {
Preconditions.checkState(metricNativePointer != 0, "Metric native pointer must not be 0.");
// Get tag key list from map;
@@ -46,20 +44,22 @@ public abstract class Metric {
tagValues.add(entry.getValue());
}
// Get tag value list from map;
NativeMetric.recordNative(metricNativePointer, getAndReset(), nativeTagKeyList.stream()
.map(TagKey::getTagKey).collect(Collectors.toList()), tagValues);
NativeMetric.recordNative(
metricNativePointer,
getAndReset(),
nativeTagKeyList.stream().map(TagKey::getTagKey).collect(Collectors.toList()),
tagValues);
}
/**
* Get the value to record and then reset.
*
* @return latest updating value.
* <p>Returns latest updating value.
*/
protected abstract double getAndReset();
/**
* Update gauge value without tags.
* Update metric info for user.
* Update gauge value without tags. Update metric info for user.
*
* @param value latest value for updating
*/
@@ -69,21 +69,18 @@ public abstract class Metric {
* Update gauge value with dynamic tag values.
*
* @param value latest value for updating
* @param tags tag map
* @param tags tag map
*/
public void update(double value, Map<TagKey, String> tags) {
update(value);
this.tags = tags;
}
/**
* Deallocate object from stats and reset native pointer in null.
*/
/** Deallocate object from stats and reset native pointer in null. */
public void unregister() {
if (0 != metricNativePointer) {
NativeMetric.unregisterMetricNative(metricNativePointer);
}
metricNativePointer = 0;
}
}
@@ -2,17 +2,16 @@ package io.ray.runtime.metric;
import com.google.common.base.MoreObjects;
/**
* Configurations of the metric.
*/
/** Configurations of the metric. */
public class MetricConfig {
private static final long DEFAULT_TIME_INTERVAL_MS = 5000L;
private static final int DEFAULT_THREAD_POLL_SIZE = 1;
private static final long DEFAULT_SHUTDOWN_WAIT_TIME_MS = 3000L;
public static final MetricConfig DEFAULT_CONFIG = new MetricConfig(DEFAULT_TIME_INTERVAL_MS,
DEFAULT_THREAD_POLL_SIZE, DEFAULT_SHUTDOWN_WAIT_TIME_MS);
public static final MetricConfig DEFAULT_CONFIG =
new MetricConfig(
DEFAULT_TIME_INTERVAL_MS, DEFAULT_THREAD_POLL_SIZE, DEFAULT_SHUTDOWN_WAIT_TIME_MS);
private final long timeIntervalMs;
private final int threadPoolSize;
@@ -73,6 +72,4 @@ public class MetricConfig {
return this;
}
}
}
}
@@ -5,10 +5,9 @@ import java.util.Map;
import java.util.Objects;
/**
* MetricId represents a metric with a given type, name and tags.
* If two metrics have the same type and name but different tags(including key and value), they have
* a different MetricId. And in this way, {@link MetricRegistry} can register two metrics with same
* name but different tags.
* MetricId represents a metric with a given type, name and tags. If two metrics have the same type
* and name but different tags(including key and value), they have a different MetricId. And in this
* way, {@link MetricRegistry} can register two metrics with same name but different tags.
*/
public class MetricId {
@@ -31,9 +30,9 @@ public class MetricId {
return false;
}
MetricId metricId = (MetricId) o;
return type == metricId.type &&
Objects.equals(name, metricId.name) &&
Objects.equals(tags, metricId.tags);
return type == metricId.type
&& Objects.equals(name, metricId.name)
&& Objects.equals(tags, metricId.tags);
}
@Override
@@ -61,4 +60,4 @@ public class MetricId {
public Map<TagKey, String> getTags() {
return tags;
}
}
}
@@ -9,9 +9,7 @@ import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* MetricRegistry is a registry for metrics to be registered and updates metrics.
*/
/** MetricRegistry is a registry for metrics to be registered and updates metrics. */
public class MetricRegistry {
public static final MetricRegistry DEFAULT_REGISTRY = new MetricRegistry();
@@ -32,10 +30,15 @@ public class MetricRegistry {
synchronized (this) {
if (!isRunning) {
this.metricConfig = metricConfig;
scheduledExecutorService = new ScheduledThreadPoolExecutor(metricConfig.threadPoolSize(),
new ThreadFactoryBuilder().setNameFormat("metric-registry-%d").build());
scheduledExecutorService.scheduleAtFixedRate(this::update, metricConfig.timeIntervalMs(),
metricConfig.timeIntervalMs(), TimeUnit.MILLISECONDS);
scheduledExecutorService =
new ScheduledThreadPoolExecutor(
metricConfig.threadPoolSize(),
new ThreadFactoryBuilder().setNameFormat("metric-registry-%d").build());
scheduledExecutorService.scheduleAtFixedRate(
this::update,
metricConfig.timeIntervalMs(),
metricConfig.timeIntervalMs(),
TimeUnit.MILLISECONDS);
isRunning = true;
LOG.info("Finished startup metric registry, metricConfig is {}.", metricConfig);
}
@@ -47,15 +50,18 @@ public class MetricRegistry {
if (isRunning && scheduledExecutorService != null) {
try {
scheduledExecutorService.shutdownNow();
if (!scheduledExecutorService.awaitTermination(metricConfig.shutdownWaitTimeMs(),
TimeUnit.MILLISECONDS)) {
LOG.warn("Metric registry did not shut down in {}ms time, so try to shut down again.",
if (!scheduledExecutorService.awaitTermination(
metricConfig.shutdownWaitTimeMs(), TimeUnit.MILLISECONDS)) {
LOG.warn(
"Metric registry did not shut down in {}ms time, so try to shut down again.",
metricConfig.shutdownWaitTimeMs());
scheduledExecutorService.shutdownNow();
}
} catch (InterruptedException e) {
LOG.warn("Interrupted when shutting down metric registry, so try to shut down again.",
e.getMessage(), e);
LOG.warn(
"Interrupted when shutting down metric registry, so try to shut down again.",
e.getMessage(),
e);
scheduledExecutorService.shutdownNow();
}
if (scheduledExecutorService.isShutdown()) {
@@ -106,9 +112,10 @@ public class MetricRegistry {
}
private void update() {
registeredMetrics.forEach((id, metric) -> {
metric.record();
});
registeredMetrics.forEach(
(id, metric) -> {
metric.record();
});
}
private MetricType getMetricType(Metric metric) {
@@ -131,5 +138,4 @@ public class MetricRegistry {
private MetricId genMetricIdByMetric(Metric metric) {
return new MetricId(getMetricType(metric), metric.name, metric.tags);
}
}
}
@@ -1,10 +1,7 @@
package io.ray.runtime.metric;
/**
* Types of the metric.
*/
/** Types of the metric. */
public enum MetricType {
COUNT,
GAUGE,
@@ -12,5 +9,4 @@ public enum MetricType {
SUM,
HISTOGRAM
}
}
@@ -5,9 +5,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* The entry of metrics for easy use.
*/
/** The entry of metrics for easy use. */
public final class Metrics {
private static MetricRegistry metricRegistry;
@@ -113,7 +111,7 @@ public final class Metrics {
/**
* Creates a metric by sub-class.
*
* @return a metric
* <p>Returns a metric
*/
protected abstract M create();
@@ -126,10 +124,11 @@ public final class Metrics {
private static Map<TagKey, String> generateTagKeysMap(Map<String, String> tags) {
Map<TagKey, String> tagKeys = new HashMap<>(tags.size() * 2);
tags.forEach((key, value) -> {
TagKey tagKey = new TagKey(key);
tagKeys.put(tagKey, value);
});
tags.forEach(
(key, value) -> {
TagKey tagKey = new TagKey(key);
tagKeys.put(tagKey, value);
});
return tagKeys;
}
@@ -144,5 +143,4 @@ public final class Metrics {
}
}
}
}
}
@@ -8,22 +8,20 @@ import java.util.List;
class NativeMetric {
public static native void registerTagkeyNative(String tagKey);
public static native long registerCountNative(String name, String description,
String unit, List<String> tagKeys);
public static native long registerCountNative(
String name, String description, String unit, List<String> tagKeys);
public static native long registerGaugeNative(String name, String description,
String unit, List<String> tagKeys);
public static native long registerGaugeNative(
String name, String description, String unit, List<String> tagKeys);
public static native long registerHistogramNative(String name, String description,
String unit, double[] boundaries,
List<String> tagKeys);
public static native long registerHistogramNative(
String name, String description, String unit, double[] boundaries, List<String> tagKeys);
public static native long registerSumNative(String name, String description,
String unit, List<String> tagKeys);
public static native long registerSumNative(
String name, String description, String unit, List<String> tagKeys);
public static native void recordNative(long metricNativePointer, double value,
List tagKeys, List<String> tagValues);
public static native void recordNative(
long metricNativePointer, double value, List tagKeys, List<String> tagValues);
public static native void unregisterMetricNative(long gaugePtr);
}
@@ -1,14 +1,13 @@
package io.ray.runtime.metric;
import com.google.common.base.Preconditions;
import java.util.Map;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.stream.Collectors;
/**
* Sum measurement is mapped to sum object in stats.
* Property sum is used for storing transient sum for registry aggregation.
* Sum measurement is mapped to sum object in stats. Property sum is used for storing transient sum
* for registry aggregation.
*/
public class Sum extends Metric {
@@ -16,8 +15,12 @@ public class Sum extends Metric {
public Sum(String name, String description, String unit, Map<TagKey, String> tags) {
super(name, tags);
metricNativePointer = NativeMetric.registerSumNative(name, description, unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
metricNativePointer =
NativeMetric.registerSumNative(
name,
description,
unit,
tags.keySet().stream().map(TagKey::getTagKey).collect(Collectors.toList()));
Preconditions.checkState(metricNativePointer != 0, "Count native pointer must not be 0.");
this.sum = new DoubleAdder();
}
@@ -2,9 +2,7 @@ package io.ray.runtime.metric;
import java.util.Objects;
/**
* Tagkey is mapping java object to stats tagkey object.
*/
/** Tagkey is mapping java object to stats tagkey object. */
public class TagKey {
private String tagKey;
@@ -37,8 +35,6 @@ public class TagKey {
@Override
public String toString() {
return "TagKey{" +
", tagKey='" + tagKey + '\'' +
'}';
return "TagKey{" + ", tagKey='" + tagKey + '\'' + '}';
}
}
}
@@ -14,9 +14,7 @@ import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Object store methods for local mode.
*/
/** Object store methods for local mode. */
public class LocalModeObjectStore extends ObjectStore {
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class);
@@ -100,12 +98,10 @@ public class LocalModeObjectStore extends ObjectStore {
}
@Override
public void addLocalReference(UniqueId workerId, ObjectId objectId) {
}
public void addLocalReference(UniqueId workerId, ObjectId objectId) {}
@Override
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {
}
public void removeLocalReference(UniqueId workerId, ObjectId objectId) {}
@Override
public Address getOwnerAddress(ObjectId id) {
@@ -119,6 +115,5 @@ public class LocalModeObjectStore extends ObjectStore {
@Override
public void registerOwnershipInfoAndResolveFuture(
ObjectId objectId, ObjectId outerObjectId, byte[] ownerAddress) {
}
ObjectId objectId, ObjectId outerObjectId, byte[] ownerAddress) {}
}
@@ -76,8 +76,8 @@ public class NativeObjectStore extends ObjectStore {
}
@Override
public void registerOwnershipInfoAndResolveFuture(ObjectId objectId, ObjectId outerObjectId,
byte[] ownerAddress) {
public void registerOwnershipInfoAndResolveFuture(
ObjectId objectId, ObjectId outerObjectId, byte[] ownerAddress) {
byte[] outer = null;
if (outerObjectId != null) {
outer = outerObjectId.getBytes();
@@ -87,8 +87,7 @@ public class NativeObjectStore extends ObjectStore {
public Map<ObjectId, long[]> getAllReferenceCounts() {
Map<ObjectId, long[]> referenceCounts = new HashMap<>();
for (Map.Entry<byte[], long[]> entry :
nativeGetAllReferenceCounts().entrySet()) {
for (Map.Entry<byte[], long[]> entry : nativeGetAllReferenceCounts().entrySet()) {
referenceCounts.put(new ObjectId(entry.getKey()), entry.getValue());
}
return referenceCounts;
@@ -113,8 +112,8 @@ public class NativeObjectStore extends ObjectStore {
private static native List<NativeRayObject> nativeGet(List<byte[]> ids, long timeoutMs);
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
long timeoutMs);
private static native List<Boolean> nativeWait(
List<byte[]> objectIds, int numObjects, long timeoutMs);
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly);
@@ -128,6 +127,6 @@ public class NativeObjectStore extends ObjectStore {
private static native byte[] nativePromoteAndGetOwnershipInfo(byte[] objectId);
private static native void nativeRegisterOwnershipInfoAndResolveFuture(byte[] objectId,
byte[] outerObjectId, byte[] ownerAddress);
private static native void nativeRegisterOwnershipInfoAndResolveFuture(
byte[] objectId, byte[] outerObjectId, byte[] ownerAddress);
}
@@ -7,9 +7,7 @@ import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
/**
* Binary representation of a ray object. See `RayObject` class in C++ for details.
*/
/** Binary representation of a ray object. See `RayObject` class in C++ for details. */
public class NativeRayObject {
public byte[] data;
@@ -43,4 +41,3 @@ public class NativeRayObject {
return "<data>: " + bufferLength(data) + ", <metadata>: " + bufferLength(metadata);
}
}
@@ -17,9 +17,7 @@ import java.lang.ref.Reference;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Implementation of {@link ObjectRef}.
*/
/** Implementation of {@link ObjectRef}. */
public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue();
@@ -40,8 +38,7 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
addLocalReference();
}
public ObjectRefImpl() {
}
public ObjectRefImpl() {}
@Override
public synchronized T get() {
@@ -81,8 +78,10 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
in.readFully(ownerAddress);
addLocalReference();
RayRuntimeInternal runtime = (RayRuntimeInternal) Ray.internal();
runtime.getObjectStore().registerOwnershipInfoAndResolveFuture(
this.id, ObjectSerializer.getOuterObjectId(), ownerAddress);
runtime
.getObjectStore()
.registerOwnershipInfoAndResolveFuture(
this.id, ObjectSerializer.getOuterObjectId(), ownerAddress);
}
private void addLocalReference() {
@@ -93,8 +92,8 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
new ObjectRefImplReference(this);
}
private static final class ObjectRefImplReference extends
FinalizableWeakReference<ObjectRefImpl<?>> {
private static final class ObjectRefImplReference
extends FinalizableWeakReference<ObjectRefImpl<?>> {
private final UniqueId workerId;
private final ObjectId objectId;
@@ -116,7 +115,8 @@ public final class ObjectRefImpl<T> implements ObjectRef<T>, Externalizable {
REFERENCES.remove(this);
// It's possible that GC is executed after the runtime is shutdown.
if (Ray.isInitialized()) {
((RayRuntimeInternal) (Ray.internal())).getObjectStore()
((RayRuntimeInternal) (Ray.internal()))
.getObjectStore()
.removeLocalReference(workerId, objectId);
}
}
@@ -25,14 +25,14 @@ import org.apache.commons.lang3.tuple.Pair;
*/
public class ObjectSerializer {
private static final byte[] WORKER_EXCEPTION_META = String
.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
private static final byte[] ACTOR_EXCEPTION_META = String
.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META = String
.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META = String
.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
private static final byte[] WORKER_EXCEPTION_META =
String.valueOf(ErrorType.WORKER_DIED.getNumber()).getBytes();
private static final byte[] ACTOR_EXCEPTION_META =
String.valueOf(ErrorType.ACTOR_DIED.getNumber()).getBytes();
private static final byte[] UNRECONSTRUCTABLE_EXCEPTION_META =
String.valueOf(ErrorType.OBJECT_UNRECONSTRUCTABLE.getNumber()).getBytes();
private static final byte[] TASK_EXECUTION_EXCEPTION_META =
String.valueOf(ErrorType.TASK_EXECUTION_EXCEPTION.getNumber()).getBytes();
public static final byte[] OBJECT_METADATA_TYPE_CROSS_LANGUAGE = "XLANG".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
@@ -55,11 +55,10 @@ public class ObjectSerializer {
* Deserialize an object from an {@link NativeRayObject} instance.
*
* @param nativeRayObject The object to deserialize.
* @param objectId The associated object ID of the object.
* @return The deserialized object.
* @param objectId The associated object ID of the object. Returns The deserialized object.
*/
public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objectId,
Class<?> objectType) {
public static Object deserialize(
NativeRayObject nativeRayObject, ObjectId objectId, Class<?> objectType) {
byte[] meta = nativeRayObject.metadata;
byte[] data = nativeRayObject.data;
@@ -70,8 +69,8 @@ public class ObjectSerializer {
return ByteBuffer.wrap(data);
}
return data;
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) == 0 ||
Bytes.indexOf(meta, OBJECT_METADATA_TYPE_JAVA) == 0) {
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) == 0
|| Bytes.indexOf(meta, OBJECT_METADATA_TYPE_JAVA) == 0) {
return Serializer.decode(data, objectType);
} else if (Bytes.indexOf(meta, WORKER_EXCEPTION_META) == 0) {
return new RayWorkerException();
@@ -92,15 +91,14 @@ public class ObjectSerializer {
return RayTaskException.fromBytes(serialized);
} catch (InvalidProtocolBufferException e) {
throw new IllegalArgumentException(
"Can't deserialize RayTaskException object: " + objectId
.toString());
"Can't deserialize RayTaskException object: " + objectId.toString());
}
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_ACTOR_HANDLE) == 0) {
byte[] serialized = Serializer.decode(data, byte[].class);
return NativeActorHandle.fromBytes(serialized);
} else if (Bytes.indexOf(meta, OBJECT_METADATA_TYPE_PYTHON) == 0) {
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
.toString());
throw new IllegalArgumentException(
"Can't deserialize Python object: " + objectId.toString());
}
throw new IllegalArgumentException("Unrecognized metadata " + Arrays.toString(meta));
} else {
@@ -112,8 +110,7 @@ public class ObjectSerializer {
/**
* Serialize an Java object to an {@link NativeRayObject} instance.
*
* @param object The object to serialize.
* @return The serialized object.
* @param object The object to serialize. Returns The serialized object.
*/
public static NativeRayObject serialize(Object object) {
if (object instanceof NativeRayObject) {
@@ -142,7 +139,7 @@ public class ObjectSerializer {
// any other type should be the MessagePack serialized bytes.
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
} else if (object instanceof NativeActorHandle) {
NativeActorHandle actorHandle = (NativeActorHandle)object;
NativeActorHandle actorHandle = (NativeActorHandle) object;
byte[] serializedBytes = Serializer.encode(actorHandle.toBytes()).getLeft();
// serializedBytes is MessagePack serialized bytes
// Only OBJECT_METADATA_TYPE_RAW is raw bytes,
@@ -151,10 +148,12 @@ public class ObjectSerializer {
} else {
try {
Pair<byte[], Boolean> serialized = Serializer.encode(object);
NativeRayObject nativeRayObject = new NativeRayObject(serialized.getLeft(),
serialized.getRight()
? OBJECT_METADATA_TYPE_CROSS_LANGUAGE
: OBJECT_METADATA_TYPE_JAVA);
NativeRayObject nativeRayObject =
new NativeRayObject(
serialized.getLeft(),
serialized.getRight()
? OBJECT_METADATA_TYPE_CROSS_LANGUAGE
: OBJECT_METADATA_TYPE_JAVA);
nativeRayObject.setContainedObjectIds(getAndClearContainedObjectIds());
return nativeRayObject;
} catch (Exception e) {
@@ -14,9 +14,7 @@ import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* A class that is used to put/get objects to/from the object store.
*/
/** A class that is used to put/get objects to/from the object store. */
public abstract class ObjectStore {
private final WorkerContext workerContext;
@@ -28,8 +26,7 @@ public abstract class ObjectStore {
/**
* Put a raw object into object store.
*
* @param obj The ray object.
* @return Generated ID of the object.
* @param obj The ray object. Returns Generated ID of the object.
*/
public abstract ObjectId putRaw(NativeRayObject obj);
@@ -44,8 +41,7 @@ public abstract class ObjectStore {
/**
* Serialize and put an object to the object store.
*
* @param object The object to put.
* @return Id of the object.
* @param object The object to put. Returns Id of the object.
*/
public ObjectId put(Object object) {
if (object instanceof NativeRayObject) {
@@ -58,7 +54,7 @@ public abstract class ObjectStore {
/**
* Serialize and put an object to the object store, with the given object id.
*
* This method is only used for testing.
* <p>This method is only used for testing.
*
* @param object The object to put.
* @param objectId Object id.
@@ -75,8 +71,8 @@ public abstract class ObjectStore {
* Get a list of raw objects from the object store.
*
* @param objectIds IDs of the objects to get.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return Result list of objects data.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. Returns Result list
* of objects data.
*/
public abstract List<NativeRayObject> getRaw(List<ObjectId> objectIds, long timeoutMs);
@@ -84,8 +80,7 @@ public abstract class ObjectStore {
* Get a list of objects from the object store.
*
* @param ids List of the object ids.
* @param <T> Type of these objects.
* @return A list of GetResult objects.
* @param <T> Type of these objects. Returns A list of GetResult objects.
*/
@SuppressWarnings("unchecked")
public <T> List<T> get(List<ObjectId> ids, Class<?> elementType) {
@@ -99,8 +94,7 @@ public abstract class ObjectStore {
if (dataAndMeta != null) {
try {
ObjectSerializer.setOuterObjectId(ids.get(i));
object = ObjectSerializer
.deserialize(dataAndMeta, ids.get(i), elementType);
object = ObjectSerializer.deserialize(dataAndMeta, ids.get(i), elementType);
} finally {
ObjectSerializer.resetOuterObjectId();
}
@@ -124,8 +118,8 @@ public abstract class ObjectStore {
*
* @param objectIds IDs of the objects to wait for.
* @param numObjects Number of objects that should appear.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative.
* @return A bitset that indicates each object has appeared or not.
* @param timeoutMs Timeout in milliseconds, wait infinitely if it's negative. Returns A bitset
* that indicates each object has appeared or not.
*/
public abstract List<Boolean> wait(List<ObjectId> objectIds, int numObjects, long timeoutMs);
@@ -135,8 +129,8 @@ public abstract class ObjectStore {
*
* @param waitList A list of object references to wait for.
* @param numReturns The number of objects that should be returned.
* @param timeoutMs The maximum time in milliseconds to wait before returning.
* @return Two lists, one containing locally available objects, one containing the rest.
* @param timeoutMs The maximum time in milliseconds to wait before returning. Returns Two lists,
* one containing locally available objects, one containing the rest.
*/
public <T> WaitResult<T> wait(List<ObjectRef<T>> waitList, int numReturns, int timeoutMs) {
Preconditions.checkNotNull(waitList);
@@ -144,8 +138,8 @@ public abstract class ObjectStore {
return new WaitResult<>(Collections.emptyList(), Collections.emptyList());
}
List<ObjectId> ids = waitList.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId())
.collect(Collectors.toList());
List<ObjectId> ids =
waitList.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(Collectors.toList());
List<Boolean> ready = wait(ids, numReturns, timeoutMs);
List<ObjectRef<T>> readyList = new ArrayList<>();
@@ -191,8 +185,7 @@ public abstract class ObjectStore {
/**
* Promote the given object to the underlying object store, and get the ownership info.
*
* @param objectId The ID of the object to promote
* @return the serialized ownership address
* @param objectId The ID of the object to promote Returns the serialized ownership address
*/
public abstract byte[] promoteAndGetOwnershipInfo(ObjectId objectId);
@@ -204,10 +197,10 @@ public abstract class ObjectStore {
*
* @param objectId The object ID to deserialize.
* @param outerObjectId The object ID that contained objectId, if any. This may be nil if the
* object ID was inlined directly in a task spec or if it was passed
* out-of-band by the application (deserialized from a byte string).
* object ID was inlined directly in a task spec or if it was passed out-of-band by the
* application (deserialized from a byte string).
* @param ownerAddress The address of the object's owner.
*/
public abstract void registerOwnershipInfoAndResolveFuture(ObjectId objectId,
ObjectId outerObjectId, byte[] ownerAddress);
public abstract void registerOwnershipInfoAndResolveFuture(
ObjectId objectId, ObjectId outerObjectId, byte[] ownerAddress);
}
@@ -8,9 +8,7 @@ import io.ray.api.placementgroup.PlacementStrategy;
import java.util.List;
import java.util.Map;
/**
* The default implementation of `PlacementGroup` interface.
*/
/** The default implementation of `PlacementGroup` interface. */
public class PlacementGroupImpl implements PlacementGroup {
private final PlacementGroupId id;
@@ -19,10 +17,12 @@ public class PlacementGroupImpl implements PlacementGroup {
private final PlacementStrategy strategy;
private final PlacementGroupState state;
private PlacementGroupImpl(PlacementGroupId id, String name,
List<Map<String, Double>> bundles,
PlacementStrategy strategy,
PlacementGroupState state) {
private PlacementGroupImpl(
PlacementGroupId id,
String name,
List<Map<String, Double>> bundles,
PlacementStrategy strategy,
PlacementGroupState state) {
this.id = id;
this.name = name;
this.bundles = bundles;
@@ -52,16 +52,15 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Wait for the placement group to be ready within the specified time.
* @param timeoutSeconds Timeout in seconds.
* @return True if the placement group is created. False otherwise.
*
* @param timeoutSeconds Timeout in seconds. Returns True if the placement group is created. False
* otherwise.
*/
public boolean wait(int timeoutSeconds) {
return Ray.internal().waitPlacementGroupReady(id, timeoutSeconds);
}
/**
* A help class for create the placement group.
*/
/** A help class for create the placement group. */
public static class Builder {
private PlacementGroupId id;
private String name;
@@ -71,8 +70,8 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Set the Id of the placement group.
* @param id Id of the placement group.
* @return self.
*
* @param id Id of the placement group. Returns self.
*/
public Builder setId(PlacementGroupId id) {
this.id = id;
@@ -81,8 +80,8 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Set the name of the placement group.
* @param name Name of the placement group.
* @return self.
*
* @param name Name of the placement group. Returns self.
*/
public Builder setName(String name) {
this.name = name;
@@ -91,8 +90,8 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Set the bundles of the placement group.
* @param bundles the bundles of the placement group.
* @return self.
*
* @param bundles the bundles of the placement group. Returns self.
*/
public Builder setBundles(List<Map<String, Double>> bundles) {
this.bundles = bundles;
@@ -101,8 +100,8 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Set the placement strategy of the placement group.
* @param strategy the placement strategy of the placement group.
* @return self.
*
* @param strategy the placement strategy of the placement group. Returns self.
*/
public Builder setStrategy(PlacementStrategy strategy) {
this.strategy = strategy;
@@ -111,8 +110,8 @@ public class PlacementGroupImpl implements PlacementGroup {
/**
* Set the placement state of the placement group.
* @param state the state of the placement group.
* @return self.
*
* @param state the state of the placement group. Returns self.
*/
public Builder setState(PlacementGroupState state) {
this.state = state;
@@ -123,5 +122,4 @@ public class PlacementGroupImpl implements PlacementGroup {
return new PlacementGroupImpl(id, name, bundles, strategy, state);
}
}
}
@@ -12,9 +12,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Utils for placement group.
*/
/** Utils for placement group. */
public class PlacementGroupUtils {
private static List<Map<String, Double>> covertToUserSpecifiedBundles(List<Bundle> bundles) {
@@ -62,45 +60,50 @@ public class PlacementGroupUtils {
/**
* Generate a PlacementGroupImpl from placementGroupTableData protobuf data.
* @param placementGroupTableData protobuf data.
* @return placement group info {@link PlacementGroupImpl}
*
* @param placementGroupTableData protobuf data. Returns placement group info {@link
* PlacementGroupImpl}
*/
private static PlacementGroupImpl generatePlacementGroupFromPbData(
PlacementGroupTableData placementGroupTableData) {
PlacementGroupState state = covertToUserSpecifiedState(
placementGroupTableData.getState());
PlacementStrategy strategy = covertToUserSpecifiedStrategy(
placementGroupTableData.getStrategy());
PlacementGroupState state = covertToUserSpecifiedState(placementGroupTableData.getState());
PlacementStrategy strategy =
covertToUserSpecifiedStrategy(placementGroupTableData.getStrategy());
List<Map<String, Double>> bundles = covertToUserSpecifiedBundles(
placementGroupTableData.getBundlesList());
List<Map<String, Double>> bundles =
covertToUserSpecifiedBundles(placementGroupTableData.getBundlesList());
PlacementGroupId placementGroupId = PlacementGroupId.fromByteBuffer(
placementGroupTableData.getPlacementGroupId().asReadOnlyByteBuffer());
PlacementGroupId placementGroupId =
PlacementGroupId.fromByteBuffer(
placementGroupTableData.getPlacementGroupId().asReadOnlyByteBuffer());
return new PlacementGroupImpl.Builder()
.setId(placementGroupId).setName(placementGroupTableData.getName())
.setState(state).setStrategy(strategy).setBundles(bundles)
.build();
.setId(placementGroupId)
.setName(placementGroupTableData.getName())
.setState(state)
.setStrategy(strategy)
.setBundles(bundles)
.build();
}
/**
* Generate a PlacementGroupImpl from byte array.
* @param placementGroupByteArray bytes array from native method.
* @return placement group info {@link PlacementGroupImpl}
*
* @param placementGroupByteArray bytes array from native method. Returns placement group info
* {@link PlacementGroupImpl}
*/
public static PlacementGroupImpl generatePlacementGroupFromByteArray(
byte[] placementGroupByteArray) {
Preconditions.checkNotNull(placementGroupByteArray,
"Can't generate a placement group from empty byte array.");
Preconditions.checkNotNull(
placementGroupByteArray, "Can't generate a placement group from empty byte array.");
PlacementGroupTableData placementGroupTableData;
try {
placementGroupTableData = PlacementGroupTableData.parseFrom(placementGroupByteArray);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(
"Received invalid placement group table protobuf data from GCS.", e);
"Received invalid placement group table protobuf data from GCS.", e);
}
return generatePlacementGroupFromPbData(placementGroupTableData);
@@ -16,18 +16,14 @@ import org.apache.commons.io.IOUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Ray service management on one box.
*/
/** Ray service management on one box. */
public class RunManager {
private static final Logger LOGGER = LoggerFactory.getLogger(RunManager.class);
private static final Pattern pattern = Pattern.compile("--address='([^']+)'");
/**
* Start the head node.
*/
/** Start the head node. */
public static void startRayHead(RayConfig rayConfig) {
LOGGER.debug("Starting ray runtime @ {}.", rayConfig.nodeIp);
List<String> command = new ArrayList<>();
@@ -54,9 +50,7 @@ public class RunManager {
LOGGER.info("Ray runtime started @ {}.", rayConfig.nodeIp);
}
/**
* Stop ray.
*/
/** Stop ray. */
public static void stopRay() {
List<String> command = new ArrayList<>();
command.add("ray");
@@ -73,10 +67,12 @@ public class RunManager {
public static void getAddressInfoAndFillConfig(RayConfig rayConfig) {
// NOTE(kfstorm): This method depends on an internal Python API of ray to get the
// address info of the local node.
String script = String.format("import ray;"
+ " print(ray._private.services.get_address_info_from_redis("
+ "'%s', '%s', redis_password='%s'))",
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
String script =
String.format(
"import ray;"
+ " print(ray._private.services.get_address_info_from_redis("
+ "'%s', '%s', redis_password='%s'))",
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
List<String> command = Arrays.asList("python", "-c", script);
String output = null;
@@ -110,9 +106,14 @@ public class RunManager {
String output = IOUtils.toString(p.getInputStream(), Charset.defaultCharset());
p.waitFor();
if (p.exitValue() != 0) {
String sb = "The exit value of the process is " + p.exitValue()
+ ". Command: " + Joiner.on(" ").join(command) + "\n"
+ "output:\n" + output;
String sb =
"The exit value of the process is "
+ p.exitValue()
+ ". Command: "
+ Joiner.on(" ").join(command)
+ "\n"
+ "output:\n"
+ output;
throw new RuntimeException(sb);
}
return output;
@@ -3,9 +3,7 @@ package io.ray.runtime.runner.worker;
import io.ray.api.Ray;
import io.ray.runtime.RayRuntimeInternal;
/**
* Default implementation of the worker process.
*/
/** Default implementation of the worker process. */
public class DefaultWorker {
public static void main(String[] args) {
@@ -16,4 +14,4 @@ public class DefaultWorker {
Ray.init();
((RayRuntimeInternal) Ray.internal()).run();
}
}
}
@@ -4,17 +4,17 @@ import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.actor.NativeActorHandleSerializer;
import org.nustaq.serialization.FSTConfiguration;
/**
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
*/
/** Java object serialization TODO: use others (e.g. Arrow) for higher performance */
public class FstSerializer {
private static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(() -> {
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
conf.registerSerializer(NativeActorHandle.class, new NativeActorHandleSerializer(), true);
return conf;
});
private static final ThreadLocal<FSTConfiguration> conf =
ThreadLocal.withInitial(
() -> {
FSTConfiguration conf = FSTConfiguration.createDefaultConfiguration();
conf.registerSerializer(
NativeActorHandle.class, new NativeActorHandleSerializer(), true);
return conf;
});
public static byte[] encode(Object obj) {
FSTConfiguration current = conf.get();
@@ -22,7 +22,6 @@ public class FstSerializer {
return current.asByteArray(obj);
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs) {
FSTConfiguration current = conf.get();
@@ -44,38 +44,40 @@ public class MessagePackSerializer {
NULL_PACKER = (object, packer, javaSerializer) -> packer.packNil();
// Array packer.
ARRAY_PACKER = ((object, packer, javaSerializer) -> {
int length = Array.getLength(object);
packer.packArrayHeader(length);
for (int i = 0; i < length; ++i) {
pack(Array.get(object, i), packer, javaSerializer);
}
});
ARRAY_PACKER =
((object, packer, javaSerializer) -> {
int length = Array.getLength(object);
packer.packArrayHeader(length);
for (int i = 0; i < length; ++i) {
pack(Array.get(object, i), packer, javaSerializer);
}
});
// Extension packer.
EXTENSION_PACKER = ((object, packer, javaSerializer) -> {
javaSerializer.serialize(object, packer);
});
EXTENSION_PACKER =
((object, packer, javaSerializer) -> {
javaSerializer.serialize(object, packer);
});
packers.put(Boolean.class,
((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object)));
packers.put(Byte.class,
((object, packer, javaSerializer) -> packer.packByte((Byte) object)));
packers.put(Short.class,
((object, packer, javaSerializer) -> packer.packShort((Short) object)));
packers.put(Integer.class,
((object, packer, javaSerializer) -> packer.packInt((Integer) object)));
packers.put(Long.class,
((object, packer, javaSerializer) -> packer.packLong((Long) object)));
packers.put(BigInteger.class,
packers.put(
Boolean.class, ((object, packer, javaSerializer) -> packer.packBoolean((Boolean) object)));
packers.put(Byte.class, ((object, packer, javaSerializer) -> packer.packByte((Byte) object)));
packers.put(
Short.class, ((object, packer, javaSerializer) -> packer.packShort((Short) object)));
packers.put(
Integer.class, ((object, packer, javaSerializer) -> packer.packInt((Integer) object)));
packers.put(Long.class, ((object, packer, javaSerializer) -> packer.packLong((Long) object)));
packers.put(
BigInteger.class,
((object, packer, javaSerializer) -> packer.packBigInteger((BigInteger) object)));
packers.put(Float.class,
((object, packer, javaSerializer) -> packer.packFloat((Float) object)));
packers.put(Double.class,
((object, packer, javaSerializer) -> packer.packDouble((Double) object)));
packers.put(String.class,
((object, packer, javaSerializer) -> packer.packString((String) object)));
packers.put(byte[].class,
packers.put(
Float.class, ((object, packer, javaSerializer) -> packer.packFloat((Float) object)));
packers.put(
Double.class, ((object, packer, javaSerializer) -> packer.packDouble((Double) object)));
packers.put(
String.class, ((object, packer, javaSerializer) -> packer.packString((String) object)));
packers.put(
byte[].class,
((object, packer, javaSerializer) -> {
byte[] bytes = (byte[]) object;
packer.packBinaryHeader(bytes.length);
@@ -97,68 +99,89 @@ public class MessagePackSerializer {
// Null unpacker.
unpackers.put(ValueType.NIL, (value, targetClass, javaDeserializer) -> null);
// Boolean unpacker.
unpackers.put(ValueType.BOOLEAN, (value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(booleanClasses, targetClass),
"Boolean can't be deserialized as {}.", targetClass);
return value.asBooleanValue().getBoolean();
});
unpackers.put(
ValueType.BOOLEAN,
(value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(
checkTypeCompatible(booleanClasses, targetClass),
"Boolean can't be deserialized as {}.",
targetClass);
return value.asBooleanValue().getBoolean();
});
// Integer unpacker.
unpackers.put(ValueType.INTEGER, ((value, targetClass, javaDeserializer) -> {
IntegerValue iv = value.asIntegerValue();
if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) {
return iv.asByte();
} else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) {
return iv.asShort();
} else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) {
return iv.asInt();
} else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) {
return iv.asLong();
} else if (checkTypeCompatible(bigIntClasses, targetClass)) {
return iv.asBigInteger();
}
throw new IllegalArgumentException("Integer can't be deserialized as " + targetClass + ".");
}));
unpackers.put(
ValueType.INTEGER,
((value, targetClass, javaDeserializer) -> {
IntegerValue iv = value.asIntegerValue();
if (iv.isInByteRange() && checkTypeCompatible(byteClasses, targetClass)) {
return iv.asByte();
} else if (iv.isInShortRange() && checkTypeCompatible(shortClasses, targetClass)) {
return iv.asShort();
} else if (iv.isInIntRange() && checkTypeCompatible(intClasses, targetClass)) {
return iv.asInt();
} else if (iv.isInLongRange() && checkTypeCompatible(longClasses, targetClass)) {
return iv.asLong();
} else if (checkTypeCompatible(bigIntClasses, targetClass)) {
return iv.asBigInteger();
}
throw new IllegalArgumentException(
"Integer can't be deserialized as " + targetClass + ".");
}));
// Float unpacker.
unpackers.put(ValueType.FLOAT, ((value, targetClass, javaDeserializer) -> {
if (checkTypeCompatible(doubleClasses, targetClass)) {
return value.asFloatValue().toDouble();
} else if (checkTypeCompatible(floatClasses, targetClass)) {
return value.asFloatValue().toFloat();
}
throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + ".");
}));
unpackers.put(
ValueType.FLOAT,
((value, targetClass, javaDeserializer) -> {
if (checkTypeCompatible(doubleClasses, targetClass)) {
return value.asFloatValue().toDouble();
} else if (checkTypeCompatible(floatClasses, targetClass)) {
return value.asFloatValue().toFloat();
}
throw new IllegalArgumentException("Float can't be deserialized as " + targetClass + ".");
}));
// String unpacker.
unpackers.put(ValueType.STRING, ((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(stringClasses, targetClass),
"String can't be deserialized as {}.", targetClass);
return value.asStringValue().asString();
}));
unpackers.put(
ValueType.STRING,
((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(
checkTypeCompatible(stringClasses, targetClass),
"String can't be deserialized as {}.",
targetClass);
return value.asStringValue().asString();
}));
// Binary unpacker.
unpackers.put(ValueType.BINARY, ((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(checkTypeCompatible(binaryClasses, targetClass),
"Binary can't be deserialized as {}.", targetClass);
return value.asBinaryValue().asByteArray();
}));
unpackers.put(
ValueType.BINARY,
((value, targetClass, javaDeserializer) -> {
Preconditions.checkArgument(
checkTypeCompatible(binaryClasses, targetClass),
"Binary can't be deserialized as {}.",
targetClass);
return value.asBinaryValue().asByteArray();
}));
// Array unpacker.
unpackers.put(ValueType.ARRAY, ((value, targetClass, javaDeserializer) -> {
ArrayValue av = value.asArrayValue();
Class<?> componentType =
targetClass.isArray() ? targetClass.getComponentType() : Object.class;
Object array = Array.newInstance(componentType, av.size());
for (int i = 0; i < av.size(); ++i) {
Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer));
}
return array;
}));
unpackers.put(
ValueType.ARRAY,
((value, targetClass, javaDeserializer) -> {
ArrayValue av = value.asArrayValue();
Class<?> componentType =
targetClass.isArray() ? targetClass.getComponentType() : Object.class;
Object array = Array.newInstance(componentType, av.size());
for (int i = 0; i < av.size(); ++i) {
Array.set(array, i, unpack(av.get(i), componentType, javaDeserializer));
}
return array;
}));
// Extension unpacker.
unpackers.put(ValueType.EXTENSION, ((value, targetClass, javaDeserializer) -> {
ExtensionValue ev = value.asExtensionValue();
byte extType = ev.getType();
if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) {
return javaDeserializer.deserialize(ev);
}
throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + ".");
}));
unpackers.put(
ValueType.EXTENSION,
((value, targetClass, javaDeserializer) -> {
ExtensionValue ev = value.asExtensionValue();
byte extType = ev.getType();
if (extType == LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID) {
return javaDeserializer.deserialize(ev);
}
throw new IllegalArgumentException("Unknown extension type id " + ev.getType() + ".");
}));
}
interface JavaSerializer {
@@ -173,14 +196,13 @@ public class MessagePackSerializer {
interface TypePacker {
void pack(Object object, MessagePacker packer,
JavaSerializer javaSerializer) throws IOException;
void pack(Object object, MessagePacker packer, JavaSerializer javaSerializer)
throws IOException;
}
interface TypeUnpacker {
Object unpack(Value value, Class<?> targetClass,
JavaDeserializer javaDeserializer);
Object unpack(Value value, Class<?> targetClass, JavaDeserializer javaDeserializer);
}
private static boolean checkTypeCompatible(List<Class<?>> expected, Class<?> actual) {
@@ -230,12 +252,15 @@ public class MessagePackSerializer {
packer.writePayload(new byte[MESSAGE_PACK_OFFSET]);
// Serialize input object by MessagePack.
MutableBoolean isCrossLanguage = new MutableBoolean(true);
pack(obj, packer, ((object, packer1) -> {
byte[] payload = FstSerializer.encode(object);
packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length);
packer1.addPayload(payload);
isCrossLanguage.setFalse();
}));
pack(
obj,
packer,
((object, packer1) -> {
byte[] payload = FstSerializer.encode(object);
packer1.packExtensionTypeHeader(LANGUAGE_SPECIFIC_TYPE_EXTENSION_ID, payload.length);
packer1.addPayload(payload);
isCrossLanguage.setFalse();
}));
byte[] msgpackBytes = packer.toByteArray();
// Serialize MessagePack bytes length.
MessageBufferPacker headerPacker = MessagePack.newDefaultBufferPacker();
@@ -252,7 +277,6 @@ public class MessagePackSerializer {
}
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs, Class<?> type) {
try {
@@ -263,14 +287,13 @@ public class MessagePackSerializer {
// Check MessagePack bytes length is valid.
Preconditions.checkState(MESSAGE_PACK_OFFSET + msgpackBytesLength <= bs.length);
// Deserialize MessagePack bytes from MESSAGE_PACK_OFFSET.
MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET,
(int) msgpackBytesLength);
MessageUnpacker unpacker =
MessagePack.newDefaultUnpacker(bs, MESSAGE_PACK_OFFSET, (int) msgpackBytesLength);
Value v = unpacker.unpackValue();
if (type == null) {
type = Object.class;
}
return (T) unpack(v, type,
((ExtensionValue ev) -> FstSerializer.decode(ev.getData())));
return (T) unpack(v, type, ((ExtensionValue ev) -> FstSerializer.decode(ev.getData())));
} catch (Exception e) {
throw new RuntimeException(e);
}
@@ -16,9 +16,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Helper methods to convert arguments from/to objects.
*/
/** Helper methods to convert arguments from/to objects. */
public class ArgumentsBuilder {
/**
@@ -28,15 +26,11 @@ public class ArgumentsBuilder {
// TODO(kfstorm): Read from internal config `max_direct_call_object_size`.
public static final int LARGEST_SIZE_PASS_BY_VALUE = 100 * 1024;
/**
* This dummy type is also defined in signature.py. Please keep it synced.
*/
private static final NativeRayObject PYTHON_DUMMY_TYPE = ObjectSerializer
.serialize("__RAY_DUMMY__".getBytes());
/** This dummy type is also defined in signature.py. Please keep it synced. */
private static final NativeRayObject PYTHON_DUMMY_TYPE =
ObjectSerializer.serialize("__RAY_DUMMY__".getBytes());
/**
* Convert real function arguments to task spec arguments.
*/
/** Convert real function arguments to task spec arguments. */
public static List<FunctionArg> wrap(Object[] args, Language language) {
List<FunctionArg> ret = new ArrayList<>();
for (Object arg : args) {
@@ -51,20 +45,21 @@ public class ArgumentsBuilder {
value = ObjectSerializer.serialize(arg);
if (language != Language.JAVA) {
boolean isCrossData =
Bytes.indexOf(value.metadata,
ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) == 0 ||
Bytes.indexOf(value.metadata,
ObjectSerializer.OBJECT_METADATA_TYPE_RAW) == 0 ||
Bytes.indexOf(value.metadata,
ObjectSerializer.OBJECT_METADATA_TYPE_ACTOR_HANDLE) == 0;
Bytes.indexOf(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE)
== 0
|| Bytes.indexOf(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW) == 0
|| Bytes.indexOf(
value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_ACTOR_HANDLE)
== 0;
if (!isCrossData) {
throw new IllegalArgumentException(String.format("Can't transfer %s data to %s",
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
throw new IllegalArgumentException(
String.format(
"Can't transfer %s data to %s",
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
}
}
if (value.data.length > LARGEST_SIZE_PASS_BY_VALUE) {
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore()
.putRaw(value);
id = ((RayRuntimeInternal) Ray.internal()).getObjectStore().putRaw(value);
address = ((RayRuntimeInternal) Ray.internal()).getWorkerContext().getRpcAddress();
value = null;
}
@@ -81,9 +76,7 @@ public class ArgumentsBuilder {
return ret;
}
/**
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
*/
/** Convert list of NativeRayObject/ByteBuffer to real function arguments. */
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
@@ -6,25 +6,18 @@ import io.ray.runtime.generated.Common.Address;
import io.ray.runtime.object.NativeRayObject;
/**
* Represents a function argument in task spec.
* Either `id` or `data` should be null, when id is not null, this argument will be
* passed by reference, otherwise it will be passed by value.
* Represents a function argument in task spec. Either `id` or `data` should be null, when id is not
* null, this argument will be passed by reference, otherwise it will be passed by value.
*/
public class FunctionArg {
/**
* The id of this argument (passed by reference).
*/
/** The id of this argument (passed by reference). */
public final ObjectId id;
/**
* The owner address of this argument (passed by reference).
*/
/** The owner address of this argument (passed by reference). */
public final Address ownerAddress;
/**
* Serialized data of this argument (passed by value).
*/
/** Serialized data of this argument (passed by value). */
public final NativeRayObject value;
private FunctionArg(ObjectId id, Address ownerAddress) {
@@ -42,16 +35,12 @@ public class FunctionArg {
this.value = nativeRayObject;
}
/**
* Create a FunctionArg that will be passed by reference.
*/
/** Create a FunctionArg that will be passed by reference. */
public static FunctionArg passByReference(ObjectId id, Address ownerAddress) {
return new FunctionArg(id, ownerAddress);
}
/**
* Create a FunctionArg that will be passed by value.
*/
/** Create a FunctionArg that will be passed by value. */
public static FunctionArg passByValue(NativeRayObject value) {
return new FunctionArg(value);
}
@@ -3,16 +3,12 @@ package io.ray.runtime.task;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
/**
* Task executor for local mode.
*/
/** Task executor for local mode. */
public class LocalModeTaskExecutor extends TaskExecutor<LocalModeTaskExecutor.LocalActorContext> {
static class LocalActorContext extends TaskExecutor.ActorContext {
/**
* The worker ID of the actor.
*/
/** The worker ID of the actor. */
private final UniqueId workerId;
public LocalActorContext(UniqueId workerId) {
@@ -32,5 +28,4 @@ public class LocalModeTaskExecutor extends TaskExecutor<LocalModeTaskExecutor.Lo
protected LocalActorContext createActorContext() {
return new LocalActorContext(runtime.getWorkerContext().getCurrentWorkerId());
}
}
@@ -51,9 +51,7 @@ import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Task submitter for local mode.
*/
/** Task submitter for local mode. */
public class LocalModeTaskSubmitter implements TaskSubmitter {
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeTaskSubmitter.class);
@@ -78,8 +76,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private final Map<PlacementGroupId, PlacementGroup> placementGroups = new ConcurrentHashMap<>();
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
LocalModeObjectStore objectStore) {
public LocalModeTaskSubmitter(
RayRuntimeInternal runtime, TaskExecutor taskExecutor, LocalModeObjectStore objectStore) {
this.runtime = runtime;
this.taskExecutor = taskExecutor;
this.objectStore = objectStore;
@@ -117,8 +115,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
}
if (taskSpec.getType() == TaskType.ACTOR_TASK) {
ObjectId dummyObjectId = new ObjectId(
taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
ObjectId dummyObjectId =
new ObjectId(
taskSpec.getActorTaskSpec().getPreviousActorTaskDummyObjectId().toByteArray());
if (!objectStore.isObjectReady(dummyObjectId)) {
unreadyObjects.add(dummyObjectId);
}
@@ -126,9 +125,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
return unreadyObjects;
}
private TaskSpec.Builder getTaskSpecBuilder(TaskType taskType,
FunctionDescriptor functionDescriptor,
List<FunctionArg> args) {
private TaskSpec.Builder getTaskSpecBuilder(
TaskType taskType, FunctionDescriptor functionDescriptor, List<FunctionArg> args) {
byte[] taskIdBytes = new byte[TaskId.LENGTH];
new Random().nextBytes(taskIdBytes);
List<String> functionDescriptorList = functionDescriptor.toList();
@@ -136,64 +134,83 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
return TaskSpec.newBuilder()
.setType(taskType)
.setLanguage(Language.JAVA)
.setJobId(
ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
.setJobId(ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes()))
.setTaskId(ByteString.copyFrom(taskIdBytes))
.setFunctionDescriptor(Common.FunctionDescriptor.newBuilder()
.setJavaFunctionDescriptor(
Common.JavaFunctionDescriptor.newBuilder()
.setClassName(functionDescriptorList.get(0))
.setFunctionName(functionDescriptorList.get(1))
.setSignature(functionDescriptorList.get(2))))
.addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder()
.setObjectRef(ObjectReference.newBuilder().setObjectId(
ByteString.copyFrom(arg.id.getBytes()))).build()
: TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data))
.setMetadata(arg.value.metadata != null ? ByteString
.copyFrom(arg.value.metadata) : ByteString.EMPTY).build())
.collect(Collectors.toList()));
.setFunctionDescriptor(
Common.FunctionDescriptor.newBuilder()
.setJavaFunctionDescriptor(
Common.JavaFunctionDescriptor.newBuilder()
.setClassName(functionDescriptorList.get(0))
.setFunctionName(functionDescriptorList.get(1))
.setSignature(functionDescriptorList.get(2))))
.addAllArgs(
args.stream()
.map(
arg ->
arg.id != null
? TaskArg.newBuilder()
.setObjectRef(
ObjectReference.newBuilder()
.setObjectId(ByteString.copyFrom(arg.id.getBytes())))
.build()
: TaskArg.newBuilder()
.setData(ByteString.copyFrom(arg.value.data))
.setMetadata(
arg.value.metadata != null
? ByteString.copyFrom(arg.value.metadata)
: ByteString.EMPTY)
.build())
.collect(Collectors.toList()));
}
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
public List<ObjectId> submitTask(
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options) {
Preconditions.checkState(numReturns <= 1);
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args)
.setNumReturns(numReturns)
.build();
TaskSpec taskSpec =
getTaskSpecBuilder(TaskType.NORMAL_TASK, functionDescriptor, args)
.setNumReturns(numReturns)
.build();
submitTaskSpec(taskSpec);
return getReturnIds(taskSpec);
}
@Override
public BaseActorHandle createActor(
FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) throws IllegalArgumentException {
FunctionDescriptor functionDescriptor, List<FunctionArg> args, ActorCreationOptions options)
throws IllegalArgumentException {
if (options != null) {
if (options.group != null) {
PlacementGroupImpl group = (PlacementGroupImpl)options.group;
Preconditions.checkArgument(options.bundleIndex >= 0
&& options.bundleIndex < group.getBundles().size(),
PlacementGroupImpl group = (PlacementGroupImpl) options.group;
Preconditions.checkArgument(
options.bundleIndex >= 0 && options.bundleIndex < group.getBundles().size(),
String.format("Bundle index %s is invalid", options.bundleIndex));
}
}
ActorId actorId = ActorId.fromRandom();
TaskSpec taskSpec = getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
.setNumReturns(1)
.setActorCreationTaskSpec(ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
.build())
.build();
TaskSpec taskSpec =
getTaskSpecBuilder(TaskType.ACTOR_CREATION_TASK, functionDescriptor, args)
.setNumReturns(1)
.setActorCreationTaskSpec(
ActorCreationTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(actorId.toByteBuffer()))
.build())
.build();
submitTaskSpec(taskSpec);
final LocalModeActorHandle actorHandle
= new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0));
final LocalModeActorHandle actorHandle =
new LocalModeActorHandle(actorId, getReturnIds(taskSpec).get(0));
actorHandles.put(actorId, actorHandle.copy());
if (StringUtils.isNotBlank(options.name)) {
String fullName = options.global ? options.name :
String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), options.name);
Preconditions.checkArgument(!namedActors.containsKey(fullName),
String.format("Actor of name %s exists", fullName));
String fullName =
options.global
? options.name
: String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), options.name);
Preconditions.checkArgument(
!namedActors.containsKey(fullName), String.format("Actor of name %s exists", fullName));
namedActors.put(fullName, actorHandle);
}
return actorHandle;
@@ -201,22 +218,29 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitActorTask(
BaseActorHandle actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
BaseActorHandle actor,
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options) {
Preconditions.checkState(numReturns <= 1);
TaskSpec.Builder builder = getTaskSpecBuilder(TaskType.ACTOR_TASK, functionDescriptor, args);
List<ObjectId> returnIds = getReturnIds(
TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
TaskSpec taskSpec = builder
.setNumReturns(numReturns + 1)
.setActorTaskSpec(
ActorTaskSpec.newBuilder().setActorId(ByteString.copyFrom(actor.getId().getBytes()))
.setPreviousActorTaskDummyObjectId(ByteString.copyFrom(
((LocalModeActorHandle) actor)
.exchangePreviousActorTaskDummyObjectId(returnIds.get(returnIds.size() - 1))
.getBytes()))
.build())
.build();
List<ObjectId> returnIds =
getReturnIds(TaskId.fromBytes(builder.getTaskId().toByteArray()), numReturns + 1);
TaskSpec taskSpec =
builder
.setNumReturns(numReturns + 1)
.setActorTaskSpec(
ActorTaskSpec.newBuilder()
.setActorId(ByteString.copyFrom(actor.getId().getBytes()))
.setPreviousActorTaskDummyObjectId(
ByteString.copyFrom(
((LocalModeActorHandle) actor)
.exchangePreviousActorTaskDummyObjectId(
returnIds.get(returnIds.size() - 1))
.getBytes()))
.build())
.build();
submitTaskSpec(taskSpec);
if (numReturns == 0) {
return ImmutableList.of();
@@ -226,11 +250,15 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
@Override
public PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
PlacementGroupImpl placementGroup = new PlacementGroupImpl.Builder()
.setId(PlacementGroupId.fromRandom()).setName(name)
.setBundles(bundles).setStrategy(strategy).build();
public PlacementGroup createPlacementGroup(
String name, List<Map<String, Double>> bundles, PlacementStrategy strategy) {
PlacementGroupImpl placementGroup =
new PlacementGroupImpl.Builder()
.setId(PlacementGroupId.fromRandom())
.setName(name)
.setBundles(bundles)
.setStrategy(strategy)
.build();
placementGroups.put(placementGroup.getId(), placementGroup);
return placementGroup;
}
@@ -251,8 +279,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
public Optional<BaseActorHandle> getActor(String name, boolean global) {
String fullName = global ? name :
String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), name);
String fullName =
global ? name : String.format("%s-%s", Ray.getRuntimeContext().getCurrentJobId(), name);
ActorHandle actorHandle = namedActors.get(fullName);
if (null == actorHandle) {
return Optional.empty();
@@ -289,14 +317,15 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
synchronized (taskAndObjectLock) {
Set<ObjectId> unreadyObjects = getUnreadyObjects(taskSpec);
final Runnable runnable = () -> {
try {
executeTask(taskSpec);
} catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1);
}
};
final Runnable runnable =
() -> {
try {
executeTask(taskSpec);
} catch (Exception ex) {
LOGGER.error("Unexpected exception when executing a task.", ex);
System.exit(-1);
}
};
if (unreadyObjects.isEmpty()) {
// If all dependencies are ready, execute this task.
@@ -318,7 +347,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
executorService.submit(runnable);
} catch (RejectedExecutionException e) {
if (executorService.isShutdown()) {
LOGGER.warn("Ignore task submission due to the ExecutorService is shutdown. Task: {}",
LOGGER.warn(
"Ignore task submission due to the ExecutorService is shutdown. Task: {}",
taskSpec);
}
}
@@ -338,16 +368,20 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
Preconditions.checkNotNull(actorContext);
}
taskExecutor.setActorContext(actorContext);
List<NativeRayObject> args = getFunctionArgs(taskSpec).stream()
.map(arg -> arg.id != null ?
objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
List<NativeRayObject> args =
getFunctionArgs(taskSpec).stream()
.map(
arg ->
arg.id != null
? objectStore.getRaw(Collections.singletonList(arg.id), -1).get(0)
: arg.value)
.collect(Collectors.toList());
runtime.setIsContextSet(true);
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentTask(taskSpec);
UniqueId workerId = actorContext != null
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
UniqueId workerId =
actorContext != null
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
@@ -379,10 +413,9 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
}
private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) {
Common.FunctionDescriptor functionDescriptor =
taskSpec.getFunctionDescriptor();
if (functionDescriptor.getFunctionDescriptorCase() ==
Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
Common.FunctionDescriptor functionDescriptor = taskSpec.getFunctionDescriptor();
if (functionDescriptor.getFunctionDescriptorCase()
== Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) {
return new JavaFunctionDescriptor(
functionDescriptor.getJavaFunctionDescriptor().getClassName(),
functionDescriptor.getJavaFunctionDescriptor().getFunctionName(),
@@ -397,30 +430,35 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
for (int i = 0; i < taskSpec.getArgsCount(); i++) {
TaskArg arg = taskSpec.getArgs(i);
if (arg.getObjectRef().getObjectId() != ByteString.EMPTY) {
functionArgs.add(FunctionArg
.passByReference(new ObjectId(arg.getObjectRef().getObjectId().toByteArray()),
Address.getDefaultInstance()));
functionArgs.add(
FunctionArg.passByReference(
new ObjectId(arg.getObjectRef().getObjectId().toByteArray()),
Address.getDefaultInstance()));
} else {
functionArgs.add(FunctionArg.passByValue(
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
functionArgs.add(
FunctionArg.passByValue(
new NativeRayObject(arg.getData().toByteArray(), arg.getMetadata().toByteArray())));
}
}
return functionArgs;
}
private static List<ObjectId> getReturnIds(TaskSpec taskSpec) {
return getReturnIds(TaskId.fromBytes(taskSpec.getTaskId().toByteArray()),
taskSpec.getNumReturns());
return getReturnIds(
TaskId.fromBytes(taskSpec.getTaskId().toByteArray()), taskSpec.getNumReturns());
}
private static List<ObjectId> getReturnIds(TaskId taskId, long numReturns) {
List<ObjectId> returnIds = new ArrayList<>();
for (int i = 0; i < numReturns; i++) {
returnIds.add(ObjectId.fromByteBuffer(
(ByteBuffer) ByteBuffer.allocate(ObjectId.LENGTH).put(taskId.getBytes())
.putInt(TaskId.LENGTH, i + 1).position(0)));
returnIds.add(
ObjectId.fromByteBuffer(
(ByteBuffer)
ByteBuffer.allocate(ObjectId.LENGTH)
.put(taskId.getBytes())
.putInt(TaskId.LENGTH, i + 1)
.position(0)));
}
return returnIds;
}
}
@@ -3,13 +3,10 @@ package io.ray.runtime.task;
import io.ray.api.id.UniqueId;
import io.ray.runtime.RayRuntimeInternal;
/**
* Task executor for cluster mode.
*/
/** Task executor for cluster mode. */
public class NativeTaskExecutor extends TaskExecutor<NativeTaskExecutor.NativeActorContext> {
static class NativeActorContext extends TaskExecutor.ActorContext {
}
static class NativeActorContext extends TaskExecutor.ActorContext {}
public NativeTaskExecutor(RayRuntimeInternal runtime) {
super(runtime);
@@ -20,16 +20,18 @@ import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.lang3.StringUtils;
/**
* Task submitter for cluster mode. This is a wrapper class for core worker task interface.
*/
/** Task submitter for cluster mode. This is a wrapper class for core worker task interface. */
public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, functionDescriptor.hashCode(),
args, numReturns, options);
public List<ObjectId> submitTask(
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options) {
List<byte[]> returnIds =
nativeSubmitTask(
functionDescriptor, functionDescriptor.hashCode(), args, numReturns, options);
if (returnIds == null) {
return ImmutableList.of();
}
@@ -37,25 +39,26 @@ public class NativeTaskSubmitter implements TaskSubmitter {
}
@Override
public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) throws IllegalArgumentException {
public BaseActorHandle createActor(
FunctionDescriptor functionDescriptor, List<FunctionArg> args, ActorCreationOptions options)
throws IllegalArgumentException {
if (options != null) {
if (options.group != null) {
PlacementGroupImpl group = (PlacementGroupImpl)options.group;
Preconditions.checkArgument(options.bundleIndex >= 0
&& options.bundleIndex < group.getBundles().size(),
PlacementGroupImpl group = (PlacementGroupImpl) options.group;
Preconditions.checkArgument(
options.bundleIndex >= 0 && options.bundleIndex < group.getBundles().size(),
String.format("Bundle index %s is invalid", options.bundleIndex));
}
if (StringUtils.isNotBlank(options.name)) {
Optional<BaseActorHandle> actor =
options.global ? Ray.getGlobalActor(options.name) : Ray.getActor(options.name);
Preconditions.checkArgument(!actor.isPresent(),
String.format("Actor of name %s exists", options.name));
Preconditions.checkArgument(
!actor.isPresent(), String.format("Actor of name %s exists", options.name));
}
}
byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args,
options);
byte[] actorId =
nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args, options);
return NativeActorHandle.create(actorId, functionDescriptor.getLanguage());
}
@@ -66,11 +69,20 @@ public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitActorTask(
BaseActorHandle actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options) {
BaseActorHandle actor,
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options) {
Preconditions.checkState(actor instanceof NativeActorHandle);
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
functionDescriptor, functionDescriptor.hashCode(), args, numReturns, options);
List<byte[]> returnIds =
nativeSubmitActorTask(
actor.getId().getBytes(),
functionDescriptor,
functionDescriptor.hashCode(),
args,
numReturns,
options);
if (returnIds == null) {
return ImmutableList.of();
}
@@ -78,12 +90,15 @@ public class NativeTaskSubmitter implements TaskSubmitter {
}
@Override
public PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
public PlacementGroup createPlacementGroup(
String name, List<Map<String, Double>> bundles, PlacementStrategy strategy) {
byte[] bytes = nativeCreatePlacementGroup(name, bundles, strategy.value());
return new PlacementGroupImpl.Builder()
.setId(PlacementGroupId.fromBytes(bytes))
.setName(name).setBundles(bundles).setStrategy(strategy).build();
.setId(PlacementGroupId.fromBytes(bytes))
.setName(name)
.setBundles(bundles)
.setStrategy(strategy)
.build();
}
@Override
@@ -96,22 +111,32 @@ public class NativeTaskSubmitter implements TaskSubmitter {
return nativeWaitPlacementGroupReady(id.getBytes(), timeoutMs);
}
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
int functionDescriptorHash, List<FunctionArg> args, int numReturns, CallOptions callOptions);
private static native List<byte[]> nativeSubmitTask(
FunctionDescriptor functionDescriptor,
int functionDescriptorHash,
List<FunctionArg> args,
int numReturns,
CallOptions callOptions);
private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
int functionDescriptorHash, List<FunctionArg> args,
private static native byte[] nativeCreateActor(
FunctionDescriptor functionDescriptor,
int functionDescriptorHash,
List<FunctionArg> args,
ActorCreationOptions actorCreationOptions);
private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
FunctionDescriptor functionDescriptor, int functionDescriptorHash, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
private static native List<byte[]> nativeSubmitActorTask(
byte[] actorId,
FunctionDescriptor functionDescriptor,
int functionDescriptorHash,
List<FunctionArg> args,
int numReturns,
CallOptions callOptions);
private static native byte[] nativeCreatePlacementGroup(String name,
List<Map<String, Double>> bundles, int strategy);
private static native byte[] nativeCreatePlacementGroup(
String name, List<Map<String, Double>> bundles, int strategy);
private static native void nativeRemovePlacementGroup(byte[] placementGroupId);
private static native boolean nativeWaitPlacementGroupReady(byte[] placementGroupId,
int timeoutMs);
private static native boolean nativeWaitPlacementGroupReady(
byte[] placementGroupId, int timeoutMs);
}
@@ -20,9 +20,7 @@ import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* The task executor, which executes tasks assigned by raylet continuously.
*/
/** The task executor, which executes tasks assigned by raylet continuously. */
public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private static final Logger LOGGER = LoggerFactory.getLogger(TaskExecutor.class);
@@ -35,14 +33,10 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
static class ActorContext {
/**
* The current actor object, if this worker is an actor, otherwise null.
*/
/** The current actor object, if this worker is an actor, otherwise null. */
Object currentActor = null;
/**
* The exception that failed the actor creation task, if any.
*/
/** The exception that failed the actor creation task, if any. */
Throwable actorCreationException = null;
}
@@ -74,9 +68,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
}
/**
* The return value indicates which parameters are ByteBuffer.
*/
/** The return value indicates which parameters are ByteBuffer. */
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
localRayFunction.set(null);
try {
@@ -93,8 +85,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
return results;
}
protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<Object> argsBytes) {
protected List<NativeRayObject> execute(List<String> rayFunctionInfo, List<Object> argsBytes) {
runtime.setIsContextSet(true);
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
@@ -130,8 +121,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
}
actor = actorContext.currentActor;
}
Object[] args = ArgumentsBuilder
.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
Object[] args =
ArgumentsBuilder.unwrap(argsBytes, rayFunction.executable.getParameterTypes());
// Execute the task.
Object result;
try {
@@ -168,8 +159,9 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
returnObjects.add(
ObjectSerializer.serialize(
new RayTaskException("Error executing task " + taskId, e)));
}
} else {
actorContext.actorCreationException = e;
@@ -184,8 +176,7 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {
private JavaFunctionDescriptor parseFunctionDescriptor(List<String> rayFunctionInfo) {
Preconditions.checkState(rayFunctionInfo != null && rayFunctionInfo.size() == 3);
return new JavaFunctionDescriptor(rayFunctionInfo.get(0), rayFunctionInfo.get(1),
rayFunctionInfo.get(2));
return new JavaFunctionDescriptor(
rayFunctionInfo.get(0), rayFunctionInfo.get(1), rayFunctionInfo.get(2));
}
}
@@ -12,70 +12,76 @@ import io.ray.runtime.functionmanager.FunctionDescriptor;
import java.util.List;
import java.util.Map;
/**
* A set of methods to submit tasks and create actors.
*/
/** A set of methods to submit tasks and create actors. */
public interface TaskSubmitter {
/**
* Submit a normal task.
*
* @param functionDescriptor The remote function to execute.
* @param args Arguments of this task.
* @param numReturns Return object count.
* @param options Options for this task.
* @return Ids of the return objects.
* @param options Options for this task. Returns Ids of the return objects.
*/
List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options);
List<ObjectId> submitTask(
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options);
/**
* Create an actor.
*
* @param functionDescriptor The remote function that generates the actor object.
* @param args Arguments of this task.
* @param options Options for this actor creation task.
* @return Handle to the actor.
* @param options Options for this actor creation task. Returns Handle to the actor.
* @throws IllegalArgumentException if actor of specified name exists
*/
BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) throws IllegalArgumentException;
BaseActorHandle createActor(
FunctionDescriptor functionDescriptor, List<FunctionArg> args, ActorCreationOptions options)
throws IllegalArgumentException;
/**
* Submit an actor task.
*
* @param actor Handle to the actor.
* @param functionDescriptor The remote function to execute.
* @param args Arguments of this task.
* @param numReturns Return object count.
* @param options Options for this task.
* @return Ids of the return objects.
* @param options Options for this task. Returns Ids of the return objects.
*/
List<ObjectId> submitActorTask(BaseActorHandle actor, FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions options);
List<ObjectId> submitActorTask(
BaseActorHandle actor,
FunctionDescriptor functionDescriptor,
List<FunctionArg> args,
int numReturns,
CallOptions options);
/**
* Create a placement group.
*
* @param name Name of the placement group.
* @param bundles Pre-allocated resource list.
* @param strategy Actor placement strategy.
* @return A handle to the created placement group.
* @param strategy Actor placement strategy. Returns A handle to the created placement group.
*/
PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy);
PlacementGroup createPlacementGroup(
String name, List<Map<String, Double>> bundles, PlacementStrategy strategy);
/**
* Remove a placement group by id.
*
* @param id Id of the placement group.
*/
void removePlacementGroup(PlacementGroupId id);
/**
* Wait for the placement group to be ready within the specified time.
*
* @param id Id of placement group.
* @param timeoutMs Timeout in milliseconds.
* @return True if the placement group is created. False otherwise.
* @param timeoutMs Timeout in milliseconds. Returns True if the placement group is created. False
* otherwise.
*/
boolean waitPlacementGroupReady(PlacementGroupId id, int timeoutMs);
BaseActorHandle getActor(ActorId actorId);
}
@@ -16,13 +16,12 @@ public class BinaryFileUtil {
public static final String CORE_WORKER_JAVA_LIBRARY = "core_worker_library_java";
/**
* Extract a platform-native resource file to <code>destDir</code>.
* Note that this a process-safe operation. If multi processes extract the file to same
* directory concurrently, this operation will be protected by a file lock.
* Extract a platform-native resource file to <code>destDir</code>. Note that this a process-safe
* operation. If multi processes extract the file to same directory concurrently, this operation
* will be protected by a file lock.
*
* @param destDir a directory to extract resource file to
* @param fileName resource file name
* @return extracted resource file
* @param destDir a directory to extract resource file to
* @param fileName resource file name Returns extracted resource file
*/
public static File getNativeFile(String destDir, String fileName) {
final File dir = new File(destDir);
@@ -34,8 +33,7 @@ public class BinaryFileUtil {
}
}
String lockFilePath = destDir + File.separator + "file_lock";
try (FileLock ignored = new RandomAccessFile(lockFilePath, "rw")
.getChannel().lock()) {
try (FileLock ignored = new RandomAccessFile(lockFilePath, "rw").getChannel().lock()) {
String resourceDir;
if (SystemUtils.IS_OS_MAC) {
resourceDir = "native/darwin/";
@@ -12,15 +12,16 @@ public class IdUtil {
/**
* Compute the actor ID of the task which created this object.
* @return The actor ID of the task which created this object.
*
* <p>Returns The actor ID of the task which created this object.
*/
public static ActorId getActorIdFromObjectId(ObjectId objectId) {
byte[] taskIdBytes = new byte[TaskId.LENGTH];
System.arraycopy(objectId.getBytes(), 0, taskIdBytes, 0, TaskId.LENGTH);
TaskId taskId = TaskId.fromBytes(taskIdBytes);
byte[] actorIdBytes = new byte[ActorId.LENGTH];
System.arraycopy(taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH,
actorIdBytes, 0, ActorId.LENGTH);
System.arraycopy(
taskId.getBytes(), TaskId.UNIQUE_BYTES_LENGTH, actorIdBytes, 0, ActorId.LENGTH);
return ActorId.fromBytes(actorIdBytes);
}
}
@@ -9,10 +9,14 @@ public final class JniExceptionUtil {
private static final Logger LOGGER = LoggerFactory.getLogger(JniExceptionUtil.class);
public static String getStackTrace(String fileName, int lineNumber, String function,
Throwable throwable) {
LOGGER.error("An unexpected exception occurred while executing Java code from JNI ({}:{} {}).",
fileName, lineNumber, function, throwable);
public static String getStackTrace(
String fileName, int lineNumber, String function, Throwable throwable) {
LOGGER.error(
"An unexpected exception occurred while executing Java code from JNI ({}:{} {}).",
fileName,
lineNumber,
function,
throwable);
// Return the exception in string form to JNI.
return ExceptionUtils.getStackTrace(throwable);
}
@@ -15,20 +15,20 @@ public class JniUtils {
private static String defaultDestDir;
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
* Loads the native library specified by the <code>libraryName</code> argument. The <code>
* libraryName</code> argument must not contain any platform specific prefix, file extension or
* path.
*
* @param libraryName the name of the library.
* @param libraryName the name of the library.
*/
public static synchronized void loadLibrary(String libraryName) {
loadLibrary(getDefaultDestDir(), libraryName);
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
* Loads the native library specified by the <code>libraryName</code> argument. The <code>
* libraryName</code> argument must not contain any platform specific prefix, file extension or
* path.
*
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
@@ -38,9 +38,9 @@ public class JniUtils {
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
* Loads the native library specified by the <code>libraryName</code> argument. The <code>
* libraryName</code> argument must not contain any platform specific prefix, file extension or
* path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
@@ -50,16 +50,16 @@ public class JniUtils {
}
/**
* Loads the native library specified by the <code>libraryName</code> argument.
* The <code>libraryName</code> argument must not contain any platform specific
* prefix, file extension or path.
* Loads the native library specified by the <code>libraryName</code> argument. The <code>
* libraryName</code> argument must not contain any platform specific prefix, file extension or
* path.
*
* @param destDir The destination dir the library to be extracted.
* @param libraryName the name of the library.
* @param libraryName the name of the library.
* @param exportSymbols export symbols of library so that it can be used by other libs.
*/
public static synchronized void loadLibrary(String destDir, String libraryName,
boolean exportSymbols) {
public static synchronized void loadLibrary(
String destDir, String libraryName, boolean exportSymbols) {
if (!loadedLibs.contains(libraryName)) {
LOGGER.debug("Loading native library {}.", libraryName);
// Load native library.
@@ -77,9 +77,7 @@ public class JniUtils {
}
}
/**
* Cache the result so that multiple calls return the same dest dir.
*/
/** Cache the result so that multiple calls return the same dest dir. */
private static synchronized String getDefaultDestDir() {
if (defaultDestDir == null) {
try {
@@ -4,14 +4,10 @@ import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
/**
* see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html.
*/
/** see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html. */
public final class LambdaUtils {
private LambdaUtils() {
}
private LambdaUtils() {}
public static SerializedLambda getSerializedLambda(Serializable lambda) {
// Note.
@@ -26,6 +22,4 @@ public final class LambdaUtils {
throw new RuntimeException("failed to getSerializedLambda:" + lambda.getClass().getName(), e);
}
}
}
}
@@ -44,5 +44,4 @@ public class LoggingUtil {
appender.activateOptions();
Logger.getLogger("io.ray").addAppender(appender);
}
}
@@ -27,8 +27,8 @@ public class NetworkUtil {
if (!current.isUp() || current.isLoopback() || current.isVirtual()) {
continue;
}
if (!Strings.isNullOrEmpty(interfaceName) && !interfaceName
.equals(current.getDisplayName())) {
if (!Strings.isNullOrEmpty(interfaceName)
&& !interfaceName.equals(current.getDisplayName())) {
continue;
}
Enumeration<InetAddress> addresses = current.getInetAddresses();
@@ -65,7 +65,8 @@ public class NetworkUtil {
throw new IllegalArgumentException("Invalid start port: " + port);
}
try (ServerSocket ss = new ServerSocket(port); DatagramSocket ds = new DatagramSocket(port)) {
try (ServerSocket ss = new ServerSocket(port);
DatagramSocket ds = new DatagramSocket(port)) {
ss.setReuseAddress(true);
ds.setReuseAddress(true);
return true;
@@ -8,11 +8,11 @@ public class ResourceUtil {
public static final String GPU_LITERAL = "GPU";
/**
* Convert resources map to a string that is used
* for the command line argument of starting raylet.
* Convert resources map to a string that is used for the command line argument of starting
* raylet.
*
* @param resources The resources map to be converted.
* @return The starting-raylet command line argument, like "CPU,4,GPU,0".
* @param resources The resources map to be converted. Returns The starting-raylet command line
* argument, like "CPU,4,GPU,0".
*/
public static String getResourcesStringFromMap(Map<String, Double> resources) {
StringBuilder builder = new StringBuilder();
@@ -32,11 +32,10 @@ public class ResourceUtil {
/**
* Parse the static resources configure field and convert to the resources map.
*
* @param resources The static resources string to be parsed.
* @return The map whose key represents the resource name
* and the value represents the resource quantity.
* @throws IllegalArgumentException If the resources string's format does match,
* it will throw an IllegalArgumentException.
* @param resources The static resources string to be parsed. Returns The map whose key represents
* the resource name and the value represents the resource quantity.
* @throws IllegalArgumentException If the resources string's format does match, it will throw an
* IllegalArgumentException.
*/
public static Map<String, Double> getResourcesMapFromString(String resources)
throws IllegalArgumentException {
@@ -5,9 +5,7 @@ import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.util.concurrent.locks.ReentrantLock;
/**
* some utilities for system process.
*/
/** some utilities for system process. */
public class SystemUtil {
static final ReentrantLock pidlock = new ReentrantLock();
@@ -33,13 +31,12 @@ public class SystemUtil {
}
}
return pid;
}
public static boolean isProcessAlive(int pid) {
Process process;
try {
process = Runtime.getRuntime().exec(new String[]{"ps", "-p", String.valueOf(pid)});
process = Runtime.getRuntime().exec(new String[] {"ps", "-p", String.valueOf(pid)});
process.waitFor();
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
@@ -20,5 +20,4 @@ public abstract class BaseGenerator {
sb.append(" ");
}
}
}
@@ -13,9 +13,7 @@ import org.apache.commons.io.FileUtils;
*/
public class RayCallGenerator extends BaseGenerator {
/**
* @return Whole file content of `RayCall.java`.
*/
/** Returns Whole file content of `RayCall.java`. */
private String generateRayCallDotJava() {
sb = new StringBuilder();
@@ -70,9 +68,7 @@ public class RayCallGenerator extends BaseGenerator {
return sb.toString();
}
/**
* @return Whole file content of `ActorCall.java`.
*/
/** Returns Whole file content of `ActorCall.java`. */
private String generateActorCallDotJava() {
sb = new StringBuilder();
@@ -102,9 +98,7 @@ public class RayCallGenerator extends BaseGenerator {
return sb.toString();
}
/**
* @return Whole file content of `PyActorCall.java`.
*/
/** Returns Whole file content of `PyActorCall.java`. */
private String generatePyActorCallDotJava() {
sb = new StringBuilder();
@@ -128,16 +122,16 @@ public class RayCallGenerator extends BaseGenerator {
}
/**
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with
* the given number of parameters.
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with the given number of
* parameters.
*
* @param numParameters the number of parameters
* @param forActor Build `actor.call` when true, otherwise build `Ray.call`.
* @param hasReturn if true, Build api for functions with return.
* @param forActorCreation Build `Ray.createActor` when true, otherwise build `Ray.call`.
*/
private void buildCalls(int numParameters, boolean forActor,
boolean forActorCreation, boolean hasReturn) {
private void buildCalls(
int numParameters, boolean forActor, boolean forActorCreation, boolean hasReturn) {
// Template of the generated function:
// [modifiers] [genericTypes] [returnType] [callFunc]([argsDeclaration]) {
// Objects[] args = new Object[]{[args]};
@@ -186,10 +180,12 @@ public class RayCallGenerator extends BaseGenerator {
rayFuncGenericTypes = rayFuncGenericTypes.replace("<", "<A, ");
}
}
String argsDeclarationPrefix = String.format("RayFunc%s%d%s f, ",
hasReturn ? "" : "Void",
!forActor ? numParameters : numParameters + 1,
rayFuncGenericTypes);
String argsDeclarationPrefix =
String.format(
"RayFunc%s%d%s f, ",
hasReturn ? "" : "Void",
!forActor ? numParameters : numParameters + 1,
rayFuncGenericTypes);
String callFunc = forActorCreation ? "actor" : "task";
String caller;
@@ -209,10 +205,15 @@ public class RayCallGenerator extends BaseGenerator {
// Trim trailing ", ";
argsDeclaration = argsDeclaration.substring(0, argsDeclaration.length() - 2);
// Print the first line (method signature).
newLine(1, String.format(
"%s%s %s %s(%s) {", modifiers,
genericTypes.isEmpty() ? "" : " " + genericTypes, returnType, callFunc, argsDeclaration
));
newLine(
1,
String.format(
"%s%s %s %s(%s) {",
modifiers,
genericTypes.isEmpty() ? "" : " " + genericTypes,
returnType,
callFunc,
argsDeclaration));
// 4) Construct the `args` part.
String args = "";
@@ -240,15 +241,14 @@ public class RayCallGenerator extends BaseGenerator {
}
/**
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with
* the given number of parameters.
* Build `Ray.call`, `Ray.createActor` and `actor.call` methods with the given number of
* parameters.
*
* @param numParameters the number of parameters
* @param forActor Build `actor.call` when true, otherwise build `Ray.call`.
* @param forActorCreation Build `Ray.createActor` when true, otherwise build `Ray.call`.
*/
private void buildPyCalls(int numParameters, boolean forActor,
boolean forActorCreation) {
private void buildPyCalls(int numParameters, boolean forActor, boolean forActorCreation) {
String modifiers = forActor ? "default" : "public static";
String argList = "";
@@ -281,23 +281,23 @@ public class RayCallGenerator extends BaseGenerator {
}
String genericType = forActorCreation ? "" : " <R>";
String returnType = forActorCreation ? "PyActorCreator" :
forActor ? "PyActorTaskCaller<R>" : "PyTaskCaller<R>";
String returnType =
forActorCreation ? "PyActorCreator" : forActor ? "PyActorTaskCaller<R>" : "PyTaskCaller<R>";
String funcName = forActorCreation ? "actor" : "task";
String caller = forActorCreation ? "PyActorCreator" :
forActor ? "PyActorTaskCaller<>" : "PyTaskCaller<>";
String caller =
forActorCreation ? "PyActorCreator" : forActor ? "PyActorTaskCaller<>" : "PyTaskCaller<>";
funcArgs += ", args";
// Method signature.
newLine(1, String.format(
"%s%s %s %s(%s) {", modifiers, genericType,
returnType, funcName, paramPrefix + paramList
));
newLine(
1,
String.format(
"%s%s %s %s(%s) {",
modifiers, genericType, returnType, funcName, paramPrefix + paramList));
// Method body.
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
if (forActor) {
newLine(2, String.format("return new %s((PyActorHandle)this, %s);",
caller, funcArgs));
newLine(2, String.format("return new %s((PyActorHandle)this, %s);", caller, funcArgs));
} else {
newLine(2, String.format("return new %s(%s);", caller, funcArgs));
}
@@ -323,17 +323,18 @@ public class RayCallGenerator extends BaseGenerator {
}
public static void main(String[] args) throws IOException {
String path = System.getProperty("user.dir")
+ "/api/src/main/java/io/ray/api/RayCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generateRayCallDotJava(),
String path = System.getProperty("user.dir") + "/api/src/main/java/io/ray/api/RayCall.java";
FileUtils.write(
new File(path), new RayCallGenerator().generateRayCallDotJava(), Charset.defaultCharset());
path = System.getProperty("user.dir") + "/api/src/main/java/io/ray/api/ActorCall.java";
FileUtils.write(
new File(path),
new RayCallGenerator().generateActorCallDotJava(),
Charset.defaultCharset());
path = System.getProperty("user.dir")
+ "/api/src/main/java/io/ray/api/ActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generateActorCallDotJava(),
Charset.defaultCharset());
path = System.getProperty("user.dir")
+ "/api/src/main/java/io/ray/api/PyActorCall.java";
FileUtils.write(new File(path), new RayCallGenerator().generatePyActorCallDotJava(),
path = System.getProperty("user.dir") + "/api/src/main/java/io/ray/api/PyActorCall.java";
FileUtils.write(
new File(path),
new RayCallGenerator().generatePyActorCallDotJava(),
Charset.defaultCharset());
}
}
@@ -5,9 +5,7 @@ import java.io.IOException;
import java.nio.charset.Charset;
import org.apache.commons.io.FileUtils;
/**
* A util class that generates all the RayFuncX classes under io.ray.api.function package.
*/
/** A util class that generates all the RayFuncX classes under io.ray.api.function package. */
public class RayFuncGenerator extends BaseGenerator {
private String generate(int numParameters, boolean hasReturn) {
@@ -36,15 +34,18 @@ public class RayFuncGenerator extends BaseGenerator {
newLine("package io.ray.api.function;");
newLine("");
newLine("/**");
String comment = String.format(
" * Functional interface for a remote function that has %d parameter%s.",
numParameters, numParameters > 1 ? "s" : "");
String comment =
String.format(
" * Functional interface for a remote function that has %d parameter%s.",
numParameters, numParameters > 1 ? "s" : "");
newLine(comment);
newLine(" */");
newLine("@FunctionalInterface");
String className = "RayFunc" + (hasReturn ? "" : "Void") + numParameters;
newLine(String.format("public interface %s%s extends %s {",
className, genericTypes, hasReturn ? "RayFuncR<R>" : "RayFuncVoid"));
newLine(
String.format(
"public interface %s%s extends %s {",
className, genericTypes, hasReturn ? "RayFuncR<R>" : "RayFuncVoid"));
newLine("");
indents(1);
newLine(String.format("%s apply(%s) throws Exception;", hasReturn ? "R" : "void", paramList));
@@ -54,19 +55,16 @@ public class RayFuncGenerator extends BaseGenerator {
}
public static void main(String[] args) throws IOException {
String root = System.getProperty("user.dir")
+ "/api/src/main/java/io/ray/api/function/";
String root = System.getProperty("user.dir") + "/api/src/main/java/io/ray/api/function/";
RayFuncGenerator generator = new RayFuncGenerator();
for (int i = 0; i <= MAX_PARAMETERS; i++) {
// Functions that have return.
String content = generator.generate(i, true);
FileUtils.write(new File(root + "RayFunc" + i + ".java"), content,
Charset.defaultCharset());
FileUtils.write(new File(root + "RayFunc" + i + ".java"), content, Charset.defaultCharset());
// Functions that don't have return.
content = generator.generate(i, false);
FileUtils.write(new File(root + "RayFuncVoid" + i + ".java"), content,
Charset.defaultCharset());
FileUtils.write(
new File(root + "RayFuncVoid" + i + ".java"), content, Charset.defaultCharset());
}
}
}
@@ -12,12 +12,14 @@ public class UniqueIdTest {
@Test
public void testConstructUniqueId() {
// Test `fromHexString()`
UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
UniqueId id1 =
UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString());
Assert.assertFalse(id1.isNil());
try {
UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
UniqueId id2 =
UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
// This shouldn't be happened.
Assert.assertTrue(false);
} catch (IllegalArgumentException e) {
@@ -33,16 +35,18 @@ public class UniqueIdTest {
}
// Test `fromByteBuffer()`
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF");
byte[] bytes =
DatatypeConverter.parseHexBinary(
"0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF");
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28);
UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer);
Assert.assertTrue(Arrays.equals(bytes, id4.getBytes()));
Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString());
// Test `genNil()`
UniqueId id6 = UniqueId.NIL;
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
Assert.assertEquals(
"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
Assert.assertTrue(id6.isNil());
}
}
@@ -30,8 +30,8 @@ public class RayConfigTest {
}
RayConfig rayConfig = RayConfig.create();
Assert.assertEquals(WorkerType.DRIVER, rayConfig.workerMode);
Assert.assertEquals(Collections.singletonList("path/to/ray/job/resource/path"),
rayConfig.codeSearchPath);
Assert.assertEquals(
Collections.singletonList("path/to/ray/job/resource/path"), rayConfig.codeSearchPath);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("one"), 1);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("zero"), 0);
Assert.assertEquals(rayConfig.rayletConfigParameters.get("positive-integer"), 123);
@@ -18,9 +18,7 @@ import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
/**
* Tests for {@link FunctionManager}
*/
/** Tests for {@link FunctionManager} */
public class FunctionManagerTest {
public static Object foo() {
@@ -36,7 +34,6 @@ public class FunctionManagerTest {
public Object bar() {
return null;
}
}
public interface ChildClassInterface {
@@ -44,13 +41,11 @@ public class FunctionManagerTest {
default String interfaceName() {
return getClass().getName();
}
}
public static class ChildClass extends ParentClass implements ChildClassInterface {
public ChildClass() {
}
public ChildClass() {}
@Override
public Object bar() {
@@ -80,17 +75,20 @@ public class FunctionManagerTest {
fooFunc = FunctionManagerTest::foo;
childClassConstructor = ChildClass::new;
childClassBarFunc = ChildClass::bar;
fooDescriptor = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), "foo",
"()Ljava/lang/Object;");
childClassBarDescriptor = new JavaFunctionDescriptor(ChildClass.class.getName(), "bar",
"()Ljava/lang/Object;");
childClassConstructorDescriptor = new JavaFunctionDescriptor(ChildClass.class.getName(),
FunctionManager.CONSTRUCTOR_NAME,
"()V");
overloadFunctionDescriptorInt = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", "(I)Ljava/lang/Object;");
overloadFunctionDescriptorDouble = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", "(D)Ljava/lang/Object;");
fooDescriptor =
new JavaFunctionDescriptor(
FunctionManagerTest.class.getName(), "foo", "()Ljava/lang/Object;");
childClassBarDescriptor =
new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;");
childClassConstructorDescriptor =
new JavaFunctionDescriptor(
ChildClass.class.getName(), FunctionManager.CONSTRUCTOR_NAME, "()V");
overloadFunctionDescriptorInt =
new JavaFunctionDescriptor(
FunctionManagerTest.class.getName(), "overloadFunction", "(I)Ljava/lang/Object;");
overloadFunctionDescriptorDouble =
new JavaFunctionDescriptor(
FunctionManagerTest.class.getName(), "overloadFunction", "(D)Ljava/lang/Object;");
}
@Test
@@ -131,70 +129,87 @@ public class FunctionManagerTest {
Assert.assertEquals(func.getFunctionDescriptor(), childClassConstructorDescriptor);
// Test raise overload exception
Assert.expectThrows(RuntimeException.class, () -> {
functionManager.getFunction(JobId.NIL,
new JavaFunctionDescriptor(FunctionManagerTest.class.getName(),
"overloadFunction", ""));
});
Assert.expectThrows(
RuntimeException.class,
() -> {
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(
FunctionManagerTest.class.getName(), "overloadFunction", ""));
});
}
@Test
public void testInheritance() {
final FunctionManager functionManager = new FunctionManager(null);
// Check inheritance can work and FunctionManager can find method in parent class.
fooDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "foo",
"()Ljava/lang/Object;");
Assert.assertEquals(functionManager.getFunction(JobId.NIL, fooDescriptor)
.executable.getDeclaringClass(), ParentClass.class);
RayFunction fooFunc = functionManager.getFunction(JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "foo",
"()Ljava/lang/Object;"));
fooDescriptor =
new JavaFunctionDescriptor(ParentClass.class.getName(), "foo", "()Ljava/lang/Object;");
Assert.assertEquals(
functionManager.getFunction(JobId.NIL, fooDescriptor).executable.getDeclaringClass(),
ParentClass.class);
RayFunction fooFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "foo", "()Ljava/lang/Object;"));
Assert.assertEquals(fooFunc.executable.getDeclaringClass(), ParentClass.class);
// Check FunctionManager can use method in child class if child class methods overrides methods
// in parent class.
childClassBarDescriptor = new JavaFunctionDescriptor(ParentClass.class.getName(), "bar",
"()Ljava/lang/Object;");
Assert.assertEquals(functionManager.getFunction(JobId.NIL, childClassBarDescriptor)
.executable.getDeclaringClass(), ParentClass.class);
RayFunction barFunc = functionManager.getFunction(JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "bar",
"()Ljava/lang/Object;"));
childClassBarDescriptor =
new JavaFunctionDescriptor(ParentClass.class.getName(), "bar", "()Ljava/lang/Object;");
Assert.assertEquals(
functionManager
.getFunction(JobId.NIL, childClassBarDescriptor)
.executable
.getDeclaringClass(),
ParentClass.class);
RayFunction barFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "bar", "()Ljava/lang/Object;"));
Assert.assertEquals(barFunc.executable.getDeclaringClass(), ChildClass.class);
// Check interface default methods.
RayFunction interfaceNameFunc = functionManager.getFunction(JobId.NIL,
new JavaFunctionDescriptor(ChildClass.class.getName(), "interfaceName",
"()Ljava/lang/String;"));
Assert.assertEquals(interfaceNameFunc.executable.getDeclaringClass(),
ChildClassInterface.class);
RayFunction interfaceNameFunc =
functionManager.getFunction(
JobId.NIL,
new JavaFunctionDescriptor(
ChildClass.class.getName(), "interfaceName", "()Ljava/lang/String;"));
Assert.assertEquals(
interfaceNameFunc.executable.getDeclaringClass(), ChildClassInterface.class);
}
@Test
public void testLoadFunctionTableForClass() {
JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader());
Map<Pair<String, String>, RayFunction> res = functionTable
.loadFunctionsForClass(ChildClass.class.getName());
Map<Pair<String, String>, RayFunction> res =
functionTable.loadFunctionsForClass(ChildClass.class.getName());
// The result should be 4 entries:
// 1, the constructor with signature
// 2, the constructor without signature
// 3, bar with signature
// 4, bar without signature
Assert.assertEquals(res.size(), 11);
Assert.assertTrue(res.containsKey(
ImmutablePair.of(childClassBarDescriptor.name, childClassBarDescriptor.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(childClassConstructorDescriptor.name, childClassConstructorDescriptor.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(childClassBarDescriptor.name, "")));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(childClassConstructorDescriptor.name, "")));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorDouble.name, overloadFunctionDescriptorDouble.signature)));
Assert.assertTrue(res.containsKey(
ImmutablePair.of(overloadFunctionDescriptorInt.name, "")));
Assert.assertTrue(
res.containsKey(
ImmutablePair.of(childClassBarDescriptor.name, childClassBarDescriptor.signature)));
Assert.assertTrue(
res.containsKey(
ImmutablePair.of(
childClassConstructorDescriptor.name, childClassConstructorDescriptor.signature)));
Assert.assertTrue(res.containsKey(ImmutablePair.of(childClassBarDescriptor.name, "")));
Assert.assertTrue(res.containsKey(ImmutablePair.of(childClassConstructorDescriptor.name, "")));
Assert.assertTrue(
res.containsKey(
ImmutablePair.of(
overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature)));
Assert.assertTrue(
res.containsKey(
ImmutablePair.of(
overloadFunctionDescriptorDouble.name,
overloadFunctionDescriptorDouble.signature)));
Assert.assertTrue(res.containsKey(ImmutablePair.of(overloadFunctionDescriptorInt.name, "")));
Pair<String, String> overloadKey = ImmutablePair.of(overloadFunctionDescriptorInt.name, "");
RayFunction func = res.get(overloadKey);
// The function is overloaded.
@@ -230,12 +245,11 @@ public class FunctionManagerTest {
}
// Test loading the function.
JavaFunctionDescriptor descriptor = new JavaFunctionDescriptor(
"DemoApp", "hello", "()Ljava/lang/String;");
final FunctionManager functionManager = new FunctionManager(
Collections.singletonList(codeSearchPath));
JavaFunctionDescriptor descriptor =
new JavaFunctionDescriptor("DemoApp", "hello", "()Ljava/lang/String;");
final FunctionManager functionManager =
new FunctionManager(Collections.singletonList(codeSearchPath));
RayFunction func = functionManager.getFunction(jobId, descriptor);
Assert.assertEquals(func.getFunctionDescriptor(), descriptor);
}
}
@@ -12,8 +12,8 @@ public class SerializerTest {
public void testBasicSerialization() {
// Test serialize / deserialize primitive types with type conversion.
{
Object[] foo = new Object[]{"hello", (byte) 1, 2.0, (short) 3, 4, 5L,
new String[]{"hello", "world"}};
Object[] foo =
new Object[] {"hello", (byte) 1, 2.0, (short) 3, 4, 5L, new String[] {"hello", "world"}};
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
Object[] bar = Serializer.decode(serialized.getLeft(), Object[].class);
Assert.assertTrue(serialized.getRight());
@@ -26,10 +26,12 @@ public class SerializerTest {
}
// Test multidimensional array.
{
Object[][] foo = new Object[][]{{1, 2}, {"3", 4}};
Assert.expectThrows(RuntimeException.class, () -> {
Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class);
});
Object[][] foo = new Object[][] {{1, 2}, {"3", 4}};
Assert.expectThrows(
RuntimeException.class,
() -> {
Object[][] bar = Serializer.decode(Serializer.encode(foo).getLeft(), Integer[][].class);
});
Pair<byte[], Boolean> serialized = Serializer.encode(foo);
Object[][] bar = Serializer.decode(serialized.getLeft(), Object[][].class);
Assert.assertTrue(serialized.getRight());