[Java] Implement GcsClient (#4601)

This commit is contained in:
Wang Qing
2019-04-12 22:44:47 +08:00
committed by Hao Chen
parent fe07a5b4b1
commit 5cfbfe5df6
16 changed files with 364 additions and 105 deletions
@@ -16,6 +16,7 @@ import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.function.RayFunc;
import org.ray.api.gcs.GcsClient;
import org.ray.api.id.UniqueId;
import org.ray.api.options.ActorCreationOptions;
import org.ray.api.options.BaseTaskOptions;
@@ -67,6 +68,7 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected ObjectStoreProxy objectStoreProxy;
protected FunctionManager functionManager;
protected RuntimeContext runtimeContext;
protected GcsClient gcsClient;
public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
@@ -317,6 +319,11 @@ public abstract class AbstractRayRuntime implements RayRuntime {
return actor;
}
/**
 * Returns the GCS client held in this runtime's {@code gcsClient} field.
 *
 * @return The current value of {@code gcsClient}; this method does not initialize it.
 */
@Override
public GcsClient getGcsClient() {
  return this.gcsClient;
}
/**
* Create the task specification.
*
@@ -6,27 +6,18 @@ import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.gcs.GcsClientImpl;
import org.ray.runtime.gcs.RedisClient;
import org.ray.runtime.generated.ActorCheckpointIdData;
import org.ray.runtime.generated.TablePrefix;
import org.ray.runtime.objectstore.ObjectStoreProxy;
import org.ray.runtime.raylet.RayletClientImpl;
import org.ray.runtime.runner.RunManager;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -37,14 +28,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
private static final Logger LOGGER = LoggerFactory.getLogger(RayNativeRuntime.class);
/**
* Redis client of the primary shard.
*/
private RedisClient redisClient;
/**
* Redis clients of all shards.
*/
private List<RedisClient> redisClients;
private RunManager manager = null;
static {
@@ -109,7 +92,7 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
manager.startRayProcesses(true);
}
initRedisClients();
gcsClient = new GcsClientImpl(rayConfig.getRedisAddress(), rayConfig.redisPassword);
// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
@@ -128,16 +111,6 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
rayConfig.objectStoreSocketName, rayConfig.rayletSocketName);
}
/**
 * Creates the client for the primary Redis instance and one client per Redis shard.
 *
 * <p>The shard count is read from the {@code NumRedisShards} key and the shard addresses from
 * the {@code RedisShards} list, both via the primary client. After this method returns,
 * {@code redisClients} holds all shard clients followed by the primary client.
 *
 * @throws NumberFormatException if the {@code NumRedisShards} value is missing or not an integer.
 * @throws IllegalStateException if the number of shard addresses differs from the shard count.
 */
private void initRedisClients() {
  redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
  int numRedisShards = Integer.parseInt(redisClient.get("NumRedisShards", null));
  List<String> addresses = redisClient.lrange("RedisShards", 0, -1);
  Preconditions.checkState(numRedisShards == addresses.size());
  // Collectors.toList() gives no guarantee that the returned list is mutable, so collect
  // into an explicit ArrayList before appending the primary client below.
  redisClients = addresses.stream().map(RedisClient::new)
      .collect(Collectors.toCollection(ArrayList::new));
  redisClients.add(redisClient);
}
@Override
public void shutdown() {
if (null != manager) {
@@ -145,7 +118,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
}
}
/**
* Register this worker or driver to GCS.
*/
private void registerWorker() {
RedisClient redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);
Map<String, String> workerInfo = new HashMap<>();
String workerId = new String(workerContext.getCurrentWorkerId().getBytes());
if (rayConfig.workerMode == WorkerMode.DRIVER) {
@@ -165,70 +142,4 @@ public final class RayNativeRuntime extends AbstractRayRuntime {
redisClient.hmset("Workers:" + workerId, workerInfo);
}
}
/**
 * Looks up the checkpoints recorded for the given actor.
 *
 * <p>The clients in {@code redisClients} are queried in order; the first one that returns a
 * value for the actor's checkpoint key supplies the data and the remaining clients are not
 * consulted.
 *
 * @param actorId ID of the actor whose checkpoints are requested.
 * @return The checkpoints sorted by timestamp in descending order; empty if no client holds
 *     a value for the key.
 */
