[Java Worker] Support raylet on Java (#2479)

This commit is contained in:
Wang Qing
2018-08-02 08:52:49 +08:00
committed by Robert Nishihara
parent 9a479b3a63
commit e4f68ff8cf
41 changed files with 916 additions and 303 deletions
@@ -1,5 +1,6 @@
package org.ray.core;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
@@ -122,7 +123,13 @@ public abstract class RayRuntime implements RayApi {
functions = new LocalFunctionManager(remoteLoader);
localSchedulerProxy = new LocalSchedulerProxy(slink);
objectStoreProxy = new ObjectStoreProxy(plink);
if (!params.use_raylet) {
objectStoreProxy = new ObjectStoreProxy(plink);
} else {
objectStoreProxy = new ObjectStoreProxy(plink, slink);
}
worker = new Worker(localSchedulerProxy, functions);
}
@@ -188,7 +195,9 @@ public abstract class RayRuntime implements RayApi {
public <T, TMT> void putRaw(UniqueID taskId, UniqueID objectId, T obj, TMT metadata) {
RayLog.core.info("Task " + taskId.toString() + " Object " + objectId.toString() + " put");
localSchedulerProxy.markTaskPutDependency(taskId, objectId);
if (!params.use_raylet) {
localSchedulerProxy.markTaskPutDependency(taskId, objectId);
}
objectStoreProxy.put(objectId, obj, metadata);
}
@@ -274,22 +283,32 @@ public abstract class RayRuntime implements RayApi {
return worker.rpcWithReturnIndices(taskId, funcCls, lambda, returnCount, args);
}
private <T> List<T> doGet(List<UniqueID> objectIds, boolean isMetadata)
throws TaskExecutionException {
boolean wasBlocked = false;
UniqueID taskId = getCurrentTaskId();
try {
int numObjectIds = objectIds.size();
// Do an initial fetch for remote objects.
dividedFetch(objectIds);
List<List<UniqueID>> fetchBatches =
splitIntoBatches(objectIds, params.worker_fetch_request_size);
for (List<UniqueID> batch : fetchBatches) {
if (!params.use_raylet) {
objectStoreProxy.fetch(batch);
} else {
localSchedulerProxy.reconstructObjects(batch, true);
}
}
// Get the objects. We initially try to get the objects immediately.
List<Pair<T, GetStatus>> ret = objectStoreProxy
.get(objectIds, params.default_first_check_timeout_ms, isMetadata);
assert ret.size() == numObjectIds;
// mapping the object IDs that we haven't gotten yet to their original index in objectIds
// Mapping the object IDs that we haven't gotten yet to their original index in objectIds.
Map<UniqueID, Integer> unreadys = new HashMap<>();
for (int i = 0; i < numObjectIds; i++) {
if (ret.get(i).getRight() != GetStatus.SUCCESS) {
@@ -301,15 +320,22 @@ public abstract class RayRuntime implements RayApi {
// Try reconstructing any objects we haven't gotten yet. Try to get them
// until at least PlasmaLink.GET_TIMEOUT_MS milliseconds passes, then repeat.
while (unreadys.size() > 0) {
for (UniqueID id : unreadys.keySet()) {
localSchedulerProxy.reconstructObject(id);
}
// Do another fetch for objects that aren't available locally yet, in case
// they were evicted since the last fetch.
List<UniqueID> unreadyList = new ArrayList<>(unreadys.keySet());
List<List<UniqueID>> reconstructBatches =
splitIntoBatches(unreadyList, params.worker_fetch_request_size);
dividedFetch(unreadyList);
for (List<UniqueID> batch : reconstructBatches) {
if (!params.use_raylet) {
for (UniqueID objectId : batch) {
localSchedulerProxy.reconstructObject(objectId, false);
}
// Do another fetch for objects that aren't available locally yet, in case
// they were evicted since the last fetch.
objectStoreProxy.fetch(batch);
} else {
localSchedulerProxy.reconstructObjects(batch, false);
}
}
List<Pair<T, GetStatus>> results = objectStoreProxy
.get(unreadyList, params.default_get_check_interval_ms, isMetadata);
@@ -329,9 +355,11 @@ public abstract class RayRuntime implements RayApi {
RayLog.core
.debug("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray()) + " get");
List<T> finalRet = new ArrayList<>();
for (Pair<T, GetStatus> value : ret) {
finalRet.add(value.getLeft());
}
return finalRet;
} catch (TaskExecutionException e) {
RayLog.core.error("Task " + taskId + " Objects " + Arrays.toString(objectIds.toArray())
@@ -344,68 +372,30 @@ public abstract class RayRuntime implements RayApi {
localSchedulerProxy.notifyUnblocked();
}
}
}
private <T> T doGet(UniqueID objectId, boolean isMetadata) throws TaskExecutionException {
ImmutableList<UniqueID> objectIds = ImmutableList.of(objectId);
List<T> results = doGet(objectIds, isMetadata);
boolean wasBlocked = false;
UniqueID taskId = getCurrentTaskId();
try {
// Do an initial fetch.
objectStoreProxy.fetch(objectId);
// Get the object. We initially try to get the object immediately.
Pair<T, GetStatus> ret = objectStoreProxy
.get(objectId, params.default_first_check_timeout_ms, isMetadata);
wasBlocked = (ret.getRight() != GetStatus.SUCCESS);
// Try reconstructing the object. Try to get it until at least PlasmaLink.GET_TIMEOUT_MS
// milliseconds passes, then repeat.
while (ret.getRight() != GetStatus.SUCCESS) {
RayLog.core.warn(
"Task " + taskId + " Object " + objectId.toString() + " get failed, reconstruct ...");
localSchedulerProxy.reconstructObject(objectId);
// Do another fetch
objectStoreProxy.fetch(objectId);
ret = objectStoreProxy.get(objectId, params.default_get_check_interval_ms,
isMetadata);//check the result every 5s, but it will return once available
}
RayLog.core.debug(
"Task " + taskId + " Object " + objectId.toString() + " get" + ", the result " + ret
.getLeft());
return ret.getLeft();
} catch (TaskExecutionException e) {
RayLog.core
.error("Task " + taskId + " Object " + objectId.toString() + " get with Exception", e);
throw e;
} finally {
// If the object was not able to get locally, let the local scheduler
// know that we're now unblocked.
if (wasBlocked) {
localSchedulerProxy.notifyUnblocked();
}
}
assert results.size() == 1;
return results.get(0);
}
// We divide the fetch into smaller fetches so as to not block the manager
// for a prolonged period of time in a single call.
private void dividedFetch(List<UniqueID> objectIds) {
int fetchSize = objectStoreProxy.getFetchSize();
private List<List<UniqueID>> splitIntoBatches(List<UniqueID> objectIds, int batchSize) {
List<List<UniqueID>> batches = new ArrayList<>();
int objectsSize = objectIds.size();
int numObjectIds = objectIds.size();
for (int i = 0; i < numObjectIds; i += fetchSize) {
int endIndex = i + fetchSize;
if (endIndex < numObjectIds) {
objectStoreProxy.fetch(objectIds.subList(i, endIndex));
} else {
objectStoreProxy.fetch(objectIds.subList(i, numObjectIds));
}
for (int i = 0; i < objectsSize; i += batchSize) {
int endIndex = i + batchSize;
List<UniqueID> batchIds = (endIndex < objectsSize)
? objectIds.subList(i, endIndex)
: objectIds.subList(i, objectsSize);
batches.add(batchIds);
}
return batches;
}
/**
@@ -112,6 +112,18 @@ public class RayParameters {
@AConfig(comment = "delay seconds under onebox before app logic for debugging")
public int onebox_delay_seconds_before_run_app_logic = 0;
@AConfig(comment = "whether to use raylet")
public boolean use_raylet = false;
@AConfig(comment = "raylet socket name (e.g., /tmp/raylet1111")
public String raylet_socket_name = "";
@AConfig(comment = "raylet rpc listen port")
public int raylet_port = 35567;
@AConfig(comment = "worker fetch request size")
public int worker_fetch_request_size = 10000;
public RayParameters(ConfigReader config) {
if (null != config) {
String networkInterface = config.getStringValue("ray.java", "network_interface", null,
@@ -1,5 +1,6 @@
package org.ray.spi;
import java.util.List;
import org.ray.api.UniqueID;
import org.ray.spi.model.TaskSpec;
@@ -14,7 +15,11 @@ public interface LocalSchedulerLink {
void markTaskPutDependency(UniqueID taskId, UniqueID objectId);
void reconstructObject(UniqueID objectId);
void reconstructObject(UniqueID objectId, boolean fetchOnly);
void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly);
void notifyUnblocked();
List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns);
}
@@ -1,13 +1,17 @@
package org.ray.spi;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.ray.api.RayList;
import org.ray.api.RayMap;
import org.ray.api.RayObject;
import org.ray.api.RayObjects;
import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.ArgumentsBuilder;
import org.ray.core.UniqueIdHelper;
import org.ray.core.WorkerContext;
@@ -124,11 +128,44 @@ public class LocalSchedulerProxy {
scheduler.markTaskPutDependency(taskId, objectId);
}
public void reconstructObject(UniqueID objectId) {
scheduler.reconstructObject(objectId);
public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
scheduler.reconstructObject(objectId, fetchOnly);
}
public void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly) {
scheduler.reconstructObjects(objectIds, fetchOnly);
}
public void notifyUnblocked() {
scheduler.notifyUnblocked();
}
private static byte[][] getIdBytes(List<UniqueID> objectIds) {
int size = objectIds.size();
byte[][] ids = new byte[size][];
for (int i = 0; i < size; i++) {
ids[i] = objectIds.get(i).getBytes();
}
return ids;
}
public <T> WaitResult<T> wait(RayList<T> waitfor, int numReturns, int timeout) {
List<UniqueID> ids = new ArrayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
ids.add(obj.getId());
}
List<byte[]> readys = scheduler.wait(getIdBytes(ids), timeout, numReturns);
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
for (RayObject<T> obj : waitfor.Objects()) {
if (readys.contains(obj.getId().getBytes())) {
readyObjs.add(obj);
} else {
remainObjs.add(obj);
}
}
return new WaitResult<>(readyObjs, remainObjs);
}
}
@@ -10,6 +10,7 @@ import org.ray.api.UniqueID;
import org.ray.api.WaitResult;
import org.ray.core.Serializer;
import org.ray.core.WorkerContext;
import org.ray.spi.LocalSchedulerLink;
import org.ray.util.exception.TaskExecutionException;
/**
@@ -19,12 +20,19 @@ import org.ray.util.exception.TaskExecutionException;
public class ObjectStoreProxy {
private final ObjectStoreLink store;
private final LocalSchedulerLink localSchedulerLink;
private final int getTimeoutMs = 1000;
public ObjectStoreProxy(ObjectStoreLink store) {
this.store = store;
this.localSchedulerLink = null;
}
public ObjectStoreProxy(ObjectStoreLink store, LocalSchedulerLink localSchedulerLink) {
this.store = store;
this.localSchedulerLink = localSchedulerLink;
}
public <T> Pair<T, GetStatus> get(UniqueID objectId, boolean isMetadata)
throws TaskExecutionException {
return get(objectId, getTimeoutMs, isMetadata);
@@ -88,7 +96,12 @@ public class ObjectStoreProxy {
for (RayObject<T> obj : waitfor.Objects()) {
ids.add(obj.getId());
}
List<byte[]> readys = store.wait(getIdBytes(ids), timeout, numReturns);
List<byte[]> readys;
if (localSchedulerLink == null) {
readys = store.wait(getIdBytes(ids), timeout, numReturns);
} else {
readys = localSchedulerLink.wait(getIdBytes(ids), timeout, numReturns);
}
RayList<T> readyObjs = new RayList<>();
RayList<T> remainObjs = new RayList<>();
@@ -103,19 +116,14 @@ public class ObjectStoreProxy {
return new WaitResult<>(readyObjs, remainObjs);
}
public void fetch(UniqueID objectId) {
store.fetch(objectId.getBytes());
}
public void fetch(List<UniqueID> objectIds) {
store.fetch(getIdBytes(objectIds));
if (localSchedulerLink == null) {
store.fetch(getIdBytes(objectIds));
} else {
localSchedulerLink.reconstructObjects(objectIds, true);
}
}
public int getFetchSize() {
return 10000;
}
public enum GetStatus {
SUCCESS, FAILED
}
@@ -37,6 +37,9 @@ public class PathConfig {
@AConfig(comment = "path to global scheduler")
public String global_scheduler;
@AConfig(comment = "path to raylet")
public String raylet;
@AConfig(comment = "path to python directory")
public String python_dir;
@@ -8,9 +8,11 @@ public class AddressInfo {
public String managerName;
public String storeName;
public String schedulerName;
public String rayletSocketName;
public int managerPort;
public int workerCount;
public String managerRpcAddr;
public String storeRpcAddr;
public String schedulerRpcAddr;
public String rayletRpcAddr;
}