[Placement Group] Add get / get all / remove interface for Placement Group Java api. (#11821)

* add placement group java get/get all interface

* add remove placement group api

* fix some issue like: Placement Group -> placement group

* extract dumplicate code to placement group utils

* specify running mode for placement group ut

* update checkGlobalStateAccessorPointerValid -> validateGlobalStateAccessorPointer

* use THROW_EXCEPTION_AND_RETURN_IF_NOT_OK

* update pg log print
This commit is contained in:
DK.Pino
2020-11-17 12:32:39 +08:00
committed by GitHub
parent 90574b66cc
commit 0f9e2fec12
22 changed files with 479 additions and 71 deletions
@@ -14,6 +14,7 @@ import io.ray.api.function.PyFunction;
import io.ray.api.function.RayFunc;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.api.placementgroup.PlacementGroup;
@@ -184,6 +185,21 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
return createPlacementGroup(DEFAULT_PLACEMENT_GROUP_NAME, bundles, strategy);
}
@Override
public void removePlacementGroup(PlacementGroupId id) {
taskSubmitter.removePlacementGroup(id);
}
@Override
public PlacementGroup getPlacementGroup(PlacementGroupId id) {
return gcsClient.getPlacementGroupInfo(id);
}
@Override
public List<PlacementGroup> getAllPlacementGroups() {
return gcsClient.getAllPlacementGroupInfo();
}
@SuppressWarnings("unchecked")
@Override
public <T extends BaseActorHandle> T getActorHandle(ActorId actorId) {
@@ -3,12 +3,15 @@ package io.ray.runtime;
import com.google.common.base.Preconditions;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.JobId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.UniqueId;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.runtime.config.RayConfig;
import io.ray.runtime.context.LocalModeWorkerContext;
import io.ray.runtime.object.LocalModeObjectStore;
import io.ray.runtime.task.LocalModeTaskExecutor;
import io.ray.runtime.task.LocalModeTaskSubmitter;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
@@ -84,6 +87,21 @@ public class RayDevRuntime extends AbstractRayRuntime {
super.setAsyncContext(asyncContext);
}
@Override
public PlacementGroup getPlacementGroup(
PlacementGroupId id) {
//@TODO(clay4444): We need a LocalGcsClient before implements this.
throw new UnsupportedOperationException(
"Ray doesn't support placement group operations in local mode.");
}
@Override
public List<PlacementGroup> getAllPlacementGroups() {
//@TODO(clay4444): We need a LocalGcsClient before implements this.
throw new UnsupportedOperationException(
"Ray doesn't support placement group operations in local mode.");
}
@Override
public void exitActor() {
@@ -6,13 +6,16 @@ import io.ray.api.Checkpointable.Checkpoint;
import io.ray.api.id.ActorId;
import io.ray.api.id.BaseId;
import io.ray.api.id.JobId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.NodeInfo;
import io.ray.runtime.generated.Gcs;
import io.ray.runtime.generated.Gcs.ActorCheckpointIdData;
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
import io.ray.runtime.generated.Gcs.TablePrefix;
import io.ray.runtime.placementgroup.PlacementGroupUtils;
import io.ray.runtime.util.IdUtil;
import java.util.ArrayList;
import java.util.HashMap;
@@ -52,6 +55,30 @@ public class GcsClient {
globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword);
}
/**
* Get placement group by {@link PlacementGroupId}
* @param placementGroupId Id of placement group.
* @return The placement group.
*/
public PlacementGroup getPlacementGroupInfo(PlacementGroupId placementGroupId) {
byte[] result = globalStateAccessor.getPlacementGroupInfo(placementGroupId);
return PlacementGroupUtils.generatePlacementGroupFromByteArray(result);
}
/**
* Get all placement groups in this cluster.
* @return All placement groups.
*/
public List<PlacementGroup> getAllPlacementGroupInfo() {
List<byte[]> results = globalStateAccessor.getAllPlacementGroupInfo();
List<PlacementGroup> placementGroups = new ArrayList<>();
for (byte[] result : results) {
placementGroups.add(PlacementGroupUtils.generatePlacementGroupFromByteArray(result));
}
return placementGroups;
}
public List<NodeInfo> getAllNodeInfo() {
List<byte[]> results = globalStateAccessor.getAllNodeInfo();
@@ -2,6 +2,7 @@ package io.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import io.ray.api.id.ActorId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.UniqueId;
import java.util.List;
@@ -33,8 +34,7 @@ public class GlobalStateAccessor {
private GlobalStateAccessor(String redisAddress, String redisPassword) {
globalStateAccessorNativePointer =
nativeCreateGlobalStateAccessor(redisAddress, redisPassword);
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Global state accessor native pointer must not be 0.");
validateGlobalStateAccessorPointer();
connect();
}
@@ -42,14 +42,18 @@ public class GlobalStateAccessor {
return this.nativeConnect(globalStateAccessorNativePointer);
}
private void validateGlobalStateAccessorPointer() {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Global state accessor native pointer must not be 0.");
}
/**
* @return A list of job info with JobInfo protobuf schema.
*/
public List<byte[]> getAllJobInfo() {
// Fetch a job list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get all job info when global state accessor have been destroyed.");
validateGlobalStateAccessorPointer();
return this.nativeGetAllJobInfo(globalStateAccessorNativePointer);
}
}
@@ -60,8 +64,7 @@ public class GlobalStateAccessor {
public List<byte[]> getAllNodeInfo() {
// Fetch a node list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get all node info when global state accessor have been destroyed.");
validateGlobalStateAccessorPointer();
return this.nativeGetAllNodeInfo(globalStateAccessorNativePointer);
}
}
@@ -72,16 +75,30 @@ public class GlobalStateAccessor {
*/
public byte[] getNodeResourceInfo(UniqueId nodeId) {
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get resource info by node id when global state accessor have been destroyed.");
validateGlobalStateAccessorPointer();
return nativeGetNodeResourceInfo(globalStateAccessorNativePointer, nodeId.getBytes());
}
}
public byte[] getPlacementGroupInfo(PlacementGroupId placementGroupId) {
synchronized (GlobalStateAccessor.class) {
Preconditions.checkNotNull(placementGroupId,
"PlacementGroupId can't be null when get placement group info.");
return nativeGetPlacementGroupInfo(globalStateAccessorNativePointer,
placementGroupId.getBytes());
}
}
public List<byte[]> getAllPlacementGroupInfo() {
synchronized (GlobalStateAccessor.class) {
validateGlobalStateAccessorPointer();
return this.nativeGetAllPlacementGroupInfo(globalStateAccessorNativePointer);
}
}
public byte[] getInternalConfig() {
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0,
"Get internal config when global state accessor have been destroyed.");
validateGlobalStateAccessorPointer();
return nativeGetInternalConfig(globalStateAccessorNativePointer);
}
}
@@ -92,7 +109,7 @@ public class GlobalStateAccessor {
public List<byte[]> getAllActorInfo() {
// Fetch a actor list with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0);
validateGlobalStateAccessorPointer();
return this.nativeGetAllActorInfo(globalStateAccessorNativePointer);
}
}
@@ -103,7 +120,7 @@ public class GlobalStateAccessor {
public byte[] getActorInfo(ActorId actorId) {
// Fetch an actor with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0);
validateGlobalStateAccessorPointer();
return this.nativeGetActorInfo(globalStateAccessorNativePointer, actorId.getBytes());
}
}
@@ -114,7 +131,7 @@ public class GlobalStateAccessor {
public byte[] getActorCheckpointId(ActorId actorId) {
// Fetch an actor checkpoint id with protobuf bytes format from GCS.
synchronized (GlobalStateAccessor.class) {
Preconditions.checkState(globalStateAccessorNativePointer != 0);
validateGlobalStateAccessorPointer();
return this.nativeGetActorCheckpointId(globalStateAccessorNativePointer, actorId.getBytes());
}
}
@@ -148,4 +165,9 @@ public class GlobalStateAccessor {
private native byte[] nativeGetActorInfo(long nativePtr, byte[] actorId);
private native byte[] nativeGetActorCheckpointId(long nativePtr, byte[] actorId);
private native byte[] nativeGetPlacementGroupInfo(long nativePtr,
byte[] placementGroupId);
private native List<byte[]> nativeGetAllPlacementGroupInfo(long nativePtr);
}
@@ -1,58 +0,0 @@
package io.ray.runtime.placementgroup;
import io.ray.api.id.BaseId;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Random;
/**
* Represents the id of a placement group.
*/
public class PlacementGroupId extends BaseId implements Serializable {
public static final int LENGTH = 16;
public static final PlacementGroupId NIL = nil();
private PlacementGroupId(byte[] id) {
super(id);
}
/**
* Creates a PlacementGroupId from the given ByteBuffer.
*/
public static PlacementGroupId fromByteBuffer(ByteBuffer bb) {
return new PlacementGroupId(byteBuffer2Bytes(bb));
}
/**
* Create a PlacementGroupId instance according to the given bytes.
*/
public static PlacementGroupId fromBytes(byte[] bytes) {
return new PlacementGroupId(bytes);
}
/**
* Generate a nil PlacementGroupId.
*/
private static PlacementGroupId nil() {
byte[] b = new byte[LENGTH];
Arrays.fill(b, (byte) 0xFF);
return new PlacementGroupId(b);
}
/**
* Generate an PlacementGroupId with random value. Used for local mode and test only.
*/
public static PlacementGroupId fromRandom() {
byte[] b = new byte[LENGTH];
new Random().nextBytes(b);
return new PlacementGroupId(b);
}
@Override
public int size() {
return LENGTH;
}
}
@@ -1,6 +1,8 @@
package io.ray.runtime.placementgroup;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.placementgroup.PlacementGroupState;
import io.ray.api.placementgroup.PlacementStrategy;
import java.util.List;
import java.util.Map;
@@ -48,7 +50,7 @@ public class PlacementGroupImpl implements PlacementGroup {
}
/**
* A help class for create the Placement Group.
* A help class for create the placement group.
*/
public static class Builder {
private PlacementGroupId id;
@@ -58,8 +60,8 @@ public class PlacementGroupImpl implements PlacementGroup {
private PlacementGroupState state;
/**
* Set the Id of the Placement Group.
* @param id Id of the Placement Group.
* Set the Id of the placement group.
* @param id Id of the placement group.
* @return self.
*/
public Builder setId(PlacementGroupId id) {
@@ -68,8 +70,8 @@ public class PlacementGroupImpl implements PlacementGroup {
}
/**
* Set the name of the Placement Group.
* @param name Name of the Placement Group.
* Set the name of the placement group.
* @param name Name of the placement group.
* @return self.
*/
public Builder setName(String name) {
@@ -78,8 +80,8 @@ public class PlacementGroupImpl implements PlacementGroup {
}
/**
* Set the bundles of the Placement Group.
* @param bundles the bundles of the Placement Group.
* Set the bundles of the placement group.
* @param bundles the bundles of the placement group.
* @return self.
*/
public Builder setBundles(List<Map<String, Double>> bundles) {
@@ -88,8 +90,8 @@ public class PlacementGroupImpl implements PlacementGroup {
}
/**
* Set the placement strategy of the Placement Group.
* @param strategy the placement strategy of the Placement Group.
* Set the placement strategy of the placement group.
* @param strategy the placement strategy of the placement group.
* @return self.
*/
public Builder setStrategy(PlacementStrategy strategy) {
@@ -98,8 +100,8 @@ public class PlacementGroupImpl implements PlacementGroup {
}
/**
* Set the placement state of the Placement Group.
* @param state the state of the Placement Group.
* Set the placement state of the placement group.
* @param state the state of the placement group.
* @return self.
*/
public Builder setState(PlacementGroupState state) {
@@ -1,32 +0,0 @@
package io.ray.runtime.placementgroup;
/**
* State of Placement Group.
*/
public enum PlacementGroupState {
/**
* Wait for resource to schedule.
*/
PENDING(0),
/**
* The Placement Group has created on some node.
*/
CREATED(1),
/**
* The Placement Group has removed.
*/
REMOVED(2);
private int value = 0;
PlacementGroupState(int value) {
this.value = value;
}
public int value() {
return this.value;
}
}
@@ -0,0 +1,108 @@
package io.ray.runtime.placementgroup;
import com.google.common.base.Preconditions;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.placementgroup.PlacementGroupState;
import io.ray.api.placementgroup.PlacementStrategy;
import io.ray.runtime.generated.Common;
import io.ray.runtime.generated.Common.Bundle;
import io.ray.runtime.generated.Gcs.PlacementGroupTableData;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
* Utils for placement group.
*/
public class PlacementGroupUtils {
private static List<Map<String, Double>> covertToUserSpecifiedBundles(List<Bundle> bundles) {
List<Map<String, Double>> result = new ArrayList<>();
// NOTE(clay4444): We need to guarantee the order here.
for (int i = 0; i < bundles.size(); i++) {
Bundle bundle = bundles.get(i);
result.add(bundle.getUnitResourcesMap());
}
return result;
}
private static PlacementStrategy covertToUserSpecifiedStrategy(
Common.PlacementStrategy strategy) {
switch (strategy) {
case PACK:
return PlacementStrategy.PACK;
case STRICT_PACK:
return PlacementStrategy.STRICT_PACK;
case SPREAD:
return PlacementStrategy.SPREAD;
case STRICT_SPREAD:
return PlacementStrategy.STRICT_SPREAD;
default:
return PlacementStrategy.UNRECOGNIZED;
}
}
private static PlacementGroupState covertToUserSpecifiedState(
PlacementGroupTableData.PlacementGroupState state) {
switch (state) {
case PENDING:
return PlacementGroupState.PENDING;
case CREATED:
return PlacementGroupState.CREATED;
case REMOVED:
return PlacementGroupState.REMOVED;
case RESCHEDULING:
return PlacementGroupState.RESCHEDULING;
default:
return PlacementGroupState.UNRECOGNIZED;
}
}
/**
* Generate a PlacementGroupImpl from placementGroupTableData protobuf data.
* @param placementGroupTableData protobuf data.
* @return placement group info {@link PlacementGroupImpl}
*/
private static PlacementGroupImpl generatePlacementGroupFromPbData(
PlacementGroupTableData placementGroupTableData) {
PlacementGroupState state = covertToUserSpecifiedState(
placementGroupTableData.getState());
PlacementStrategy strategy = covertToUserSpecifiedStrategy(
placementGroupTableData.getStrategy());
List<Map<String, Double>> bundles = covertToUserSpecifiedBundles(
placementGroupTableData.getBundlesList());
PlacementGroupId placementGroupId = PlacementGroupId.fromByteBuffer(
placementGroupTableData.getPlacementGroupId().asReadOnlyByteBuffer());
return new PlacementGroupImpl.Builder()
.setId(placementGroupId).setName(placementGroupTableData.getName())
.setState(state).setStrategy(strategy).setBundles(bundles)
.build();
}
/**
* Generate a PlacementGroupImpl from byte array.
* @param placementGroupByteArray bytes array from native method.
* @return placement group info {@link PlacementGroupImpl}
*/
public static PlacementGroupImpl generatePlacementGroupFromByteArray(
byte[] placementGroupByteArray) {
Preconditions.checkNotNull(placementGroupByteArray,
"Can't generate a placement group from empty byte array.");
PlacementGroupTableData placementGroupTableData;
try {
placementGroupTableData = PlacementGroupTableData.parseFrom(placementGroupByteArray);
} catch (InvalidProtocolBufferException e) {
throw new RuntimeException(
"Received invalid placement group table protobuf data from GCS.", e);
}
return generatePlacementGroupFromPbData(placementGroupTableData);
}
}
@@ -8,6 +8,7 @@ import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.api.options.ActorCreationOptions;
@@ -75,6 +76,8 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
private final Map<ActorId, TaskExecutor.ActorContext> actorContexts = new ConcurrentHashMap<>();
private final Map<PlacementGroupId, PlacementGroup> placementGroups = new ConcurrentHashMap<>();
public LocalModeTaskSubmitter(RayRuntimeInternal runtime, TaskExecutor taskExecutor,
LocalModeObjectStore objectStore) {
this.runtime = runtime;
@@ -225,8 +228,16 @@ public class LocalModeTaskSubmitter implements TaskSubmitter {
@Override
public PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy) {
return new PlacementGroupImpl.Builder()
.setName(name).setBundles(bundles).setStrategy(strategy).build();
PlacementGroupImpl placementGroup = new PlacementGroupImpl.Builder()
.setId(PlacementGroupId.fromRandom()).setName(name)
.setBundles(bundles).setStrategy(strategy).build();
placementGroups.put(placementGroup.getId(), placementGroup);
return placementGroup;
}
@Override
public void removePlacementGroup(PlacementGroupId id) {
placementGroups.remove(id);
}
@Override
@@ -6,13 +6,13 @@ import io.ray.api.BaseActorHandle;
import io.ray.api.Ray;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.placementgroup.PlacementStrategy;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.functionmanager.FunctionDescriptor;
import io.ray.runtime.placementgroup.PlacementGroupId;
import io.ray.runtime.placementgroup.PlacementGroupImpl;
import java.util.List;
import java.util.Map;
@@ -86,6 +86,11 @@ public class NativeTaskSubmitter implements TaskSubmitter {
.setName(name).setBundles(bundles).setStrategy(strategy).build();
}
@Override
public void removePlacementGroup(PlacementGroupId id) {
nativeRemovePlacementGroup(id.getBytes());
}
private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
int functionDescriptorHash, List<FunctionArg> args, int numReturns, CallOptions callOptions);
@@ -99,4 +104,7 @@ public class NativeTaskSubmitter implements TaskSubmitter {
private static native byte[] nativeCreatePlacementGroup(String name,
List<Map<String, Double>> bundles, int strategy);
private static native void nativeRemovePlacementGroup(byte[] placementGroupId);
}
@@ -3,6 +3,7 @@ package io.ray.runtime.task;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ActorId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.options.ActorCreationOptions;
import io.ray.api.options.CallOptions;
import io.ray.api.placementgroup.PlacementGroup;
@@ -53,7 +54,7 @@ public interface TaskSubmitter {
/**
* Create a placement group.
*
* @param name Name of the Placement Group.
* @param name Name of the placement group.
* @param bundles Pre-allocated resource list.
* @param strategy Actor placement strategy.
* @return A handle to the created placement group.
@@ -61,6 +62,12 @@ public interface TaskSubmitter {
PlacementGroup createPlacementGroup(String name, List<Map<String, Double>> bundles,
PlacementStrategy strategy);
/**
* Remove a placement group by id.
* @param id Id of the placement group.
*/
void removePlacementGroup(PlacementGroupId id);
BaseActorHandle getActor(ActorId actorId);
}