List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
  List<Checkpoint> found = new ArrayList<>();
  // TODO(hchen): implement the equivalent of Python's `GlobalState`, to avoid looping over
  // all redis shards.
  final byte[] key = ArrayUtils.addAll(
      TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID).getBytes(), actorId.getBytes());

  // Stop at the first client that has an entry for this key.
  byte[] serialized = null;
  for (RedisClient shard : redisClients) {
    serialized = shard.get(key, null);
    if (serialized != null) {
      break;
    }
  }

  if (serialized != null) {
    ActorCheckpointIdData table =
        ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(serialized));
    UniqueId[] ids = UniqueIdUtil.getUniqueIdsFromByteBuffer(table.checkpointIdsAsByteBuffer());
    int index = 0;
    for (UniqueId checkpointId : ids) {
      found.add(new Checkpoint(checkpointId, table.timestamps(index)));
      index++;
    }
  }

  // Newest checkpoint first.
  found.sort((left, right) -> Long.compare(right.timestamp, left.timestamp));
  return found;
}
/**
 * Checks whether an entry for the given actor exists in GCS.
 *
 * <p>The key is the bytes of {@code "ACTOR"} followed by the bytes of the actor ID. Every
 * client in {@code redisClients} is probed until one reports that the key exists.
 *
 * @param actorId ID of the actor to look for.
 * @return True if any client reports the key as existing, false otherwise.
 */
boolean actorExistsInGcs(UniqueId actorId) {
  final byte[] actorKey = ArrayUtils.addAll("ACTOR".getBytes(), actorId.getBytes());
  // TODO(qwang): refactor this with `GlobalState` once
  // https://github.com/ray-project/ray/issues/3933 is finished.
  return redisClients.stream().anyMatch(shard -> shard.exists(actorKey));
}
/**
 * Checks whether an entry for the given raylet task exists in GCS.
 *
 * <p>The key is the bytes of {@code "RAYLET_TASK"} followed by the bytes of the task ID. Every
 * client in {@code redisClients} is probed until one reports that the key exists.
 *
 * @param taskId ID of the task to look for.
 * @return True if any client reports the key as existing, false otherwise.
 */
