mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 17:04:56 +08:00
[Java] Implement GcsClient (#4601)
This commit is contained in:
@@ -16,6 +16,7 @@ import org.ray.api.RuntimeContext;
|
||||
import org.ray.api.WaitResult;
|
||||
import org.ray.api.exception.RayException;
|
||||
import org.ray.api.function.RayFunc;
|
||||
import org.ray.api.gcs.GcsClient;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.api.options.ActorCreationOptions;
|
||||
import org.ray.api.options.BaseTaskOptions;
|
||||
@@ -67,6 +68,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
protected ObjectStoreProxy objectStoreProxy;
|
||||
protected FunctionManager functionManager;
|
||||
protected RuntimeContext runtimeContext;
|
||||
protected GcsClient gcsClient;
|
||||
|
||||
public AbstractRayRuntime(RayConfig rayConfig) {
|
||||
this.rayConfig = rayConfig;
|
||||
@@ -317,6 +319,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
|
||||
return actor;
|
||||
}
|
||||
|
||||
@Override
|
||||
public GcsClient getGcsClient() {
|
||||
return gcsClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the task specification.
|
||||
*
|
||||
|
||||
@@ -6,27 +6,18 @@ import java.io.File;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.lang.reflect.Field;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.nio.file.StandardCopyOption;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RayConfig;
|
||||
import org.ray.runtime.config.WorkerMode;
|
||||
import org.ray.runtime.gcs.GcsClientImpl;
|
||||
import org.ray.runtime.gcs.RedisClient;
|
||||
import org.ray.runtime.generated.ActorCheckpointIdData;
|
||||
import org.ray.runtime.generated.TablePrefix;
|
||||
import org.ray.runtime.objectstore.ObjectStoreProxy;
|
||||
import org.ray.runtime.raylet.RayletClientImpl;
|
||||
import org.ray.runtime.runner.RunManager;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
@@ -37,14 +28,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
|
||||
|
||||
/**
|
||||
* Redis client of the primary shard.
|
||||
*/
|
||||
private RedisClient redisClient;
|
||||
/**
|
||||
* Redis clients of all shards.
|
||||
*/
|
||||
private List<RedisClient> redisClients;
|
||||
private RunManager manager = null;
|
||||
|
||||
static {
|
||||
@@ -109,7 +92,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
manager.startRayProcesses(true);
|
||||
}
|
||||
|
||||
initRedisClients();
|
||||
gcsClient = new GcsClientImpl(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
|
||||
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
|
||||
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
|
||||
@@ -128,16 +111,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
|
||||
}
|
||||
|
||||
private void initRedisClients() {
|
||||
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
int numRedisShards = Integer.valueOf(redisClient.get("NumRedisShards", null));
|
||||
List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
|
||||
Preconditions.checkState(numRedisShards == addresses.size());
|
||||
redisClients = addresses.stream().map(RedisClient::new)
|
||||
.collect(Collectors.toList());
|
||||
redisClients.add(redisClient);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void shutdown() {
|
||||
if (null != manager) {
|
||||
@@ -145,7 +118,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Register this worker or driver to GCS.
|
||||
*/
|
||||
private void registerWorker() {
|
||||
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
|
||||
Map<String, String> workerInfo = new HashMap<>();
|
||||
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
|
||||
if (rayConfig.workerMode == WorkerMode.DRIVER) {
|
||||
@@ -165,70 +142,4 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
|
||||
redisClient.hmset("Workers:" + workerId, workerInfo);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the available checkpoints for the given actor ID, return a list sorted by checkpoint
|
||||
* timestamp in descending order.
|
||||
*/
|
||||
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
|
||||
List<Checkpoint> checkpoints = new ArrayList<>();
|
||||
// TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
|
||||
// all redis shards..
|
||||
String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
|
||||
byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
|
||||
for (RedisClient client : redisClients) {
|
||||
byte[] result = client.get(key, null);
|
||||
if (result == null) {
|
||||
continue;
|
||||
}
|
||||
ActorCheckpointIdData data = ActorCheckpointIdData
|
||||
.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
|
||||
|
||||
UniqueId[] checkpointIds
|
||||
= UniqueIdUtil.getUniqueIdsFromByteBuffer(data.checkpointIdsAsByteBuffer());
|
||||
|
||||
for (int i = 0; i < checkpointIds.length; i++) {
|
||||
checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
|
||||
return checkpoints;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Query whether the actor exists in Gcs.
|
||||
*/
|
||||
boolean actorExistsInGcs(UniqueId actorId) {
|
||||
byte[] key = ArrayUtils.addAll("ACTOR".getBytes(), actorId.getBytes());
|
||||
|
||||
// TODO(qwang): refactor this with `GlobalState` after this issue
|
||||
// getting finished. https://github.com/ray-project/ray/issues/3933
|
||||
for (RedisClient client : redisClients) {
|
||||
if (client.exists(key)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query whether the raylet task exists in Gcs.
|
||||
*/
|
||||
public boolean rayletTaskExistsInGcs(UniqueId taskId) {
|
||||
byte[] key = ArrayUtils.addAll("RAYLET_TASK".getBytes(), taskId.getBytes());
|
||||
|
||||
// TODO(qwang): refactor this with `GlobalState` after this issue
|
||||
// getting finished. https://github.com/ray-project/ray/issues/3933
|
||||
for (RedisClient client : redisClients) {
|
||||
if (client.exists(key)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package org.ray.runtime;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import org.ray.api.Ray;
|
||||
import org.ray.api.RuntimeContext;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.gcs.GcsClientImpl;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
|
||||
public class RuntimeContextImpl implements RuntimeContext {
|
||||
@@ -36,7 +38,7 @@ public class RuntimeContextImpl implements RuntimeContext {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ((RayNativeRuntime) runtime).actorExistsInGcs(getCurrentActorId());
|
||||
return ((GcsClientImpl) Ray.getGcsClient()).actorExists(getCurrentActorId());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.ray.api.exception.RayTaskException;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.config.RunMode;
|
||||
import org.ray.runtime.functionmanager.RayFunction;
|
||||
import org.ray.runtime.gcs.GcsClientImpl;
|
||||
import org.ray.runtime.task.ArgumentsBuilder;
|
||||
import org.ray.runtime.task.TaskSpec;
|
||||
import org.slf4j.Logger;
|
||||
@@ -171,8 +172,8 @@ public class Worker {
|
||||
numTasksSinceLastCheckpoint = 0;
|
||||
lastCheckpointTimestamp = System.currentTimeMillis();
|
||||
checkpointIds = new ArrayList<>();
|
||||
List<Checkpoint> availableCheckpoints = ((RayNativeRuntime) runtime)
|
||||
.getCheckpointsForActor(actorId);
|
||||
List<Checkpoint> availableCheckpoints
|
||||
= ((GcsClientImpl)runtime.getGcsClient()).getCheckpointsForActor(actorId);
|
||||
if (availableCheckpoints.isEmpty()) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
package org.ray.runtime.gcs;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.ray.api.Checkpointable.Checkpoint;
|
||||
import org.ray.api.gcs.GcsClient;
|
||||
import org.ray.api.gcs.NodeInfo;
|
||||
import org.ray.api.id.UniqueId;
|
||||
import org.ray.runtime.generated.ActorCheckpointIdData;
|
||||
import org.ray.runtime.generated.ClientTableData;
|
||||
import org.ray.runtime.generated.TablePrefix;
|
||||
import org.ray.runtime.util.UniqueIdUtil;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* An implementation of GcsClient.
|
||||
*/
|
||||
public class GcsClientImpl implements GcsClient {
|
||||
|
||||
private static Logger LOGGER = LoggerFactory.getLogger(GcsClientImpl.class);
|
||||
|
||||
private RedisClient primary;
|
||||
|
||||
private List<RedisClient> shards;
|
||||
|
||||
public GcsClientImpl(String redisAddress, String redisPassword) {
|
||||
primary = new RedisClient(redisAddress, redisPassword);
|
||||
int numShards = 0;
|
||||
try {
|
||||
numShards = Integer.valueOf(primary.get("NumRedisShards", null));
|
||||
Preconditions.checkState(numShards > 0,
|
||||
String.format("Expected at least one Redis shards, found %d.", numShards));
|
||||
} catch (NumberFormatException e) {
|
||||
throw new RuntimeException("Failed to get number of redis shards.", e);
|
||||
}
|
||||
|
||||
List<byte[]> shardAddresses = primary.lrange("RedisShards".getBytes(), 0, -1);
|
||||
Preconditions.checkState(shardAddresses.size() == numShards);
|
||||
shards = shardAddresses.stream().map((byte[] address) -> {
|
||||
return new RedisClient(new String(address));
|
||||
}).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<NodeInfo> getAllNodeInfo() {
|
||||
final String prefix = TablePrefix.name(TablePrefix.CLIENT);
|
||||
final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes());
|
||||
List<byte[]> results = primary.lrange(key, 0, -1);
|
||||
|
||||
if (results == null) {
|
||||
return new ArrayList<>();
|
||||
}
|
||||
|
||||
// This map is used for deduplication of client entries.
|
||||
Map<UniqueId, NodeInfo> clients = new HashMap<>();
|
||||
for (byte[] result : results) {
|
||||
Preconditions.checkNotNull(result);
|
||||
ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result));
|
||||
final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer());
|
||||
|
||||
if (data.isInsertion()) {
|
||||
//Code path of node insertion.
|
||||
Map<String, Double> resources = new HashMap<>();
|
||||
// Compute resources.
|
||||
Preconditions.checkState(
|
||||
data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength());
|
||||
for (int i = 0; i < data.resourcesTotalLabelLength(); i++) {
|
||||
resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i));
|
||||
}
|
||||
|
||||
NodeInfo nodeInfo = new NodeInfo(
|
||||
clientId, data.nodeManagerAddress(), true, resources);
|
||||
clients.put(clientId, nodeInfo);
|
||||
} else {
|
||||
// Code path of node deletion.
|
||||
NodeInfo nodeInfo = new NodeInfo(clientId, clients.get(clientId).nodeAddress,
|
||||
false, clients.get(clientId).resources);
|
||||
clients.put(clientId, nodeInfo);
|
||||
}
|
||||
}
|
||||
|
||||
return new ArrayList<>(clients.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* If the actor exists in GCS.
|
||||
*/
|
||||
public boolean actorExists(UniqueId actorId) {
|
||||
byte[] key = ArrayUtils.addAll(
|
||||
TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes());
|
||||
return primary.exists(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Query whether the raylet task exists in Gcs.
|
||||
*/
|
||||
public boolean rayletTaskExistsInGcs(UniqueId taskId) {
|
||||
byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(),
|
||||
taskId.getBytes());
|
||||
RedisClient client = getShardClient(taskId);
|
||||
return client.exists(key);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the available checkpoints for the given actor ID.
|
||||
*/
|
||||
public List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
|
||||
List<Checkpoint> checkpoints = new ArrayList<>();
|
||||
final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
|
||||
final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
|
||||
RedisClient client = getShardClient(actorId);
|
||||
|
||||
byte[] result = client.get(key);
|
||||
if (result != null) {
|
||||
ActorCheckpointIdData data =
|
||||
ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
|
||||
UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(
|
||||
data.checkpointIdsAsByteBuffer());
|
||||
|
||||
for (int i = 0; i < checkpointIds.length; i++) {
|
||||
checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
|
||||
}
|
||||
}
|
||||
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
|
||||
return checkpoints;
|
||||
}
|
||||
|
||||
private RedisClient getShardClient(UniqueId key) {
|
||||
return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key),
|
||||
shards.size()));
|
||||
}
|
||||
|
||||
}
|
||||
@@ -70,6 +70,10 @@ public class RedisClient {
|
||||
|
||||
}
|
||||
|
||||
public byte[] get(byte[] key) {
|
||||
return get(key, null);
|
||||
}
|
||||
|
||||
public byte[] get(byte[] key, byte[] field) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
if (field == null) {
|
||||
@@ -80,7 +84,12 @@ public class RedisClient {
|
||||
}
|
||||
}
|
||||
|
||||
public List<String> lrange(String key, long start, long end) {
|
||||
/**
|
||||
* Return the specified elements of the list stored at the specified key.
|
||||
*
|
||||
* @return Multi bulk reply, specifically a list of elements in the specified range.
|
||||
*/
|
||||
public List<byte[]> lrange(byte[] key, long start, long end) {
|
||||
try (Jedis jedis = jedisPool.getResource()) {
|
||||
return jedis.lrange(key, start, end);
|
||||
}
|
||||
|
||||
@@ -134,4 +134,73 @@ public class UniqueIdUtil {
|
||||
|
||||
return ByteBuffer.wrap(bytesOfIds);
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Compute the murmur hash code of this ID.
|
||||
*/
|
||||
public static long murmurHashCode(UniqueId id) {
|
||||
return murmurHash64A(id.getBytes(), UniqueId.LENGTH, 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is the same as `hash()` method of `ID` class in ray/src/ray/id.h
|
||||
*/
|
||||
private static long murmurHash64A(byte[] data, int length, int seed) {
|
||||
final long m = 0xc6a4a7935bd1e995L;
|
||||
final int r = 47;
|
||||
|
||||
long h = (seed & 0xFFFFFFFFL) ^ (length * m);
|
||||
|
||||
int length8 = length / 8;
|
||||
|
||||
for (int i = 0; i < length8; i++) {
|
||||
final int i8 = i * 8;
|
||||
long k = ((long)data[i8] & 0xff)
|
||||
+ (((long)data[i8 + 1] & 0xff) << 8)
|
||||
+ (((long)data[i8 + 2] & 0xff) << 16)
|
||||
+ (((long)data[i8 + 3] & 0xff) << 24)
|
||||
+ (((long)data[i8 + 4] & 0xff) << 32)
|
||||
+ (((long)data[i8 + 5] & 0xff) << 40)
|
||||
+ (((long)data[i8 + 6] & 0xff) << 48)
|
||||
+ (((long)data[i8 + 7] & 0xff) << 56);
|
||||
|
||||
k *= m;
|
||||
k ^= k >>> r;
|
||||
k *= m;
|
||||
|
||||
h ^= k;
|
||||
h *= m;
|
||||
}
|
||||
|
||||
final int remaining = length % 8;
|
||||
if (remaining >= 7) {
|
||||
h ^= (long) (data[(length & ~7) + 6] & 0xff) << 48;
|
||||
}
|
||||
if (remaining >= 6) {
|
||||
h ^= (long) (data[(length & ~7) + 5] & 0xff) << 40;
|
||||
}
|
||||
if (remaining >= 5) {
|
||||
h ^= (long) (data[(length & ~7) + 4] & 0xff) << 32;
|
||||
}
|
||||
if (remaining >= 4) {
|
||||
h ^= (long) (data[(length & ~7) + 3] & 0xff) << 24;
|
||||
}
|
||||
if (remaining >= 3) {
|
||||
h ^= (long) (data[(length & ~7) + 2] & 0xff) << 16;
|
||||
}
|
||||
if (remaining >= 2) {
|
||||
h ^= (long) (data[(length & ~7) + 1] & 0xff) << 8;
|
||||
}
|
||||
if (remaining >= 1) {
|
||||
h ^= (long) (data[length & ~7] & 0xff);
|
||||
h *= m;
|
||||
}
|
||||
|
||||
h ^= h >>> r;
|
||||
h *= m;
|
||||
h ^= h >>> r;
|
||||
|
||||
return h;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user