diff --git a/java/api/src/main/java/org/ray/api/RayRemote.java b/java/api/src/main/java/org/ray/api/RayRemote.java index d3690809d..0e4af13b5 100644 --- a/java/api/src/main/java/org/ray/api/RayRemote.java +++ b/java/api/src/main/java/org/ray/api/RayRemote.java @@ -4,6 +4,7 @@ import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; +import org.ray.util.ResourceItem; /** * a ray remote function or class (as an actor). @@ -11,10 +12,9 @@ import java.lang.annotation.Target; @Target({ElementType.METHOD, ElementType.TYPE}) @Retention(RetentionPolicy.RUNTIME) public @interface RayRemote { - /** - * whether to use external I/O pool to execute the function. + * This is used for default resources. + * @return The resources of the method or actor need. */ - boolean externalIo() default false; - + ResourceItem[] resources() default {@ResourceItem()}; } diff --git a/java/common/src/main/java/org/ray/util/ResourceItem.java b/java/common/src/main/java/org/ray/util/ResourceItem.java new file mode 100644 index 000000000..25e67acfb --- /dev/null +++ b/java/common/src/main/java/org/ray/util/ResourceItem.java @@ -0,0 +1,7 @@ +package org.ray.util; + +public @interface ResourceItem { + public String name() default ""; + public double value() default 0; + +} diff --git a/java/common/src/main/java/org/ray/util/ResourceUtil.java b/java/common/src/main/java/org/ray/util/ResourceUtil.java new file mode 100644 index 000000000..a8adf2c33 --- /dev/null +++ b/java/common/src/main/java/org/ray/util/ResourceUtil.java @@ -0,0 +1,102 @@ +package org.ray.util; + +import java.util.HashMap; +import java.util.Map; + +public class ResourceUtil { + public static final String CPU_LITERAL = "CPU"; + public static final String GPU_LITERAL = "GPU"; + + /** + * Convert the array that contains resource items to a map. + * + * @param resourceArray The resources list to be converted. + * @return The map whose key represents the resource name + * and the value represents the resource quantity. + */ + public static Map getResourcesMapFromArray(ResourceItem[] resourceArray) { + Map resourceMap = new HashMap<>(); + if (resourceArray != null) { + for (ResourceItem item : resourceArray) { + if (!item.name().isEmpty()) { + resourceMap.put(item.name(), item.value()); + } + } + } + + return resourceMap; + } + + /** + * Convert the resources map to a format string. + * + * @param resources The resource map to be Converted. + * @return The format resources string, like "{CPU:4, GPU:0}". + */ + public static String getResourcesFromatStringFromMap(Map resources) { + StringBuilder builder = new StringBuilder(); + builder.append("{"); + int count = 1; + for (Map.Entry entry : resources.entrySet()) { + builder.append(entry.getKey()).append(":").append(entry.getValue()); + count++; + if (count != resources.size()) { + builder.append(", "); + } + } + builder.append("}"); + return builder.toString(); + } + + /** + * Convert resources map to a string that is used + * for the command line argument of starting raylet. + * + * @param resources The resources map to be converted. + * @return The starting-raylet command line argument, like "CPU,4,GPU,0". + */ + public static String getResourcesStringFromMap(Map resources) { + StringBuilder builder = new StringBuilder(); + if (resources != null) { + int count = 1; + for (Map.Entry entry : resources.entrySet()) { + builder.append(entry.getKey()).append(",").append(entry.getValue()); + if (count != resources.size()) { + builder.append(","); + } + count++; + } + } + return builder.toString(); + } + + /** + * Parse the static resources configure field and convert to the resources map. + * + * @param resources The static resources string to be parsed. + * @return The map whose key represents the resource name + * and the value represents the resource quantity. + * @throws IllegalArgumentException If the resources string's format does match, + * it will throw an IllegalArgumentException. + */ + public static Map getResourcesMapFromString(String resources) + throws IllegalArgumentException { + Map ret = new HashMap<>(); + if (resources != null) { + String[] items = resources.split(","); + for (String item : items) { + String trimItem = item.trim(); + String[] resourcePair = trimItem.split(":"); + + if (resourcePair.length != 2) { + throw new IllegalArgumentException("Format of static resurces configure is invalid."); + } + + final String resourceName = resourcePair[0].trim(); + final Double resourceValue = Double.valueOf(resourcePair[1].trim()); + ret.put(resourceName, resourceValue); + } + } + return ret; + } +} diff --git a/java/ray.config.ini b/java/ray.config.ini index a8a6b0981..8427ed635 100644 --- a/java/ray.config.ini +++ b/java/ray.config.ini @@ -62,7 +62,9 @@ deploy = false onebox_delay_seconds_before_run_app_logic = 0 -use_raylet = false +use_raylet = true + +static_resources = CPU:4,GPU:0 ; java class which main is served as the driver in a java worker driver_class = diff --git a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java index a7ca6ac62..2b722a41c 100644 --- a/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java +++ b/java/runtime-common/src/main/java/org/ray/core/model/RayParameters.java @@ -124,6 +124,9 @@ public class RayParameters { @AConfig(comment = "worker fetch request size") public int worker_fetch_request_size = 10000; + @AConfig(comment = "static resource list of this node") + public String static_resources = ""; + public RayParameters(ConfigReader config) { if (null != config) { String networkInterface = config.getStringValue("ray.java", "network_interface", null, diff --git a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java index f999f8e8a..4ddfdf306 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java +++ b/java/runtime-common/src/main/java/org/ray/spi/LocalSchedulerProxy.java @@ -6,6 +6,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; + import org.ray.api.RayList; import org.ray.api.RayMap; import org.ray.api.RayObject; @@ -17,6 +18,7 @@ import org.ray.core.UniqueIdHelper; import org.ray.core.WorkerContext; import org.ray.spi.model.RayInvocation; import org.ray.spi.model.TaskSpec; +import org.ray.util.ResourceUtil; import org.ray.util.logger.RayLog; /** @@ -108,6 +110,8 @@ public class LocalSchedulerProxy { task.taskId = taskId; task.returnIds = returnIds; task.cursorId = invocation.getActor() != null ? invocation.getActor().getTaskCursor() : null; + task.resources = ResourceUtil + .getResourcesMapFromArray(invocation.getRemoteAnnotation().resources()); //WorkerContext.onSubmitTask(); RayLog.core.info( diff --git a/java/runtime-common/src/main/java/org/ray/spi/model/TaskSpec.java b/java/runtime-common/src/main/java/org/ray/spi/model/TaskSpec.java index 91d8eb09a..88101092a 100644 --- a/java/runtime-common/src/main/java/org/ray/spi/model/TaskSpec.java +++ b/java/runtime-common/src/main/java/org/ray/spi/model/TaskSpec.java @@ -1,7 +1,9 @@ package org.ray.spi.model; import java.util.Arrays; +import java.util.Map; import org.ray.api.UniqueID; +import org.ray.util.ResourceUtil; /** * Represents necessary information of a task for scheduling and executing. @@ -42,6 +44,9 @@ public class TaskSpec { // Id for create a target actor public UniqueID createActorId; + // The task's resource demands. + public Map resources; + public UniqueID cursorId; @Override @@ -56,6 +61,8 @@ public class TaskSpec { builder.append("\treturnIds: ").append(Arrays.toString(returnIds)).append("\n"); builder.append("\tactorHandleId: ").append(actorHandleId).append("\n"); builder.append("\tcreateActorId: ").append(createActorId).append("\n"); + builder.append("\tresources: ") + .append(ResourceUtil.getResourcesFromatStringFromMap(resources)).append("\n"); builder.append("\tcursorId: ").append(cursorId).append("\n"); builder.append("\targs:\n"); for (FunctionArg arg : args) { @@ -65,4 +72,5 @@ public class TaskSpec { } return builder.toString(); } -} \ No newline at end of file + +} diff --git a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java index 43742133c..d625f3c90 100644 --- a/java/runtime-native/src/main/java/org/ray/runner/RunManager.java +++ b/java/runtime-native/src/main/java/org/ray/runner/RunManager.java @@ -16,6 +16,7 @@ import org.ray.core.model.RunMode; import org.ray.runner.RunInfo.ProcessType; import org.ray.spi.PathConfig; import org.ray.spi.model.AddressInfo; +import org.ray.util.ResourceUtil; import org.ray.util.StringUtil; import org.ray.util.config.ConfigReader; import org.ray.util.logger.RayLog; @@ -350,10 +351,13 @@ public class RunManager { startObjectStore(0, info, params.working_directory + "/store", params.redis_address, params.node_ip_address, params.redirect, params.cleanup); + Map staticResources = + ResourceUtil.getResourcesMapFromString(params.static_resources); + //Start raylet - startRaylet(storeName, info, params.num_cpus[0],params.num_gpus[0], - params.num_workers,params.working_directory + "/raylet", - params.redis_address, params.node_ip_address, params.redirect, params.cleanup); + startRaylet(storeName, info, params.num_workers, + params.working_directory + "/raylet", params.redis_address, + params.node_ip_address, params.redirect, staticResources, params.cleanup); runInfo.localStores.add(info); } else { @@ -677,10 +681,9 @@ public class RunManager { } } - private void startRaylet(String storeName, AddressInfo info, int numCpus, - int numGpus, int numWorkers, String workDir, - String redisAddress, String ip, boolean redirect, - boolean cleanup) { + private void startRaylet(String storeName, AddressInfo info, int numWorkers, + String workDir, String redisAddress, String ip, boolean redirect, + Map staticResources, boolean cleanup) { int rpcPort = params.raylet_port; String rayletSocketName = "/tmp/raylet" + rpcPort; @@ -695,8 +698,8 @@ public class RunManager { assert (sep != -1); String gcsIp = redisAddress.substring(0, sep); String gcsPort = redisAddress.substring(sep + 1); - - String resourceArgument = "GPU," + numGpus + ",CPU," + numCpus; + + String resourceArgument = ResourceUtil.getResourcesStringFromMap(staticResources); String[] cmds = new String[]{filePath, rayletSocketName, storeName, ip, gcsIp, gcsPort, "" + numWorkers, workerCmd, resourceArgument}; diff --git a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java index 78deb5413..d1a71b711 100644 --- a/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java +++ b/java/runtime-native/src/main/java/org/ray/spi/impl/DefaultLocalSchedulerClient.java @@ -5,11 +5,13 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.ray.api.UniqueID; import org.ray.core.RayRuntime; import org.ray.spi.LocalSchedulerLink; import org.ray.spi.model.FunctionArg; import org.ray.spi.model.TaskSpec; +import org.ray.util.ResourceUtil; import org.ray.util.logger.RayLog; /** @@ -64,6 +66,20 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { @Override public void submitTask(TaskSpec task) { + // We don't support resources management in non raylet mode. + if (!useRaylet) { + task.resources.clear(); + task.resources.put(ResourceUtil.CPU_LITERAL, 0.0); + } else { + if (!task.resources.containsKey(ResourceUtil.CPU_LITERAL)) { + task.resources.put(ResourceUtil.CPU_LITERAL, 0.0); + } + + if (!task.resources.containsKey(ResourceUtil.GPU_LITERAL)) { + task.resources.put(ResourceUtil.GPU_LITERAL, 0.0); + } + } + ByteBuffer info = taskSpec2Info(task); byte[] a = null; if (!task.actorId.isNil()) { @@ -220,13 +236,15 @@ public class DefaultLocalSchedulerClient implements LocalSchedulerLink { // The required_resources vector indicates the quantities of the different // resources required by this task. The index in this vector corresponds to // the resource type defined in the ResourceIndex enum. For example, - - int[]requiredResourcesOffsets = new int[1]; - for (int i = 0; i < requiredResourcesOffsets.length; i++) { - int keyOffset = 0; - keyOffset = fbb.createString(ByteBuffer.wrap("CPU".getBytes())); - requiredResourcesOffsets[i] = ResourcePair.createResourcePair(fbb, keyOffset, 0.0); + int[] requiredResourcesOffsets = new int[task.resources.size()]; + int i = 0; + for (Map.Entry entry : task.resources.entrySet()) { + int keyOffset = fbb.createString(ByteBuffer.wrap(entry.getKey().getBytes())); + requiredResourcesOffsets[i] = + ResourcePair.createResourcePair(fbb, keyOffset, entry.getValue()); + i++; } + int requiredResourcesOffset = fbb.createVectorOfTables(requiredResourcesOffsets); int root = TaskInfo.createTaskInfo( diff --git a/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java new file mode 100644 index 000000000..a1edacb29 --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/ResourcesManagementTest.java @@ -0,0 +1,82 @@ +package org.ray.api.test; + +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.api.RayRemote; +import org.ray.api.WaitResult; +import org.ray.core.RayRuntime; +import org.ray.util.ResourceItem; + +/** + * Resources Management Test. + */ +@RunWith(MyRunner.class) +public class ResourcesManagementTest { + + @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), + @ResourceItem(name = "GPU", value = 0)}) + public static Integer echo1(Integer number) { + return number; + } + + @RayRemote(resources = {@ResourceItem(name = "CPU", value = 4), + @ResourceItem(name = "GPU", value = 2)}) + public static Integer echo2(Integer number) { + return number; + } + + @RayRemote(resources = {@ResourceItem(name = "CPU", value = 2), + @ResourceItem(name = "GPU", value = 0)}) + public static class Echo1 { + public Integer echo(Integer number) { + return number; + } + } + + @RayRemote(resources = {@ResourceItem(name = "CPU", value = 8), + @ResourceItem(name = "GPU", value = 0)}) + public static class Echo2 { + public Integer echo(Integer number) { + return number; + } + } + + @Test + public void testMethods() { + Assume.assumeTrue(RayRuntime.getParams().use_raylet); + // This is a case that can satisfy required resources. + RayObject result1 = Ray.call(ResourcesManagementTest::echo1, 100); + Assert.assertEquals(100, (int) result1.get()); + + // This is a case that can't satisfy required resources. + final RayObject result2 = Ray.call(ResourcesManagementTest::echo2, 200); + WaitResult waitResult = Ray.wait(result2, 1000); + + Assert.assertEquals(0, waitResult.getReadyOnes().size()); + Assert.assertEquals(1, waitResult.getRemainOnes().size()); + } + + @Test + public void testActors() { + Assume.assumeTrue(RayRuntime.getParams().use_raylet); + // This is a case that can satisfy required resources. + RayActor echo1 = Ray.create(Echo1.class); + final RayObject result1 = Ray.call(Echo1::echo, echo1, 100); + Assert.assertEquals(100, (int) result1.get()); + + // This is a case that can't satisfy required resources. + RayActor echo2 = Ray.create(Echo2.class); + final RayObject result2 = Ray.call(Echo2::echo, echo2, 100); + WaitResult waitResult = Ray.wait(result2, 1000); + + Assert.assertEquals(0, waitResult.getReadyOnes().size()); + Assert.assertEquals(1, waitResult.getRemainOnes().size()); + } + +} + diff --git a/java/test/src/main/java/org/ray/api/test/TypesTest.java b/java/test/src/main/java/org/ray/api/test/TypesTest.java index 264f61322..9d9170474 100644 --- a/java/test/src/main/java/org/ray/api/test/TypesTest.java +++ b/java/test/src/main/java/org/ray/api/test/TypesTest.java @@ -88,17 +88,17 @@ public class TypesTest { return rets; } - @RayRemote(externalIo = true) + @RayRemote public static Integer sayRayFuture() { return 123; } - @RayRemote(externalIo = true) + @RayRemote public static MultipleReturns2 sayRayFutures() { return new MultipleReturns2<>(123, "123"); } - @RayRemote(externalIo = true) + @RayRemote public static Map sayRayFuturesN( Collection userReturnIds, String prefix) {