[Java Worker] Support raylet on Java (#2479)

This commit is contained in:
Wang Qing
2018-08-02 08:52:49 +08:00
committed by Robert Nishihara
parent 9a479b3a63
commit e4f68ff8cf
41 changed files with 916 additions and 303 deletions
@@ -25,8 +25,9 @@ import org.ray.spi.RemoteFunctionManager;
import org.ray.spi.StateStoreProxy;
import org.ray.spi.impl.DefaultLocalSchedulerClient;
import org.ray.spi.impl.NativeRemoteFunctionManager;
import org.ray.spi.impl.NonRayletStateStoreProxyImpl;
import org.ray.spi.impl.RayletStateStoreProxyImpl;
import org.ray.spi.impl.RedisClient;
import org.ray.spi.impl.StateStoreProxyImpl;
import org.ray.spi.model.AddressInfo;
import org.ray.util.exception.TaskExecutionException;
import org.ray.util.logger.RayLog;
@@ -62,14 +63,19 @@ public class RayNativeRuntime extends RayRuntime {
throw new Error("Redis address must be configured under Worker mode.");
}
startOnebox(params, pathConfig);
initStateStore(params.redis_address);
initStateStore(params.redis_address, params.use_raylet);
} else {
initStateStore(params.redis_address);
initStateStore(params.redis_address, params.use_raylet);
if (!isWorker) {
List<AddressInfo> nodes = stateStoreProxy.getAddressInfo(params.node_ip_address, 5);
List<AddressInfo> nodes = stateStoreProxy.getAddressInfo(
params.node_ip_address, params.redis_address, 5);
params.object_store_name = nodes.get(0).storeName;
params.object_store_manager_name = nodes.get(0).managerName;
params.local_scheduler_name = nodes.get(0).schedulerName;
if (!params.use_raylet) {
params.object_store_manager_name = nodes.get(0).managerName;
params.local_scheduler_name = nodes.get(0).schedulerName;
} else {
params.raylet_socket_name = nodes.get(0).rayletSocketName;
}
}
}
@@ -101,23 +107,45 @@ public class RayNativeRuntime extends RayRuntime {
.getIntegerValue("ray", "plasma_default_release_delay", 0,
"how many release requests should be delayed in plasma client");
ObjectStoreLink plink = new PlasmaClient(params.object_store_name, params
.object_store_manager_name, releaseDelay);
if (!params.use_raylet) {
ObjectStoreLink plink = new PlasmaClient(params.object_store_name,
params.object_store_manager_name, releaseDelay);
LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
params.local_scheduler_name,
WorkerContext.currentWorkerId(),
UniqueID.nil,
isWorker,
WorkerContext.currentTask().taskId,
0
);
LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
params.local_scheduler_name,
WorkerContext.currentWorkerId(),
UniqueID.nil,
isWorker,
WorkerContext.currentTask().taskId,
0,
false
);
init(slink, plink, funcMgr, pathConfig);
init(slink, plink, funcMgr, pathConfig);
// register
registerWorker(isWorker, params.node_ip_address, params.object_store_name,
params.object_store_manager_name, params.local_scheduler_name);
// register
registerWorker(isWorker, params.node_ip_address, params.object_store_name,
params.object_store_manager_name, params.local_scheduler_name);
} else {
ObjectStoreLink plink = new PlasmaClient(params.object_store_name, "", releaseDelay);
LocalSchedulerLink slink = new DefaultLocalSchedulerClient(
params.raylet_socket_name,
WorkerContext.currentWorkerId(),
UniqueID.nil,
isWorker,
WorkerContext.currentTask().taskId,
0,
true
);
init(slink, plink, funcMgr, pathConfig);
// register
registerWorker(isWorker, params.node_ip_address, params.object_store_name,
params.raylet_socket_name);
}
}
RayLog.core.info("RayNativeRuntime start with "
@@ -152,19 +180,44 @@ public class RayNativeRuntime extends RayRuntime {
params.object_store_name = manager.info().localStores.get(0).storeName;
params.object_store_manager_name = manager.info().localStores.get(0).managerName;
params.local_scheduler_name = manager.info().localStores.get(0).schedulerName;
params.raylet_socket_name = manager.info().localStores.get(0).rayletSocketName;
//params.node_ip_address = NetworkUtil.getIpAddress();
}
private void initStateStore(String redisAddress) throws Exception {
private void initStateStore(String redisAddress, boolean useRaylet) throws Exception {
kvStore = new RedisClient();
kvStore.setAddr(redisAddress);
stateStoreProxy = new StateStoreProxyImpl(kvStore);
stateStoreProxy = useRaylet
? new RayletStateStoreProxyImpl(kvStore)
: new NonRayletStateStoreProxyImpl(kvStore);
//stateStoreProxy.setStore(kvStore);
stateStoreProxy.initializeGlobalState();
}
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
String managerName, String schedulerName) {
String rayletSocketName) {
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
if (!isWorker) {
workerInfo.put("node_ip_address", nodeIpAddress);
workerInfo.put("driver_id", workerId);
workerInfo.put("start_time", String.valueOf(System.currentTimeMillis()));
workerInfo.put("plasma_store_socket", storeName);
workerInfo.put("raylet_socket", rayletSocketName);
workerInfo.put("name", System.getProperty("user.dir"));
//TODO: worker.redis_client.hmset(b"Drivers:" + worker.workerId, driver_info)
kvStore.hmset("Drivers:" + workerId, workerInfo);
} else {
workerInfo.put("node_ip_address", nodeIpAddress);
workerInfo.put("plasma_store_socket", storeName);
workerInfo.put("raylet_socket", rayletSocketName);
//TODO: b"Workers:" + worker.workerId,
kvStore.hmset("Workers:" + workerId, workerInfo);
}
}
private void registerWorker(boolean isWorker, String nodeIpAddress, String storeName,
String managerName, String schedulerName) {
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(WorkerContext.currentWorkerId().getBytes());
if (!isWorker) {
@@ -0,0 +1,79 @@
package org.ray.format.gcs;
// automatically generated by the FlatBuffers compiler, do not modify
import java.nio.*;
import java.lang.*;
import com.google.flatbuffers.*;
/**
 * FlatBuffers accessor and builder for a {@code ClientTableData} table.
 *
 * <p>This class is generated code (see the file header); the comments below only describe what
 * the generated members do. Instances are views over a {@link ByteBuffer}: each getter looks up
 * its field through the vtable offset passed to {@code __offset} and returns a default
 * ({@code null}, {@code 0} or {@code false}) when the field is absent.
 *
 * <p>The table has nine fields, built in slots 0-8: client id, node manager address, raylet
 * socket name, object store socket name, node manager port, object manager port, the
 * is-insertion flag, and two vectors, {@code resourcesTotalLabel} (strings) and
 * {@code resourcesTotalCapacity} (doubles).
 */
@SuppressWarnings("unused")
public final class ClientTableData extends Table {
// ---- Root lookup: position an accessor on the root table stored in a buffer. ----
public static ClientTableData getRootAsClientTableData(ByteBuffer _bb) { return getRootAsClientTableData(_bb, new ClientTableData()); }
// Forces little-endian byte order on the buffer before resolving the root offset.
public static ClientTableData getRootAsClientTableData(ByteBuffer _bb, ClientTableData obj) { _bb.order(ByteOrder.LITTLE_ENDIAN); return (obj.__assign(_bb.getInt(_bb.position()) + _bb.position(), _bb)); }
public void __init(int _i, ByteBuffer _bb) { bb_pos = _i; bb = _bb; }
public ClientTableData __assign(int _i, ByteBuffer _bb) { __init(_i, _bb); return this; }
// ---- Field accessors. String fields also expose their raw bytes as a ByteBuffer. ----
public String clientId() { int o = __offset(4); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer clientIdAsByteBuffer() { return __vector_as_bytebuffer(4, 1); }
public ByteBuffer clientIdInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 1); }
public String nodeManagerAddress() { int o = __offset(6); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer nodeManagerAddressAsByteBuffer() { return __vector_as_bytebuffer(6, 1); }
public ByteBuffer nodeManagerAddressInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); }
public String rayletSocketName() { int o = __offset(8); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer rayletSocketNameAsByteBuffer() { return __vector_as_bytebuffer(8, 1); }
public ByteBuffer rayletSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 8, 1); }
public String objectStoreSocketName() { int o = __offset(10); return o != 0 ? __string(o + bb_pos) : null; }
public ByteBuffer objectStoreSocketNameAsByteBuffer() { return __vector_as_bytebuffer(10, 1); }
public ByteBuffer objectStoreSocketNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 10, 1); }
public int nodeManagerPort() { int o = __offset(12); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public int objectManagerPort() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public boolean isInsertion() { int o = __offset(16); return o != 0 ? 0!=bb.get(o + bb_pos) : false; }
// Vector fields: element j is read at a stride of 4 bytes (string offsets) or 8 bytes (doubles).
public String resourcesTotalLabel(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int resourcesTotalLabelLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
public double resourcesTotalCapacity(int j) { int o = __offset(20); return o != 0 ? bb.getDouble(__vector(o) + j * 8) : 0; }
public int resourcesTotalCapacityLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
public ByteBuffer resourcesTotalCapacityAsByteBuffer() { return __vector_as_bytebuffer(20, 8); }
public ByteBuffer resourcesTotalCapacityInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 20, 8); }
// ---- Builder helpers. ----
// One-shot builder: starts a 9-field object, adds every field, and returns the finished
// table's offset. The *Offset arguments are offsets previously created in the same builder.
public static int createClientTableData(FlatBufferBuilder builder,
int client_idOffset,
int node_manager_addressOffset,
int raylet_socket_nameOffset,
int object_store_socket_nameOffset,
int node_manager_port,
int object_manager_port,
boolean is_insertion,
int resources_total_labelOffset,
int resources_total_capacityOffset) {
builder.startObject(9);
ClientTableData.addResourcesTotalCapacity(builder, resources_total_capacityOffset);
ClientTableData.addResourcesTotalLabel(builder, resources_total_labelOffset);
ClientTableData.addObjectManagerPort(builder, object_manager_port);
ClientTableData.addNodeManagerPort(builder, node_manager_port);
ClientTableData.addObjectStoreSocketName(builder, object_store_socket_nameOffset);
ClientTableData.addRayletSocketName(builder, raylet_socket_nameOffset);
ClientTableData.addNodeManagerAddress(builder, node_manager_addressOffset);
ClientTableData.addClientId(builder, client_idOffset);
ClientTableData.addIsInsertion(builder, is_insertion);
return ClientTableData.endClientTableData(builder);
}
// Incremental builder: call startClientTableData, any subset of the add* methods (slot index is
// the first argument to the builder call), then endClientTableData.
public static void startClientTableData(FlatBufferBuilder builder) { builder.startObject(9); }
public static void addClientId(FlatBufferBuilder builder, int clientIdOffset) { builder.addOffset(0, clientIdOffset, 0); }
public static void addNodeManagerAddress(FlatBufferBuilder builder, int nodeManagerAddressOffset) { builder.addOffset(1, nodeManagerAddressOffset, 0); }
public static void addRayletSocketName(FlatBufferBuilder builder, int rayletSocketNameOffset) { builder.addOffset(2, rayletSocketNameOffset, 0); }
public static void addObjectStoreSocketName(FlatBufferBuilder builder, int objectStoreSocketNameOffset) { builder.addOffset(3, objectStoreSocketNameOffset, 0); }
public static void addNodeManagerPort(FlatBufferBuilder builder, int nodeManagerPort) { builder.addInt(4, nodeManagerPort, 0); }
public static void addObjectManagerPort(FlatBufferBuilder builder, int objectManagerPort) { builder.addInt(5, objectManagerPort, 0); }
public static void addIsInsertion(FlatBufferBuilder builder, boolean isInsertion) { builder.addBoolean(6, isInsertion, false); }
public static void addResourcesTotalLabel(FlatBufferBuilder builder, int resourcesTotalLabelOffset) { builder.addOffset(7, resourcesTotalLabelOffset, 0); }
// Vector builders add elements from the last index down to the first.
public static int createResourcesTotalLabelVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startResourcesTotalLabelVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addResourcesTotalCapacity(FlatBufferBuilder builder, int resourcesTotalCapacityOffset) { builder.addOffset(8, resourcesTotalCapacityOffset, 0); }
public static int createResourcesTotalCapacityVector(FlatBufferBuilder builder, double[] data) { builder.startVector(8, data.length, 8); for (int i = data.length - 1; i >= 0; i--) builder.addDouble(data[i]); return builder.endVector(); }
public static void startResourcesTotalCapacityVector(FlatBufferBuilder builder, int numElems) { builder.startVector(8, numElems, 8); }
// Finishes the object started by startClientTableData and returns its offset.
public static int endClientTableData(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
}
}
@@ -35,7 +35,7 @@ public class RunInfo {
public enum ProcessType {
PT_WORKER, PT_LOCAL_SCHEDULER, PT_PLASMA_MANAGER, PT_PLASMA_STORE,
PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI,
PT_GLOBAL_SCHEDULER, PT_REDIS_SERVER, PT_WEB_UI, PT_RAYLET,
PT_DRIVER
}
}
@@ -48,7 +48,7 @@ public class RunManager {
private static boolean killProcess(Process p) {
if (p.isAlive()) {
p.destroyForcibly();
p.destroy();
return true;
} else {
return false;
@@ -307,7 +307,7 @@ public class RunManager {
redisClient.close();
// start global scheduler
if (params.include_global_scheduler) {
if (params.include_global_scheduler && !params.use_raylet) {
startGlobalScheduler(params.working_directory + "/globalScheduler",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
}
@@ -340,49 +340,70 @@ public class RunManager {
}
}
// start object stores
for (int i = 0; i < params.num_local_schedulers; i++) {
AddressInfo info = new AddressInfo();
// store
startObjectStore(i, info, params.working_directory + "/store",
AddressInfo info = new AddressInfo();
if (params.use_raylet) {
// Start object store
int rpcPort = params.object_store_rpc_port;
String storeName = "/tmp/plasma_store" + rpcPort;
startObjectStore(0, info, params.working_directory + "/store",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
// store manager
startObjectManager(i, info,
params.working_directory + "/storeManager", params.redis_address,
params.node_ip_address, params.redirect, params.cleanup);
//Start raylet
startRaylet(storeName, info, params.num_cpus[0],params.num_gpus[0],
params.num_workers,params.working_directory + "/raylet",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
runInfo.localStores.add(info);
}
} else {
for (int i = 0; i < params.num_local_schedulers; i++) {
// Start object stores
startObjectStore(i, info, params.working_directory + "/store",
params.redis_address, params.node_ip_address, params.redirect, params.cleanup);
// start local scheduler
for (int i = 0; i < params.num_local_schedulers; i++) {
int workerCount = 0;
startObjectManager(i, info,
params.working_directory + "/storeManager", params.redis_address,
params.node_ip_address, params.redirect, params.cleanup);
if (params.start_workers_from_local_scheduler) {
workerCount = localNumWorkers[i];
localNumWorkers[i] = 0;
// Start local scheduler
int workerCount = 0;
if (params.start_workers_from_local_scheduler) {
workerCount = localNumWorkers[i];
localNumWorkers[i] = 0;
}
startLocalScheduler(i, info,
params.num_cpus[i], params.num_gpus[i], workerCount,
params.working_directory + "/localsc", params.redis_address,
params.node_ip_address, params.redirect, params.cleanup);
runInfo.localStores.add(info);
}
startLocalScheduler(i, runInfo.localStores.get(i),
params.num_cpus[i], params.num_gpus[i], workerCount,
params.working_directory + "/localScheduler", params.redis_address,
params.node_ip_address, params.redirect, params.cleanup);
}
// start local workers
for (int i = 0; i < params.num_local_schedulers; i++) {
runInfo.localStores.get(i).workerCount = localNumWorkers[i];
for (int j = 0; j < localNumWorkers[i]; j++) {
startWorker(runInfo.localStores.get(i).storeName,
runInfo.localStores.get(i).managerName, runInfo.localStores.get(i).schedulerName,
params.working_directory + "/worker" + i + "." + j, params.redis_address,
params.node_ip_address, UniqueID.nil, "",
params.redirect, params.cleanup);
if (!params.use_raylet) {
for (int i = 0; i < params.num_local_schedulers; i++) {
AddressInfo localStores = runInfo.localStores.get(i);
localStores.workerCount = localNumWorkers[i];
for (int j = 0; j < localNumWorkers[i]; j++) {
startWorker(localStores.storeName, localStores.managerName, localStores.schedulerName,
params.working_directory + "/worker" + i + "." + j, params.redis_address,
params.node_ip_address, UniqueID.nil, "", params.redirect, params.cleanup);
}
}
}
HashSet<RunInfo.ProcessType> excludeTypes = new HashSet<>();
if (!params.use_raylet) {
excludeTypes.add(RunInfo.ProcessType.PT_RAYLET);
} else {
excludeTypes.add(RunInfo.ProcessType.PT_LOCAL_SCHEDULER);
excludeTypes.add(RunInfo.ProcessType.PT_GLOBAL_SCHEDULER);
excludeTypes.add(RunInfo.ProcessType.PT_PLASMA_MANAGER);
}
if (!checkAlive(excludeTypes)) {
cleanup(true);
throw new RuntimeException("Start Ray processes failed");
@@ -622,8 +643,8 @@ public class RunManager {
cmd += " -m " + info.managerName;
String workerCmd = null;
workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name, UniqueID.nil,
"", workDir + rpcPort, ip, redisAddress);
workerCmd = buildWorkerCommand(true, info.storeName, info.managerName, name,
UniqueID.nil, "", workDir + rpcPort, ip, redisAddress);
cmd += " -w \"" + workerCmd + "\"";
if (redisAddress.length() > 0) {
@@ -656,6 +677,82 @@ public class RunManager {
}
}
/**
 * Starts the raylet process for this node and records its endpoints in {@code info}.
 *
 * <p>The raylet is launched with, in order: the binary path, the raylet socket name, the object
 * store socket name, the node ip, the redis (GCS) ip and port, the worker count, the command
 * line used to spawn Java workers, and a resource string of the form
 * {@code GPU,<numGpus>,CPU,<numCpus>}.
 *
 * @param storeName    object store socket name handed to the raylet binary
 * @param info         address record updated with the raylet socket name and rpc address
 * @param numCpus      CPU count advertised in the resource argument
 * @param numGpus      GPU count advertised in the resource argument
 * @param numWorkers   worker count passed to the raylet
 * @param workDir      working-directory prefix; the raylet port is appended to it
 * @param redisAddress redis address in {@code <ip>:<port>} form
 * @param ip           ip address of this node
 * @param redirect     forwarded unchanged to {@code startProcess}
 * @param cleanup      forwarded unchanged to {@code startProcess}
 * @throws IllegalArgumentException if {@code redisAddress} contains no {@code ':'} separator
 * @throws RuntimeException         if the raylet process could not be started or died right away
 */
private void startRaylet(String storeName, AddressInfo info, int numCpus,
                         int numGpus, int numWorkers, String workDir,
                         String redisAddress, String ip, boolean redirect,
                         boolean cleanup) {
  int rpcPort = params.raylet_port;
  String rayletSocketName = "/tmp/raylet" + rpcPort;
  String filePath = paths.raylet;

  // NOTE(review): the worker command is built from info.storeName while the raylet itself
  // receives the storeName parameter; confirm the two always refer to the same socket.
  String workerCmd = buildWorkerCommandRaylet(info.storeName, rayletSocketName, UniqueID.nil,
      "", workDir + rpcPort, ip, redisAddress);

  // Validate explicitly: an assert is skipped unless the JVM runs with -ea, which would let a
  // malformed address surface later as an opaque StringIndexOutOfBoundsException.
  int sep = redisAddress.indexOf(':');
  if (sep == -1) {
    throw new IllegalArgumentException(
        "Redis address must be of the form <ip>:<port>, got: " + redisAddress);
  }
  String gcsIp = redisAddress.substring(0, sep);
  String gcsPort = redisAddress.substring(sep + 1);

  String resourceArgument = "GPU," + numGpus + ",CPU," + numCpus;

  String[] cmds = new String[]{filePath, rayletSocketName, storeName, ip, gcsIp,
      gcsPort, String.valueOf(numWorkers), workerCmd, resourceArgument};

  Process p = startProcess(cmds, null, RunInfo.ProcessType.PT_RAYLET,
      workDir + rpcPort, redisAddress, ip, redirect, cleanup);

  // Give the process a short grace period so an immediate crash is caught by the check below.
  if (p != null && p.isAlive()) {
    try {
      TimeUnit.MILLISECONDS.sleep(100);
    } catch (InterruptedException e) {
      // Keep the interrupt visible to callers instead of swallowing it.
      Thread.currentThread().interrupt();
    }
  }

  if (p == null || !p.isAlive()) {
    info.rayletSocketName = "";
    info.rayletRpcAddr = "";
    throw new RuntimeException("Failed to start raylet process.");
  }
  info.rayletSocketName = rayletSocketName;
  info.rayletRpcAddr = ip + ":" + rpcPort;
}
/**
 * Builds the command line used to launch a Java worker that connects through a raylet.
 *
 * <p>The worker configuration is a {@code ;}-separated list of {@code ray.java.start.*}
 * settings; actor id and driver class are only included when they are set.
 *
 * @param storeName        object store socket name for the worker
 * @param rayletSocketName raylet socket name for the worker
 * @param actorId          actor id, or {@code UniqueID.nil} when the worker is not an actor
 * @param actorClass       driver class name, or the empty string when there is none
 * @param workDir          working directory passed to the process command builder
 * @param ip               node ip passed to the process command builder
 * @param redisAddress     redis address passed to the process command builder
 * @return the command produced by {@code buildJavaProcessCommand}
 */
private String buildWorkerCommandRaylet(String storeName, String rayletSocketName,
                                        UniqueID actorId, String actorClass, String workDir,
                                        String ip, String redisAddress) {
  StringBuilder configs = new StringBuilder();
  configs.append("ray.java.start.object_store_name=" + storeName);
  configs.append(";ray.java.start.raylet_socket_name=" + rayletSocketName);
  configs.append(";ray.java.start.worker_mode=WORKER;ray.java.start.use_raylet=true");
  configs.append(";ray.java.start.deploy=" + params.deploy);

  boolean hasActorId = !actorId.equals(UniqueID.nil);
  if (hasActorId) {
    configs.append(";ray.java.start.actor_id=" + actorId);
  }

  boolean hasActorClass = !actorClass.equals("");
  if (hasActorClass) {
    configs.append(";ray.java.start.driver_class=" + actorClass);
  }

  // Worker logs go under the shared working directory.
  String jvmArgs = " -Dlogging.path=" + params.working_directory + "/logs/workers"
      + " -Dlogging.file.name=core-*pid_suffix*";

  return buildJavaProcessCommand(
      RunInfo.ProcessType.PT_WORKER,
      "org.ray.runner.worker.DefaultWorker",
      "",
      configs.toString(),
      jvmArgs,
      workDir,
      ip,
      redisAddress,
      null
  );
}
private String buildWorkerCommand(boolean isFromLocalScheduler, String storeName,
String storeManagerName, String localSchedulerName,
UniqueID actorId, String actorClass, String workDir, String
@@ -103,6 +103,15 @@ public interface KeyValueStoreLink {
*/
List<String> lrange(final String key, final long start, final long end);
/**
 * Return the elements of the sorted set stored at the specified key that fall within the
 * given index range.
 * NOTE(review): whether {@code end} is inclusive and whether negative indices are accepted is
 * not visible from this declaration - confirm against the implementation.
 *
 * @param key The key of the sorted set to query.
 * @param start The start index of the range.
 * @param end The end index of the range.
 * @return The set of elements in the requested range.
 */
Set<byte[]> zrange(byte[] key, long start, long end);
/**
* Rpush.
* @return Integer reply, specifically, the number of elements inside the list after the push
@@ -123,4 +132,7 @@ public interface KeyValueStoreLink {
Long publish(byte[] channel, byte[] message);
Object getImpl();
/**
 * Send a command to the store and return its reply as raw bytes.
 * NOTE(review): the set of valid commands, the meaning of {@code commandType}, and the layout
 * of the returned bytes are not visible from this declaration - confirm against the
 * implementation before relying on them.
 *
 * @param command The command to send.
 * @param commandType Numeric command type; presumably selects how the command is dispatched.
 * @param objectId The raw id bytes the command applies to.
 * @return The raw reply bytes.
 */
byte[] sendCommand(String command, int commandType, byte[] objectId);
}
@@ -31,5 +31,7 @@ public interface StateStoreProxy {
* getAddressInfo.
* @return list of address information
*/
List<AddressInfo> getAddressInfo(final String nodeIpAddress, int numRetries);
List<AddressInfo> getAddressInfo(final String nodeIpAddress,
final String redisAddress,
int numRetries);
}
@@ -0,0 +1,124 @@
package org.ray.spi.impl;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.ray.spi.KeyValueStoreLink;
import org.ray.spi.StateStoreProxy;
import org.ray.spi.model.AddressInfo;
import org.ray.util.logger.RayLog;
/**
* Base class used to interface with the Ray control state.
*/
public abstract class BaseStateStoreProxyImpl implements StateStoreProxy {
public KeyValueStoreLink rayKvStore;
public ArrayList<KeyValueStoreLink> shardStoreList = new ArrayList<>();
public BaseStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
@Override
public void setStore(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
@Override
public synchronized void initializeGlobalState() throws Exception {
String es;
checkConnected();
String s = rayKvStore.get("NumRedisShards", null);
if (s == null) {
throw new Exception("NumRedisShards not found in redis.");
}
int numRedisShards = Integer.parseInt(s);
if (numRedisShards < 1) {
es = String.format("Expected at least one Redis shard, found %d", numRedisShards);
throw new Exception(es);
}
List<String> ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1);
Set<String> distinctIpAddress = new HashSet<String>(ipAddressPorts);
if (distinctIpAddress.size() != numRedisShards) {
es = String.format("Expected %d Redis shard addresses, found2 %d.", numRedisShards,
distinctIpAddress.size());
throw new Exception(es);
}
shardStoreList.clear();
for (String ipPort : distinctIpAddress) {
shardStoreList.add(new RedisClient(ipPort));
}
}
public void checkConnected() throws Exception {
rayKvStore.checkConnected();
}
@Override
public synchronized Set<String> keys(final String pattern) {
Set<String> allKeys = new HashSet<>();
Set<String> tmpKey;
for (KeyValueStoreLink ashardStoreList : shardStoreList) {
tmpKey = ashardStoreList.keys(pattern);
allKeys.addAll(tmpKey);
}
return allKeys;
}
@Override
public List<AddressInfo> getAddressInfo(final String nodeIpAddress,
final String redisAddress,
int numRetries) {
int count = 0;
while (count < numRetries) {
try {
return doGetAddressInfo(nodeIpAddress, redisAddress);
} catch (Exception e) {
try {
RayLog.core.warn("Error occurred in BaseStateStoreProxyImpl getAddressInfo, "
+ (numRetries - count) + " retries remaining", e);
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException ie) {
RayLog.core.error("error at BaseStateStoreProxyImpl getAddressInfo", e);
throw new RuntimeException(e);
}
}
count++;
}
throw new RuntimeException("cannot get address info from state store");
}
/**
* Get address info of one node from primary redis.
* This method only tries to get address info once, without any retry.
*
* @param nodeIpAddress Usually local ip address.
* @param redisAddress The primary redis address.
* @return A list of SchedulerInfo which contains node manager or local scheduler address info.
* @throws Exception No redis client exception.
*/
protected abstract List<AddressInfo> doGetAddressInfo(final String nodeIpAddress,
final String redisAddress) throws Exception;
protected String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException {
return new String(bs, charset);
}
protected byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException {
if (str != null) {
return str.getBytes(charset);
}
return null;
}
}
@@ -24,20 +24,44 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
return bb;
});
private long client = 0;
boolean useRaylet = false;
public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId, UniqueID actorId,
boolean isWorker, UniqueID driverId, long numGpus) {
public DefaultLocalSchedulerClient(String schedulerSockName, UniqueID clientId,
UniqueID actorId, boolean isWorker, UniqueID driverId,
long numGpus, boolean useRaylet) {
client = _init(schedulerSockName, clientId.getBytes(), actorId.getBytes(), isWorker,
driverId.getBytes(), numGpus);
driverId.getBytes(), numGpus, useRaylet);
this.useRaylet = useRaylet;
}
private static native long _init(String localSchedulerSocket, byte[] workerId, byte[] actorId,
boolean isWorker, byte[] driverTaskId, long numGpus);
private static native long _init(String localSchedulerSocket, byte[] workerId,
byte[] actorId, boolean isWorker, byte[] driverTaskId,
long numGpus, boolean useRaylet);
private static native byte[] _computePutId(long client, byte[] taskId, int putIndex);
private static native void _task_done(long client);
private static native boolean[] _waitObject(long conn, byte[][] objectIds,
int numReturns, int timeout, boolean waitLocal);
/**
 * Waits on the given objects through the raylet and returns the ids reported as ready.
 *
 * <p>Only supported when this client was created in raylet mode.
 *
 * @param objectIds  ids of the objects to wait for
 * @param timeoutMs  timeout forwarded to the native wait call
 * @param numReturns number of ready objects requested from the native wait call
 * @return the subset of {@code objectIds} flagged ready, in their original order
 */
@Override
public List<byte[]> wait(byte[][] objectIds, int timeoutMs, int numReturns) {
  assert useRaylet : "wait is only supported in raylet mode";

  boolean[] readys = _waitObject(client, objectIds, numReturns, timeoutMs, false);
  // The flags are matched to ids by index, so the native call must return one flag per id.
  // (The ready list itself may legitimately be shorter than the input, so its size is not
  // asserted against the number of flags.)
  assert readys.length == objectIds.length
      : "expected " + objectIds.length + " ready flags, got " + readys.length;

  List<byte[]> ret = new ArrayList<>();
  for (int i = 0; i < readys.length; i++) {
    if (readys[i]) {
      ret.add(objectIds[i]);
    }
  }
  return ret;
}
@Override
public void submitTask(TaskSpec task) {
ByteBuffer info = taskSpec2Info(task);
@@ -45,12 +69,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
if (!task.actorId.isNil()) {
a = task.cursorId.getBytes();
}
_submitTask(client, a, info, info.position(), info.remaining());
_submitTask(client, a, info, info.position(), info.remaining(), useRaylet);
}
@Override
public TaskSpec getTaskTodo() {
byte[] bytes = _getTaskTodo(client);
byte[] bytes = _getTaskTodo(client, useRaylet);
assert (null != bytes);
ByteBuffer bb = ByteBuffer.wrap(bytes);
return taskInfo2Spec(bb);
@@ -62,8 +87,16 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
}
@Override
public void reconstructObject(UniqueID objectId) {
_reconstruct_object(client, objectId.getBytes());
public void reconstructObject(UniqueID objectId, boolean fetchOnly) {
List<UniqueID> objects = new ArrayList<>();
objects.add(objectId);
_reconstruct_objects(client, getIdBytes(objects), fetchOnly);
}
/**
 * Logs the request and forwards the given object ids to the native reconstruct call.
 *
 * @param objectIds ids of the objects to reconstruct
 * @param fetchOnly flag forwarded unchanged to the native call
 */
@Override
public void reconstructObjects(List<UniqueID> objectIds, boolean fetchOnly) {
  RayLog.core.info("reconstruct objects {}", objectIds);
  final byte[][] rawIds = getIdBytes(objectIds);
  _reconstruct_objects(client, rawIds, fetchOnly);
}
@Override
@@ -73,12 +106,13 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
private static native void _notify_unblocked(long client);
private static native void _reconstruct_object(long client, byte[] objectId);
private static native void _reconstruct_objects(long client, byte[][] objectIds,
boolean fetchOnly);
private static native void _put_object(long client, byte[] taskId, byte[] objectId);
// return TaskInfo (in FlatBuffer)
private static native byte[] _getTaskTodo(long client);
private static native byte[] _getTaskTodo(long client, boolean useRaylet);
public static TaskSpec taskInfo2Spec(ByteBuffer bb) {
bb.order(ByteOrder.LITTLE_ENDIAN);
@@ -162,7 +196,10 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
idOffsets[k] = fbb.createString(task.args[i].ids.get(k).toByteBuffer());
}
objectIdOffset = fbb.createVectorOfTables(idOffsets);
} else {
objectIdOffset = fbb.createVectorOfTables(new int[0]);
}
if (task.args[i].data != null) {
dataOffset = fbb.createString(ByteBuffer.wrap(task.args[i].data));
}
@@ -214,8 +251,17 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink {
}
// task -> TaskInfo (with FlatBuffer)
private static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task,
int pos, int sz);
protected static native void _submitTask(long client, byte[] cursorId, /*Direct*/ByteBuffer task,
int pos, int sz, boolean useRaylet);
/**
 * Converts a list of ids into an array of their raw byte representations, preserving order.
 *
 * @param objectIds ids to convert
 * @return one byte array per id, in list order
 */
private static byte[][] getIdBytes(List<UniqueID> objectIds) {
  final byte[][] rawIds = new byte[objectIds.size()][];
  int slot = 0;
  for (UniqueID id : objectIds) {
    rawIds[slot++] = id.getBytes();
  }
  return rawIds;
}
public void destroy() {
_destroy(client);
@@ -1,96 +1,18 @@
package org.ray.spi.impl;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import org.ray.spi.KeyValueStoreLink;
import org.ray.spi.StateStoreProxy;
import org.ray.spi.model.AddressInfo;
import org.ray.util.logger.RayLog;
/**
* A class used to interface with the Ray control state.
* A class used to interface with the Ray control state for non-raylet.
*/
public class StateStoreProxyImpl implements StateStoreProxy {
public KeyValueStoreLink rayKvStore;
public ArrayList<KeyValueStoreLink> shardStoreList = new ArrayList<>();
public StateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
public void setStore(KeyValueStoreLink rayKvStore) {
this.rayKvStore = rayKvStore;
}
public synchronized void initializeGlobalState() throws Exception {
String es;
checkConnected();
String s = rayKvStore.get("NumRedisShards", null);
if (s == null) {
throw new Exception("NumRedisShards not found in redis.");
}
int numRedisShards = Integer.parseInt(s);
if (numRedisShards < 1) {
es = String.format("Expected at least one Redis shard, found %d", numRedisShards);
throw new Exception(es);
}
List<String> ipAddressPorts = rayKvStore.lrange("RedisShards", 0, -1);
if (ipAddressPorts.size() != numRedisShards) {
es = String.format("Expected %d Redis shard addresses, found %d.", numRedisShards,
ipAddressPorts.size());
throw new Exception(es);
}
shardStoreList.clear();
for (String ipPort : ipAddressPorts) {
shardStoreList.add(new RedisClient(ipPort));
}
}
public void checkConnected() throws Exception {
rayKvStore.checkConnected();
}
public synchronized Set<String> keys(final String pattern) {
Set<String> allKeys = new HashSet<>();
Set<String> tmpKey;
for (KeyValueStoreLink ashardStoreList : shardStoreList) {
tmpKey = ashardStoreList.keys(pattern);
allKeys.addAll(tmpKey);
}
return allKeys;
}
public List<AddressInfo> getAddressInfo(final String nodeIpAddress, int numRetries) {
int count = 0;
while (count < numRetries) {
try {
return getAddressInfoHelper(nodeIpAddress);
} catch (Exception e) {
try {
RayLog.core.warn("Error occurred in StateStoreProxyImpl getAddressInfo, "
+ (numRetries - count) + " retries remaining", e);
TimeUnit.MILLISECONDS.sleep(1000);
} catch (InterruptedException ie) {
RayLog.core.error("error at StateStoreProxyImpl getAddressInfo", e);
throw new RuntimeException(e);
}
}
count++;
}
throw new RuntimeException("cannot get address info from state store");
public class NonRayletStateStoreProxyImpl extends BaseStateStoreProxyImpl {
public NonRayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
super(rayKvStore);
}
/*
@@ -108,9 +30,11 @@ public class StateStoreProxyImpl implements StateStoreProxy {
* "manager_socket_name"(op)
* "local_scheduler_socket_name"(op)
*/
public List<AddressInfo> getAddressInfoHelper(final String nodeIpAddress) throws Exception {
@Override
public List<AddressInfo> doGetAddressInfo(final String nodeIpAddress,
final String redisAddress) throws Exception {
if (this.rayKvStore == null) {
throw new Exception("no redis client when use getAddressInfoHelper");
throw new Exception("no redis client when use doGetAddressInfo");
}
List<AddressInfo> schedulerInfo = new ArrayList<>();
@@ -136,13 +60,13 @@ public class StateStoreProxyImpl implements StateStoreProxy {
} else if (!info.containsKey("client_type".getBytes())) {
throw new Exception("no client_type in any client");
}
if (charsetDecode(info.get("node_ip_address".getBytes()), "US-ASCII")
.equals(nodeIpAddress)) {
String clientType = charsetDecode(info.get("client_type".getBytes()), "US-ASCII");
if (clientType.equals("plasma_manager")) {
if ("plasma_manager".equals(clientType)) {
plasmaManager.add(info);
} else if (clientType.equals("local_scheduler")) {
} else if ("local_scheduler".equals(clientType)) {
localScheduler.add(info);
}
}
@@ -157,9 +81,9 @@ public class StateStoreProxyImpl implements StateStoreProxy {
for (int i = 0; i < plasmaManager.size(); i++) {
AddressInfo si = new AddressInfo();
si.storeName = charsetDecode(plasmaManager.get(i).get("store_socket_name".getBytes()),
"US-ASCII");
"US-ASCII");
si.managerName = charsetDecode(plasmaManager.get(i).get("manager_socket_name".getBytes()),
"US-ASCII");
"US-ASCII");
byte[] rpc = plasmaManager.get(i).get("manager_rpc_name".getBytes());
if (rpc != null) {
@@ -188,14 +112,4 @@ public class StateStoreProxyImpl implements StateStoreProxy {
return schedulerInfo;
}
/**
 * Decodes raw bytes into a string using the named charset.
 *
 * @param bs bytes to decode
 * @param charset name of the charset to decode with
 * @return the decoded string
 * @throws UnsupportedEncodingException if the named charset is not supported
 */
private String charsetDecode(byte[] bs, String charset) throws UnsupportedEncodingException {
  final String decoded = new String(bs, charset);
  return decoded;
}
/**
 * Encodes a string into bytes using the named charset.
 *
 * @param str string to encode; may be {@code null}
 * @param charset name of the charset to encode with
 * @return the encoded bytes, or {@code null} when {@code str} is {@code null}
 * @throws UnsupportedEncodingException if the named charset is not supported
 */
private byte[] charsetEncode(String str, String charset) throws UnsupportedEncodingException {
  if (str == null) {
    // Nothing to encode: a null input maps to a null output.
    return null;
  }
  final byte[] encoded = str.getBytes(charset);
  return encoded;
}
}
@@ -0,0 +1,62 @@
package org.ray.spi.impl;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.ray.api.UniqueID;
import org.ray.format.gcs.ClientTableData;
import org.ray.spi.KeyValueStoreLink;
import org.ray.spi.model.AddressInfo;
import org.ray.util.NetworkUtil;
/**
* A class used to interface with the Ray control state for raylet.
*/
public class RayletStateStoreProxyImpl extends BaseStateStoreProxyImpl {

  /**
   * Creates a proxy by handing {@code rayKvStore} to the base class constructor.
   *
   * @param rayKvStore key-value store link passed on to {@code BaseStateStoreProxyImpl}
   */
  public RayletStateStoreProxyImpl(KeyValueStoreLink rayKvStore) {
    super(rayKvStore);
  }

  /**
   * Collects the address info of the clients registered for a node.
   *
   * <p>The client entries are read with {@code zrange} from the key formed by the bytes of
   * {@code "CLIENT"} followed by the bytes of the nil {@code UniqueID}. Each entry is decoded
   * with {@code ClientTableData.getRootAsClientTableData}. An entry is selected when its node
   * manager address equals {@code nodeIpAddress}, or when that address is {@code "127.0.0.1"}
   * and the host part of {@code redisAddress} equals the local IP address.
   *
   * @param nodeIpAddress IP address that client entries are compared against
   * @param redisAddress redis address in {@code host:port} form; an address without a port is
   *     used as the host as is, and {@code null} never matches the local IP address
   * @return one {@code AddressInfo} per selected client, empty when nothing is selected
   * @throws Exception if no key-value store link is available
   */
  @Override
  public List<AddressInfo> doGetAddressInfo(final String nodeIpAddress,
      final String redisAddress) throws Exception {
    if (this.rayKvStore == null) {
      throw new Exception("no redis client when use doGetAddressInfo");
    }
    List<AddressInfo> schedulerInfo = new ArrayList<>();

    // Key of the client table: "CLIENT" followed by the nil id bytes.
    byte[] prefix = "CLIENT".getBytes();
    byte[] postfix = UniqueID.genNil().getBytes();
    byte[] clientKey = new byte[prefix.length + postfix.length];
    System.arraycopy(prefix, 0, clientKey, 0, prefix.length);
    System.arraycopy(postfix, 0, clientKey, prefix.length, postfix.length);

    Set<byte[]> clients = rayKvStore.zrange(clientKey, 0, -1);
    if (clients.isEmpty()) {
      // No entries: skip the local address lookup entirely.
      return schedulerInfo;
    }

    // These values do not depend on the individual client, so compute them once.
    final String localIpAddress = NetworkUtil.getIpAddress(null);
    final boolean redisIsLocal = Objects.equals(hostOf(redisAddress), localIpAddress);

    for (byte[] clientMessage : clients) {
      ClientTableData client =
          ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(clientMessage));
      String clientNodeIpAddress = client.nodeManagerAddress();

      boolean headNodeAddress = "127.0.0.1".equals(clientNodeIpAddress) && redisIsLocal;
      boolean notHeadNodeAddress = Objects.equals(clientNodeIpAddress, nodeIpAddress);
      if (headNodeAddress || notHeadNodeAddress) {
        AddressInfo si = new AddressInfo();
        si.storeName = client.objectStoreSocketName();
        si.rayletSocketName = client.rayletSocketName();
        si.managerRpcAddr = client.nodeManagerAddress();
        si.managerPort = client.nodeManagerPort();
        schedulerInfo.add(si);
      }
    }
    return schedulerInfo;
  }

  /**
   * Returns the part of {@code address} before the first {@code ':'}.
   *
   * @param address address in {@code host:port} form; may be {@code null}
   * @return the host part, the whole address when it contains no {@code ':'}, or {@code null}
   *     when {@code address} is {@code null}
   */
  private static String hostOf(String address) {
    if (address == null) {
      return null;
    }
    int colon = address.indexOf(':');
    if (colon < 0) {
      return address;
    }
    return address.substring(0, colon);
  }
}
@@ -13,6 +13,7 @@ public class RedisClient implements KeyValueStoreLink {
private String redisAddress;
private JedisPool jedisPool;
private int handle = 0;
// No-argument constructor: assigns nothing, so redisAddress and jedisPool keep their defaults.
public RedisClient() {
}
@@ -171,6 +172,13 @@ public class RedisClient implements KeyValueStoreLink {
}
}
/**
 * Runs {@code zrange} for {@code key} on a connection borrowed from the pool.
 *
 * @param key key to query
 * @param start start index passed to {@code zrange}
 * @param end end index passed to {@code zrange}
 * @return whatever the borrowed connection's {@code zrange} call returns
 */
@Override
public Set<byte[]> zrange(byte[] key, long start, long end) {
  // try-with-resources closes the borrowed connection once the call has finished.
  try (Jedis connection = jedisPool.getResource()) {
    final Set<byte[]> members = connection.zrange(key, start, end);
    return members;
  }
}
@Override
public Long rpush(String key, String... strings) {
try (Jedis jedis = jedisPool.getResource()) {
@@ -203,4 +211,20 @@ public class RedisClient implements KeyValueStoreLink {
public Object getImpl() {
// Hands out the jedisPool field itself, typed as Object.
return jedisPool;
}
/**
 * Executes a command through the native client, connecting first if needed.
 *
 * <p>While {@code handle} is still 0, {@code redisAddress} is split on {@code ':'} and the two
 * parts are passed to the native {@code connect} as host and port. The handle it returns is
 * kept for later calls.
 *
 * @param command command passed to the native {@code execute_command}
 * @param commandType command type passed to the native {@code execute_command}
 * @param objectId object id bytes passed to the native {@code execute_command}
 * @return the bytes returned by the native {@code execute_command}
 */
@Override
public byte[] sendCommand(String command, int commandType, byte[] objectId) {
  final boolean connected = handle != 0;
  if (!connected) {
    final String[] hostAndPort = redisAddress.split(":");
    final String host = hostAndPort[0];
    final int port = Integer.parseInt(hostAndPort[1]);
    handle = connect(host, port);
  }
  return execute_command(handle, command, commandType, objectId);
}
// Native entry points; their implementations are not visible in this file.
// sendCommand calls connect with the host and port parsed from redisAddress and keeps the
// returned int in `handle`.
private static native int connect(String redisAddress, int port);
// NOTE(review): presumably releases a handle obtained from connect; no caller is visible
// here, so confirm.
private static native void disconnect(int handle);
// sendCommand calls this with the cached handle and returns its byte[] result unchanged.
private static native byte[] execute_command(int handle,
String command, int commandType, byte[] objectId);
}