[java] refine and simplify java worker code structure (#2838)

This commit is contained in:
Hao Chen
2018-09-11 01:48:17 +08:00
committed by Robert Nishihara
parent 588c573d41
commit 8414e413a2
97 changed files with 749 additions and 1344 deletions
@@ -0,0 +1,406 @@
package org.ray.runtime;
import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.function.RayFunc;
import org.ray.api.id.UniqueId;
import org.ray.api.runtime.RayRuntime;
import org.ray.runtime.config.PathConfig;
import org.ray.runtime.config.RayParameters;
import org.ray.runtime.functionmanager.LocalFunctionManager;
import org.ray.runtime.functionmanager.RayMethod;
import org.ray.runtime.functionmanager.RemoteFunctionManager;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.objectstore.ObjectStoreProxy.GetStatus;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.MethodId;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.UniqueIdHelper;
import org.ray.runtime.util.config.ConfigReader;
import org.ray.runtime.util.exception.TaskExecutionException;
import org.ray.runtime.util.logger.RayLog;
/**
* Core functionality to implement Ray APIs.
*/
public abstract class AbstractRayRuntime implements RayRuntime {
public static ConfigReader configReader;
protected static AbstractRayRuntime ins = null;
protected static RayParameters params = null;
private static boolean fromRayInit = false;
protected Worker worker;
protected RayletClient rayletClient;
protected ObjectStoreProxy objectStoreProxy;
protected LocalFunctionManager functions;
protected RemoteFunctionManager remoteFunctionManager;
protected PathConfig pathConfig;
/**
* Actor ID -> local actor instance.
*/
Map<UniqueId, Object> localActors = new HashMap<>();
// app level Ray.init()
// make it private so there is no direct usage but only from Ray.init
private static AbstractRayRuntime init() {
if (ins == null) {
try {
fromRayInit = true;
AbstractRayRuntime.init(null, null);
fromRayInit = false;
} catch (Exception e) {
e.printStackTrace();
throw new RuntimeException("Ray.init failed", e);
}
}
return ins;
}
// engine level AbstractRayRuntime.init(xx, xx)
// updateConfigStr is sth like section1.k1=v1;section2.k2=v2
public static AbstractRayRuntime init(String configPath, String updateConfigStr)
throws Exception {
if (ins == null) {
if (configPath == null) {
configPath = System.getenv("RAY_CONFIG");
if (configPath == null) {
configPath = System.getProperty("ray.config");
}
if (configPath == null) {
throw new Exception(
"Please set config file path in env RAY_CONFIG or property ray.config");
}
}
configReader = new ConfigReader(configPath, updateConfigStr);
AbstractRayRuntime.params = new RayParameters(configReader);
RayLog.init(params.log_dir);
assert RayLog.core != null;
ins = instantiate(params);
assert (ins != null);
if (!fromRayInit) {
Ray.init(); // assign Ray._impl
}
}
return ins;
}
// init with command line args
// --config=ray.config.ini --overwrite=updateConfigStr
public static AbstractRayRuntime init(String[] args) throws Exception {
String config = null;
String updateConfig = null;
for (String arg : args) {
if (arg.startsWith("--config=")) {
config = arg.substring("--config=".length());
} else if (arg.startsWith("--overwrite=")) {
updateConfig = arg.substring("--overwrite=".length());
} else {
throw new RuntimeException("Input argument " + arg
+ " is not recognized, please use --overwrite to merge it into config file");
}
}
return init(config, updateConfig);
}
protected void init(
RayletClient slink,
ObjectStoreLink plink,
RemoteFunctionManager remoteLoader,
PathConfig pathManager
) {
remoteFunctionManager = remoteLoader;
pathConfig = pathManager;
functions = new LocalFunctionManager(remoteLoader);
rayletClient = slink;
objectStoreProxy = new ObjectStoreProxy(plink);
worker = new Worker(this);
}
private static AbstractRayRuntime instantiate(RayParameters params) {
AbstractRayRuntime runtime;
if (params.run_mode.isNativeRuntime()) {
runtime = new RayNativeRuntime();
} else {
runtime = new RayDevRuntime();
}
RayLog.core
.info("Start " + runtime.getClass().getName() + " with " + params.run_mode.toString());
try {
runtime.start(params);
} catch (Exception e) {
RayLog.core.error("Failed to init RayRuntime", e);
System.exit(-1);
}
return runtime;
}
/**
* start runtime.
*/
public abstract void start(RayParameters params) throws Exception;
public static AbstractRayRuntime getInstance() {
return ins;
}
public static RayParameters getParams() {
return params;
}
@Override
public abstract void shutdown();
@Override
public <T> RayObject<T> put(T obj) {
UniqueId objectId = UniqueIdHelper.computePutId(
WorkerContext.currentTask().taskId, WorkerContext.nextPutIndex());
put(objectId, obj);
return new RayObjectImpl<>(objectId);
}
public <T> void put(UniqueId objectId, T obj) {
UniqueId taskId = WorkerContext.currentTask().taskId;
RayLog.core.info("Putting object {}, for task {} ", objectId, taskId);
objectStoreProxy.put(objectId, obj, null);
}
@Override
public <T> T get(UniqueId objectId) throws TaskExecutionException {
List<T> ret = get(ImmutableList.of(objectId));
return ret.get(0);
}
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
UniqueId taskId = WorkerContext.currentTask().taskId;
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
List<List<UniqueId>> fetchBatches =
splitIntoBatches(objectIds, params.worker_fetch_request_size);
for (List<UniqueId> batch : fetchBatches) {
rayletClient.reconstructObjects(batch, true);
}
// Get the objects. We initially try to get the objects immediately.
List<Pair<T, GetStatus>> ret = objectStoreProxy
.get(objectIds, params.default_first_check_timeout_ms, false);
assert ret.size() == numObjectIds;
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map<UniqueId, Integer> unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
unreadys.put(objectIds.get(i), i);
}
}
wasBlocked = (unreadys.size() > 0);
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
while (unreadys.size() > 0) {
List<UniqueId> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueId>> reconstructBatches =
splitIntoBatches(unreadyList, params.worker_fetch_request_size);
for (List<UniqueId> batch : reconstructBatches) {
rayletClient.reconstructObjects(batch, false);
}
List<Pair<T, GetStatus>> results = objectStoreProxy
.get(unreadyList, params.default_get_check_interval_ms, false);
// Remove any entries for objects we received during this iteration so we
// don't retrieve the same object twice.
for (int i = 0; i < results.size(); i++) {
Pair<T, GetStatus> value = results.get(i);
if (value.getRight() == GetStatus.SUCCESS) {
UniqueId id = unreadyList.get(i);
ret.set(unreadys.get(id), value);
unreadys.remove(id);
}
}
}
RayLog.core
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
finalRet.add(value.getLeft());
}
return finalRet;
} catch (TaskExecutionException e) {
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
+ " get with Exception", e);
throw e;
} finally {
// If there were objects that we weren't able to get locally, let the local
// scheduler know that we're now unblocked.
if (wasBlocked) {
rayletClient.notifyUnblocked();
}
}
}
@Override
public void free(List<UniqueId> objectIds, boolean localOnly) {
rayletClient.freePlasmaObjects(objectIds, localOnly);
}
private List<List<UniqueId>> splitIntoBatches(List<UniqueId> objectIds, int batchSize) {
List<List<UniqueId>> batches = new ArrayList<>();
int objectsSize = objectIds.size();
for (int i = 0; i < objectsSize; i += batchSize) {
int endIndex = i + batchSize;
List<UniqueId> batchIds = (endIndex < objectsSize)
? objectIds.subList(i, endIndex)
: objectIds.subList(i, objectsSize);
batches.add(batchIds);
}
return batches;
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
return rayletClient.wait(waitList, numReturns, timeoutMs);
}
@Override
public RayObject call(RayFunc func, Object[] args) {
TaskSpec spec = createTaskSpec(func, RayActorImpl.NIL, args, false);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
public RayObject call(RayFunc func, RayActor actor, Object[] args) {
if (!(actor instanceof RayActorImpl)) {
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
}
RayActorImpl actorImpl = (RayActorImpl)actor;
TaskSpec spec = createTaskSpec(func, actorImpl, args, false);
actorImpl.setTaskCursor(spec.returnIds[1]);
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
@Override
@SuppressWarnings("unchecked")
public <T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args) {
TaskSpec spec = createTaskSpec(actorFactoryFunc, RayActorImpl.NIL, args, true);
RayActorImpl<?> actor = new RayActorImpl(spec.returnIds[0]);
actor.increaseTaskCounter();
actor.setTaskCursor(spec.returnIds[0]);
rayletClient.submitTask(spec);
return (RayActor<T>) actor;
}
/**
* Generate the return ids of a task.
*/
private UniqueId[] genReturnIds(UniqueId taskId, int numReturns) {
UniqueId[] ret = new UniqueId[numReturns];
for (int i = 0; i < numReturns; i++) {
ret[i] = UniqueIdHelper.computeReturnId(taskId, i + 1);
}
return ret;
}
/**
* Create the task specification.
* @param func The target remote function.
* @param actor The actor handle. If the task is not an actor task, actor id must be NIL.
* @param args The arguments for the remote function.
* @param isActorCreationTask Whether this task is an actor creation task.
* @return A TaskSpec object.
*/
private TaskSpec createTaskSpec(RayFunc func, RayActorImpl actor, Object[] args,
boolean isActorCreationTask) {
final TaskSpec current = WorkerContext.currentTask();
UniqueId taskId = rayletClient.generateTaskId(current.driverId,
current.taskId,
WorkerContext.nextCallIndex());
int numReturns = actor.getId().isNil() ? 1 : 2;
UniqueId[] returnIds = genReturnIds(taskId, numReturns);
UniqueId actorCreationId = UniqueId.NIL;
if (isActorCreationTask) {
actorCreationId = returnIds[0];
}
MethodId methodId = MethodId.fromSerializedLambda(func);
// NOTE: we append the class name at the end of arguments,
// so that we can look up the method based on the class name.
// TODO(hchen): move class name to task spec.
args = Arrays.copyOf(args, args.length + 1);
args[args.length - 1] = methodId.className;
RayMethod rayMethod = functions.getMethod(
current.driverId, actor.getId(), new UniqueId(methodId.getSha1Hash()), methodId.className
).getRight();
UniqueId funcId = rayMethod.getFuncId();
return new TaskSpec(
current.driverId,
taskId,
current.taskId,
-1,
actor.getId(),
actor.increaseTaskCounter(),
funcId,
ArgumentsBuilder.wrap(args),
returnIds,
actor.getHandleId(),
actorCreationId,
ResourceUtil.getResourcesMapFromArray(rayMethod.remoteAnnotation),
actor.getTaskCursor()
);
}
public void loop() {
worker.loop();
}
public Worker getWorker() {
return worker;
}
public RayletClient getRayletClient() {
return rayletClient;
}
public LocalFunctionManager getLocalFunctionManager() {
return functions;
}
}
@@ -0,0 +1,89 @@
package org.ray.runtime;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import org.ray.api.RayActor;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.Sha1Digestor;
public final class RayActorImpl<T> implements RayActor<T>, Externalizable {
public static final RayActorImpl NIL = new RayActorImpl();
private UniqueId id;
private UniqueId handleId;
/**
* The number of tasks that have been invoked on this actor.
*/
private int taskCounter;
/**
* The unique id of the last return of the last task.
* It's used as a dependency for the next task.
*/
private UniqueId taskCursor;
/**
* The number of times that this actor handle has been forked.
* It's used to make sure ids of actor handles are unique.
*/
private int numForks;
public RayActorImpl() {
this(UniqueId.NIL, UniqueId.NIL);
}
public RayActorImpl(UniqueId id) {
this(id, UniqueId.NIL);
}
public RayActorImpl(UniqueId id, UniqueId handleId) {
this.id = id;
this.handleId = handleId;
this.taskCounter = 0;
this.taskCursor = null;
numForks = 0;
}
@Override
public UniqueId getId() {
return id;
}
@Override
public UniqueId getHandleId() {
return handleId;
}
public void setTaskCursor(UniqueId taskCursor) {
this.taskCursor = taskCursor;
}
public UniqueId getTaskCursor() {
return taskCursor;
}
public int increaseTaskCounter() {
return taskCounter++;
}
private UniqueId computeNextActorHandleId() {
byte[] bytes = Sha1Digestor.digest(handleId.getBytes(), ++numForks);
return new UniqueId(bytes);
}
@Override
public void writeExternal(ObjectOutput out) throws IOException {
out.writeObject(this.id);
out.writeObject(this.computeNextActorHandleId());
out.writeObject(this.taskCursor);
}
@Override
public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
this.id = (UniqueId) in.readObject();
this.handleId = (UniqueId) in.readObject();
this.taskCursor = (UniqueId) in.readObject();
}
}
@@ -0,0 +1,26 @@
package org.ray.runtime;
import org.ray.runtime.config.PathConfig;
import org.ray.runtime.config.RayParameters;
import org.ray.runtime.functionmanager.NopRemoteFunctionManager;
import org.ray.runtime.functionmanager.RemoteFunctionManager;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.raylet.MockRayletClient;
public class RayDevRuntime extends AbstractRayRuntime {
@Override
public void start(RayParameters params) {
PathConfig pathConfig = new PathConfig(configReader);
RemoteFunctionManager rfm = new NopRemoteFunctionManager(params.driver_id);
MockObjectStore store = new MockObjectStore();
MockRayletClient scheduler = new MockRayletClient(this, store);
init(scheduler, store, rfm, pathConfig);
scheduler.setLocalFunctionManager(this.functions);
}
@Override
public void shutdown() {
// nothing to do
}
}
@@ -0,0 +1,158 @@
package org.ray.runtime;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.PlasmaClient;
import org.ray.runtime.config.PathConfig;
import org.ray.runtime.config.RayParameters;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.functionmanager.NativeRemoteFunctionManager;
import org.ray.runtime.functionmanager.NopRemoteFunctionManager;
import org.ray.runtime.functionmanager.RemoteFunctionManager;
import org.ray.runtime.gcs.AddressInfo;
import org.ray.runtime.gcs.KeyValueStoreLink;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.gcs.StateStoreProxy;
import org.ray.runtime.gcs.StateStoreProxyImpl;
import org.ray.runtime.raylet.RayletClient;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
import org.ray.runtime.util.logger.RayLog;
/**
* native runtime for local box and cluster run.
*/
public final class RayNativeRuntime extends AbstractRayRuntime {
static {
System.err.println("Current working directory is " + System.getProperty("user.dir"));
System.loadLibrary("local_scheduler_library_java");
System.loadLibrary("plasma_java");
}
private StateStoreProxy stateStoreProxy;
private KeyValueStoreLink kvStore = null;
private RunManager manager = null;
public RayNativeRuntime() {
}
@Override
public void start(RayParameters params) throws Exception {
boolean isWorker = (params.worker_mode == WorkerMode.WORKER);
PathConfig pathConfig = new PathConfig(configReader);
// initialize params
if (params.redis_address.length() == 0) {
if (isWorker) {
throw new Error("Redis address must be configured under Worker mode.");
}
startOnebox(params, pathConfig);
initStateStore(params.redis_address);
} else {
initStateStore(params.redis_address);
if (!isWorker) {
List<AddressInfo> nodes = stateStoreProxy.getAddressInfo(
params.node_ip_address, params.redis_address, 5);
params.object_store_name = nodes.get(0).storeName;
params.raylet_socket_name = nodes.get(0).rayletSocketName;
}
}
// initialize remote function manager
RemoteFunctionManager funcMgr = params.run_mode.isDevPathManager()
? new NopRemoteFunctionManager(params.driver_id) : new NativeRemoteFunctionManager(kvStore);
// initialize worker context
if (params.worker_mode == WorkerMode.DRIVER) {
// TODO: The relationship between workerID, driver_id and dummy_task.driver_id should be
// recheck carefully
WorkerContext.workerID = params.driver_id;
}
WorkerContext.init(params);
if (params.onebox_delay_seconds_before_run_app_logic > 0) {
for (int i = 0; i < params.onebox_delay_seconds_before_run_app_logic; ++i) {
System.err.println("Pause for debugger, "
+ (params.onebox_delay_seconds_before_run_app_logic - i)
+ " seconds left ...");
Thread.sleep(1000);
}
}
if (params.worker_mode != WorkerMode.NONE) {
// initialize the links
int releaseDelay = AbstractRayRuntime.configReader
.getIntegerValue("ray", "plasma_default_release_delay", 0,
"how many release requests should be delayed in plasma client");
ObjectStoreLink plink = new PlasmaClient(params.object_store_name, "", releaseDelay);
RayletClient rayletClient = new RayletClientImpl(
params.raylet_socket_name,
WorkerContext.currentWorkerId(),
isWorker,
WorkerContext.currentTask().taskId
);
init(rayletClient, plink, funcMgr, pathConfig);
// register
registerWorker(isWorker, params.node_ip_address, params.object_store_name,
params.raylet_socket_name);
}
RayLog.core.info("RayNativeRuntime started with store {}, raylet {}",
params.object_store_name, params.raylet_socket_name);
}
@Override
public void shutdown() {
if (null != manager) {
manager.cleanup(true);
}
}
private void startOnebox(RayParameters params, PathConfig paths) throws Exception {
params.cleanup = true;
manager = new RunManager(params, paths, AbstractRayRuntime.configReader);
manager.startRayHead();
params.redis_address = manager.info().redisAddress;
params.object_store_name = manager.info().localStores.get(0).storeName;
params.raylet_socket_name = manager.info().localStores.get(0).rayletSocketName;
//params.node_ip_address = NetworkUtil.getIpAddress();
}
private void initStateStore(String redisAddress) throws Exception {
kvStore = new RedisClient();
kvStore.setAddr(redisAddress);
stateStoreProxy = new StateStoreProxyImpl(kvStore);
stateStoreProxy.initializeGlobalState();
}
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
String rayletSocketName) {
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
if (!isWorker) {
workerInfo.put("node_ip_address", nodeIpAddress);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
workerInfo.put("plasma_store_socket", storeName);
workerInfo.put("raylet_socket", rayletSocketName);
workerInfo.put("name", System.getProperty("user.dir"));
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
kvStore.hmset("Drivers:" + workerId, workerInfo);
} else {
workerInfo.put("node_ip_address", nodeIpAddress);
workerInfo.put("plasma_store_socket", storeName);
workerInfo.put("raylet_socket", rayletSocketName);
//TODO: b"Workers:" + worker.workerId,
kvStore.hmset("Workers:" + workerId, workerInfo);
}
}
}
@@ -0,0 +1,26 @@
package org.ray.runtime;
import java.io.Serializable;
import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
public final class RayObjectImpl<T> implements RayObject<T>, Serializable {
private final UniqueId id;
public RayObjectImpl(UniqueId id) {
this.id = id;
}
@Override
public T get() {
return Ray.get(id);
}
@Override
public UniqueId getId() {
return id;
}
}
@@ -0,0 +1,71 @@
package org.ray.runtime;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.functionmanager.RayMethod;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.logger.RayLog;
/**
* The worker, which pulls tasks from {@code org.ray.spi.LocalSchedulerProxy} and executes them
* continuously.
*/
public class Worker {
private final AbstractRayRuntime runtime;
public Worker(AbstractRayRuntime runtime) {
this.runtime = runtime;
}
public void loop() {
while (true) {
RayLog.core.info(Thread.currentThread().getName() + ":fetching new task...");
TaskSpec task = runtime.getRayletClient().getTask();
execute(task);
}
}
/**
* Execute a task.
*/
public void execute(TaskSpec spec) {
RayLog.core.info("Executing task {}", spec.taskId);
UniqueId returnId = spec.returnIds[0];
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
try {
// Get method
Pair<ClassLoader, RayMethod> pair = runtime.getLocalFunctionManager().getMethod(
spec.driverId, spec.actorId, spec.functionId, spec.args);
ClassLoader classLoader = pair.getLeft();
RayMethod method = pair.getRight();
// Set context
WorkerContext.prepare(spec, classLoader);
Thread.currentThread().setContextClassLoader(classLoader);
// Get local actor object and arguments.
Object actor = spec.isActorTask() ? runtime.localActors.get(spec.actorId) : null;
Object[] args = ArgumentsBuilder.unwrap(spec, classLoader);
// Execute the task.
Object result;
if (!method.isConstructor()) {
result = method.getMethod().invoke(actor, args);
} else {
result = method.getConstructor().newInstance(args);
}
// Set result
if (!spec.isActorCreationTask()) {
runtime.put(returnId, result);
} else {
runtime.localActors.put(returnId, result);
}
RayLog.core.info("Finished executing task {}", spec.taskId);
} catch (Exception e) {
RayLog.core.error("Error executing task " + spec, e);
runtime.put(returnId, new RayException("Error executing task " + spec, e));
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
}
}
}
@@ -0,0 +1,82 @@
package org.ray.runtime;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayParameters;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
public class WorkerContext {
private static final ThreadLocal<WorkerContext> currentWorkerCtx =
ThreadLocal.withInitial(() -> init(AbstractRayRuntime.getParams()));
/**
* id of worker.
*/
public static UniqueId workerID = UniqueId.randomId();
/**
* current doing task.
*/
private TaskSpec currentTask;
/**
* current app classloader.
*/
private ClassLoader currentClassLoader;
/**
* how many puts done by current task.
*/
private int currentTaskPutCount;
/**
* how many calls done by current task.
*/
private int currentTaskCallCount;
public static WorkerContext init(RayParameters params) {
WorkerContext ctx = new WorkerContext();
currentWorkerCtx.set(ctx);
TaskSpec dummy = new TaskSpec();
dummy.parentTaskId = UniqueId.NIL;
if (params.worker_mode == WorkerMode.DRIVER) {
dummy.taskId = UniqueId.randomId();
} else {
dummy.taskId = UniqueId.NIL;
}
dummy.actorId = UniqueId.NIL;
dummy.driverId = params.driver_id;
prepare(dummy, null);
return ctx;
}
public static void prepare(TaskSpec task, ClassLoader classLoader) {
WorkerContext wc = get();
wc.currentTask = task;
wc.currentTaskPutCount = 0;
wc.currentTaskCallCount = 0;
wc.currentClassLoader = classLoader;
}
public static WorkerContext get() {
return currentWorkerCtx.get();
}
public static TaskSpec currentTask() {
return get().currentTask;
}
public static int nextPutIndex() {
return ++get().currentTaskPutCount;
}
public static int nextCallIndex() {
return ++get().currentTaskCallCount;
}
public static UniqueId currentWorkerId() {
return WorkerContext.workerID;
}
public static ClassLoader currentClassLoader() {
return get().currentClassLoader;
}
}
@@ -0,0 +1,57 @@
package org.ray.runtime.config;
import org.ray.runtime.util.config.AConfig;
import org.ray.runtime.util.config.ConfigReader;
/**
* Path related configurations.
*/
public class PathConfig {
@AConfig(comment = "additional class path for JAVA",
defaultArrayIndirectSectionName = "ray.java.path.classes.source")
public String[] java_class_paths;
@AConfig(comment = "additional JNI library paths for JAVA",
defaultArrayIndirectSectionName = "ray.java.path.jni.build")
public String[] java_jnilib_paths;
@AConfig(comment = "path to ray_functions.txt for the default rewritten functions in ray runtime")
public String java_runtime_rewritten_jars_dir = "";
@AConfig(comment = "path to redis-server")
public String redis_server;
@AConfig(comment = "path to redis module")
public String redis_module;
@AConfig(comment = "path to plasma storage")
public String store;
@AConfig(comment = "path to raylet")
public String raylet;
@AConfig(comment = "path to python directory")
public String python_dir;
@AConfig(comment = "path to log server")
public String log_server;
@AConfig(comment = "path to log server config file")
public String log_server_config;
public PathConfig(ConfigReader config) {
if (config.getBooleanValue("ray.java.start", "deploy", false,
"whether the package is used as a cluster deployment")) {
config.readObject("ray.java.path.deploy", this, this);
} else {
boolean isJar = this.getClass().getResource(this.getClass().getSimpleName() + ".class")
.getFile().split("!")[0].endsWith(".jar");
if (isJar) {
config.readObject("ray.java.path.package", this, this);
} else {
config.readObject("ray.java.path.source", this, this);
}
}
}
}
@@ -0,0 +1,106 @@
package org.ray.runtime.config;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.config.AConfig;
import org.ray.runtime.util.config.ConfigReader;
/**
* Runtime parameters of Ray process.
*/
public class RayParameters {
@AConfig(comment = "worker mode for this process DRIVER | WORKER | NONE")
public WorkerMode worker_mode = WorkerMode.DRIVER;
@AConfig(comment = "run mode for this app SINGLE_PROCESS | SINGLE_BOX | CLUSTER")
public RunMode run_mode = RunMode.SINGLE_PROCESS;
@AConfig(comment = "local node ip")
public String node_ip_address = NetworkUtil.getIpAddress(null);
@AConfig(comment = "primary redis address (e.g., 127.0.0.1:34222")
public String redis_address = "";
@AConfig(comment = "object store name (e.g., /tmp/store1111")
public String object_store_name = "";
@AConfig(comment = "object store rpc listen port")
public int object_store_rpc_port = 32567;
@AConfig(comment = "driver ID when the worker is served as a driver")
public UniqueId driver_id = UniqueId.NIL;
@AConfig(comment = "logging directory")
public String log_dir = "/tmp/raylogs";
@AConfig(comment = "primary redis port")
public int redis_port = 34222;
@AConfig(comment = "number of workers started initially")
public int num_workers = 1;
@AConfig(comment = "redirect err and stdout to files for newly created processes")
public boolean redirect = true;
@AConfig(comment = "whether to start redis shard server in addition to the primary server")
public boolean start_redis_shards = false;
@AConfig(comment = "whether to clean up the processes when there is a process start failure")
public boolean cleanup = false;
@AConfig(comment = "number of redis shard servers to be started")
public int num_redis_shards = 0;
@AConfig(comment = "whether this is a deployment in cluster")
public boolean deploy = false;
@AConfig(comment = "whether this is for python deployment")
public boolean py = false;
@AConfig(comment = "the max bytes of the buffer for task submit")
public int max_submit_task_buffer_size_bytes = 2 * 1024 * 1024;
@AConfig(comment = "default first check timeout(ms)")
public int default_first_check_timeout_ms = 1000;
@AConfig(comment = "default get check rate(ms)")
public int default_get_check_interval_ms = 5000;
@AConfig(comment = "add the jvm parameters for java worker")
public String jvm_parameters = "";
@AConfig(comment = "set the occupied memory(MB) size of object store")
public int object_store_occupied_memory_MB = 1000;
@AConfig(comment = "whether to use supreme failover strategy")
public boolean supremeFO = false;
@AConfig(comment = "whether to disable process failover")
public boolean disable_process_failover = false;
@AConfig(comment = "delay seconds under onebox before app logic for debugging")
public int onebox_delay_seconds_before_run_app_logic = 0;
@AConfig(comment = "raylet socket name (e.g., /tmp/raylet1111")
public String raylet_socket_name = "";
@AConfig(comment = "raylet rpc listen port")
public int raylet_port = 35567;
@AConfig(comment = "worker fetch request size")
public int worker_fetch_request_size = 10000;
@AConfig(comment = "static resource list of this node")
public String static_resources = "";
public RayParameters(ConfigReader config) {
if (null != config) {
String networkInterface = config.getStringValue("ray.java", "network_interface", null,
"Network interface to be specified for host ip address(e.g., en0, eth0), may use "
+ "ifconfig to get options");
node_ip_address = NetworkUtil.getIpAddress(networkInterface);
config.readObject("ray.java.start", this, this);
}
}
}
@@ -0,0 +1,39 @@
package org.ray.runtime.config;
public enum RunMode {
SINGLE_PROCESS(true, false), // dev path, dev runtime
SINGLE_BOX(true, true), // dev path, native runtime
CLUSTER(false, true); // deploy path, naive runtime
RunMode(boolean devPathManager,
boolean nativeRuntime) {
this.devPathManager = devPathManager;
this.nativeRuntime = nativeRuntime;
}
/**
* the jar has add to java -cp, no need to load jar after started.
*/
private final boolean devPathManager;
private final boolean nativeRuntime;
/**
* Getter method for property <tt>devPathManager</tt>.
*
* @return property value of devPathManager
*/
public boolean isDevPathManager() {
return devPathManager;
}
/**
* Getter method for property <tt>nativeRuntime</tt>.
*
* @return property value of nativeRuntime
*/
public boolean isNativeRuntime() {
return nativeRuntime;
}
}
@@ -0,0 +1,7 @@
package org.ray.runtime.config;
public enum WorkerMode {
NONE, // not set
DRIVER, // driver
WORKER // worker
}
@@ -0,0 +1,119 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Preconditions;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.id.UniqueId;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.logger.RayLog;
/**
* local function manager which pulls remote functions on demand.
*/
public class LocalFunctionManager {
private final RemoteFunctionManager remoteLoader;
private final ConcurrentHashMap<UniqueId, FunctionTable> functionTables
= new ConcurrentHashMap<>();
/**
* initialize load function manager using remote function manager to pull remote functions on
* demand.
*/
public LocalFunctionManager(RemoteFunctionManager remoteLoader) {
this.remoteLoader = remoteLoader;
}
private FunctionTable loadDriverFunctions(UniqueId driverId) {
FunctionTable functionTable = functionTables.get(driverId);
if (functionTable == null) {
RayLog.core.info("DriverId " + driverId + " Try to load functions");
ClassLoader classLoader = remoteLoader.loadResource(driverId);
if (classLoader == null) {
throw new RuntimeException(
"Cannot find resource' classLoader for app " + driverId.toString());
}
functionTable = new FunctionTable(classLoader);
functionTables.put(driverId, functionTable);
}
return functionTable;
}
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
UniqueId methodId, String className) {
// assert the driver's resource is load.
FunctionTable functionTable = loadDriverFunctions(driverId);
Preconditions.checkNotNull(functionTable, "driver's resource is not loaded:%s", driverId);
RayMethod method = actorId.isNil() ? functionTable.getTaskMethod(methodId, className)
: functionTable.getActorMethod(methodId, className);
Preconditions
.checkNotNull(method, "method not found, class=%s, methodId=%s, driverId=%s", className,
methodId, driverId);
return Pair.of(functionTable.classLoader, method);
}
/**
* get local method for executing, which pulls information from remote repo on-demand, therefore
* it may block for a while if the related resources (e.g., jars) are not ready on local machine
*/
public Pair<ClassLoader, RayMethod> getMethod(UniqueId driverId, UniqueId actorId,
UniqueId methodId, FunctionArg[] args) {
Preconditions.checkArgument(args.length >= 1, "method's args len %s<=1", args.length);
String className = (String) Serializer.decode(args[args.length - 1].data);
return getMethod(driverId, actorId, methodId, className);
}
/**
* unload the functions when the driver is declared dead.
*/
public synchronized void removeApp(UniqueId driverId) {
FunctionTable funcs = functionTables.get(driverId);
if (funcs != null) {
functionTables.remove(driverId);
remoteLoader.unloadFunctions(driverId);
}
}
private static class FunctionTable {
final ClassLoader classLoader;
final ConcurrentHashMap<String, RayTaskMethods> taskMethods = new ConcurrentHashMap<>();
final ConcurrentHashMap<String, RayActorMethods> actorMethods = new ConcurrentHashMap<>();
FunctionTable(ClassLoader classLoader) {
this.classLoader = classLoader;
}
RayMethod getTaskMethod(UniqueId methodId, String className) {
RayTaskMethods taskMethods = this.taskMethods.get(className);
if (taskMethods == null) {
taskMethods = RayTaskMethods.fromClass(className, classLoader);
RayLog.core.info("create RayTaskMethods: {}", taskMethods);
this.taskMethods.put(className, taskMethods);
}
RayMethod m = taskMethods.functions.get(methodId);
if (m != null) {
return m;
}
// it is a actor static func.
return getActorMethod(methodId, className, true);
}
RayMethod getActorMethod(UniqueId methodId, String className) {
return getActorMethod(methodId, className, false);
}
private RayMethod getActorMethod(UniqueId methodId, String className, boolean isStatic) {
RayActorMethods actorMethods = this.actorMethods.get(className);
if (actorMethods == null) {
actorMethods = RayActorMethods.fromClass(className, classLoader);
RayLog.core.info("create RayActorMethods: {}", actorMethods);
this.actorMethods.put(className, actorMethods);
}
return isStatic ? actorMethods.staticFunctions.get(methodId)
: actorMethods.functions.get(methodId);
}
}
}
@@ -0,0 +1,131 @@
package org.ray.runtime.functionmanager;
import java.io.File;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.ConcurrentHashMap;
import net.lingala.zip4j.core.ZipFile;
import org.ray.api.id.UniqueId;
import org.ray.runtime.gcs.KeyValueStoreLink;
import org.ray.runtime.util.FileUtil;
import org.ray.runtime.util.JarLoader;
import org.ray.runtime.util.Sha1Digestor;
import org.ray.runtime.util.SystemUtil;
import org.ray.runtime.util.logger.RayLog;
/**
* native implementation of remote function manager.
*/
public class NativeRemoteFunctionManager implements RemoteFunctionManager {
private final ConcurrentHashMap<UniqueId, ClassLoader> loadedApps = new ConcurrentHashMap<>();
private MessageDigest md;
private final String appDir = System.getProperty("user.dir") + "/apps";
private final KeyValueStoreLink kvStore;
public NativeRemoteFunctionManager(KeyValueStoreLink kvStore) throws NoSuchAlgorithmException {
this.kvStore = kvStore;
md = MessageDigest.getInstance("SHA-1");
File appDir = new File(this.appDir);
if (!appDir.exists()) {
appDir.mkdirs();
}
}
@Override
public UniqueId registerResource(byte[] resourceZip) {
byte[] digest = Sha1Digestor.digest(resourceZip);
assert (digest.length == UniqueId.LENGTH);
UniqueId resourceId = new UniqueId(digest);
// TODO: resources must be saved in persistent store
kvStore.set(resourceId.getBytes(), resourceZip, null);
return resourceId;
}
@Override
public byte[] getResource(UniqueId resourceId) {
return kvStore.get(resourceId.getBytes(), null);
}
@Override
public void unregisterResource(UniqueId resourceId) {
kvStore.delete(resourceId.getBytes(), null);
}
@Override
public void registerApp(UniqueId driverId, UniqueId resourceId) {
kvStore.set("App2ResMap", resourceId.toString(), driverId.toString());
}
@Override
public UniqueId getAppResourceId(UniqueId driverId) {
return UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString()));
}
@Override
public void unregisterApp(UniqueId driverId) {
kvStore.delete("App2ResMap", driverId.toString());
}
@Override
public ClassLoader loadResource(UniqueId driverId) {
ClassLoader classLoader = loadedApps.get(driverId);
if (classLoader == null) {
synchronized (this) {
classLoader = loadedApps.get(driverId);
if (classLoader == null) {
classLoader = initLoadedApps(driverId);
}
}
}
return classLoader;
}
private ClassLoader initLoadedApps(UniqueId driverId) {
try {
RayLog.core.info("initLoadedApps" + driverId.toString());
ClassLoader cl = loadedApps.get(driverId);
if (cl == null) {
UniqueId resId = UniqueId.fromHexString(kvStore.get("App2ResMap", driverId.toString()));
byte[] res = getResource(resId);
if (res == null) {
throw new RuntimeException("get resource null, the resId " + resId.toString());
}
RayLog.core.info("get resource of " + resId.toString() + ", result len " + res.length);
String resPath =
appDir + "/" + driverId.toString() + "/" + String.valueOf(SystemUtil.pid());
File dir = new File(resPath);
if (!dir.exists()) {
dir.mkdirs();
}
String zipPath = resPath + ".zip";
RayLog.rapp.info("unzip app file: zipPath " + zipPath + " resPath " + resPath);
FileUtil.bytesToFile(res, zipPath);
ZipFile zipFile = new ZipFile(zipPath);
zipFile.extractAll(resPath);
cl = JarLoader.loadJars(resPath, false);
loadedApps.put(driverId, cl);
}
return cl;
} catch (Exception e) {
RayLog.rapp.error("load function for " + driverId + " failed, ex = " + e.getMessage(), e);
return null;
}
}
@Override
public synchronized void unloadFunctions(UniqueId driverId) {
ClassLoader cl = loadedApps.get(driverId);
try {
JarLoader.unloadJars(cl);
} catch (Exception e) {
RayLog.rapp.error("unload function for " + driverId + " failed, ex = " + e.getMessage(), e);
}
}
}
@@ -0,0 +1,58 @@
package org.ray.runtime.functionmanager;
import org.ray.api.id.UniqueId;
/**
* mock version of remote function manager using local loaded jars.
*/
public class NopRemoteFunctionManager implements RemoteFunctionManager {
public NopRemoteFunctionManager(UniqueId driverId) {
//onLoad(driverId, Agent.hookedMethods);
//Agent.consumers.add(m -> { this.onLoad(m); });
}
@Override
public UniqueId registerResource(byte[] resourceZip) {
return null;
// nothing to do
}
@Override
public byte[] getResource(UniqueId resourceId) {
return null;
}
@Override
public void unregisterResource(UniqueId resourceId) {
// nothing to do
}
@Override
public void registerApp(UniqueId driverId, UniqueId resourceId) {
// nothing to do
}
@Override
public UniqueId getAppResourceId(UniqueId driverId) {
return null;
// nothing to do
}
@Override
public void unregisterApp(UniqueId driverId) {
// nothing to do
}
@Override
public ClassLoader loadResource(UniqueId driverId) {
//assert (startupDriverId().equals(driverId));
return this.getClass().getClassLoader();
}
@Override
public void unloadFunctions(UniqueId driverId) {
// never
//assert (startupDriverId().equals(driverId));
}
}
@@ -0,0 +1,68 @@
package org.ray.runtime.functionmanager;
import com.google.common.base.Preconditions;
import java.lang.reflect.Executable;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
public final class RayActorMethods {
public final Class clazz;
public final RayRemote remoteAnnotation;
public final Map<UniqueId, RayMethod> functions;
/**
* the static function in Actor, call as task.
*/
public final Map<UniqueId, RayMethod> staticFunctions;
private RayActorMethods(Class clazz, RayRemote remoteAnnotation,
Map<UniqueId, RayMethod> functions, Map<UniqueId, RayMethod> staticFunctions) {
this.clazz = clazz;
this.remoteAnnotation = remoteAnnotation;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
this.staticFunctions = Collections.unmodifiableMap(new HashMap<>(staticFunctions));
}
public static RayActorMethods fromClass(String className, ClassLoader classLoader) {
try {
Class clazz = Class.forName(className, true, classLoader);
RayRemote remoteAnnotation = (RayRemote) clazz.getAnnotation(RayRemote.class);
Preconditions.checkNotNull(remoteAnnotation,
"%s must be annotated with @RayRemote", className);
List<Executable> executables = new ArrayList<>(Arrays.asList(clazz.getDeclaredMethods()));
Map<UniqueId, RayMethod> functions = new HashMap<>();
Map<UniqueId, RayMethod> staticFunctions = new HashMap<>();
for (Executable e : executables) {
RayMethod rayMethod = RayMethod.from(e, remoteAnnotation);
if (Modifier.isStatic(e.getModifiers())) {
staticFunctions.put(rayMethod.getFuncId(), rayMethod);
} else {
functions.put(rayMethod.getFuncId(), rayMethod);
}
}
return new RayActorMethods(clazz, remoteAnnotation, functions, staticFunctions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayActorMethods from " + className, e);
}
}
@Override
public String toString() {
return String
.format("RayActorMethods:%s, funcNum=%s:{%s}, sfuncNum=%s:{%s}", clazz, functions.size(),
functions.values(),
staticFunctions.size(), staticFunctions.values());
}
}
@@ -0,0 +1,57 @@
package org.ray.runtime.functionmanager;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.MethodId;
/**
* method info.
*/
public class RayMethod {
public final Executable invokable;
public final String fullName;
public final RayRemote remoteAnnotation;
private final UniqueId funcId;
private RayMethod(Executable e, RayRemote remoteAnnotation, UniqueId funcId) {
this.invokable = e;
this.remoteAnnotation = remoteAnnotation;
this.funcId = funcId;
fullName = e.getDeclaringClass().getName() + "." + e.getName();
}
public static RayMethod from(Executable e, RayRemote parentRemoteAnnotation) {
RayRemote remoteAnnotation = e.getAnnotation(RayRemote.class);
MethodId mid = MethodId.fromExecutable(e);
UniqueId funcId = new UniqueId(mid.getSha1Hash());
RayMethod method = new RayMethod(e,
remoteAnnotation != null ? remoteAnnotation : parentRemoteAnnotation,
funcId);
return method;
}
public boolean isConstructor() {
return invokable instanceof Constructor;
}
public Constructor<?> getConstructor() {
return (Constructor<?>) invokable;
}
public Method getMethod() {
return (Method) invokable;
}
@Override
public String toString() {
return fullName;
}
public UniqueId getFuncId() {
return funcId;
}
}
@@ -0,0 +1,56 @@
package org.ray.runtime.functionmanager;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.id.UniqueId;
public final class RayTaskMethods {
public final Class clazz;
public final Map<UniqueId, RayMethod> functions;
public RayTaskMethods(Class clazz,
Map<UniqueId, RayMethod> functions) {
this.clazz = clazz;
this.functions = Collections.unmodifiableMap(new HashMap<>(functions));
}
public static RayTaskMethods fromClass(String clazzName, ClassLoader classLoader) {
try {
Class clazz = Class.forName(clazzName, true, classLoader);
List<Executable> executables = new ArrayList<>();
executables.addAll(Arrays.asList(clazz.getDeclaredMethods()));
executables.addAll(Arrays.asList(clazz.getConstructors()));
Map<UniqueId, RayMethod> functions = new HashMap<>(executables.size());
for (Executable e : executables) {
// This executable must be either a constructor or a static method.
if (!(e instanceof Constructor)
&& !Modifier.isStatic(e.getModifiers())) {
continue;
}
e.setAccessible(true);
RayMethod rayMethod = RayMethod.from(e, null);
functions.put(rayMethod.getFuncId(), rayMethod);
}
return new RayTaskMethods(clazz, functions);
} catch (Exception e) {
throw new RuntimeException("failed to get RayTaskMethods from " + clazzName, e);
}
}
@Override
public String toString() {
return String
.format("RayTaskMethods:%s, funcNum=%s:{%s}", clazz, functions.size(), functions.values());
}
}
@@ -0,0 +1,64 @@
package org.ray.runtime.functionmanager;
import org.ray.api.id.UniqueId;
/**
* register and load functions from function table.
*/
public interface RemoteFunctionManager {
/*
* register <resourceId, resource> mapping, and upload resource.
* this function is invoked by app proxy or other stand-alone tools it should detect for
* duplication first though
*
* @param resourceZip a directory zip from @JarRewriter
* @return SHA-1 hash of the content
*/
UniqueId registerResource(byte[] resourceZip);
/**
* download resource content.
*
* @return resource content
*/
byte[] getResource(UniqueId resourceId);
/**
* remove resource by its hash id
* be careful of invoking this function to make sure it is no longer used.
*
* @param resourceId SHA-1 hash of the resource zip bytes
*/
void unregisterResource(UniqueId resourceId);
/*
* register the <driver, resource> mapping to repo,
* this function is invoked by whoever initiates the driver id
*/
void registerApp(UniqueId driverId, UniqueId resourceId);
/**
* get the resourceId of one app.
*
* @return resourceId of the app driver
*/
UniqueId getAppResourceId(UniqueId driverId);
/*
* unregister <dirver, resource> mapping
* this function is called when the driver exits or detected dead
*/
void unregisterApp(UniqueId driverId);
/**
* load resource.
*/
ClassLoader loadResource(UniqueId driverId);
/**
* unload functions for this driver
* this function is used by the workers on demand when a driver is dead.
*/
void unloadFunctions(UniqueId driverId);
}
@@ -0,0 +1,18 @@
package org.ray.runtime.gcs;
/**
* Represents information of different process roles.
*/
public class AddressInfo {
public String managerName;
public String storeName;
public String schedulerName;
public String rayletSocketName;
public int managerPort;
public int workerCount;
public String managerRpcAddr;
public String storeRpcAddr;
public String schedulerRpcAddr;
public String rayletRpcAddr;
}
@@ -0,0 +1,138 @@
package org.ray.runtime.gcs;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Ray K/V abstraction.
*/
public interface KeyValueStoreLink {
/**
* set address of kv store: format "ip:port".
*/
void setAddr(String addr);
/**
* check if the kvstore client connected.
*/
void checkConnected() throws Exception;
/**
* set Key-value into State Store, such as redis.
*
* @param key the key to set
* @param value the value to set
* @param field the field is being set when the item is a hash If it is not hash field should be
* filled with null
* @return If the key(or field) already exists, and the StateStoreSet just produced an update of
* the value, 0 is returned, otherwise if a new key(or field) is created 1 is returned.
*/
Long set(final String key, final String value, final String field);
Long set(final byte[] key, final byte[] value, final byte[] field);
/**
* multi hash value set.
*
* @param key the key in kvStore
* @param hash the multi hash value to be set
* @return Return OK or Exception if hash is empty
*/
String hmset(final String key, final Map<String, String> hash);
String hmset(final byte[] key, final Map<byte[], byte[]> hash);
/**
* multi hash value get.
*
* @param key the key in kvStore
* @param fields the fields to be get
* @return Multi Bulk Reply specifically a list of all the values associated with the specified
* fields, in the same order of the request.
*/
List<String> hmget(final String key, final String... fields);
List<byte[]> hmget(final byte[] key, final byte[]... fields);
/**
* get the value of the specified key from State Store.
*
* @param key the key to get
* @param field the field is being got when the item is a hash If it is not hash field should be
* filled with null
* @return Bulk reply If the key does not exist null is returned.
*/
String get(final String key, final String field);
byte[] get(final byte[] key, final byte[] field);
/**
* delete the key(or the specified field of the key) from State Store.
*
* @param key the key to delete
* @param field the field is to delete when the item is a hash If it is not hash field should be
* filled with null
* @return Integer reply, specifically: an integer greater than 0 if the key(or the field) was
* removed 0 if none of the specified key existed
*/
Long delete(final String key, final String field);
Long delete(final byte[] key, final byte[] field);
/**
* get all keys which fit the pattern.
*/
Set<byte[]> keys(final byte[] pattern);
/**
* get all keys which fit the pattern.
*/
Set<String> keys(String pattern);
/**
* get all hash of the key.
*/
Map<byte[], byte[]> hgetAll(final byte[] key);
/**
* Return the specified elements of the list stored at the specified key.
*
* @return Multi bulk reply, specifically a list of elements in the specified range.
*/
List<String> lrange(final String key, final long start, final long end);
/**
* Return the set of elements of the sorted set stored at the specified key.
* @param key The specified key you want to query.
* @param start The start index of the range.
* @param end The end index of the range.
* @return The set of elements you queried.
*/
Set<byte[]> zrange(byte[] key, long start, long end);
/**
* Rpush.
* @return Integer reply, specifically, the number of elements inside the list after the push
* operation.
*/
Long rpush(final String key, final String... strings);
Long rpush(final byte[] key, final byte[]... strings);
/**
* Publish.
* @param channel To which channel the message will be published
* @param message What to publish
* @return the number of clients that received the message
*/
Long publish(final String channel, final String message);
Long publish(byte[] channel, byte[] message);
Object getImpl();
byte[] sendCommand(String command, int commandType, byte[] objectId);
}
@@ -0,0 +1,229 @@
package org.ray.runtime.gcs;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
public class RedisClient implements KeyValueStoreLink {
private String redisAddress;
private JedisPool jedisPool;
private int handle = 0;
public RedisClient() {
}
public RedisClient(String addr) {
setAddr(addr);
}
@Override
public synchronized void setAddr(String addr) {
if (StringUtils.isEmpty(redisAddress)) {
redisAddress = addr;
String[] ipPort = addr.split(":");
JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
//TODO NUM maybe equels to the thread num
jedisPoolConfig.setMaxTotal(1);
jedisPool = new JedisPool(jedisPoolConfig, ipPort[0], Integer.parseInt(ipPort[1]), 30000);
}
}
@Override
public void checkConnected() throws Exception {
if (jedisPool == null) {
throw new Exception("the GlobalState API can't be used before ray init.");
}
}
@Override
public Long set(final String key, final String value, final String field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
jedis.set(key, value);
return (long) 1;
} else {
return jedis.hset(key, field, value);
}
}
}
@Override
public Long set(byte[] key, byte[] value, byte[] field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
jedis.set(key, value);
return (long) 1;
} else {
return jedis.hset(key, field, value);
}
}
}
@Override
public String hmset(String key, Map<String, String> hash) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hmset(key, hash);
}
}
@Override
public String hmset(byte[] key, Map<byte[], byte[]> hash) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hmset(key, hash);
}
}
@Override
public List<String> hmget(String key, String... fields) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hmget(key, fields);
}
}
@Override
public List<byte[]> hmget(byte[] key, byte[]... fields) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hmget(key, fields);
}
}
@Override
public String get(final String key, final String field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
return jedis.get(key);
} else {
return jedis.hget(key, field);
}
}
}
@Override
public byte[] get(byte[] key, byte[] field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
return jedis.get(key);
} else {
return jedis.hget(key, field);
}
}
}
@Override
public Long delete(final String key, final String field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
return jedis.del(key);
} else {
return jedis.hdel(key, field);
}
}
}
@Override
public Long delete(byte[] key, byte[] field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
return jedis.del(key);
} else {
return jedis.hdel(key, field);
}
}
}
@Override
public Set<byte[]> keys(byte[] pattern) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.keys(pattern);
}
}
@Override
public Set<String> keys(String pattern) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.keys(pattern);
}
}
@Override
public Map<byte[], byte[]> hgetAll(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.hgetAll(key);
}
}
@Override
public List<String> lrange(String key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.lrange(key, start, end);
}
}
@Override
public Set<byte[]> zrange(byte[] key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.zrange(key, start, end);
}
}
@Override
public Long rpush(String key, String... strings) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.rpush(key, strings);
}
}
@Override
public Long rpush(byte[] key, byte[]... strings) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.rpush(key, strings);
}
}
@Override
public Long publish(String channel, String message) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.publish(channel, message);
}
}
@Override
public Long publish(byte[] channel, byte[] message) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.publish(channel, message);
}
}
@Override
public Object getImpl() {
return jedisPool;
}
@Override
public byte[] sendCommand(String command, int commandType, byte[] objectId) {
if (handle == 0) {
String[] ipPort = redisAddress.split(":");
handle = connect(ipPort[0], Integer.parseInt(ipPort[1]));
}
return execute_command(handle, command, commandType, objectId);
}
private static native int connect(String redisAddress, int port);
private static native void disconnect(int handle);
private static native byte[] execute_command(int handle,
String command, int commandType, byte[] objectId);
}
@@ -0,0 +1,36 @@
package org.ray.runtime.gcs;
import java.util.List;
import java.util.Set;
/**
* Proxy client for state store, for instance redis.
*/
public interface StateStoreProxy {
/**
* setStore.
* @param rayKvStore the underlying kv store used to store states
*/
void setStore(KeyValueStoreLink rayKvStore);
/**
* initialize the store.
*/
void initializeGlobalState() throws Exception;
/**
* keys.
* @param pattern filter which keys you are interested in.
*/
Set<String> keys(final String pattern);
/**
* getAddressInfo.
* @return list of address information
*/
List<AddressInfo> getAddressInfo(final String nodeIpAddress,
final String redisAddress,
int numRetries);
}
@@ -0,0 +1,161 @@
package org.ray.runtime.gcs;
import java.io.UnsupportedEncodingException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.ray.api.id.UniqueId;
import org.ray.runtime.generated.ClientTableData;
import org.ray.runtime.util.NetworkUtil;
import org.ray.runtime.util.logger.RayLog;
/**
* A class used to interface with the Ray control state.
*/
public class StateStoreProxyImpl implements StateStoreProxy {
public KeyValueStoreLink rayKvStore;
public ArrayList<KeyValueStoreLink> shardStoreList = new ArrayList<>();
public StateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
@Override
public void setStore(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
@Override
public synchronized void initializeGlobalState() throws Exception {
String es;
checkConnected();
String s = rayKvStore.get("NumRedisShards", null);
if (s == null) {
throw new Exception("NumRedisShards not found in redis.");
}
int numRedisShards = Integer.parseInt(s);
if (numRedisShards < 1) {
es = String.format("Expected at least one Redis shard, found %d", numRedisShards);
throw new Exception(es);
}
List<String> ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1);
Set<String> distinctIpAddress = new HashSet<String>(ipAddressPorts);
if (distinctIpAddress.size() != numRedisShards) {
es = String.format("Expected %d Redis shard addresses, found2 %d.", numRedisShards,
distinctIpAddress.size());
throw new Exception(es);
}
shardStoreList.clear();
for (String ipPort : distinctIpAddress) {
shardStoreList.add(new RedisClient(ipPort));
}
}
public void checkConnected() throws Exception {
rayKvStore.checkConnected();
}
@Override
public synchronized Set<String> keys(final String pattern) {
Set<String> allKeys = new HashSet<>();
Set<String> tmpKey;
for (KeyValueStoreLink ashardStoreList : shardStoreList) {
tmpKey = ashardStoreList.keys(pattern);
allKeys.addAll(tmpKey);
}
return allKeys;
}
@Override
public List<AddressInfo> getAddressInfo(final String nodeIpAddress,
final String redisAddress,
int numRetries) {
int count = 0;
while (count < numRetries) {
try {
return doGetAddressInfo(nodeIpAddress, redisAddress);
} catch (Exception e) {
try {
RayLog.core.warn("Error occurred in StateStoreProxyImpl getAddressInfo, "
+ (numRetries - count) + " retries remaining", e);
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException ie) {
RayLog.core.error("error at StateStoreProxyImpl getAddressInfo", e);
throw new RuntimeException(e);
}
}
count++;
}
throw new RuntimeException("cannot get address info from state store");
}
/**
* Get address info of one node from primary redis.
* This method only tries to get address info once, without any retry.
*
* @param nodeIpAddress Usually local ip address.
* @param redisAddress The primary redis address.
* @return A list of SchedulerInfo which contains node manager or local scheduler address info.
* @throws Exception No redis client exception.
*/
public List<AddressInfo> doGetAddressInfo(final String nodeIpAddress,
final String redisAddress) throws Exception {
if (this.rayKvStore == null) {
throw new Exception("no redis client when use doGetAddressInfo");
}
List<AddressInfo> schedulerInfo = new ArrayList<>();
byte[] prefix = "CLIENT".getBytes();
byte[] postfix = UniqueId.genNil().getBytes();
byte[] clientKey = new byte[prefix.length + postfix.length];
System.arraycopy(prefix, 0, clientKey, 0, prefix.length);
System.arraycopy(postfix, 0, clientKey, prefix.length, postfix.length);
Set<byte[]> clients = rayKvStore.zrange(clientKey, 0, -1);
for (byte[] clientMessage : clients) {
ByteBuffer bb = ByteBuffer.wrap(clientMessage);
ClientTableData client = ClientTableData.getRootAsClientTableData(bb);
String clientNodeIpAddress = client.nodeManagerAddress();
String localIpAddress = NetworkUtil.getIpAddress(null);
String redisIpAddress = redisAddress.substring(0, redisAddress.indexOf(':'));
boolean headNodeAddress = "127.0.0.1".equals(clientNodeIpAddress)
&& Objects.equals(redisIpAddress, localIpAddress);
boolean notHeadNodeAddress = Objects.equals(clientNodeIpAddress, nodeIpAddress);
if (headNodeAddress || notHeadNodeAddress) {
AddressInfo si = new AddressInfo();
si.storeName = client.objectStoreSocketName();
si.rayletSocketName = client.rayletSocketName();
si.managerRpcAddr = client.nodeManagerAddress();
si.managerPort = client.nodeManagerPort();
schedulerInfo.add(si);
}
}
return schedulerInfo;
}
protected String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException {
return new String(bs, charset);
}
protected byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException {
if (str != null) {
return str.getBytes(charset);
}
return null;
}
}
@@ -0,0 +1,58 @@
// automatically generated by the FlatBuffers compiler, do not modify
package org.ray.runtime.generated;
import java.nio.*;
import java.lang.*;
import com.google.flatbuffers.*;
@SuppressWarnings("unused")
public final class Arg extends Table {
public static Arg getRootAsArg(ByteBuffer _bb) { return getRootAsArg(_bb, new Arg()); }
public static Arg getRootAsArg(ByteBuffer _bb, Arg obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
public Arg __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
public String objectIds(int j) { int o = __offset(4); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int objectIdsLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; }
public String data() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer dataAsByteBuffer() { return __vector_as_bytebuffer(6, 1); }
public ByteBuffer dataInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); }
public static int createArg(FlatBufferBuilder builder,
int object_idsOffset,
int dataOffset) {
builder.startObject(2);
Arg.addData(builder, dataOffset);
Arg.addObjectIds(builder, object_idsOffset);
return Arg.endArg(builder);
}
public static void startArg(FlatBufferBuilder builder) { builder.startObject(2); }
public static void addObjectIds(FlatBufferBuilder builder, int objectIdsOffset) { builder.addOffset(0, objectIdsOffset, 0); }
public static int createObjectIdsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startObjectIdsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addData(FlatBufferBuilder builder, int dataOffset) { builder.addOffset(1, dataOffset, 0); }
public static int endArg(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
}
//this is manually added to avoid encoding/decoding cost as our object id is a byte array
// instead of a string
public ByteBuffer objectIdAsByteBuffer(int j) {
int o = __offset(4);
if (o == 0) {
return null;
}
int offset = __vector(o) + j * 4;
offset += bb.getInt(offset);
ByteBuffer src = bb.duplicate().order(ByteOrder.LITTLE_ENDIAN);
int length = src.getInt(offset);
src.position(offset + 4);
src.limit(offset + 4 + length);
return src;
}
}
@@ -0,0 +1,79 @@
package org.ray.runtime.generated;
// automatically generated by the FlatBuffers compiler, do not modify
import java.nio.*;
import java.lang.*;
import com.google.flatbuffers.*;
@SuppressWarnings("unused")
public final class ClientTableData extends Table {
public static ClientTableData getRootAsClientTableData(ByteBuffer _bb) { return getRootAsClientTableData(_bb, new ClientTableData()); }
public static ClientTableData getRootAsClientTableData(ByteBuffer _bb, ClientTableData obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
public ClientTableData __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
public String clientId() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer clientIdAsByteBuffer() { return __vector_as_bytebuffer(4, 1); }
public ByteBuffer clientIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); }
public String nodeManagerAddress() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer nodeManagerAddressAsByteBuffer() { return __vector_as_bytebuffer(6, 1); }
public ByteBuffer nodeManagerAddressInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); }
public String rayletSocketName() { int o = __offset(8); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer rayletSocketNameAsByteBuffer() { return __vector_as_bytebuffer(8, 1); }
public ByteBuffer rayletSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 8, 1); }
public String objectStoreSocketName() { int o = __offset(10); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer objectStoreSocketNameAsByteBuffer() { return __vector_as_bytebuffer(10, 1); }
public ByteBuffer objectStoreSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 10, 1); }
public int nodeManagerPort() { int o = __offset(12); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public int objectManagerPort() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public boolean isInsertion() { int o = __offset(16); return o != 0 ? 0!=bb.get(o + bb_pos) : false; }
public String resourcesTotalLabel(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int resourcesTotalLabelLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
public double resourcesTotalCapacity(int j) { int o = __offset(20); return o != 0 ? bb.getDouble(__vector(o) + j * 8) : 0; }
public int resourcesTotalCapacityLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
public ByteBuffer resourcesTotalCapacityAsByteBuffer() { return __vector_as_bytebuffer(20, 8); }
public ByteBuffer resourcesTotalCapacityInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 8); }
public static int createClientTableData(FlatBufferBuilder builder,
int client_idOffset,
int node_manager_addressOffset,
int raylet_socket_nameOffset,
int object_store_socket_nameOffset,
int node_manager_port,
int object_manager_port,
boolean is_insertion,
int resources_total_labelOffset,
int resources_total_capacityOffset) {
builder.startObject(9);
ClientTableData.addResourcesTotalCapacity(builder, resources_total_capacityOffset);
ClientTableData.addResourcesTotalLabel(builder, resources_total_labelOffset);
ClientTableData.addObjectManagerPort(builder, object_manager_port);
ClientTableData.addNodeManagerPort(builder, node_manager_port);
ClientTableData.addObjectStoreSocketName(builder, object_store_socket_nameOffset);
ClientTableData.addRayletSocketName(builder, raylet_socket_nameOffset);
ClientTableData.addNodeManagerAddress(builder, node_manager_addressOffset);
ClientTableData.addClientId(builder, client_idOffset);
ClientTableData.addIsInsertion(builder, is_insertion);
return ClientTableData.endClientTableData(builder);
}
public static void startClientTableData(FlatBufferBuilder builder) { builder.startObject(9); }
public static void addClientId(FlatBufferBuilder builder, int clientIdOffset) { builder.addOffset(0, clientIdOffset, 0); }
public static void addNodeManagerAddress(FlatBufferBuilder builder, int nodeManagerAddressOffset) { builder.addOffset(1, nodeManagerAddressOffset, 0); }
public static void addRayletSocketName(FlatBufferBuilder builder, int rayletSocketNameOffset) { builder.addOffset(2, rayletSocketNameOffset, 0); }
public static void addObjectStoreSocketName(FlatBufferBuilder builder, int objectStoreSocketNameOffset) { builder.addOffset(3, objectStoreSocketNameOffset, 0); }
public static void addNodeManagerPort(FlatBufferBuilder builder, int nodeManagerPort) { builder.addInt(4, nodeManagerPort, 0); }
public static void addObjectManagerPort(FlatBufferBuilder builder, int objectManagerPort) { builder.addInt(5, objectManagerPort, 0); }
public static void addIsInsertion(FlatBufferBuilder builder, boolean isInsertion) { builder.addBoolean(6, isInsertion, false); }
public static void addResourcesTotalLabel(FlatBufferBuilder builder, int resourcesTotalLabelOffset) { builder.addOffset(7, resourcesTotalLabelOffset, 0); }
public static int createResourcesTotalLabelVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startResourcesTotalLabelVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addResourcesTotalCapacity(FlatBufferBuilder builder, int resourcesTotalCapacityOffset) { builder.addOffset(8, resourcesTotalCapacityOffset, 0); }
public static int createResourcesTotalCapacityVector(FlatBufferBuilder builder, double[] data) { builder.startVector(8, data.length, 8); for (int i = data.length - 1; i >= 0; i--) builder.addDouble(data[i]); return builder.endVector(); }
public static void startResourcesTotalCapacityVector(FlatBufferBuilder builder, int numElems) { builder.startVector(8, numElems, 8); }
public static int endClientTableData(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
}
}
@@ -0,0 +1,38 @@
// automatically generated by the FlatBuffers compiler, do not modify
package org.ray.runtime.generated;
import java.nio.*;
import java.lang.*;
import com.google.flatbuffers.*;
@SuppressWarnings("unused")
public final class ResourcePair extends Table {
public static ResourcePair getRootAsResourcePair(ByteBuffer _bb) { return getRootAsResourcePair(_bb, new ResourcePair()); }
public static ResourcePair getRootAsResourcePair(ByteBuffer _bb, ResourcePair obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
public ResourcePair __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
public String key() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer keyAsByteBuffer() { return __vector_as_bytebuffer(4, 1); }
public ByteBuffer keyInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); }
public double value() { int o = __offset(6); return o != 0 ? bb.getDouble(o + bb_pos) : 0.0; }
public static int createResourcePair(FlatBufferBuilder builder,
int keyOffset,
double value) {
builder.startObject(2);
ResourcePair.addValue(builder, value);
ResourcePair.addKey(builder, keyOffset);
return ResourcePair.endResourcePair(builder);
}
public static void startResourcePair(FlatBufferBuilder builder) { builder.startObject(2); }
public static void addKey(FlatBufferBuilder builder, int keyOffset) { builder.addOffset(0, keyOffset, 0); }
public static void addValue(FlatBufferBuilder builder, double value) { builder.addDouble(1, value, 0.0); }
public static int endResourcePair(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
}
}
@@ -0,0 +1,132 @@
// automatically generated by the FlatBuffers compiler, do not modify
package org.ray.runtime.generated;
import java.nio.*;
import java.lang.*;
import com.google.flatbuffers.*;
@SuppressWarnings("unused")
public final class TaskInfo extends Table {
public static TaskInfo getRootAsTaskInfo(ByteBuffer _bb) { return getRootAsTaskInfo(_bb, new TaskInfo()); }
public static TaskInfo getRootAsTaskInfo(ByteBuffer _bb, TaskInfo obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
public TaskInfo __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
public String driverId() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer driverIdAsByteBuffer() { return __vector_as_bytebuffer(4, 1); }
public ByteBuffer driverIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); }
public String taskId() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer taskIdAsByteBuffer() { return __vector_as_bytebuffer(6, 1); }
public ByteBuffer taskIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); }
public String parentTaskId() { int o = __offset(8); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer parentTaskIdAsByteBuffer() { return __vector_as_bytebuffer(8, 1); }
public ByteBuffer parentTaskIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 8, 1); }
public int parentCounter() { int o = __offset(10); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public String actorCreationId() { int o = __offset(12); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer actorCreationIdAsByteBuffer() { return __vector_as_bytebuffer(12, 1); }
public ByteBuffer actorCreationIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 12, 1); }
public String actorCreationDummyObjectId() { int o = __offset(14); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer actorCreationDummyObjectIdAsByteBuffer() { return __vector_as_bytebuffer(14, 1); }
public ByteBuffer actorCreationDummyObjectIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); }
public String actorId() { int o = __offset(16); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer actorIdAsByteBuffer() { return __vector_as_bytebuffer(16, 1); }
public ByteBuffer actorIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 16, 1); }
public String actorHandleId() { int o = __offset(18); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer actorHandleIdAsByteBuffer() { return __vector_as_bytebuffer(18, 1); }
public ByteBuffer actorHandleIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 18, 1); }
public int actorCounter() { int o = __offset(20); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public boolean isActorCheckpointMethod() { int o = __offset(22); return o != 0 ? 0!=bb.get(o + bb_pos) : false; }
public String functionId() { int o = __offset(24); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer functionIdAsByteBuffer() { return __vector_as_bytebuffer(24, 1); }
public ByteBuffer functionIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 24, 1); }
public Arg args(int j) { return args(new Arg(), j); }
public Arg args(Arg obj, int j) { int o = __offset(26); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int argsLength() { int o = __offset(26); return o != 0 ? __vector_len(o) : 0; }
public String returns(int j) { int o = __offset(28); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int returnsLength() { int o = __offset(28); return o != 0 ? __vector_len(o) : 0; }
public ResourcePair requiredResources(int j) { return requiredResources(new ResourcePair(), j); }
public ResourcePair requiredResources(ResourcePair obj, int j) { int o = __offset(30); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; }
public int requiredResourcesLength() { int o = __offset(30); return o != 0 ? __vector_len(o) : 0; }
public int language() { int o = __offset(32); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public static int createTaskInfo(FlatBufferBuilder builder,
int driver_idOffset,
int task_idOffset,
int parent_task_idOffset,
int parent_counter,
int actor_creation_idOffset,
int actor_creation_dummy_object_idOffset,
int actor_idOffset,
int actor_handle_idOffset,
int actor_counter,
boolean is_actor_checkpoint_method,
int function_idOffset,
int argsOffset,
int returnsOffset,
int required_resourcesOffset,
int language) {
builder.startObject(15);
TaskInfo.addLanguage(builder, language);
TaskInfo.addRequiredResources(builder, required_resourcesOffset);
TaskInfo.addReturns(builder, returnsOffset);
TaskInfo.addArgs(builder, argsOffset);
TaskInfo.addFunctionId(builder, function_idOffset);
TaskInfo.addActorCounter(builder, actor_counter);
TaskInfo.addActorHandleId(builder, actor_handle_idOffset);
TaskInfo.addActorId(builder, actor_idOffset);
TaskInfo.addActorCreationDummyObjectId(builder, actor_creation_dummy_object_idOffset);
TaskInfo.addActorCreationId(builder, actor_creation_idOffset);
TaskInfo.addParentCounter(builder, parent_counter);
TaskInfo.addParentTaskId(builder, parent_task_idOffset);
TaskInfo.addTaskId(builder, task_idOffset);
TaskInfo.addDriverId(builder, driver_idOffset);
TaskInfo.addIsActorCheckpointMethod(builder, is_actor_checkpoint_method);
return TaskInfo.endTaskInfo(builder);
}
public static void startTaskInfo(FlatBufferBuilder builder) { builder.startObject(15); }
public static void addDriverId(FlatBufferBuilder builder, int driverIdOffset) { builder.addOffset(0, driverIdOffset, 0); }
public static void addTaskId(FlatBufferBuilder builder, int taskIdOffset) { builder.addOffset(1, taskIdOffset, 0); }
public static void addParentTaskId(FlatBufferBuilder builder, int parentTaskIdOffset) { builder.addOffset(2, parentTaskIdOffset, 0); }
public static void addParentCounter(FlatBufferBuilder builder, int parentCounter) { builder.addInt(3, parentCounter, 0); }
public static void addActorCreationId(FlatBufferBuilder builder, int actorCreationIdOffset) { builder.addOffset(4, actorCreationIdOffset, 0); }
public static void addActorCreationDummyObjectId(FlatBufferBuilder builder, int actorCreationDummyObjectIdOffset) { builder.addOffset(5, actorCreationDummyObjectIdOffset, 0); }
public static void addActorId(FlatBufferBuilder builder, int actorIdOffset) { builder.addOffset(6, actorIdOffset, 0); }
public static void addActorHandleId(FlatBufferBuilder builder, int actorHandleIdOffset) { builder.addOffset(7, actorHandleIdOffset, 0); }
public static void addActorCounter(FlatBufferBuilder builder, int actorCounter) { builder.addInt(8, actorCounter, 0); }
public static void addIsActorCheckpointMethod(FlatBufferBuilder builder, boolean isActorCheckpointMethod) { builder.addBoolean(9, isActorCheckpointMethod, false); }
public static void addFunctionId(FlatBufferBuilder builder, int functionIdOffset) { builder.addOffset(10, functionIdOffset, 0); }
public static void addArgs(FlatBufferBuilder builder, int argsOffset) { builder.addOffset(11, argsOffset, 0); }
public static int createArgsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startArgsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addReturns(FlatBufferBuilder builder, int returnsOffset) { builder.addOffset(12, returnsOffset, 0); }
public static int createReturnsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startReturnsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addRequiredResources(FlatBufferBuilder builder, int requiredResourcesOffset) { builder.addOffset(13, requiredResourcesOffset, 0); }
public static int createRequiredResourcesVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startRequiredResourcesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addLanguage(FlatBufferBuilder builder, int language) { builder.addInt(14, language, 0); }
public static int endTaskInfo(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
}
//this is manually added to avoid encoding/decoding cost as our object
//id is a byte array instead of a string
public ByteBuffer returnsAsByteBuffer(int j) {
int o = __offset(28);
if (o == 0) {
return null;
}
int offset = __vector(o) + j * 4;
offset += bb.getInt(offset);
ByteBuffer src = bb.duplicate().order(ByteOrder.LITTLE_ENDIAN);
int length = src.getInt(offset);
src.position(offset + 4);
src.limit(offset + 4 + length);
return src;
}
}
@@ -0,0 +1,14 @@
// automatically generated by the FlatBuffers compiler, do not modify
package org.ray.runtime.generated;
public final class TaskLanguage {
private TaskLanguage() { }
public static final int PYTHON = 0;
public static final int JAVA = 1;
public static final String[] names = { "PYTHON", "JAVA", };
public static String name(int e) { return names[e]; }
}
@@ -0,0 +1,110 @@
package org.ray.runtime.objectstore;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.ray.api.id.UniqueId;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.raylet.MockRayletClient;
import org.ray.runtime.util.logger.RayLog;
/**
* A mock implementation of {@code org.ray.spi.ObjectStoreLink}, which use Map to store data.
*/
public class MockObjectStore implements ObjectStoreLink {
private final Map<UniqueId, byte[]> data = new ConcurrentHashMap<>();
private final Map<UniqueId, byte[]> metadata = new ConcurrentHashMap<>();
private MockRayletClient scheduler = null;
@Override
public void put(byte[] objectId, byte[] value, byte[] metadataValue) {
if (objectId == null || objectId.length == 0 || value == null) {
RayLog.core
.error(logPrefix() + "cannot put null: " + objectId + "," + Arrays.toString(value));
System.exit(-1);
}
UniqueId uniqueId = new UniqueId(objectId);
data.put(uniqueId, value);
metadata.put(uniqueId, metadataValue);
if (scheduler != null) {
scheduler.onObjectPut(uniqueId);
}
}
@Override
public List<byte[]> get(byte[][] objectIds, int timeoutMs, boolean isMetadata) {
final Map<UniqueId, byte[]> dataMap = isMetadata ? metadata : data;
ArrayList<byte[]> rets = new ArrayList<>(objectIds.length);
for (byte[] objId : objectIds) {
UniqueId uniqueId = new UniqueId(objId);
RayLog.core.info(logPrefix() + " is notified for objectid " + uniqueId);
rets.add(dataMap.get(uniqueId));
}
return rets;
}
@Override
public List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns) {
ArrayList<byte[]> rets = new ArrayList<>();
for (byte[] objId : objectIds) {
//tod test
if (data.containsKey(new UniqueId(objId))) {
rets.add(objId);
}
}
return rets;
}
@Override
public byte[] hash(byte[] objectId) {
return null;
}
@Override
public void fetch(byte[][] objectIds) {
}
@Override
public long evict(long numBytes) {
return 0;
}
@Override
public void release(byte[] objectId) {
return;
}
@Override
public boolean contains(byte[] objectId) {
return data.containsKey(new UniqueId(objectId));
}
private String logPrefix() {
return WorkerContext.currentTask().taskId + "-" + getUserTrace() + " -> ";
}
private String getUserTrace() {
StackTraceElement[] stes = Thread.currentThread().getStackTrace();
int k = 1;
while (stes[k].getClassName().startsWith("org.ray")
&& !stes[k].getClassName().contains("test")) {
k++;
}
return stes[k].getFileName() + ":" + stes[k].getLineNumber();
}
public boolean isObjectReady(UniqueId id) {
return data.containsKey(id);
}
public void registerScheduler(MockRayletClient s) {
scheduler = s;
}
}
@@ -0,0 +1,86 @@
package org.ray.runtime.objectstore;
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.id.UniqueId;
import org.ray.runtime.WorkerContext;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.exception.TaskExecutionException;
/**
* Object store proxy, which handles serialization and deserialization, and utilize a {@code
* org.ray.spi.ObjectStoreLink} to actually store data.
*/
public class ObjectStoreProxy {
private final ObjectStoreLink store;
private final int getTimeoutMs = 1000;
public ObjectStoreProxy(ObjectStoreLink store) {
this.store = store;
}
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
throws TaskExecutionException {
return get(objectId, getTimeoutMs, isMetadata);
}
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws TaskExecutionException {
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
T t = Serializer.decode(obj, WorkerContext.currentClassLoader());
store.release(id.getBytes());
if (t instanceof TaskExecutionException) {
throw (TaskExecutionException) t;
}
return Pair.of(t, GetStatus.SUCCESS);
} else {
return Pair.of(null, GetStatus.FAILED);
}
}
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMetadata)
throws TaskExecutionException {
return get(objectIds, getTimeoutMs, isMetadata);
}
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws TaskExecutionException {
List<byte[]> objs = store.get(getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
for (int i = 0; i < objs.size(); i++) {
byte[] obj = objs.get(i);
if (obj != null) {
T t = Serializer.decode(obj, WorkerContext.currentClassLoader());
store.release(ids.get(i).getBytes());
if (t instanceof TaskExecutionException) {
throw (TaskExecutionException) t;
}
ret.add(Pair.of(t, GetStatus.SUCCESS));
} else {
ret.add(Pair.of(null, GetStatus.FAILED));
}
}
return ret;
}
private static byte[][] getIdBytes(List<UniqueId> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
ids[i] = objectIds.get(i).getBytes();
}
return ids;
}
public void put(UniqueId id, Object obj, Object metadata) {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}
public enum GetStatus {
SUCCESS, FAILED
}
}
@@ -0,0 +1,97 @@
package org.ray.runtime.raylet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.functionmanager.LocalFunctionManager;
import org.ray.runtime.objectstore.MockObjectStore;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
/**
* A mock implementation of RayletClient, used in single process mode.
*/
public class MockRayletClient implements RayletClient {
private final Map<UniqueId, Map<UniqueId, TaskSpec>> waitTasks = new ConcurrentHashMap<>();
private final MockObjectStore store;
private LocalFunctionManager functions = null;
private final RayDevRuntime runtime;
public MockRayletClient(RayDevRuntime runtime, MockObjectStore store) {
this.runtime = runtime;
this.store = store;
store.registerScheduler(this);
}
public void setLocalFunctionManager(LocalFunctionManager mgr) {
functions = mgr;
}
public void onObjectPut(UniqueId id) {
Map<UniqueId, TaskSpec> bucket = waitTasks.get(id);
if (bucket != null) {
waitTasks.remove(id);
for (TaskSpec ts : bucket.values()) {
submitTask(ts);
}
}
}
@Override
public void submitTask(TaskSpec task) {
UniqueId id = isTaskReady(task);
if (id == null) {
runtime.getWorker().execute(task);
} else {
Map<UniqueId, TaskSpec> bucket = waitTasks
.computeIfAbsent(id, id_ -> new ConcurrentHashMap<>());
bucket.put(id, task);
}
}
private UniqueId isTaskReady(TaskSpec spec) {
for (FunctionArg arg : spec.args) {
if (arg.id != null) {
if (!store.isObjectReady(arg.id)) {
return arg.id;
}
}
}
return null;
}
@Override
public TaskSpec getTask() {
throw new RuntimeException("invalid execution flow here");
}
@Override
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
}
@Override
public void notifyUnblocked() {
}
@Override
public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) {
throw new RuntimeException("Not implemented here.");
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
throw new RuntimeException("Not implemented here.");
}
@Override
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
return;
}
}
@@ -0,0 +1,27 @@
package org.ray.runtime.raylet;
import java.util.List;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.runtime.task.TaskSpec;
/**
* Client to the Raylet backend.
*/
public interface RayletClient {
void submitTask(TaskSpec task);
TaskSpec getTask();
void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly);
void notifyUnblocked();
UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex);
<T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs);
void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly);
}
@@ -0,0 +1,295 @@
package org.ray.runtime.raylet;
import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.generated.Arg;
import org.ray.runtime.generated.ResourcePair;
import org.ray.runtime.generated.TaskInfo;
import org.ray.runtime.generated.TaskLanguage;
import org.ray.runtime.task.FunctionArg;
import org.ray.runtime.task.TaskSpec;
import org.ray.runtime.util.UniqueIdHelper;
import org.ray.runtime.util.logger.RayLog;
public class RayletClientImpl implements RayletClient {
private static ThreadLocal<ByteBuffer> _taskBuffer = ThreadLocal.withInitial(() -> {
ByteBuffer bb = ByteBuffer
.allocateDirect(AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes);
bb.order(ByteOrder.LITTLE_ENDIAN);
return bb;
});
private long client = 0;
public RayletClientImpl(String schedulerSockName, UniqueId clientId,
boolean isWorker, UniqueId driverId) {
client = nativeInit(schedulerSockName, clientId.getBytes(),
isWorker, driverId.getBytes());
}
@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitFor, int numReturns, int timeoutMs) {
List<UniqueId> ids = new ArrayList<>();
for (RayObject<T> element : waitFor) {
ids.add(element.getId());
}
boolean[] ready = nativeWaitObject(client, getIdBytes(ids), numReturns, timeoutMs, false);
List<RayObject<T>> readyList = new ArrayList<>();
List<RayObject<T>> unreadyList = new ArrayList<>();
for (int i = 0; i < ready.length; i++) {
if (ready[i]) {
readyList.add(waitFor.get(i));
} else {
unreadyList.add(waitFor.get(i));
}
}
return new WaitResult<>(readyList, unreadyList);
}
@Override
public void submitTask(TaskSpec task) {
RayLog.core.debug("Submitting task: {}", task);
ByteBuffer info = taskSpec2Info(task);
byte[] cursorId = null;
if (!task.actorId.isNil()) {
cursorId = task.cursorId.getBytes();
}
nativeSubmitTask(client, cursorId, info, info.position(), info.remaining());
}
@Override
public TaskSpec getTask() {
byte[] bytes = nativeGetTask(client);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return taskInfo2Spec(bb);
}
@Override
public void reconstructObjects(List<UniqueId> objectIds, boolean fetchOnly) {
if (RayLog.core.isInfoEnabled()) {
RayLog.core.info("Reconstructing objects for task {}, object IDs are {}",
UniqueIdHelper.computeTaskId(objectIds.get(0)), objectIds);
}
nativeReconstructObjects(client, getIdBytes(objectIds), fetchOnly);
}
@Override
public UniqueId generateTaskId(UniqueId driverId, UniqueId parentTaskId, int taskIndex) {
byte[] bytes = nativeGenerateTaskId(driverId.getBytes(), parentTaskId.getBytes(), taskIndex);
return new UniqueId(bytes);
}
@Override
public void notifyUnblocked() {
nativeNotifyUnblocked(client);
}
@Override
public void freePlasmaObjects(List<UniqueId> objectIds, boolean localOnly) {
byte[][] objectIdsArray = getIdBytes(objectIds);
nativeFreePlasmaObjects(client, objectIdsArray, localOnly);
}
public static TaskSpec taskInfo2Spec(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
TaskInfo info = TaskInfo.getRootAsTaskInfo(bb);
TaskSpec spec = new TaskSpec();
spec.driverId = UniqueId.fromByteBuffer(info.driverIdAsByteBuffer());
spec.taskId = UniqueId.fromByteBuffer(info.taskIdAsByteBuffer());
spec.parentTaskId = UniqueId.fromByteBuffer(info.parentTaskIdAsByteBuffer());
spec.parentCounter = info.parentCounter();
spec.actorId = UniqueId.fromByteBuffer(info.actorIdAsByteBuffer());
spec.actorCounter = info.actorCounter();
spec.createActorId = UniqueId.fromByteBuffer(info.actorCreationIdAsByteBuffer());
spec.functionId = UniqueId.fromByteBuffer(info.functionIdAsByteBuffer());
List<FunctionArg> args = new ArrayList<>();
for (int i = 0; i < info.argsLength(); i++) {
UniqueId id = null;
byte[] data = null;
Arg sarg = info.args(i);
int idCount = sarg.objectIdsLength();
if (idCount > 0) {
ByteBuffer lbb = sarg.objectIdAsByteBuffer(0);
assert (lbb != null && lbb.remaining() > 0);
id = UniqueId.fromByteBuffer(lbb);
}
ByteBuffer lbb = sarg.dataAsByteBuffer();
if (lbb != null && lbb.remaining() > 0) {
// TODO: how to avoid memory copy
data = new byte[lbb.remaining()];
lbb.get(data);
}
args.add(new FunctionArg(id, data));
}
spec.args = args.toArray(new FunctionArg[0]);
List<UniqueId> rids = new ArrayList<>();
for (int i = 0; i < info.returnsLength(); i++) {
ByteBuffer lbb = info.returnsAsByteBuffer(i);
assert (lbb != null && lbb.remaining() > 0);
rids.add(UniqueId.fromByteBuffer(lbb));
}
spec.returnIds = rids.toArray(new UniqueId[0]);
return spec;
}
public static ByteBuffer taskSpec2Info(TaskSpec task) {
ByteBuffer bb = _taskBuffer.get();
bb.clear();
FlatBufferBuilder fbb = new FlatBufferBuilder(bb);
final int driverIdOffset = fbb.createString(task.driverId.toByteBuffer());
final int taskIdOffset = fbb.createString(task.taskId.toByteBuffer());
final int parentTaskIdOffset = fbb.createString(task.parentTaskId.toByteBuffer());
final int parentCounter = task.parentCounter;
final int actorCreateIdOffset = fbb.createString(task.createActorId.toByteBuffer());
final int actorCreateDummyIdOffset = fbb.createString(UniqueId.NIL.toByteBuffer());
final int actorIdOffset = fbb.createString(task.actorId.toByteBuffer());
final int actorHandleIdOffset = fbb.createString(task.actorHandleId.toByteBuffer());
final int actorCounter = task.actorCounter;
final int functionIdOffset = fbb.createString(task.functionId.toByteBuffer());
// serialize args
int[] argsOffsets = new int[task.args.length];
for (int i = 0; i < argsOffsets.length; i++) {
int objectIdOffset = 0;
int dataOffset = 0;
if (task.args[i].id != null) {
int[] idOffsets = new int[] {
fbb.createString(task.args[i].id.toByteBuffer())
};
objectIdOffset = fbb.createVectorOfTables(idOffsets);
} else {
objectIdOffset = fbb.createVectorOfTables(new int[0]);
}
if (task.args[i].data != null) {
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
}
argsOffsets[i] = Arg.createArg(fbb, objectIdOffset, dataOffset);
}
int argsOffset = fbb.createVectorOfTables(argsOffsets);
// serialize returns
int returnCount = task.returnIds.length;
int[] returnsOffsets = new int[returnCount];
for (int k = 0; k < returnCount; k++) {
returnsOffsets[k] = fbb.createString(task.returnIds[k].toByteBuffer());
}
int returnsOffset = fbb.createVectorOfTables(returnsOffsets);
// serialize required resources
// The required_resources vector indicates the quantities of the different
// resources required by this task. The index in this vector corresponds to
// the resource type defined in the ResourceIndex enum. For example,
int[] requiredResourcesOffsets = new int[task.resources.size()];
int i = 0;
for (Map.Entry<String, Double> entry : task.resources.entrySet()) {
int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes()));
requiredResourcesOffsets[i] =
ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue());
i++;
}
int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets);
int root = TaskInfo.createTaskInfo(
fbb, driverIdOffset, taskIdOffset,
parentTaskIdOffset, parentCounter,
actorCreateIdOffset, actorCreateDummyIdOffset,
actorIdOffset, actorHandleIdOffset, actorCounter,
false, functionIdOffset,
argsOffset, returnsOffset, requiredResourcesOffset, TaskLanguage.JAVA);
fbb.finish(root);
ByteBuffer buffer = fbb.dataBuffer();
if (buffer.remaining() > AbstractRayRuntime.getParams().max_submit_task_buffer_size_bytes) {
RayLog.core.error(
"Allocated buffer is not enough to transfer the task specification: " + AbstractRayRuntime
.getParams().max_submit_task_buffer_size_bytes + " vs " + buffer.remaining());
assert (false);
}
return buffer;
}
private static byte[][] getIdBytes(List<UniqueId> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
ids[i] = objectIds.get(i).getBytes();
}
return ids;
}
public void destroy() {
nativeDestroy(client);
}
/// Native method declarations.
///
/// If you change the signature of any native methods, please re-generate
/// the C++ header file and update the C++ implementation accordingly:
///
/// Suppose that $Dir is your ray root directory.
/// 1) pushd $Dir/java/runtime/target/classes
/// 2) javah -classpath .:$Dir/java/api/target/classes org.ray.runtime.raylet.RayletClientImpl
/// 3) clang-format -i org_ray_runtime_raylet_RayletClientImpl.h
/// 4) cp org_ray_runtime_raylet_RayletClientImpl.h $Dir/src/local_scheduler/lib/java/
/// 5) vim $Dir/src/local_scheduler/lib/java/org_ray_runtime_raylet_RayletClientImpl.cc
/// 6) popd
private static native long nativeInit(String localSchedulerSocket, byte[] workerId,
boolean isWorker, byte[] driverTaskId);
private static native void nativeSubmitTask(long client, byte[] cursorId, ByteBuffer taskBuff,
int pos, int taskSize);
// return TaskInfo (in FlatBuffer)
private static native byte[] nativeGetTask(long client);
private static native void nativeDestroy(long client);
private static native void nativeReconstructObjects(long client, byte[][] objectIds,
boolean fetchOnly);
private static native void nativeNotifyUnblocked(long client);
private static native void nativePutObject(long client, byte[] taskId, byte[] objectId);
private static native boolean[] nativeWaitObject(long conn, byte[][] objectIds,
int numReturns, int timeout, boolean waitLocal);
private static native byte[] nativeGenerateTaskId(byte[] driverId, byte[] parentTaskId,
int taskIndex);
private static native void nativeFreePlasmaObjects(long conn, byte[][] objectIds,
boolean localOnly);
}
@@ -0,0 +1,13 @@
package org.ray.runtime.runner;
public class ProcessInfo {
public Process process;
public String[] cmd;
public RunInfo.ProcessType type;
public String name;
public String redisAddress;
public String ip;
public boolean redirect;
public boolean cleanup;
}
@@ -0,0 +1,45 @@
package org.ray.runtime.runner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.ray.runtime.gcs.AddressInfo;
/**
* information of kinds of processes.
*/
public class RunInfo {
public String redisAddress;
public List<String> redisShards;
public List<AddressInfo> localStores = new ArrayList<>();
public ArrayList<List<ProcessInfo>> allProcesses = initProcessInfoArray();
public ArrayList<List<Process>> toBeCleanedProcesses = initProcessArray();
public ArrayList<ProcessInfo> deadProcess = new ArrayList<>();
private ArrayList<List<Process>> initProcessArray() {
ArrayList<List<Process>> processes = new ArrayList<>();
for (ProcessType ignored : ProcessType.values()) {
processes.add(Collections.synchronizedList(new ArrayList<>()));
}
return processes;
}
private ArrayList<List<ProcessInfo>> initProcessInfoArray() {
ArrayList<List<ProcessInfo>> processes = new ArrayList<>();
for (ProcessType ignored : ProcessType.values()) {
processes.add(Collections.synchronizedList(new ArrayList<>()));
}
return processes;
}
public enum ProcessType {
PT_WORKER,
PT_PLASMA_STORE,
PT_REDIS_SERVER,
PT_WEB_UI,
PT_RAYLET,
PT_DRIVER
}
}
@@ -0,0 +1,527 @@
package org.ray.runtime.runner;
import com.google.common.collect.ImmutableList;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.PathConfig;
import org.ray.runtime.config.RayParameters;
import org.ray.runtime.gcs.AddressInfo;
import org.ray.runtime.runner.RunInfo.ProcessType;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
import org.ray.runtime.util.config.ConfigReader;
import org.ray.runtime.util.logger.RayLog;
import redis.clients.jedis.Jedis;
/**
* Ray service management on one box.
*/
public class RunManager {
private static final DateTimeFormatter DATE_TIME_FORMATTER =
DateTimeFormatter.ofPattern("Y-m-d_H-M-S");
private RayParameters params;
private PathConfig paths;
private ConfigReader configReader;
private RunInfo runInfo = new RunInfo();
private Random random = new Random();
public RunManager(RayParameters params, PathConfig paths, ConfigReader configReader) {
this.params = params;
this.paths = paths;
this.configReader = configReader;
}
private static boolean killProcess(Process p) {
if (p.isAlive()) {
p.destroy();
return true;
} else {
return false;
}
}
public RunInfo info() {
return runInfo;
}
public void startRayHead() throws Exception {
if (params.redis_address.length() != 0) {
throw new Exception("Redis address must be empty in head node.");
}
if (params.num_redis_shards <= 0) {
params.num_redis_shards = 1;
}
params.start_redis_shards = true;
startRayProcesses();
}
public void startRayNode() throws Exception {
if (params.redis_address.length() == 0) {
throw new Exception("Redis address cannot be empty in non-head node.");
}
if (params.num_redis_shards != 0) {
throw new Exception("Number of redis shards should be zero in non-head node.");
}
params.start_redis_shards = false;
startRayProcesses();
}
public Process startDriver(String mainClass, String redisAddress, UniqueId driverId,
String logDir, String ip,
String driverClass, String driverArgs, String additonalClassPaths,
String additionalConfigs) {
String driverConfigs =
"ray.java.start.driver_id=" + driverId + ";ray.java.start.driver_class=" + driverClass;
if (driverArgs != null) {
driverConfigs += ";ray.java.start.driver_args=" + driverArgs;
}
if (null != additionalConfigs) {
additionalConfigs += ";" + driverConfigs;
} else {
additionalConfigs = driverConfigs;
}
return startJavaProcess(
RunInfo.ProcessType.PT_DRIVER,
mainClass,
additonalClassPaths,
additionalConfigs,
"",
ip,
redisAddress,
false,
false,
null
);
}
private Process startJavaProcess(RunInfo.ProcessType pt, String mainClass,
String additonalClassPaths, String additionalConfigs,
String additionalJvmArgs, String ip, String
redisAddr, boolean redirect,
boolean cleanup, String agentlibAddr) {
String cmd = buildJavaProcessCommand(pt, mainClass, additonalClassPaths, additionalConfigs,
additionalJvmArgs, ip, redisAddr, agentlibAddr);
return startProcess(cmd.split(" "), null, pt, "", redisAddr, ip, redirect, cleanup);
}
private String buildJavaProcessCommand(
RunInfo.ProcessType pt, String mainClass, String additionalClassPaths,
String additionalConfigs,
String additionalJvmArgs, String ip, String redisAddr, String agentlibAddr) {
String cmd = "java -ea -noverify " + params.jvm_parameters + " ";
if (agentlibAddr != null && !agentlibAddr.equals("")) {
cmd += " -agentlib:jdwp=transport=dt_socket,address=" + agentlibAddr + ",server=y,suspend=n";
}
cmd += " -Djava.library.path=" + StringUtil.mergeArray(paths.java_jnilib_paths, ":");
cmd += " -classpath " + StringUtil.mergeArray(paths.java_class_paths, ":");
if (additionalClassPaths.length() > 0) {
cmd += ":" + additionalClassPaths;
}
if (additionalJvmArgs.length() > 0) {
cmd += " " + additionalJvmArgs;
}
cmd += " " + mainClass;
String section = "ray.java.start.";
cmd += " --config=" + configReader.filePath();
cmd += " --overwrite="
+ section + "node_ip_address=" + ip + ";"
+ section + "redis_address=" + redisAddr + ";"
+ section + "log_dir=" + params.log_dir + ";"
+ section + "run_mode=" + params.run_mode;
if (additionalConfigs.length() > 0) {
cmd += ";" + additionalConfigs;
}
return cmd;
}
private Process startProcess(String[] cmd, Map<String, String> env, RunInfo.ProcessType type,
String name,
String redisAddress, String ip, boolean redirect,
boolean cleanup) {
ProcessBuilder builder;
List<String> newCommand = Arrays.asList(cmd);
builder = new ProcessBuilder(newCommand);
if (redirect) {
int logId = random.nextInt(10000);
String date = DATE_TIME_FORMATTER.format(LocalDateTime.now());
String stdout = String.format("%s/%s-%s-%05d.out", params.log_dir, name, date, logId);
String stderr = String.format("%s/%s-%s-%05d.err", params.log_dir, name, date, logId);
builder.redirectOutput(new File(stdout));
builder.redirectError(new File(stderr));
recordLogFilesInRedis(redisAddress, ip, ImmutableList.of(stdout, stderr));
}
if (env != null && !env.isEmpty()) {
builder.environment().putAll(env);
}
Process p = null;
try {
p = builder.start();
} catch (IOException e) {
RayLog.core.error("Failed to start process {}", name, e);
return null;
}
RayLog.core.info("Process {} started", name);
if (cleanup) {
runInfo.toBeCleanedProcesses.get(type.ordinal()).add(p);
}
ProcessInfo processInfo = new ProcessInfo();
processInfo.cmd = cmd;
processInfo.type = type;
processInfo.name = name;
processInfo.redisAddress = redisAddress;
processInfo.ip = ip;
processInfo.redirect = redirect;
processInfo.cleanup = cleanup;
processInfo.process = p;
runInfo.allProcesses.get(type.ordinal()).add(processInfo);
return p;
}
private void recordLogFilesInRedis(String redisAddress, String nodeIpAddress,
List<String> logFiles) {
if (redisAddress != null && !redisAddress.isEmpty() && nodeIpAddress != null
&& !nodeIpAddress.isEmpty() && logFiles.size() > 0) {
String[] ipPort = redisAddress.split(":");
Jedis jedisClient = new Jedis(ipPort[0], Integer.parseInt(ipPort[1]));
String logFileListKey = String.format("LOG_FILENAMES:{%s}", nodeIpAddress);
for (String logfile : logFiles) {
jedisClient.rpush(logFileListKey, logfile);
}
jedisClient.close();
}
}
private void startRayProcesses() {
Jedis redisClient = null;
RayLog.core.info("start ray processes @ " + params.node_ip_address + " ...");
// start primary redis
if (params.redis_address.length() == 0) {
List<String> primaryShards = startRedis(
params.node_ip_address, params.redis_port, 1, params.redirect, params.cleanup);
params.redis_address = primaryShards.get(0);
String[] args = params.redis_address.split(":");
redisClient = new Jedis(args[0], Integer.parseInt(args[1]));
// Register the number of Redis shards in the primary shard, so that clients
// know how many redis shards to expect under RedisShards.
redisClient.set("NumRedisShards", Integer.toString(params.num_redis_shards));
} else {
String[] args = params.redis_address.split(":");
redisClient = new Jedis(args[0], Integer.parseInt(args[1]));
}
runInfo.redisAddress = params.redis_address;
// start redis shards
if (params.start_redis_shards) {
runInfo.redisShards = startRedis(
params.node_ip_address, params.redis_port + 1, params.num_redis_shards,
params.redirect,
params.cleanup);
// Store redis shard information in the primary redis shard.
for (int i = 0; i < runInfo.redisShards.size(); i++) {
String addr = runInfo.redisShards.get(i);
redisClient.rpush("RedisShards", addr);
}
}
redisClient.close();
AddressInfo info = new AddressInfo();
// Start object store
int rpcPort = params.object_store_rpc_port;
String storeName = "/tmp/plasma_store" + rpcPort;
startObjectStore(0, info,
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
Map<String, Double> staticResources =
ResourceUtil.getResourcesMapFromString(params.static_resources);
//Start raylet
startRaylet(storeName, info, params.num_workers,
params.redis_address,
params.node_ip_address, params.redirect, staticResources, params.cleanup);
runInfo.localStores.add(info);
if (!checkAlive()) {
cleanup(true);
throw new RuntimeException("Start Ray processes failed");
}
}
private boolean checkAlive() {
RunInfo.ProcessType[] types = RunInfo.ProcessType.values();
for (int i = 0; i < types.length; i++) {
ProcessInfo p;
for (int j = 0; j < runInfo.allProcesses.get(i).size(); ) {
p = runInfo.allProcesses.get(i).get(j);
if (!p.process.isAlive()) {
RayLog.core.error("Process " + p.process.hashCode() + " is not alive!" + " Process Type "
+ types[i].name());
runInfo.deadProcess.add(p);
runInfo.allProcesses.get(i).remove(j);
} else {
j++;
}
}
}
return runInfo.deadProcess.isEmpty();
}
// kill all processes started by startRayHead
public void cleanup(boolean killAll) {
// clean up the process in reverse order
for (int i = ProcessType.values().length - 1; i >= 0; i--) {
if (killAll) {
runInfo.allProcesses.get(i).forEach(p -> {
if (killProcess(p.process)) {
RayLog.core.info("Kill process " + p.process.hashCode() + " forcely");
}
});
} else {
runInfo.toBeCleanedProcesses.get(i).forEach(p -> {
if (killProcess(p)) {
RayLog.core.info("Kill process " + p.hashCode() + " forcely");
}
});
}
runInfo.toBeCleanedProcesses.get(i).clear();
runInfo.allProcesses.get(i).clear();
runInfo.deadProcess.clear();
}
}
//
// start a redis server
//
// @param ip the IP address of the local node
// @param port port to be opended for redis traffic
// @param numOfShards the number of redis shards to start
// @param redirect whether to redirect the output/err to the log files
// @param cleanup true if using ray in local mode. If cleanup is true, when
// all Redis processes started by this method will be killed by @cleanup
// when the worker exits
// @return primary redis shard address
//
private List<String> startRedis(String ip, int port, int numOfShards,
boolean redirect, boolean cleanup) {
ArrayList<String> shards = new ArrayList<>();
String addr;
for (int i = 0; i < numOfShards; i++) {
addr = startRedisInstance(ip, port + i, redirect, cleanup);
if (addr.length() == 0) {
cleanup(cleanup);
shards.clear();
return shards;
} else {
shards.add(addr);
}
}
for (String shard : shards) {
// TODO: wait for redis server to start
}
return shards;
}
//
// @param ip local node ip, only used for logging purpose
// @param port given port for this redis instance, 0 for auto-selected port
// @return redis server address
//
private String startRedisInstance(String ip, int port,
boolean redirect, boolean cleanup) {
String redisFilePath = paths.redis_server;
String redisModule = paths.redis_module;
assert (new File(redisFilePath).exists()) : "file don't exsits : " + redisFilePath;
assert (new File(redisModule).exists()) : "file don't exsits : " + redisModule;
String cmd = redisFilePath + " --protected-mode no --port " + port + " --loglevel warning"
+ " --loadmodule " + redisModule;
Map<String, String> env = null;
Process p = startProcess(cmd.split(" "), env, RunInfo.ProcessType.PT_REDIS_SERVER,
"redis", "", ip, redirect, cleanup);
if (p == null || !p.isAlive()) {
return "";
}
try {
TimeUnit.MILLISECONDS.sleep(300);
} catch (InterruptedException e) {
e.printStackTrace();
}
Jedis client = new Jedis(params.node_ip_address, port);
// Configure Redis to only generate notifications for the export keys.
client.configSet("notify-keyspace-events", "Kl");
// Put a time stamp in Redis to indicate when it was started.
client.set("redis_start_time", LocalDateTime.now().toString());
client.close();
return ip + ":" + port;
}
private void startRaylet(String storeName, AddressInfo info, int numWorkers,
String redisAddress, String ip, boolean redirect,
Map<String, Double> staticResources, boolean cleanup) {
int rpcPort = params.raylet_port;
String rayletSocketName = "/tmp/raylet" + rpcPort;
String filePath = paths.raylet;
//Create the worker command that the raylet will use to start workers.
String workerCommand = buildWorkerCommandRaylet(info.storeName, rayletSocketName,
UniqueId.NIL, "", ip, redisAddress);
int sep = redisAddress.indexOf(':');
assert (sep != -1);
String gcsIp = redisAddress.substring(0, sep);
String gcsPort = redisAddress.substring(sep + 1);
String resourceArgument = ResourceUtil.getResourcesStringFromMap(staticResources);
int hardwareConcurrency = Runtime.getRuntime().availableProcessors();
int maximumStartupConcurrency = Math.max(1, Math.min(staticResources.get("CPU").intValue(),
hardwareConcurrency));
// The second-last arugment is the worker command for Python, not needed for Java.
String[] cmds = new String[]{filePath, rayletSocketName, storeName, ip, gcsIp,
gcsPort, String.valueOf(numWorkers), String.valueOf(maximumStartupConcurrency),
resourceArgument, "", workerCommand};
Process p = startProcess(cmds, null, RunInfo.ProcessType.PT_RAYLET,
"raylet", redisAddress, ip, redirect, cleanup);
if (p != null && p.isAlive()) {
try {
TimeUnit.MILLISECONDS.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
if (p == null || !p.isAlive()) {
info.rayletSocketName = "";
info.rayletRpcAddr = "";
throw new RuntimeException("Failed to start raylet process.");
} else {
info.rayletSocketName = rayletSocketName;
info.rayletRpcAddr = ip + ":" + rpcPort;
}
}
private String buildWorkerCommandRaylet(String storeName, String rayletSocketName,
UniqueId actorId, String actorClass,
String ip, String redisAddress) {
String workerConfigs = "ray.java.start.object_store_name=" + storeName
+ ";ray.java.start.raylet_socket_name=" + rayletSocketName
+ ";ray.java.start.worker_mode=WORKER";
workerConfigs += ";ray.java.start.deploy=" + params.deploy;
if (!actorId.equals(UniqueId.NIL)) {
workerConfigs += ";ray.java.start.actor_id=" + actorId;
}
if (!actorClass.equals("")) {
workerConfigs += ";ray.java.start.driver_class=" + actorClass;
}
String jvmArgs = "";
jvmArgs += " -Dlogging.path=" + params.log_dir;
jvmArgs += " -Dlogging.file.name=core-*pid_suffix*";
return buildJavaProcessCommand(
RunInfo.ProcessType.PT_WORKER,
"org.ray.runtime.runner.worker.DefaultWorker",
"",
workerConfigs,
jvmArgs,
ip,
redisAddress,
null
);
}
private void startObjectStore(int index, AddressInfo info, String redisAddress,
String ip, boolean redirect, boolean cleanup) {
int occupiedMemoryMb = params.object_store_occupied_memory_MB;
long memoryBytes = occupiedMemoryMb * 1000000;
String filePath = paths.store;
int rpcPort = params.object_store_rpc_port + index;
String name = "/tmp/plasma_store" + rpcPort;
String rpcAddr = "";
String cmd = filePath + " -s " + name + " -m " + memoryBytes;
Map<String, String> env = null;
Process p = startProcess(cmd.split(" "), env, RunInfo.ProcessType.PT_PLASMA_STORE,
"plasma_store", redisAddress, ip, redirect, cleanup);
if (p != null && p.isAlive()) {
try {
TimeUnit.MILLISECONDS.sleep(100);
} catch (InterruptedException e) {
e.printStackTrace();
}
}
if (p == null || !p.isAlive()) {
info.storeName = "";
info.storeRpcAddr = "";
throw new RuntimeException("Start object store failed ...");
} else {
info.storeName = name;
info.storeRpcAddr = rpcAddr;
}
}
}
@@ -0,0 +1,35 @@
package org.ray.runtime.runner.worker;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.WorkerMode;
/**
* The main function of DefaultDriver.
*/
public class DefaultDriver {
//
// " --node-ip-address=" + ip
// + " --redis-address=" + redisAddress
// + " --driver-class" + className
//
public static void main(String[] args) {
try {
AbstractRayRuntime.init(args);
assert AbstractRayRuntime.getParams().worker_mode == WorkerMode.DRIVER;
String driverClass = AbstractRayRuntime.configReader
.getStringValue("ray.java.start", "driver_class", "",
"java class which main is served as the driver in a java worker");
String driverArgs = AbstractRayRuntime.configReader
.getStringValue("ray.java.start", "driver_args", "",
"arguments for the java class main function which is served at the driver");
Class<?> cls = Class.forName(driverClass);
String[] argsArray = (driverArgs != null) ? driverArgs.split(",") : (new String[] {});
cls.getMethod("main", String[].class).invoke(null, (Object) argsArray);
} catch (Throwable e) {
e.printStackTrace();
System.exit(-1);
}
}
}
@@ -0,0 +1,31 @@
package org.ray.runtime.runner.worker;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.config.WorkerMode;
/**
* default worker implementation.
*/
public class DefaultWorker {
//
// String workerCmd = "java" + " -jarls " + workerPath + " --node-ip-address=" + ip
// + " --object-store-name=" + storeName
// + " --object-store-manager-name=" + storeManagerName
// + " --local-scheduler-name=" + name + " --redis-address=" + redisAddress
//
public static void main(String[] args) {
try {
AbstractRayRuntime.init(args);
assert AbstractRayRuntime.getParams().worker_mode == WorkerMode.WORKER;
AbstractRayRuntime.getInstance().loop();
throw new RuntimeException("Control flow should never reach here");
} catch (Throwable e) {
e.printStackTrace();
System.err
.println("--config=ray.config.ini --overwrite=ray.java.start.worker_mode=WORKER;...");
System.exit(-1);
}
}
}
@@ -0,0 +1,61 @@
package org.ray.runtime.task;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.Serializer;
public class ArgumentsBuilder {
private static boolean checkSimpleValue(Object o) {
// TODO(raulchen): implement this.
return true;
}
/**
* Convert real function arguments to task spec arguments.
*/
public static FunctionArg[] wrap(Object[] args) {
FunctionArg[] ret = new FunctionArg[args.length];
for (int i = 0; i < ret.length; i++) {
Object arg = args[i];
UniqueId id = null;
byte[] data = null;
if (arg == null) {
data = Serializer.encode(null);
} else if (arg instanceof RayActor) {
data = Serializer.encode(arg);
} else if (arg instanceof RayObject) {
id = ((RayObject) arg).getId();
} else if (checkSimpleValue(arg)) {
data = Serializer.encode(arg);
} else {
RayObject obj = Ray.put(arg);
id = obj.getId();
}
ret[i] = new FunctionArg(id, data);
}
return ret;
}
/**
* Convert task spec arguments to real function arguments.
*/
public static Object[] unwrap(TaskSpec task, ClassLoader classLoader) {
// Ignore the last arg, which is the class name
Object[] realArgs = new Object[task.args.length - 1];
for (int i = 0; i < task.args.length - 1; i++) {
FunctionArg arg = task.args[i];
if (arg.id == null) {
// pass by value
Object obj = Serializer.decode(arg.data, classLoader);
realArgs[i] = obj;
} else if (arg.data == null) {
// pass by reference
realArgs[i] = Ray.get(arg.id);
}
}
return realArgs;
}
}
@@ -0,0 +1,21 @@
package org.ray.runtime.task;
import org.ray.api.id.UniqueId;
/**
* Represents arguments for ray function calls.
*/
public class FunctionArg {
public final UniqueId id;
public final byte[] data;
public FunctionArg(UniqueId id, byte[] data) {
this.id = id;
this.data = data;
}
public void toString(StringBuilder builder) {
builder.append("ids: ").append(id).append(", ").append("<data>:").append(data);
}
}
@@ -0,0 +1,112 @@
package org.ray.runtime.task;
import java.util.Arrays;
import java.util.Map;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.ResourceUtil;
/**
* Represents necessary information of a task for scheduling and executing.
*/
public class TaskSpec {
// ID of the driver that created this task.
public UniqueId driverId;
// Task ID of the task.
public UniqueId taskId;
// Task ID of the parent task.
public UniqueId parentTaskId;
// A count of the number of tasks submitted by the parent task before this one.
public int parentCounter;
// Actor ID of the task. This is the actor that this task is executed on
// or NIL_ACTOR_ID if the task is just a normal task.
public UniqueId actorId;
// Number of tasks that have been submitted to this actor so far.
public int actorCounter;
// Function ID of the task.
public UniqueId functionId;
// Task arguments.
public FunctionArg[] args;
// return ids
public UniqueId[] returnIds;
// ID per actor client for session consistency
public UniqueId actorHandleId;
// Id for createActor a target actor
public UniqueId createActorId;
// The task's resource demands.
public Map<String, Double> resources;
public UniqueId cursorId;
public TaskSpec() {}
public TaskSpec(UniqueId driverId, UniqueId taskId, UniqueId parentTaskId, int parentCounter,
UniqueId actorId, int actorCounter, UniqueId functionId, FunctionArg[] args,
UniqueId[] returnIds, UniqueId actorHandleId, UniqueId createActorId,
Map<String, Double> resources, UniqueId cursorId) {
this.driverId = driverId;
this.taskId = taskId;
this.parentTaskId = parentTaskId;
this.parentCounter = parentCounter;
this.actorId = actorId;
this.actorCounter = actorCounter;
this.functionId = functionId;
this.args = args;
this.returnIds = returnIds;
this.actorHandleId = actorHandleId;
this.createActorId = createActorId;
this.resources = resources;
this.cursorId = cursorId;
if (!this.resources.containsKey(ResourceUtil.CPU_LITERAL)) {
this.resources.put(ResourceUtil.CPU_LITERAL, 0.0);
}
if (!this.resources.containsKey(ResourceUtil.GPU_LITERAL)) {
this.resources.put(ResourceUtil.GPU_LITERAL, 0.0);
}
}
@Override
public String toString() {
StringBuilder builder = new StringBuilder();
builder.append("\ttaskId: ").append(taskId).append("\n");
builder.append("\tdriverId: ").append(driverId).append("\n");
builder.append("\tparentCounter: ").append(parentCounter).append("\n");
builder.append("\tactorId: ").append(actorId).append("\n");
builder.append("\tactorCounter: ").append(actorCounter).append("\n");
builder.append("\tfunctionId: ").append(functionId).append("\n");
builder.append("\treturnIds: ").append(Arrays.toString(returnIds)).append("\n");
builder.append("\tactorHandleId: ").append(actorHandleId).append("\n");
builder.append("\tcreateActorId: ").append(createActorId).append("\n");
builder.append("\tresources: ")
.append(ResourceUtil.getResourcesFromatStringFromMap(resources)).append("\n");
builder.append("\tcursorId: ").append(cursorId).append("\n");
builder.append("\targs:\n");
for (FunctionArg arg : args) {
builder.append("\t\t");
arg.toString(builder);
builder.append("\n");
}
return builder.toString();
}
public boolean isActorTask() {
return !actorId.isNil();
}
public boolean isActorCreationTask() {
return !createActorId.isNil();
}
}
@@ -0,0 +1,134 @@
package org.ray.runtime.util;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Scanner;
public class FileUtil {
public static String getFilename(String logPath) {
if (logPath != null && !logPath.isEmpty()) {
int lastPos = logPath.lastIndexOf('/');
if (lastPos != -1) {
return logPath.substring(lastPos + 1);
} else {
return logPath;
}
}
return null;
}
public static boolean deleteFile(String filePath) {
File file = new File(filePath);
if (!file.exists()) {
return true;
} else {
if (file.isFile()) {
return file.delete();
} else {
for (File f : file.listFiles()) {
deleteFile(f.getAbsolutePath());
}
return file.delete();
}
}
}
public static void mkDir(File dir) {
if (dir.exists()) {
return;
}
if (dir.getParentFile().exists()) {
dir.mkdir();
} else {
mkDir(dir.getParentFile());
dir.mkdir();
}
}
public static void mkDirAndFile(File file) throws IOException {
if (file.exists()) {
return;
}
if (!file.getParentFile().exists()) {
mkDir(file.getParentFile());
}
file.createNewFile();
}
public static String readResourceFile(String fileName) throws FileNotFoundException {
ClassLoader classLoader = FileUtil.class.getClassLoader();
File file = new File(classLoader.getResource(fileName).getFile());
StringBuilder result = new StringBuilder();
try (Scanner scanner = new Scanner(file)) {
//Get file from resources folder
while (scanner.hasNextLine()) {
String line = scanner.nextLine();
result.append(line).append("\n");
}
return result.toString();
}
}
public static void overrideFile(String file, String str) throws IOException {
try (FileWriter fw = new FileWriter(file)) {
fw.write(str);
}
}
public static boolean createDir(String dirName, boolean failIfExist) {
File wdir = new File(dirName);
if (wdir.isFile()) {
return false;
}
if (!wdir.exists()) {
wdir.mkdirs();
} else if (failIfExist) {
return false;
}
return true;
}
public static void bytesToFile(byte[] bytes, String name) throws IOException {
Path path = Paths.get(name);
Files.write(path, bytes);
}
public static byte[] fileToBytes(String name) throws IOException {
Path path = Paths.get(name);
return Files.readAllBytes(path);
}
/**
* If the given string is the empty string, then the result is the current directory.
*
* @param rawDir a path in any legal form, such as a relative path
* @return the absolute and unique path in String
*/
public static String getCanonicalDirectory(final String rawDir) throws IOException {
String dir = rawDir.length() == 0 ? "." : rawDir;
// create working dir if necessary
File dd = new File(dir);
if (!dd.exists()) {
dd.mkdirs();
}
if (!dir.startsWith("/")) {
dir = dd.getCanonicalPath();
}
return dir;
}
}
@@ -0,0 +1,89 @@
package org.ray.runtime.util;
import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.List;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.RegexFileFilter;
import org.ray.runtime.util.logger.RayLog;
/**
* load and unload jars from a dir.
*/
public class JarLoader {
public static URLClassLoader loadJars(String dir, boolean explicitLoad) {
// get all jars
Collection<File> jars = FileUtils.listFiles(
new File(dir),
new RegexFileFilter(".*\\.jar"),
DirectoryFileFilter.DIRECTORY
);
return loadJar(jars, explicitLoad);
}
public static void unloadJars(ClassLoader loader) {
// now do nothing, if no ref to the loader and loader's class.
// they would be gc.
}
private static URLClassLoader loadJar(Collection<File> appJars, boolean explicitLoad) {
List<JarFile> jars = new ArrayList<>();
List<URL> urls = new ArrayList<>();
for (File appJar : appJars) {
try {
RayLog.core.info("load jar " + appJar.getAbsolutePath());
JarFile jar = new JarFile(appJar.getAbsolutePath());
jars.add(jar);
urls.add(appJar.toURI().toURL());
} catch (IOException e) {
throw new RuntimeException(
"invalid app jar path: " + appJar.getAbsolutePath() + ", load failed with exception",
e);
}
}
URLClassLoader cl = URLClassLoader.newInstance(urls.toArray(new URL[urls.size()]));
if (!explicitLoad) {
return cl;
}
for (JarFile jar : jars) {
try {
Enumeration<JarEntry> e = jar.entries();
while (e.hasMoreElements()) {
JarEntry je = e.nextElement();
if (je.isDirectory() || !je.getName().endsWith(".class")) {
continue;
}
String className = classNameOfJarEntry(je);
className = className.replace('/', '.');
try {
Class.forName(className, true, cl);
} catch (ClassNotFoundException e1) {
e1.printStackTrace();
}
}
} finally {
IOUtils.closeQuietly(jar);
}
}
return cl;
}
private static String classNameOfJarEntry(JarEntry je) {
return je.getName().substring(0, je.getName().length() - ".class".length());
}
}
@@ -0,0 +1,31 @@
package org.ray.runtime.util;
import java.io.Serializable;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Method;
/**
* see http://cr.openjdk.java.net/~briangoetz/lambda/lambda-translation.html.
*/
public final class LambdaUtils {
private LambdaUtils() {
}
public static SerializedLambda getSerializedLambda(Serializable lambda) {
// Note.
// the class of lambda which isAssignableFrom Serializable
// has an privte method:writeReplace
// This mechanism may be changed in the future
try {
Method m = lambda.getClass().getDeclaredMethod("writeReplace");
m.setAccessible(true);
return (SerializedLambda) m.invoke(lambda);
} catch (Exception e) {
throw new RuntimeException("failed to getSerializedLambda:" + lambda.getClass().getName(), e);
}
}
}
@@ -0,0 +1,215 @@
package org.ray.runtime.util;
import com.google.common.base.Preconditions;
import java.io.Serializable;
import java.lang.invoke.MethodHandleInfo;
import java.lang.invoke.SerializedLambda;
import java.lang.reflect.Constructor;
import java.lang.reflect.Executable;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.WeakHashMap;
import org.objectweb.asm.Type;
import org.ray.runtime.util.logger.RayLog;
/**
* An instance of RayFunc is a lambda.
* MethodId describe the information of the called function in lambda.<br/>
* e.g. Ray.call(Foo::foo), the MethodId of the lambda Foo::foo is:<br/>
* MethodId.className = Foo <br/>
* MethodId.methodName = foo <br/>
* MethodId.methodDesc = describe the types of args and return.
* see org.objectweb.asm.Type.getDescriptor.
*/
public final class MethodId {
/**
* use ThreadLocal to avoid lock.
* A cache from the lambda instances to MethodId.
* Note: the lambda instances are dynamically created per call site,
* we use WeakHashMap to avoid OOM.
*/
private static final ThreadLocal<WeakHashMap<Class<Serializable>, MethodId>>
CACHE = ThreadLocal.withInitial(() -> new WeakHashMap<>());
public final String className;
public final String methodName;
public final String methodDesc;
public final boolean isStatic;
/**
* encode the className,methodName,methodDesc,isStatic as an uniquel id.
*/
private final String encoding;
/**
* sha1 from the encoding, used as functionId.
*/
private final byte[] digest;
public MethodId(String className, String methodName, String methodDesc, boolean isStatic) {
this.className = className;
this.methodName = methodName;
this.methodDesc = methodDesc;
this.isStatic = isStatic;
this.encoding = encode(className, methodName, methodDesc, isStatic);
this.digest = getSha1Hash0();
}
private static String encode(String className, String methodName, String methodDesc,
boolean isStatic) {
StringBuilder sb = new StringBuilder(512);
sb.append(className).append('/').append(methodName).append("::").append(methodDesc).append("&&")
.append(isStatic);
return sb.toString();
}
public static MethodId fromExecutable(Executable method) {
final boolean isStatic = Modifier.isStatic(method.getModifiers());
final String className = method.getDeclaringClass().getName();
final String methodName = method instanceof Method
? method.getName() : "<init>";
final Type type = method instanceof Method
? Type.getType((Method) method) : Type.getType((Constructor) method);
final String methodDesc = type.getDescriptor();
return new MethodId(className, methodName, methodDesc, isStatic);
}
public static MethodId fromSerializedLambda(Serializable serial) {
return fromSerializedLambda(serial, false);
}
public static MethodId fromSerializedLambda(Serializable serial, boolean forceNew) {
Preconditions.checkArgument(!(serial instanceof SerializedLambda), "arg could not be "
+ "SerializedLambda");
Class<Serializable> clazz = (Class<Serializable>) serial.getClass();
WeakHashMap<Class<Serializable>, MethodId> map = CACHE.get();
MethodId id = map.get(clazz);
if (id == null || forceNew) {
final SerializedLambda lambda = LambdaUtils.getSerializedLambda(serial);
Preconditions.checkArgument(lambda.getCapturedArgCount() == 0, "could not transfer a lambda "
+ "which is closure");
final boolean isStatic = lambda.getImplMethodKind() == MethodHandleInfo.REF_invokeStatic;
final String className = lambda.getImplClass().replace('/', '.');
id = new MethodId(className, lambda.getImplMethodName(),
lambda.getImplMethodSignature(), isStatic);
if (!forceNew) {
map.put(clazz, id);
}
}
return id;
}
public Method load() {
return load(null);
}
public Method load(ClassLoader loader) {
Class<?> cls = null;
try {
RayLog.core.debug(
"load class " + className + " from class loader " + (loader == null ? this.getClass()
.getClassLoader() : loader)
+ " for method " + toString() + " with ID = " + toHexHashString()
);
cls = Class
.forName(className, true, loader == null ? this.getClass().getClassLoader() : loader);
} catch (Throwable e) {
RayLog.core.error("Cannot load class {}", className, e);
return null;
}
Method[] ms = cls.getDeclaredMethods();
ArrayList<Method> methods = new ArrayList<>();
Type t = Type.getMethodType(this.methodDesc);
Type[] params = t.getArgumentTypes();
String rt = t.getReturnType().getDescriptor();
for (Method m : ms) {
if (m.getName().equals(methodName)) {
if (!Arrays.equals(params, Type.getArgumentTypes(m))) {
continue;
}
String mrt = Type.getDescriptor(m.getReturnType());
if (!rt.equals(mrt)) {
continue;
}
if (isStatic != Modifier.isStatic(m.getModifiers())) {
continue;
}
methods.add(m);
}
}
if (methods.size() != 1) {
RayLog.core.error(
"Load method {} failed as there are {} definitions.", toString(), methods.size());
return null;
}
return methods.get(0);
}
private byte[] getSha1Hash0() {
byte[] digests = Sha1Digestor.digest(encoding);
ByteBuffer bb = ByteBuffer.wrap(digests);
bb.order(ByteOrder.LITTLE_ENDIAN);
if (methodName.contains("createActorStage1")) {
bb.putLong(Long.BYTES, 1);
} else {
bb.putLong(Long.BYTES, 0);
}
return digests;
}
public byte[] getSha1Hash() {
return digest;
}
private String toHexHashString() {
byte[] id = this.getSha1Hash();
return StringUtil.toHexHashString(id);
}
public String toEncodingString() {
return encoding;
}
@Override
public int hashCode() {
return encoding.hashCode();
}
@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
}
if (obj == null) {
return false;
}
if (getClass() != obj.getClass()) {
return false;
}
MethodId other = (MethodId) obj;
return className.equals(other.className)
&& methodName.equals(other.methodName)
&& methodDesc.equals(other.methodDesc)
&& isStatic == other.isStatic;
}
@Override
public String toString() {
return encoding;
}
}
@@ -0,0 +1,60 @@
package org.ray.runtime.util;
import java.io.IOException;
import java.net.DatagramSocket;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.util.Enumeration;
import org.ray.runtime.util.logger.RayLog;
public class NetworkUtil {
public static String getIpAddress(String interfaceName) {
try {
Enumeration<NetworkInterface> interfaces = NetworkInterface.getNetworkInterfaces();
while (interfaces.hasMoreElements()) {
NetworkInterface current = interfaces.nextElement();
if (!current.isUp() || current.isLoopback() || current.isVirtual()) {
continue;
}
if (!StringUtil.isNullOrEmpty(interfaceName) && !interfaceName
.equals(current.getDisplayName())) {
continue;
}
Enumeration<InetAddress> addresses = current.getInetAddresses();
while (addresses.hasMoreElements()) {
InetAddress addr = addresses.nextElement();
if (addr.isLoopbackAddress()) {
continue;
}
if (addr instanceof Inet6Address) {
continue;
}
return addr.getHostAddress();
}
}
RayLog.core.warn("You need to correctly specify [ray.java] net_interface in config.");
} catch (Exception e) {
RayLog.core.error("Can't get ip address, use 127.0.0.1 as default.", e);
}
return "127.0.0.1";
}
public static boolean isPortAvailable(int port) {
if (port < 1 || port > 65535) {
throw new IllegalArgumentException("Invalid start port: " + port);
}
try (ServerSocket ss = new ServerSocket(port); DatagramSocket ds = new DatagramSocket(port)) {
ss.setReuseAddress(true);
ds.setReuseAddress(true);
return true;
} catch (IOException ignored) {
/* should not be thrown */
return false;
}
}
}
@@ -0,0 +1,25 @@
package org.ray.runtime.util;
import java.lang.reflect.InvocationTargetException;
public class ObjectUtil {
public static <T> T newObject(Class<T> cls) {
try {
return cls.getConstructor().newInstance();
} catch (InstantiationException | IllegalAccessException | NoSuchMethodException
| InvocationTargetException e) {
e.printStackTrace();
return null;
}
}
public static boolean[] toBooleanArray(Object[] vs) {
boolean[] nvs = new boolean[vs.length];
for (int i = 0; i < vs.length; i++) {
nvs[i] = (boolean) vs[i];
}
return nvs;
}
}
@@ -0,0 +1,105 @@
package org.ray.runtime.util;
import java.util.HashMap;
import java.util.Map;
import org.ray.api.annotation.RayRemote;
import org.ray.api.annotation.ResourceItem;
public class ResourceUtil {
public static final String CPU_LITERAL = "CPU";
public static final String GPU_LITERAL = "GPU";
/**
* Convert the array that contains resource items to a map.
*
* @param remoteAnnotation The RayRemote annotation that contains the resource items.
* @return The map whose key represents the resource name
* and the value represents the resource quantity.
*/
public static Map<String, Double> getResourcesMapFromArray(RayRemote remoteAnnotation) {
Map<String, Double> resourceMap = new HashMap<>();
if (remoteAnnotation == null) {
return resourceMap;
}
for (ResourceItem item : remoteAnnotation.resources()) {
if (!item.name().isEmpty()) {
resourceMap.put(item.name(), item.value());
}
}
return resourceMap;
}
/**
* Convert the resources map to a format string.
*
* @param resources The resource map to be Converted.
* @return The format resources string, like "{CPU:4, GPU:0}".
*/
public static String getResourcesFromatStringFromMap(Map<String, Double> resources) {
StringBuilder builder = new StringBuilder();
builder.append("{");
int count = 1;
for (Map.Entry<String, Double> entry : resources.entrySet()) {
builder.append(entry.getKey()).append(":").append(entry.getValue());
count++;
if (count != resources.size()) {
builder.append(", ");
}
}
builder.append("}");
return builder.toString();
}
/**
* Convert resources map to a string that is used
* for the command line argument of starting raylet.
*
* @param resources The resources map to be converted.
* @return The starting-raylet command line argument, like "CPU,4,GPU,0".
*/
public static String getResourcesStringFromMap(Map<String, Double> resources) {
StringBuilder builder = new StringBuilder();
if (resources != null) {
int count = 1;
for (Map.Entry<String, Double> entry : resources.entrySet()) {
builder.append(entry.getKey()).append(",").append(entry.getValue());
if (count != resources.size()) {
builder.append(",");
}
count++;
}
}
return builder.toString();
}
/**
* Parse the static resources configure field and convert to the resources map.
*
* @param resources The static resources string to be parsed.
* @return The map whose key represents the resource name
* and the value represents the resource quantity.
* @throws IllegalArgumentException If the resources string's format does match,
* it will throw an IllegalArgumentException.
*/
public static Map<String, Double> getResourcesMapFromString(String resources)
throws IllegalArgumentException {
Map<String, Double> ret = new HashMap<>();
if (resources != null) {
String[] items = resources.split(",");
for (String item : items) {
String trimItem = item.trim();
String[] resourcePair = trimItem.split(":");
if (resourcePair.length != 2) {
throw new IllegalArgumentException("Format of static resurces configure is invalid.");
}
final String resourceName = resourcePair[0].trim();
final Double resourceValue = Double.valueOf(resourcePair[1].trim());
ret.put(resourceName, resourceValue);
}
}
return ret;
}
}
@@ -0,0 +1,55 @@
package org.ray.runtime.util;
import org.nustaq.serialization.FSTConfiguration;
/**
* Java object serialization TODO: use others (e.g. Arrow) for higher performance
*/
public class Serializer {
static final ThreadLocal<FSTConfiguration> conf = ThreadLocal.withInitial(
FSTConfiguration::createDefaultConfiguration);
public static byte[] encode(Object obj) {
return conf.get().asByteArray(obj);
}
public static byte[] encode(Object obj, ClassLoader classLoader) {
byte[] result;
FSTConfiguration current = conf.get();
if (classLoader != null && classLoader != current.getClassLoader()) {
ClassLoader old = current.getClassLoader();
current.setClassLoader(classLoader);
result = current.asByteArray(obj);
current.setClassLoader(old);
} else {
result = current.asByteArray(obj);
}
return result;
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs) {
return (T) conf.get().asObject(bs);
}
@SuppressWarnings("unchecked")
public static <T> T decode(byte[] bs, ClassLoader classLoader) {
Object object;
FSTConfiguration current = conf.get();
if (classLoader != null && classLoader != current.getClassLoader()) {
ClassLoader old = current.getClassLoader();
current.setClassLoader(classLoader);
object = current.asObject(bs);
current.setClassLoader(old);
} else {
object = current.asObject(bs);
}
return (T) object;
}
public static void setClassloader(ClassLoader classLoader) {
conf.get().setClassLoader(classLoader);
}
}
@@ -0,0 +1,41 @@
package org.ray.runtime.util;
import java.nio.ByteBuffer;
import java.security.MessageDigest;
import org.ray.runtime.util.logger.RayLog;
public class Sha1Digestor {
private static final ThreadLocal<MessageDigest> md = ThreadLocal.withInitial(() -> {
try {
return MessageDigest.getInstance("SHA1");
} catch (Exception e) {
RayLog.core.error("Cannot get SHA1 MessageDigest", e);
throw new RuntimeException("Cannot get SHA1 digest", e);
}
});
private static final ThreadLocal<ByteBuffer> longBuffer = ThreadLocal
.withInitial(() -> ByteBuffer.allocate(Long.SIZE / Byte.SIZE));
public static byte[] digest(byte[] src, long addIndex) {
MessageDigest dg = md.get();
longBuffer.get().clear();
dg.reset();
dg.update(src);
dg.update(longBuffer.get().putLong(addIndex).array());
return dg.digest();
}
public static byte[] digest(String str) {
return digest(str.getBytes(StringUtil.UTF8));
}
public static byte[] digest(byte[] src) {
MessageDigest dg = md.get();
dg.reset();
return dg.digest(src);
}
}
@@ -0,0 +1,150 @@
package org.ray.runtime.util;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Vector;
public class StringUtil {
public static final Charset UTF8 = Charset.forName("UTF-8");
private static final char[] HEX_CHARS = "0123456789abcdef".toCharArray();
/**
* split.
* @param s input string
* @param splitters common splitters
* @param open open braces
* @param close close braces
* @return output array list
*/
public static Vector<String> split(String s, String splitters, String open, String close) {
// The splits.
Vector<String> split = new Vector<>();
// The stack.
ArrayList<Start> stack = new ArrayList<>();
int lastPos = 0;
// Walk the string.
for (int i = 0; i < s.length(); i++) {
// Get the char there.
char ch = s.charAt(i);
// Is it an open brace?
int o = open.indexOf(ch);
// Is it a close brace?
int c = close.indexOf(ch);
// Is it a splitter?
int sp = splitters.indexOf(ch);
if (stack.size() == 0 && sp >= 0) {
if (i == lastPos) {
++lastPos;
continue;
}
split.add(s.substring(lastPos, i));
lastPos = i + 1;
} else if (o >= 0 && (c < 0 || stack.size() == 0)) {
// Its an open! Push it.
stack.add(new Start(o, i));
} else if (c >= 0 && stack.size() > 0) {
// Pop (if matches).
int tosPos = stack.size() - 1;
Start tos = stack.get(tosPos);
// Does the brace match?
if (tos.brace == c) {
// Done with that one.
stack.remove(tosPos);
}
}
}
if (lastPos < s.length()) {
split.add(s.substring(lastPos, s.length()));
}
// build removal filter set
HashSet<Character> removals = new HashSet<>();
for (int i = 0; i < splitters.length(); i++) {
removals.add(splitters.charAt(i));
}
for (int i = 0; i < open.length(); i++) {
removals.add(open.charAt(i));
}
for (int i = 0; i < close.length(); i++) {
removals.add(close.charAt(i));
}
// apply removal filter set
for (int i = 0; i < split.size(); i++) {
String cs = split.get(i);
// remove heading chars
int j;
for (j = 0; j < cs.length(); j++) {
if (!removals.contains(cs.charAt(j))) {
break;
}
}
cs = cs.substring(j);
// remove tail chars
for (j = cs.length() - 1; j >= 0; j--) {
if (!removals.contains(cs.charAt(j))) {
break;
}
}
cs = cs.substring(0, j + 1);
// reset cs
split.set(i, cs);
}
return split;
}
public static boolean isNullOrEmpty(String s) {
return s == null || s.length() == 0;
}
public static <T> String mergeArray(T[] objs, String concatenator) {
StringBuilder sb = new StringBuilder();
for (T obj : objs) {
sb.append(obj).append(concatenator);
}
return objs.length == 0 ? "" : sb.substring(0, sb.length() - concatenator.length());
}
public static String toHexHashString(byte[] id) {
StringBuilder sb = new StringBuilder(20);
assert (id.length == 20);
for (int i = 0; i < 20; i++) {
int val = id[i] & 0xff;
sb.append(HEX_CHARS[val >> 4]);
sb.append(HEX_CHARS[val & 0xf]);
}
return sb.toString();
}
// Holds the start of an element and which brace started it.
private static class Start {
// The brace number from the braces string in use.
final int brace;
// The position in the string it was seen.
final int pos;
// Constructor.
public Start(int brace, int pos) {
this.brace = brace;
this.pos = pos;
}
}
}
@@ -0,0 +1,64 @@
package org.ray.runtime.util;
import java.lang.management.ManagementFactory;
import java.lang.management.RuntimeMXBean;
import java.util.concurrent.locks.ReentrantLock;
import org.ray.runtime.util.logger.RayLog;
/**
* some utilities for system process.
*/
public class SystemUtil {
static final ReentrantLock pidlock = new ReentrantLock();
static Integer pid;
public static String userHome() {
return System.getProperty("user.home");
}
public static String userDir() {
return System.getProperty("user.dir");
}
public static boolean startWithJar(Class<?> cls) {
return cls.getResource(cls.getSimpleName() + ".class").getFile().split("!")[0].endsWith(".jar");
}
public static boolean startWithJar(String clsName) {
Class<?> cls;
try {
cls = Class.forName(clsName);
return cls.getResource(cls.getSimpleName() + ".class").getFile().split("!")[0]
.endsWith(".jar");
} catch (ClassNotFoundException e) {
// TODO Auto-generated catch block
e.printStackTrace();
RayLog.core.error("error at SystemUtil startWithJar", e);
return false;
}
}
public static int pid() {
if (pid == null) {
pidlock.lock();
try {
if (pid == null) {
RuntimeMXBean runtime = ManagementFactory.getRuntimeMXBean();
String name = runtime.getName();
int index = name.indexOf("@");
if (index != -1) {
pid = Integer.parseInt(name.substring(0, index));
} else {
throw new RuntimeException("parse pid error:" + name);
}
}
} finally {
pidlock.unlock();
}
}
return pid;
}
}
@@ -0,0 +1,71 @@
package org.ray.runtime.util;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import org.ray.api.id.UniqueId;
/**
* Helper method for UniqueId.
* Note: any changes to these methods must be synced with C++ helper functions
* in src/ray/id.h
*/
public class UniqueIdHelper {
public static final int OBJECT_INDEX_POS = 0;
public static final int OBJECT_INDEX_LENGTH = 4;
/**
* Compute the object ID of an object returned by the task.
*
* @param taskId The task ID of the task that created the object.
* @param returnIndex What number return value this object is in the task.
* @return The computed object ID.
*/
public static UniqueId computeReturnId(UniqueId taskId, int returnIndex) {
return computeObjectId(taskId, returnIndex);
}
/**
* Compute the object ID from the task ID and the index.
* @param taskId The task ID of the task that created the object.
* @param index The index which can distinguish different objects in one task.
* @return The computed object ID.
*/
private static UniqueId computeObjectId(UniqueId taskId, int index) {
byte[] objId = new byte[UniqueId.LENGTH];
System.arraycopy(taskId.getBytes(),0, objId, 0, UniqueId.LENGTH);
ByteBuffer wbb = ByteBuffer.wrap(objId);
wbb.order(ByteOrder.LITTLE_ENDIAN);
wbb.putInt(UniqueIdHelper.OBJECT_INDEX_POS, index);
return new UniqueId(objId);
}
/**
* Compute the object ID of an object put by the task.
*
* @param taskId The task ID of the task that created the object.
* @param putIndex What number put this object was created by in the task.
* @return The computed object ID.
*/
public static UniqueId computePutId(UniqueId taskId, int putIndex) {
// We multiply putIndex by -1 to distinguish from returnIndex.
return computeObjectId(taskId, -1 * putIndex);
}
/**
* Compute the task ID of the task that created the object.
*
* @param objectId The object ID.
* @return The task ID of the task that created this object.
*/
public static UniqueId computeTaskId(UniqueId objectId) {
byte[] taskId = new byte[UniqueId.LENGTH];
System.arraycopy(objectId.getBytes(), 0, taskId, 0, UniqueId.LENGTH);
Arrays.fill(taskId, UniqueIdHelper.OBJECT_INDEX_POS,
UniqueIdHelper.OBJECT_INDEX_POS + UniqueIdHelper.OBJECT_INDEX_LENGTH, (byte) 0);
return new UniqueId(taskId);
}
}
@@ -0,0 +1,36 @@
package org.ray.runtime.util.config;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Annotate a field as a ray configuration item.
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
public @interface AConfig {
/**
* comments for this configuration field.
*/
String comment();
/**
* when the config is an array list, a splitter set is specified, e.g., " \t" to use ' ' and '\t'
* as possible splits.
*/
String splitters() default ", \t";
/**
* indirect with value as the new section name, the field name remains the same.
*/
String defaultIndirectSectionName() default "";
/**
* see ConfigReader.getIndirectStringArray this config tells which is the default
* indirectSectionName in that function.
*/
String defaultArrayIndirectSectionName() default "";
}
@@ -0,0 +1,15 @@
package org.ray.runtime.util.config;
/**
* A ray configuration item of type {@code T}.
*/
public class ConfigItem<T> {
public String key;
public String oriValue;
public T defaultValue;
public String desc;
}
@@ -0,0 +1,382 @@
package org.ray.runtime.util.config;
import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Vector;
import org.ini4j.Config;
import org.ini4j.Ini;
import org.ini4j.Profile;
import org.ray.api.id.UniqueId;
import org.ray.runtime.util.ObjectUtil;
import org.ray.runtime.util.StringUtil;
/**
* Loads configurations from a file.
*/
public class ConfigReader {
private final CurrentUseConfig currentUseConfig = new CurrentUseConfig();
private final Ini ini = new Ini();
private String file = "";
public ConfigReader(String filePath) throws Exception {
this(filePath, null);
}
public ConfigReader(String filePath, String updateConfigStr) throws Exception {
System.out.println("Build ConfigReader, the file path " + filePath + " ,the update config str "
+ updateConfigStr);
try {
loadConfigFile(filePath);
updateConfigFile(updateConfigStr);
} catch (Exception e) {
e.printStackTrace();
throw e;
}
}
private void loadConfigFile(String filePath) throws Exception {
this.currentUseConfig.filePath = filePath;
String configFileDir = (new File(filePath)).getAbsoluteFile().getParent();
byte[] encoded = Files.readAllBytes(Paths.get(filePath));
String content = new String(encoded, StandardCharsets.UTF_8);
content = content.replaceAll("%CONFIG_FILE_DIR%", configFileDir);
InputStream fis = new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8));
Config config = new Config();
ini.setConfig(config);
ini.load(fis);
file = currentUseConfig.filePath;
}
private void updateConfigFile(String updateConfigStr) {
if (updateConfigStr == null) {
return;
}
String[] updateConfigArray = updateConfigStr.split(";");
for (String currentUpdateConfig : updateConfigArray) {
if (StringUtil.isNullOrEmpty(currentUpdateConfig)) {
continue;
}
String[] currentUpdateConfigArray = currentUpdateConfig.split("=");
String sectionAndItemKey;
String value = "";
if (currentUpdateConfigArray.length == 2) {
sectionAndItemKey = currentUpdateConfigArray[0];
value = currentUpdateConfigArray[1];
} else if (currentUpdateConfigArray.length == 1) {
sectionAndItemKey = currentUpdateConfigArray[0];
} else {
String errorMsg = "invalid config (must be of k=v or k or k=): " + currentUpdateConfig;
System.err.println(errorMsg);
throw new RuntimeException(errorMsg);
}
int splitOffset = sectionAndItemKey.lastIndexOf(".");
int len = sectionAndItemKey.length();
if (splitOffset < 1 || splitOffset == len - 1) {
String errorMsg =
"invalid config (no '.' found for section name and key):" + currentUpdateConfig;
System.err.println(errorMsg);
throw new RuntimeException(errorMsg);
}
String sectionKey = sectionAndItemKey.substring(0, splitOffset);
String itemKey = sectionAndItemKey.substring(splitOffset + 1);
if (ini.containsKey(sectionKey)) {
ini.get(sectionKey).put(itemKey, value);
} else {
ini.add(sectionKey, itemKey, value);
}
}
}
public String filePath() {
return file;
}
public CurrentUseConfig getCurrentUseConfig() {
return currentUseConfig;
}
public String getStringValue(String sectionKey, String configKey, String defaultValue,
String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
if (value != null) {
return value;
} else {
return defaultValue;
}
}
public boolean getBooleanValue(String sectionKey, String configKey, boolean defaultValue,
String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
if (value != null) {
if (value.length() == 0) {
return defaultValue;
} else {
return Boolean.valueOf(value);
}
} else {
return defaultValue;
}
}
public int getIntegerValue(String sectionKey, String configKey, int defaultValue, String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
if (value != null) {
if (value.length() == 0) {
return defaultValue;
} else {
return Integer.valueOf(value);
}
} else {
return defaultValue;
}
}
private synchronized <T> String getOriValue(String sectionKey, String configKey, T defaultValue,
String deptr) {
if (null == deptr) {
throw new RuntimeException("desc must not be empty of the key:" + configKey);
}
Profile.Section section = ini.get(sectionKey);
String oriValue = null;
if (section != null && section.containsKey(configKey)) {
oriValue = section.get(configKey);
}
if (!currentUseConfig.sectionMap.containsKey(sectionKey)) {
ConfigSection configSection = new ConfigSection();
configSection.sectionKey = sectionKey;
updateConfigSection(configSection, configKey, defaultValue, deptr, oriValue);
currentUseConfig.sectionMap.put(sectionKey, configSection);
} else if (!currentUseConfig.sectionMap.get(sectionKey).itemMap.containsKey(configKey)) {
ConfigSection configSection = currentUseConfig.sectionMap.get(sectionKey);
updateConfigSection(configSection, configKey, defaultValue, deptr, oriValue);
}
return oriValue;
}
private <T> void updateConfigSection(ConfigSection configSection, String configKey,
T defaultValue, String deptr, String oriValue) {
ConfigItem<T> configItem = new ConfigItem<>();
configItem.defaultValue = defaultValue;
configItem.key = configKey;
configItem.oriValue = oriValue;
configItem.desc = deptr;
configSection.itemMap.put(configKey, configItem);
}
public long getLongValue(String sectionKey, String configKey, long defaultValue, String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
if (value != null) {
if (value.length() == 0) {
return defaultValue;
} else {
return Long.valueOf(value);
}
} else {
return defaultValue;
}
}
public double getDoubleValue(String sectionKey, String configKey, double defaultValue,
String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
if (value != null) {
if (value.length() == 0) {
return defaultValue;
} else {
return Double.valueOf(value);
}
} else {
return defaultValue;
}
}
public int[] getIntegerArray(String sectionKey, String configKey, int[] defaultValue,
String dsptr) {
String value = getOriValue(sectionKey, configKey, defaultValue, dsptr);
int[] array = defaultValue;
if (value != null) {
String[] list = value.split(",");
array = new int[list.length];
for (int i = 0; i < list.length; i++) {
array[i] = Integer.valueOf(list[i]);
}
}
return array;
}
/**
* get a string list from a whole section as keys e.g., [core] data_dirs = local.dirs # or
* cluster.dirs
* [local.dirs] /home/xxx/1 /home/yyy/2
* [cluster.dirs] ...
*
* @param sectionKey e.g., core
* @param configKey e.g., data_dirs
* @param indirectSectionName e.g., cluster.dirs
* @return string list
*/
public String[] getIndirectStringArray(String sectionKey, String configKey,
String indirectSectionName, String dsptr) {
String s = getStringValue(sectionKey, configKey, indirectSectionName, dsptr);
Profile.Section section = ini.get(s);
if (section == null) {
return new String[] {};
} else {
return section.keySet().toArray(new String[] {});
}
}
public <T> void readObject(String sectionKey, T obj, T defaultValues) {
for (Field fld : obj.getClass().getFields()) {
Object defaultFldValue;
try {
defaultFldValue = defaultValues != null ? fld.get(defaultValues) : null;
} catch (IllegalArgumentException | IllegalAccessException e) {
defaultFldValue = null;
}
String section = sectionKey;
String comment;
String splitters = ", \t";
String defaultArrayIndirectSectionName;
AConfig[] anns = fld.getAnnotationsByType(AConfig.class);
if (anns.length > 0) {
comment = anns[0].comment();
if (!StringUtil.isNullOrEmpty(anns[0].splitters())) {
splitters = anns[0].splitters();
}
defaultArrayIndirectSectionName = anns[0].defaultArrayIndirectSectionName();
// redirect the section if necessary
if (!StringUtil.isNullOrEmpty(anns[0].defaultIndirectSectionName())) {
section = this
.getStringValue(sectionKey, fld.getName(), anns[0].defaultIndirectSectionName(),
comment);
}
} else {
throw new RuntimeException("unspecified comment, please use @AConfig(comment = xxxx) for "
+ obj.getClass().getName() + "." + fld.getName() + "'s configuration descriptions ");
}
try {
if (fld.getType().isPrimitive()) {
if (fld.getType().equals(boolean.class)) {
boolean v = getBooleanValue(section, fld.getName(), (boolean) defaultFldValue, comment);
fld.set(obj, v);
} else if (fld.getType().equals(float.class)) {
float v = (float) getDoubleValue(section, fld.getName(),
(double) (float) defaultFldValue, comment);
fld.set(obj, v);
} else if (fld.getType().equals(double.class)) {
double v = getDoubleValue(section, fld.getName(), (double) defaultFldValue, comment);
fld.set(obj, v);
} else if (fld.getType().equals(byte.class)) {
byte v = (byte) getLongValue(section, fld.getName(), (long) (byte) defaultFldValue,
comment);
fld.set(obj, v);
} else if (fld.getType().equals(char.class)) {
char v = (char) getLongValue(section, fld.getName(), (long) (char) defaultFldValue,
comment);
fld.set(obj, v);
} else if (fld.getType().equals(short.class)) {
short v = (short) getLongValue(section, fld.getName(), (long) (short) defaultFldValue,
comment);
fld.set(obj, v);
} else if (fld.getType().equals(int.class)) {
int v = (int) getLongValue(section, fld.getName(), (long) (int) defaultFldValue,
comment);
fld.set(obj, v);
} else if (fld.getType().equals(long.class)) {
long v = getLongValue(section, fld.getName(), (long) defaultFldValue, comment);
fld.set(obj, v);
} else {
throw new RuntimeException("unhandled type " + fld.getType().getName());
}
} else if (fld.getType().equals(String.class)) {
String v = getStringValue(section, fld.getName(), (String) defaultFldValue, comment);
fld.set(obj, v);
} else if (fld.getType().isEnum()) {
String sv = getStringValue(section, fld.getName(), defaultFldValue.toString(), comment);
@SuppressWarnings({"unchecked", "rawtypes"})
Object v = Enum.valueOf((Class<Enum>) fld.getType(), sv);
fld.set(obj, v);
// TODO: this is a hack and needs to be resolved later
} else if (fld.getType().equals(UniqueId.class)) {
String sv = getStringValue(section, fld.getName(), defaultFldValue.toString(), comment);
Object v;
try {
v = UniqueId.fromHexString(sv);
} catch (IllegalArgumentException e) {
System.err.println(
section + "." + fld.getName() + "'s format (" + sv + ") is invalid, default to "
+ defaultFldValue.toString());
v = defaultFldValue;
}
fld.set(obj, v);
} else if (fld.getType().isArray()) {
Class<?> ccls = fld.getType().getComponentType();
String ss = getStringValue(section, fld.getName(), null, comment);
if (null == ss) {
fld.set(obj, defaultFldValue);
} else {
Vector<String> ls = StringUtil.split(ss, splitters, "", "");
if (ccls.equals(boolean.class)) {
boolean[] v = ObjectUtil
.toBooleanArray(ls.stream().map(Boolean::parseBoolean).toArray());
fld.set(obj, v);
} else if (ccls.equals(double.class)) {
double[] v = ls.stream().mapToDouble(Double::parseDouble).toArray();
fld.set(obj, v);
} else if (ccls.equals(int.class)) {
int[] v = ls.stream().mapToInt(Integer::parseInt).toArray();
fld.set(obj, v);
} else if (ccls.equals(long.class)) {
long[] v = ls.stream().mapToLong(Long::parseLong).toArray();
fld.set(obj, v);
} else if (ccls.equals(String.class)) {
String[] v;
if (StringUtil.isNullOrEmpty(defaultArrayIndirectSectionName)) {
v = ls.toArray(new String[] {});
} else {
v = this
.getIndirectStringArray(section, fld.getName(),
defaultArrayIndirectSectionName,
comment);
}
fld.set(obj, v);
} else {
throw new RuntimeException(
"Array with component type " + ccls.getName() + " is not supported yet");
}
}
} else {
Object fldObj = ObjectUtil.newObject(fld.getType());
fld.set(obj, fldObj);
readObject(section + "." + fld.getName(), fldObj, defaultFldValue);
}
} catch (IllegalArgumentException | IllegalAccessException e) {
throw new RuntimeException("set fld " + fld.getName() + " failed, err = " + e.getMessage(),
e);
}
}
}
}
@@ -0,0 +1,13 @@
package org.ray.runtime.util.config;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* A configuration section of related items.
*/
public class ConfigSection {
public final Map<String, ConfigItem<?>> itemMap = new ConcurrentHashMap<>();
public String sectionKey;
}
@@ -0,0 +1,15 @@
package org.ray.runtime.util.config;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
/**
* The configuration which is currently in use.
*/
public class CurrentUseConfig {
public final Map<String, ConfigSection> sectionMap = new ConcurrentHashMap<>();
public String filePath;
}
@@ -0,0 +1,15 @@
package org.ray.runtime.util.exception;
/**
* An exception which is thrown when a ray task encounters an error when executing.
*/
public class TaskExecutionException extends RuntimeException {
public TaskExecutionException(Throwable cause) {
super(cause);
}
public TaskExecutionException(String message, Throwable cause) {
super(message, cause);
}
}
@@ -0,0 +1,24 @@
package org.ray.runtime.util.generator;
public abstract class BaseGenerator {
protected static final int MAX_PARAMETERS = 6;
protected StringBuilder sb;
protected void newLine(String line) {
sb.append(line).append("\n");
}
protected void newLine(int numIndents, String line) {
indents(numIndents);
newLine(line);
}
protected void indents(int numIndents) {
for (int i = 0; i < numIndents; i++) {
sb.append(" ");
}
}
}
@@ -0,0 +1,126 @@
package org.ray.runtime.util.generator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.ray.runtime.util.FileUtil;
/**
* A util class that generates `RayCall.java`,
* which provides type-safe interfaces for `Ray.call` and `Ray.createActor`.
*/
public class RayCallGenerator extends BaseGenerator {
/**
* @return Whole file content of `RayCall.java`.
*/
private String build() {
sb = new StringBuilder();
newLine("// generated automatically, do not modify.");
newLine("");
newLine("package org.ray.api;");
newLine("");
newLine("import org.ray.api.function.*;");
newLine("");
newLine("/**");
newLine(" * This class provides type-safe interfaces for `Ray.call` and `Ray.createActor`.");
newLine(" **/");
newLine("@SuppressWarnings({\"rawtypes\", \"unchecked\"})");
newLine("class RayCall {");
newLine(1, "// =======================================");
newLine(1, "// Methods for remote function invocation.");
newLine(1, "// =======================================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, false);
}
newLine(1, "// ===========================================");
newLine(1, "// Methods for remote actor method invocation.");
newLine(1, "// ===========================================");
for (int i = 0; i <= MAX_PARAMETERS - 1; i++) {
buildCalls(i, true, false);
}
newLine(1, "// ===========================");
newLine(1, "// Methods for actor creation.");
newLine(1, "// ===========================");
for (int i = 0; i <= MAX_PARAMETERS; i++) {
buildCalls(i, false, true);
}
newLine("}");
return sb.toString();
}
/**
* Build the `Ray.call` or `Ray.createActor` methods with the given number of parameters.
* @param numParameters the number of parameters
* @param forActor build actor api when true, otherwise build task api.
* @param forActorCreation build `Ray.createActor` when true, otherwise build `Ray.call`.
*/
private void buildCalls(int numParameters, boolean forActor, boolean forActorCreation) {
String genericTypes = "";
String argList = "";
for (int i = 0; i < numParameters; i++) {
genericTypes += "T" + i + ", ";
argList += "t" + i + ", ";
}
if (forActor) {
genericTypes = "A, " + genericTypes;
}
genericTypes += forActorCreation ? "A" : "R";
if (argList.endsWith(", ")) {
argList = argList.substring(0, argList.length() - 2);
}
String paramPrefix = String.format("RayFunc%d<%s> f",
!forActor ? numParameters : numParameters + 1,
genericTypes);
if (forActor) {
paramPrefix += ", RayActor<A> actor";
}
if (numParameters > 0) {
paramPrefix += ", ";
}
String returnType = !forActorCreation ? "RayObject<R>" : "RayActor<A>";
String funcName = !forActorCreation ? "call" : "createActor";
String funcArgs = !forActor ? "f, args" : "f, actor, args";
for (String param : generateParameters(0, numParameters)) {
// method signature
newLine(1, String.format(
"public static <%s> %s %s(%s) {",
genericTypes, returnType, funcName, paramPrefix + param
));
// method body
newLine(2, String.format("Object[] args = new Object[]{%s};", argList));
newLine(2, String.format("return Ray.internal().%s(%s);", funcName, funcArgs));
newLine(1, "}");
}
}
private List<String> generateParameters(int from, int to) {
List<String> res = new ArrayList<>();
dfs(from, from, to, "", res);
return res;
}
private void dfs(int pos, int from, int to, String cur, List<String> res) {
if (pos >= to) {
res.add(cur);
return;
}
if (pos > from) {
cur += ", ";
}
String nextParameter = String.format("T%d t%d", pos, pos);
dfs(pos + 1, from, to, cur + nextParameter, res);
nextParameter = String.format("RayObject<T%d> t%d", pos, pos);
dfs(pos + 1, from, to, cur + nextParameter, res);
}
public static void main(String[] args) throws IOException {
String path = System.getProperty("user.dir")
+ "/api/src/main/java/org/ray/api/RayCall.java";
FileUtil.overrideFile(path, new RayCallGenerator().build());
}
}
@@ -0,0 +1,54 @@
package org.ray.runtime.util.generator;
import java.io.IOException;
import org.ray.runtime.util.FileUtil;
/**
* A util class that generates all the RayFuncX classes under org.ray.api.function package.
*/
public class RayFuncGenerator extends BaseGenerator {
private String generate(int numParameters) {
sb = new StringBuilder();
String genericTypes = "";
String paramList = "";
for (int i = 0; i < numParameters; i++) {
genericTypes += "T" + i + ", ";
if (i > 0) {
paramList += ", ";
}
paramList += String.format("T%d t%d", i, i);
}
newLine("// generated automatically, do not modify.");
newLine("");
newLine("package org.ray.api.function;");
newLine("");
newLine("/**");
String comment = String.format(
" * Functional interface for a remote function that has %d parameter%s.",
numParameters, numParameters > 1 ? "s" : "");
newLine(comment);
newLine(" */");
newLine("@FunctionalInterface");
newLine(String.format("public interface RayFunc%d<%sR> extends RayFunc {",
numParameters, genericTypes));
indents(1);
newLine(String.format("R apply(%s);", paramList));
newLine("}");
return sb.toString();
}
public static void main(String[] args) throws IOException {
String root = System.getProperty("user.dir")
+ "/api/src/main/java/org/ray/api/function/";
RayFuncGenerator generator = new RayFuncGenerator();
for (int i = 0; i <= MAX_PARAMETERS; i++) {
String content = generator.generate(i);
FileUtil.overrideFile(root + "RayFunc" + i + ".java", content);
}
}
}
@@ -0,0 +1,44 @@
package org.ray.runtime.util.logger;
import org.ray.runtime.util.SystemUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* loggers in Ray.
* 1. core logger is used for internal Ray status logging.
* 2. rapp for ray applications logging.
*/
public class RayLog {
/**
* for ray itself.
*/
public static Logger core;
/**
* for ray app.
*/
public static Logger rapp;
/**
* Initialize loggers
* @param logDir directory of the log files.
*/
public static void init(String logDir) {
String loggingPath = System.getProperty("logging.path");
if (loggingPath == null) {
System.setProperty("logging.path", logDir);
}
String loggingFileName = System.getProperty("logging.file.name");
if (loggingFileName != null && loggingFileName.contains("*pid_suffix*")) {
loggingFileName = loggingFileName.replaceAll("\\*pid_suffix\\*",
String.valueOf(SystemUtil.pid()));
System.setProperty("logging.file.name", loggingFileName);
}
core = LoggerFactory.getLogger("core");
rapp = core;
}
}
@@ -0,0 +1,20 @@
# define default properties here
logging.level=WARN
logging.path=./run/logs
logging.file.name=core
logging.max.log.file.num=10
logging.max.log.file.size=500MB
log4j.rootLogger=${logging.level}, stdout, core
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %p %c{1} [%t]: %m%n
log4j.appender.core=org.apache.log4j.RollingFileAppender
log4j.appender.core.File=${logging.path}/${logging.file.name}.log
log4j.appender.core.Append=true
log4j.appender.core.MaxFileSize=${logging.max.log.file.size}
log4j.appender.core.MaxBackupIndex=${logging.max.log.file.num}
log4j.appender.core.layout=org.apache.log4j.PatternLayout
log4j.appender.core.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %p %c{1} [%t]: %m%n