public boolean rayletTaskExistsInGcs(UniqueId taskId) {
  final byte[] taskKey = ArrayUtils.addAll("RAYLET_TASK".getBytes(), taskId.getBytes());
  // TODO(qwang): refactor this with `GlobalState` once
  // https://github.com/ray-project/ray/issues/3933 is finished.
  return redisClients.stream().anyMatch(shard -> shard.exists(taskKey));
}
}
@@ -1,9 +1,11 @@
package org.ray.runtime;
import com.google.common.base.Preconditions;
import org.ray.api.Ray;
import org.ray.api.RuntimeContext;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.gcs.GcsClientImpl;
import org.ray.runtime.task.TaskSpec;
public class RuntimeContextImpl implements RuntimeContext {
@@ -36,7 +38,7 @@ public class RuntimeContextImpl implements RuntimeContext {
return false;
}
return ((RayNativeRuntime) runtime).actorExistsInGcs(getCurrentActorId());
return ((GcsClientImpl) Ray.getGcsClient()).actorExists(getCurrentActorId());
}
@Override
@@ -10,6 +10,7 @@ import org.ray.api.exception.RayTaskException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.functionmanager.RayFunction;
import org.ray.runtime.gcs.GcsClientImpl;
import org.ray.runtime.task.ArgumentsBuilder;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
@@ -171,8 +172,8 @@ public class Worker {
numTasksSinceLastCheckpoint = 0;
lastCheckpointTimestamp = System.currentTimeMillis();
checkpointIds = new ArrayList<>();
List<Checkpoint> availableCheckpoints = ((RayNativeRuntime) runtime)
.getCheckpointsForActor(actorId);
List<Checkpoint> availableCheckpoints
= ((GcsClientImpl)runtime.getGcsClient()).getCheckpointsForActor(actorId);
if (availableCheckpoints.isEmpty()) {
return;
}
@@ -0,0 +1,140 @@
package org.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.ray.api.Checkpointable.Checkpoint;
import org.ray.api.gcs.GcsClient;
import org.ray.api.gcs.NodeInfo;
import org.ray.api.id.UniqueId;
import org.ray.runtime.generated.ActorCheckpointIdData;
import org.ray.runtime.generated.ClientTableData;
import org.ray.runtime.generated.TablePrefix;
import org.ray.runtime.util.UniqueIdUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * An implementation of GcsClient.
 *
 * <p>Holds one {@link RedisClient} connected to the primary Redis instance and one per Redis
 * shard. The shard addresses are discovered through the primary instance at construction time.
 */
public class GcsClientImpl implements GcsClient {

  private static final Logger LOGGER = LoggerFactory.getLogger(GcsClientImpl.class);

  /** Client of the primary Redis instance. */
  private final RedisClient primary;

  /** Clients of the Redis shards, in the order listed under the {@code RedisShards} key. */
  private final List<RedisClient> shards;

  /**
   * Connects to the primary Redis instance and to every shard it advertises.
   *
   * @param redisAddress Address of the primary Redis instance.
   * @param redisPassword Password used for the primary Redis instance.
   * @throws RuntimeException if the {@code NumRedisShards} value is missing or not an integer.
   * @throws IllegalStateException if fewer than one shard is reported, or the number of shard
   *     addresses does not match the reported shard count.
   */
  public GcsClientImpl(String redisAddress, String redisPassword) {
    primary = new RedisClient(redisAddress, redisPassword);

    final int numShards;
    try {
      // parseInt(null) also throws NumberFormatException, so a missing key ends up here too.
      numShards = Integer.parseInt(primary.get("NumRedisShards", null));
    } catch (NumberFormatException e) {
      throw new RuntimeException("Failed to get number of redis shards.", e);
    }
    Preconditions.checkState(numShards > 0,
        String.format("Expected at least one Redis shards, found %d.", numShards));

    List<byte[]> shardAddresses = primary.lrange("RedisShards".getBytes(), 0, -1);
    Preconditions.checkState(shardAddresses.size() == numShards);
    shards = shardAddresses.stream()
        .map(address -> new RedisClient(new String(address)))
        .collect(Collectors.toList());
  }

  /**
   * Reads the client table from the primary Redis instance and folds its entries into one
   * {@link NodeInfo} per node.
   *
   * <p>Entries are processed in list order. An insertion entry records the node as alive with
   * its address and total resources; a later removal entry for the same node replaces it with
   * a copy marked as not alive. A removal entry for a node that has no earlier insertion entry
   * is logged and skipped.
   *
   * @return One {@link NodeInfo} per known node; empty if the table could not be read.
   */
  @Override
  public List<NodeInfo> getAllNodeInfo() {
    final String prefix = TablePrefix.name(TablePrefix.CLIENT);
    final byte[] key = ArrayUtils.addAll(prefix.getBytes(), UniqueId.NIL.getBytes());
    List<byte[]> results = primary.lrange(key, 0, -1);

    if (results == null) {
      return new ArrayList<>();
    }

    // This map is used for deduplication of client entries.
    Map<UniqueId, NodeInfo> clients = new HashMap<>();
    for (byte[] result : results) {
      Preconditions.checkNotNull(result);
      ClientTableData data = ClientTableData.getRootAsClientTableData(ByteBuffer.wrap(result));
      final UniqueId clientId = UniqueId.fromByteBuffer(data.clientIdAsByteBuffer());

      if (data.isInsertion()) {
        // Code path of node insertion.
        Map<String, Double> resources = new HashMap<>();
        // Compute resources.
        Preconditions.checkState(
            data.resourcesTotalLabelLength() == data.resourcesTotalCapacityLength());
        for (int i = 0; i < data.resourcesTotalLabelLength(); i++) {
          resources.put(data.resourcesTotalLabel(i), data.resourcesTotalCapacity(i));
        }
        clients.put(clientId,
            new NodeInfo(clientId, data.nodeManagerAddress(), true, resources));
      } else {
        // Code path of node deletion.
        NodeInfo existing = clients.get(clientId);
        if (existing == null) {
          // Without a preceding insertion there is no address or resource data to carry
          // over; skip the entry instead of failing with a NullPointerException.
          LOGGER.warn("Ignoring removal entry for unknown node {}.", clientId);
          continue;
        }
        clients.put(clientId,
            new NodeInfo(clientId, existing.nodeAddress, false, existing.resources));
      }
    }

    return new ArrayList<>(clients.values());
  }

  /**
   * If the actor exists in GCS.
   *
   * <p>The lookup is done against the primary Redis instance.
   *
   * @param actorId ID of the actor to look for.
   * @return True if the actor's key exists, false otherwise.
   */
  public boolean actorExists(UniqueId actorId) {
    byte[] key = ArrayUtils.addAll(
        TablePrefix.name(TablePrefix.ACTOR).getBytes(), actorId.getBytes());
    return primary.exists(key);
  }

  /**
   * Query whether the raylet task exists in Gcs.
   *
   * <p>The lookup is done against the shard selected by the task ID's hash.
   *
   * @param taskId ID of the task to look for.
   * @return True if the task's key exists, false otherwise.
   */
  public boolean rayletTaskExistsInGcs(UniqueId taskId) {
    byte[] key = ArrayUtils.addAll(TablePrefix.name(TablePrefix.RAYLET_TASK).getBytes(),
        taskId.getBytes());
    RedisClient client = getShardClient(taskId);
    return client.exists(key);
  }

  /**
   * Get the available checkpoints for the given actor ID.
   *
   * <p>The lookup is done against the shard selected by the actor ID's hash.
   *
   * @param actorId ID of the actor whose checkpoints are requested.
   * @return The checkpoints sorted by timestamp in descending order; empty if the shard has
   *     no value for the actor's checkpoint key.
   */
  public List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
    List<Checkpoint> checkpoints = new ArrayList<>();
    final String prefix = TablePrefix.name(TablePrefix.ACTOR_CHECKPOINT_ID);
    final byte[] key = ArrayUtils.addAll(prefix.getBytes(), actorId.getBytes());
    RedisClient client = getShardClient(actorId);

    byte[] result = client.get(key);
    if (result != null) {
      ActorCheckpointIdData data =
          ActorCheckpointIdData.getRootAsActorCheckpointIdData(ByteBuffer.wrap(result));
      UniqueId[] checkpointIds = UniqueIdUtil.getUniqueIdsFromByteBuffer(
          data.checkpointIdsAsByteBuffer());

      for (int i = 0; i < checkpointIds.length; i++) {
        checkpoints.add(new Checkpoint(checkpointIds[i], data.timestamps(i)));
      }
    }
    checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
    return checkpoints;
  }

  /**
   * Picks the shard client for a key: the key's murmur hash, taken as an unsigned 64-bit value,
   * modulo the number of shards.
   */
  private RedisClient getShardClient(UniqueId key) {
    return shards.get((int) Long.remainderUnsigned(UniqueIdUtil.murmurHashCode(key),
        shards.size()));
  }

}
@@ -70,6 +70,10 @@ public class RedisClient {
}
/**
 * Fetches the value for the given key by delegating to {@code get(key, field)} with a null
 * field.
 *
 * @param key The key to look up.
 * @return Whatever the two-argument overload returns for a null field.
 */
