diff --git a/java/api/src/main/java/io/ray/api/Ray.java b/java/api/src/main/java/io/ray/api/Ray.java index c4044486e..693c50dd0 100644 --- a/java/api/src/main/java/io/ray/api/Ray.java +++ b/java/api/src/main/java/io/ray/api/Ray.java @@ -2,11 +2,14 @@ package io.ray.api; import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.api.runtime.RayRuntime; import io.ray.api.runtime.RayRuntimeFactory; import io.ray.api.runtimecontext.RuntimeContext; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; @@ -240,4 +243,22 @@ public final class Ray extends RayCall { public static RuntimeContext getRuntimeContext() { return runtime.getRuntimeContext(); } + + /** + * Create a placement group. + * A placement group is used to place actors according to a specific strategy + * and resource constraints. + * It will sends a request to GCS to preallocate the specified resources, which is asynchronous. + * If the specified resource cannot be allocated, it will wait for the resource + * to be updated and rescheduled. + * This function only works when gcs actor manager is turned on. + * + * @param bundles Preallocated resource list. + * @param strategy Actor placement strategy. + * @return A handle to the created placement group. + */ + public static PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy) { + return runtime.createPlacementGroup(bundles, strategy); + } } diff --git a/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java b/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java index ec281705b..827eb0412 100644 --- a/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java +++ b/java/api/src/main/java/io/ray/api/call/BaseActorCreator.java @@ -2,6 +2,7 @@ package io.ray.api.call; import io.ray.api.Ray; import io.ray.api.options.ActorCreationOptions; +import io.ray.api.placementgroup.PlacementGroup; import java.util.Map; /** @@ -85,7 +86,6 @@ public class BaseActorCreator { } /** - * /** * Set the max number of concurrent calls to allow for this actor. *

* The max concurrency defaults to 1 for threaded execution. @@ -100,6 +100,19 @@ public class BaseActorCreator { return self(); } + /** + * Set the placement group to place this actor in. + * + * @param group The placement group of the actor. + * @param bundleIndex The index of the bundle to place this actor in. + * @return self + * @see ActorCreationOptions.Builder#setPlacementGroup(PlacementGroup, int) + */ + public T setPlacementGroup(PlacementGroup group, int bundleIndex) { + builder.setPlacementGroup(group, bundleIndex); + return self(); + } + @SuppressWarnings("unchecked") private T self() { return (T) this; diff --git a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java index 363e915e9..a33c10dbe 100644 --- a/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java +++ b/java/api/src/main/java/io/ray/api/options/ActorCreationOptions.java @@ -1,6 +1,7 @@ package io.ray.api.options; import io.ray.api.Ray; +import io.ray.api.placementgroup.PlacementGroup; import java.util.HashMap; import java.util.Map; @@ -13,15 +14,20 @@ public class ActorCreationOptions extends BaseTaskOptions { public final int maxRestarts; public final String jvmOptions; public final int maxConcurrency; + public final PlacementGroup group; + public final int bundleIndex; private ActorCreationOptions(boolean global, String name, Map resources, - int maxRestarts, String jvmOptions, int maxConcurrency) { + int maxRestarts, String jvmOptions, int maxConcurrency, + PlacementGroup group, int bundleIndex) { super(resources); this.global = global; this.name = name; this.maxRestarts = maxRestarts; this.jvmOptions = jvmOptions; this.maxConcurrency = maxConcurrency; + this.group = group; + this.bundleIndex = bundleIndex; } /** @@ -34,6 +40,8 @@ public class ActorCreationOptions extends BaseTaskOptions { private int maxRestarts = 0; private String jvmOptions = null; private int maxConcurrency = 1; + private PlacementGroup group; + private int bundleIndex; /** * Set the actor name of a named actor. @@ -135,9 +143,22 @@ public class ActorCreationOptions extends BaseTaskOptions { return this; } + /** + * Set the placement group to place this actor in. + * + * @param group The placement group of the actor. + * @param bundleIndex The index of the bundle to place this actor in. + * @return self + */ + public Builder setPlacementGroup(PlacementGroup group, int bundleIndex) { + this.group = group; + this.bundleIndex = bundleIndex; + return this; + } + public ActorCreationOptions build() { return new ActorCreationOptions( - global, name, resources, maxRestarts, jvmOptions, maxConcurrency); + global, name, resources, maxRestarts, jvmOptions, maxConcurrency, group, bundleIndex); } } diff --git a/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroup.java b/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroup.java new file mode 100644 index 000000000..36531680c --- /dev/null +++ b/java/api/src/main/java/io/ray/api/placementgroup/PlacementGroup.java @@ -0,0 +1,10 @@ +package io.ray.api.placementgroup; + +/** + * A placement group is used to place interdependent actors according to a specific strategy + * {@link PlacementStrategy}. + * When a placement group is created, the corresponding actor slots and resources are preallocated. + * A placement group consists of one or more bundles plus a specific placement strategy. + */ +public interface PlacementGroup { +} diff --git a/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java b/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java new file mode 100644 index 000000000..c97c24d84 --- /dev/null +++ b/java/api/src/main/java/io/ray/api/placementgroup/PlacementStrategy.java @@ -0,0 +1,25 @@ +package io.ray.api.placementgroup; + +/** + * The actor placement strategy. + */ +public enum PlacementStrategy { + /** + * Packs Bundles close together inside nodes as tight as possible. + */ + PACK(0), + /** + * Places Bundles across distinct nodes as even as possible. + */ + SPREAD(1); + + private int value = 0; + + PlacementStrategy(int value) { + this.value = value; + } + + public int value() { + return this.value; + } +} diff --git a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java index 318b880e8..9340567ab 100644 --- a/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/io/ray/api/runtime/RayRuntime.java @@ -14,8 +14,11 @@ import io.ray.api.id.ObjectId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.api.runtimecontext.RuntimeContext; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; @@ -169,6 +172,9 @@ public interface RayRuntime { PyActorHandle createActor(PyActorClass pyActorClass, Object[] args, ActorCreationOptions options); + PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy); + RuntimeContext getRuntimeContext(); Object getAsyncContext(); diff --git a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java index b86698c60..b2acec342 100644 --- a/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/io/ray/runtime/AbstractRayRuntime.java @@ -17,6 +17,8 @@ import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.api.runtimecontext.RuntimeContext; import io.ray.runtime.config.RayConfig; import io.ray.runtime.context.RuntimeContextImpl; @@ -35,6 +37,7 @@ import io.ray.runtime.task.FunctionArg; import io.ray.runtime.task.TaskExecutor; import io.ray.runtime.task.TaskSubmitter; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Callable; import org.slf4j.Logger; @@ -155,6 +158,12 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal { return (PyActorHandle) createActorImpl(functionDescriptor, args, options); } + @Override + public PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy) { + return taskSubmitter.createPlacementGroup(bundles, strategy); + } + @SuppressWarnings("unchecked") @Override public T getActorHandle(ActorId actorId) { diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java new file mode 100644 index 000000000..46005c96e --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupId.java @@ -0,0 +1,58 @@ +package io.ray.runtime.placementgroup; + +import io.ray.api.id.BaseId; +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Random; + +/** + * Represents the id of a placement group. + */ +public class PlacementGroupId extends BaseId implements Serializable { + + public static final int LENGTH = 16; + + public static final PlacementGroupId NIL = nil(); + + private PlacementGroupId(byte[] id) { + super(id); + } + + /** + * Creates a PlacementGroupId from the given ByteBuffer. + */ + public static PlacementGroupId fromByteBuffer(ByteBuffer bb) { + return new PlacementGroupId(byteBuffer2Bytes(bb)); + } + + /** + * Create a PlacementGroupId instance according to the given bytes. + */ + public static PlacementGroupId fromBytes(byte[] bytes) { + return new PlacementGroupId(bytes); + } + + /** + * Generate a nil PlacementGroupId. + */ + private static PlacementGroupId nil() { + byte[] b = new byte[LENGTH]; + Arrays.fill(b, (byte) 0xFF); + return new PlacementGroupId(b); + } + + /** + * Generate an PlacementGroupId with random value. Used for local mode and test only. + */ + public static PlacementGroupId fromRandom() { + byte[] b = new byte[LENGTH]; + new Random().nextBytes(b); + return new PlacementGroupId(b); + } + + @Override + public int size() { + return LENGTH; + } +} diff --git a/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java new file mode 100644 index 000000000..6a0fac180 --- /dev/null +++ b/java/runtime/src/main/java/io/ray/runtime/placementgroup/PlacementGroupImpl.java @@ -0,0 +1,28 @@ +package io.ray.runtime.placementgroup; + +import io.ray.api.placementgroup.PlacementGroup; + +/** + * The default implementation of `PlacementGroup` interface. + */ +public class PlacementGroupImpl implements PlacementGroup { + + private PlacementGroupId id; + private int bundleCount = 0; + + public PlacementGroupImpl() { + } + + public PlacementGroupImpl(PlacementGroupId id, int bundleCount) { + this.id = id; + this.bundleCount = bundleCount; + } + + public PlacementGroupId getId() { + return id; + } + + public int getBundleCount() { + return bundleCount; + } +} diff --git a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java index c2c3e8616..c356169b5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/LocalModeTaskSubmitter.java @@ -12,6 +12,8 @@ import io.ray.api.id.TaskId; import io.ray.api.id.UniqueId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.runtime.RayRuntimeInternal; import io.ray.runtime.actor.LocalModeActorHandle; import io.ray.runtime.context.LocalModeWorkerContext; @@ -27,6 +29,7 @@ import io.ray.runtime.generated.Common.TaskSpec; import io.ray.runtime.generated.Common.TaskType; import io.ray.runtime.object.LocalModeObjectStore; import io.ray.runtime.object.NativeRayObject; +import io.ray.runtime.placementgroup.PlacementGroupImpl; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Collections; @@ -102,7 +105,7 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { // Check whether task arguments are ready. for (TaskArg arg : taskSpec.getArgsList()) { ByteString idByteString = arg.getObjectRef().getObjectId(); - if (idByteString != ByteString.EMPTY) { + if (idByteString != ByteString.EMPTY) { ObjectId id = new ObjectId(idByteString.toByteArray()); if (!objectStore.isObjectReady(id)) { unreadyObjects.add(id); @@ -209,6 +212,12 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } } + @Override + public PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy) { + return new PlacementGroupImpl(); + } + @Override public BaseActorHandle getActor(ActorId actorId) { return actorHandles.get(actorId).copy(); diff --git a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java index 5443410e5..c153ad1f1 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/NativeTaskSubmitter.java @@ -8,9 +8,14 @@ import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.runtime.actor.NativeActorHandle; import io.ray.runtime.functionmanager.FunctionDescriptor; +import io.ray.runtime.placementgroup.PlacementGroupId; +import io.ray.runtime.placementgroup.PlacementGroupImpl; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import org.apache.commons.lang3.StringUtils; @@ -34,11 +39,20 @@ public class NativeTaskSubmitter implements TaskSubmitter { @Override public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List args, ActorCreationOptions options) throws IllegalArgumentException { - if (options != null && StringUtils.isNotBlank(options.name)) { - Optional actor = - options.global ? Ray.getGlobalActor(options.name) : Ray.getActor(options.name); - Preconditions.checkArgument(!actor.isPresent(), - String.format("Actor of name %s exists", options.name)); + if (options != null) { + if (options.group != null) { + PlacementGroupImpl group = (PlacementGroupImpl)options.group; + Preconditions.checkArgument(options.bundleIndex >= 0 + && options.bundleIndex < group.getBundleCount(), + String.format("Bundle index %s is invalid", options.bundleIndex)); + } + + if (StringUtils.isNotBlank(options.name)) { + Optional actor = + options.global ? Ray.getGlobalActor(options.name) : Ray.getActor(options.name); + Preconditions.checkArgument(!actor.isPresent(), + String.format("Actor of name %s exists", options.name)); + } } byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args, options); @@ -63,6 +77,13 @@ public class NativeTaskSubmitter implements TaskSubmitter { return returnIds.stream().map(ObjectId::new).collect(Collectors.toList()); } + @Override + public PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy) { + byte[] bytes = nativeCreatePlacementGroup(bundles, strategy.value()); + return new PlacementGroupImpl(PlacementGroupId.fromBytes(bytes), bundles.size()); + } + private static native List nativeSubmitTask(FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); @@ -73,4 +94,7 @@ public class NativeTaskSubmitter implements TaskSubmitter { private static native List nativeSubmitActorTask(byte[] actorId, FunctionDescriptor functionDescriptor, int functionDescriptorHash, List args, int numReturns, CallOptions callOptions); + + private static native byte[] nativeCreatePlacementGroup(List> bundles, + int strategy); } diff --git a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java index 1ea6b86bb..f67b7f4d5 100644 --- a/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java +++ b/java/runtime/src/main/java/io/ray/runtime/task/TaskSubmitter.java @@ -5,8 +5,11 @@ import io.ray.api.id.ActorId; import io.ray.api.id.ObjectId; import io.ray.api.options.ActorCreationOptions; import io.ray.api.options.CallOptions; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; import io.ray.runtime.functionmanager.FunctionDescriptor; import java.util.List; +import java.util.Map; /** * A set of methods to submit tasks and create actors. @@ -47,6 +50,15 @@ public interface TaskSubmitter { List submitActorTask(BaseActorHandle actor, FunctionDescriptor functionDescriptor, List args, int numReturns, CallOptions options); + /** + * Create a placement group. + * @param bundles Preallocated resource list. + * @param strategy Actor placement strategy. + * @return A handle to the created placement group. + */ + PlacementGroup createPlacementGroup(List> bundles, + PlacementStrategy strategy); + BaseActorHandle getActor(ActorId actorId); } diff --git a/java/test/src/main/java/io/ray/test/PlacementGroupTest.java b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java new file mode 100644 index 000000000..e3c6fa845 --- /dev/null +++ b/java/test/src/main/java/io/ray/test/PlacementGroupTest.java @@ -0,0 +1,50 @@ +package io.ray.test; + +import io.ray.api.ActorHandle; +import io.ray.api.Ray; +import io.ray.api.id.ActorId; +import io.ray.api.placementgroup.PlacementGroup; +import io.ray.api.placementgroup.PlacementStrategy; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.testng.Assert; +import org.testng.annotations.Test; + +@Test +public class PlacementGroupTest extends BaseTest { + + public static class Counter { + + private int value; + + public Counter(int initValue) { + this.value = initValue; + } + + public int getValue() { + return value; + } + } + + // TODO(ffbin): Currently Java doesn't support multi-node tests. + // This test just creates a placement group with one bundle. + // It's not comprehensive to test all placement group test cases. + public void testCreateAndCallActor() { + List> bundles = new ArrayList<>(); + Map bundle = new HashMap<>(); + bundle.put("CPU", 1.0); + bundles.add(bundle); + PlacementStrategy strategy = PlacementStrategy.PACK; + PlacementGroup placementGroup = Ray.createPlacementGroup(bundles, strategy); + + // Test creating an actor from a constructor. + ActorHandle actor = Ray.actor(Counter::new, 1) + .setPlacementGroup(placementGroup, 0).remote(); + Assert.assertNotEquals(actor.getId(), ActorId.NIL); + + // Test calling an actor. + Assert.assertEquals(Integer.valueOf(1), actor.task(Counter::getValue).remote().get()); + } +} diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 95102198e..f4b5af501 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -119,18 +119,18 @@ struct ActorCreationOptions { using PlacementStrategy = rpc::PlacementStrategy; struct PlacementGroupCreationOptions { - PlacementGroupCreationOptions() {} + PlacementGroupCreationOptions() = default; PlacementGroupCreationOptions( - const std::string &name, PlacementStrategy strategy, - const std::vector> &bundles) - : strategy(strategy), bundles(bundles), name(name) {} + std::string name, PlacementStrategy strategy, + std::vector> bundles) + : name(std::move(name)), strategy(strategy), bundles(std::move(bundles)) {} + /// The name of the placement group. + const std::string name; /// The strategy to place the bundle in Placement Group. const PlacementStrategy strategy = rpc::PACK; /// The resource bundles in this placement group. const std::vector> bundles; - /// The name of the placement group. - const std::string name; }; } // namespace ray diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h index fa509697a..5f9e101a1 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.h @@ -25,7 +25,7 @@ extern "C" { * Class: io_ray_runtime_RayNativeRuntime * Method: nativeInitialize * Signature: - * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;ILjava/lang/String;Ljava/util/Map;[B)V + * (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize( JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject, @@ -42,7 +42,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, job /* * Class: io_ray_runtime_RayNativeRuntime * Method: nativeShutdown - * Signature: (Z)V + * Signature: ()V */ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *, jclass); diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 0d2a052ab..00a84f52e 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -16,10 +16,10 @@ #include +#include "jni_utils.h" #include "ray/common/id.h" #include "ray/core_worker/common.h" #include "ray/core_worker/core_worker.h" -#include "jni_utils.h" /// Store C++ instances of ray function in the cache to avoid unnessesary JNI operations. thread_local std::unordered_map>> @@ -109,6 +109,7 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, std::unordered_map resources; std::vector dynamic_worker_options; uint64_t max_concurrency = 1; + auto placement_options = std::make_pair(ray::PlacementGroupID::Nil(), -1); if (actorCreationOptions) { global = env->GetBooleanField(actorCreationOptions, java_actor_creation_options_global); @@ -130,6 +131,19 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, } max_concurrency = static_cast(env->GetIntField( actorCreationOptions, java_actor_creation_options_max_concurrency)); + + auto group = + env->GetObjectField(actorCreationOptions, java_actor_creation_options_group); + if (group) { + auto placement_group_id = env->GetObjectField(group, java_placement_group_id); + auto java_id_bytes = static_cast( + env->CallObjectMethod(placement_group_id, java_base_id_get_bytes)); + RAY_CHECK_JAVA_EXCEPTION(env); + auto id = JavaByteArrayToId(env, java_id_bytes); + auto index = env->GetIntField(actorCreationOptions, + java_actor_creation_options_bundle_index); + placement_options = std::make_pair(id, index); + } } auto full_name = GetActorFullName(global, name); @@ -142,10 +156,34 @@ inline ray::ActorCreationOptions ToActorCreationOptions(JNIEnv *env, dynamic_worker_options, /*is_detached=*/false, full_name, - /*is_asyncio=*/false}; + /*is_asyncio=*/false, + placement_options}; return actor_creation_options; } +inline ray::PlacementStrategy ConvertStrategy(jint java_strategy) { + return 0 == java_strategy ? ray::rpc::PACK : ray::rpc::SPREAD; +} + +inline ray::PlacementGroupCreationOptions ToPlacementGroupCreationOptions( + JNIEnv *env, jobject java_bundles, jint java_strategy) { + std::vector> bundles; + JavaListToNativeVector>( + env, java_bundles, &bundles, [](JNIEnv *env, jobject java_bundle) { + return JavaMapToNativeMap( + env, java_bundle, + [](JNIEnv *env, jobject java_key) { + return JavaStringToNativeString(env, (jstring)java_key); + }, + [](JNIEnv *env, jobject java_value) { + double value = env->CallDoubleMethod(java_value, java_double_double_value); + RAY_CHECK_JAVA_EXCEPTION(env); + return value; + }); + }); + return ray::PlacementGroupCreationOptions("", ConvertStrategy(java_strategy), bundles); +} + #ifdef __cplusplus extern "C" { #endif @@ -212,6 +250,19 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask( return NativeIdVectorToJavaByteArrayList(env, return_ids); } +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *env, + jclass, + jobject bundles, + jint strategy) { + auto options = ToPlacementGroupCreationOptions(env, bundles, strategy); + ray::PlacementGroupID placement_group_id; + auto status = ray::CoreWorkerProcess::GetCoreWorker().CreatePlacementGroup( + options, &placement_group_id); + THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); + return IdToJavaByteArray(env, placement_group_id); +} + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h index 1863fe311..80f1aa004 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.h @@ -52,6 +52,15 @@ Java_io_ray_runtime_task_NativeTaskSubmitter_nativeSubmitActorTask(JNIEnv *, jcl jint, jobject, jint, jobject); +/* + * Class: io_ray_runtime_task_NativeTaskSubmitter + * Method: nativeCreatePlacementGroup + * Signature: (Ljava/util/List;I)[B + */ +JNIEXPORT jbyteArray JNICALL +Java_io_ray_runtime_task_NativeTaskSubmitter_nativeCreatePlacementGroup(JNIEnv *, jclass, + jobject, jint); + #ifdef __cplusplus } #endif diff --git a/src/ray/core_worker/lib/java/jni_init.cc b/src/ray/core_worker/lib/java/jni_init.cc index c3078b114..a6472e33a 100644 --- a/src/ray/core_worker/lib/java/jni_init.cc +++ b/src/ray/core_worker/lib/java/jni_init.cc @@ -74,6 +74,8 @@ jfieldID java_actor_creation_options_name; jfieldID java_actor_creation_options_max_restarts; jfieldID java_actor_creation_options_jvm_options; jfieldID java_actor_creation_options_max_concurrency; +jfieldID java_actor_creation_options_group; +jfieldID java_actor_creation_options_bundle_index; jclass java_gcs_client_options_class; jfieldID java_gcs_client_options_ip; @@ -89,6 +91,9 @@ jclass java_task_executor_class; jmethodID java_task_executor_parse_function_arguments; jmethodID java_task_executor_execute; +jclass java_placement_group_class; +jfieldID java_placement_group_id; + JavaVM *jvm; inline jclass LoadClass(JNIEnv *env, const char *class_name) { @@ -177,6 +182,12 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_base_task_options_resources = env->GetFieldID(java_base_task_options_class, "resources", "Ljava/util/Map;"); + java_placement_group_class = + LoadClass(env, "io/ray/runtime/placementgroup/PlacementGroupImpl"); + java_placement_group_id = + env->GetFieldID(java_placement_group_class, "id", + "Lio/ray/runtime/placementgroup/PlacementGroupId;"); + java_actor_creation_options_class = LoadClass(env, "io/ray/api/options/ActorCreationOptions"); java_actor_creation_options_global = @@ -189,6 +200,11 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) { java_actor_creation_options_class, "jvmOptions", "Ljava/lang/String;"); java_actor_creation_options_max_concurrency = env->GetFieldID(java_actor_creation_options_class, "maxConcurrency", "I"); + java_actor_creation_options_group = + env->GetFieldID(java_actor_creation_options_class, "group", + "Lio/ray/api/placementgroup/PlacementGroup;"); + java_actor_creation_options_bundle_index = + env->GetFieldID(java_actor_creation_options_class, "bundleIndex", "I"); java_gcs_client_options_class = LoadClass(env, "io/ray/runtime/gcs/GcsClientOptions"); java_gcs_client_options_ip = env->GetFieldID(java_gcs_client_options_class, "ip", "Ljava/lang/String;"); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 8663f8d0e..750eb5b2d 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -127,6 +127,10 @@ extern jfieldID java_actor_creation_options_max_restarts; extern jfieldID java_actor_creation_options_jvm_options; /// maxConcurrency field of ActorCreationOptions class extern jfieldID java_actor_creation_options_max_concurrency; +/// group field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_group; +/// bundleIndex field of ActorCreationOptions class +extern jfieldID java_actor_creation_options_bundle_index; /// GcsClientOptions class extern jclass java_gcs_client_options_class; @@ -153,6 +157,11 @@ extern jmethodID java_task_executor_parse_function_arguments; /// execute method of TaskExecutor class extern jmethodID java_task_executor_execute; +/// PlacementGroup class +extern jclass java_placement_group_class; +/// id field of PlacementGroup class +extern jfieldID java_placement_group_id; + #define CURRENT_JNI_VERSION JNI_VERSION_1_8 extern JavaVM *jvm; diff --git a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc index 5b0e629ec..2230cf525 100644 --- a/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc +++ b/src/ray/gcs/gcs_server/gcs_placement_group_scheduler.cc @@ -41,8 +41,8 @@ ScheduleMap GcsPackStrategy::Schedule( const GcsNodeManager &node_manager) { ScheduleMap schedule_map; auto &alive_nodes = node_manager.GetAllAliveNodes(); - for (size_t pos = 0; pos < bundles.size(); pos++) { - schedule_map[bundles[pos]->BundleId()] = + for (auto &bundle : bundles) { + schedule_map[bundle->BundleId()] = ClientID::FromBinary(alive_nodes.begin()->second->node_id()); } return schedule_map; @@ -80,13 +80,13 @@ void GcsPlacementGroupScheduler::Schedule( auto strategy = placement_group->GetStrategy(); auto alive_nodes = gcs_node_manager_.GetAllAliveNodes(); /// If the placement group don't have bundle, the placement group creates success. - if (bundles.size() == 0) { + if (bundles.empty()) { schedule_success_handler(placement_group); return; } // If alive_node is empty, the the placement group creates fail. - if (alive_nodes.size() == 0) { + if (alive_nodes.empty()) { schedule_failure_handler(placement_group); return; } diff --git a/src/ray/raylet/scheduling_policy.cc b/src/ray/raylet/scheduling_policy.cc index 775d57ff1..43521e553 100644 --- a/src/ray/raylet/scheduling_policy.cc +++ b/src/ray/raylet/scheduling_policy.cc @@ -155,9 +155,11 @@ bool SchedulingPolicy::ScheduleBundle( ResourceSet available_node_resources = ResourceSet(node_resources.GetAvailableResources()); available_node_resources.SubtractResources(node_resources.GetLoadResources()); - RAY_LOG(DEBUG) << "client_id " << node_client_id - << " avail: " << node_resources.GetAvailableResources().ToString() - << " load: " << node_resources.GetLoadResources().ToString(); + RAY_LOG(DEBUG) << "Scheduling bundle, client id = " << node_client_id + << ", available resources = " + << node_resources.GetAvailableResources().ToString() + << ", resources load = " << node_resources.GetLoadResources().ToString() + << ", the resource needed = " << resource_demand.ToString(); /// If the resource_demand is subset of the whole available_node_resources, this bundle /// can be set in this node, return true. return resource_demand.IsSubset(available_node_resources); diff --git a/src/ray/rpc/gcs_server/gcs_rpc_client.h b/src/ray/rpc/gcs_server/gcs_rpc_client.h index f671c0f1c..068f97f3a 100644 --- a/src/ray/rpc/gcs_server/gcs_rpc_client.h +++ b/src/ray/rpc/gcs_server/gcs_rpc_client.h @@ -109,6 +109,10 @@ class GcsRpcClient { new GrpcClient(address, port, client_call_manager)); worker_info_grpc_client_ = std::unique_ptr>( new GrpcClient(address, port, client_call_manager)); + placement_group_info_grpc_client_ = + std::unique_ptr>( + new GrpcClient(address, port, + client_call_manager)); } /// Add job info to gcs server.