public byte[] get(byte[] key) {
  final byte[] noField = null;
  return get(key, noField);
}
public byte[] get(byte[] key, byte[] field) {
try (Jedis jedis = jedisPool.getResource()) {
if (field == null) {
@@ -80,7 +84,12 @@ public class RedisClient {
}
}
public List<String> lrange(String key, long start, long end) {
/**
* Return the specified elements of the list stored at the specified key.
*
* @return Multi bulk reply, specifically a list of elements in the specified range.
*/
public List<byte[]> lrange(byte[] key, long start, long end) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.lrange(key, start, end);
}
@@ -134,4 +134,73 @@ public class UniqueIdUtil {
return ByteBuffer.wrap(bytesOfIds);
}
/**
 * Computes the murmur hash code of the given ID.
 *
 * <p>Hashes the first {@code UniqueId.LENGTH} bytes of {@code id.getBytes()} with a seed of 0.
 *
 * @param id The ID to hash.
 * @return The 64-bit hash value.
 */
public static long murmurHashCode(UniqueId id) {
  final byte[] idBytes = id.getBytes();
  final int seed = 0;
  return murmurHash64A(idBytes, UniqueId.LENGTH, seed);
}
/**
 * 64-bit MurmurHash (the "64A" variant) over the first {@code length} bytes of {@code data}.
 *
 * <p>This method is the same as `hash()` method of `ID` class in ray/src/ray/id.h
 *
 * @param data The bytes to hash.
 * @param length Number of leading bytes of {@code data} to include.
 * @param seed Hash seed; treated as an unsigned 32-bit value.
 * @return The 64-bit hash value.
 */
private static long murmurHash64A(byte[] data, int length, int seed) {
  final long multiplier = 0xc6a4a7935bd1e995L;
  final int shift = 47;

  long hash = (seed & 0xFFFFFFFFL) ^ (length * multiplier);

  // Mix in every complete 8-byte block, each read as a little-endian 64-bit word.
  final int blockCount = length / 8;
  for (int block = 0; block < blockCount; block++) {
    final int offset = block * 8;
    long word = 0L;
    for (int b = 0; b < 8; b++) {
      word |= (data[offset + b] & 0xffL) << (8 * b);
    }

    word *= multiplier;
    word ^= word >>> shift;
    word *= multiplier;

    hash ^= word;
    hash *= multiplier;
  }

  // Mix in the 1..7 trailing bytes, highest byte first; every case falls through on purpose.
  final int tailStart = length & ~7;
  switch (length % 8) {
    case 7:
      hash ^= (data[tailStart + 6] & 0xffL) << 48;
      // fall through
    case 6:
      hash ^= (data[tailStart + 5] & 0xffL) << 40;
      // fall through
    case 5:
      hash ^= (data[tailStart + 4] & 0xffL) << 32;
      // fall through
    case 4:
      hash ^= (data[tailStart + 3] & 0xffL) << 24;
      // fall through
    case 3:
      hash ^= (data[tailStart + 2] & 0xffL) << 16;
      // fall through
    case 2:
      hash ^= (data[tailStart + 1] & 0xffL) << 8;
      // fall through
    case 1:
      hash ^= data[tailStart] & 0xffL;
      hash *= multiplier;
      break;
    default:
      break;
  }

  // Final avalanche.
  hash ^= hash >>> shift;
  hash *= multiplier;
  hash ^= hash >>> shift;

  return hash;
}
}