From 0648bd28efa3d86956c2e5029e26dedf289b98da Mon Sep 17 00:00:00 2001 From: fyrestone Date: Sat, 8 Feb 2020 13:01:28 +0800 Subject: [PATCH] [xlang] Cross language Python support (#6709) --- doc/source/conf.py | 13 + .../functionmanager/FunctionManager.java | 32 +- .../JavaFunctionDescriptor.java | 14 +- .../functionmanager/PyFunctionDescriptor.java | 2 +- .../org/ray/runtime/runner/RunManager.java | 8 +- .../runtime/task/LocalModeTaskSubmitter.java | 25 +- .../org/ray/runtime/task/TaskExecutor.java | 12 +- .../functionmanager/FunctionManagerTest.java | 48 ++- .../api/test/CrossLanguageInvocationTest.java | 37 +++ .../test_cross_language_invocation.py | 19 ++ python/ray/__init__.py | 5 + python/ray/_raylet.pxd | 7 +- python/ray/_raylet.pyx | 55 ++-- python/ray/actor.py | 207 +++++++++---- python/ray/cross_language.py | 84 ++++++ python/ray/dashboard/dashboard.py | 8 - python/ray/function_manager.py | 236 +-------------- python/ray/includes/common.pxd | 7 +- python/ray/includes/function_descriptor.pxd | 69 +++++ python/ray/includes/function_descriptor.pxi | 285 ++++++++++++++++++ python/ray/includes/task.pxd | 5 +- python/ray/includes/task.pxi | 8 +- python/ray/remote_function.py | 34 ++- python/ray/state.py | 11 +- python/ray/tests/test_advanced_3.py | 5 +- python/ray/tests/test_basic.py | 18 ++ python/ray/tests/test_cross_language.py | 15 + python/ray/tests/test_metrics.py | 8 +- python/ray/worker.py | 10 +- src/ray/common/function_descriptor.cc | 78 +++++ src/ray/common/function_descriptor.h | 189 ++++++++++++ src/ray/common/task/task_spec.cc | 17 +- src/ray/common/task/task_spec.h | 11 +- src/ray/common/task/task_util.h | 27 +- src/ray/core_worker/actor_handle.cc | 9 +- src/ray/core_worker/actor_handle.h | 7 +- src/ray/core_worker/common.h | 6 +- src/ray/core_worker/core_worker.cc | 16 +- src/ray/core_worker/lib/java/jni_utils.h | 23 ++ .../java/org_ray_runtime_RayNativeRuntime.cc | 4 +- .../org_ray_runtime_actor_NativeRayActor.cc | 2 +- 
...rg_ray_runtime_task_NativeTaskSubmitter.cc | 14 +- src/ray/core_worker/test/core_worker_test.cc | 37 ++- .../test/direct_task_transport_test.cc | 35 ++- src/ray/core_worker/test/mock_worker.cc | 15 +- src/ray/protobuf/common.proto | 36 ++- src/ray/protobuf/core_worker.proto | 2 +- src/ray/ray_exported_symbols.lds | 1 + src/ray/ray_version_script.lds | 1 + src/ray/raylet/lineage_cache_test.cc | 7 +- src/ray/raylet/scheduling_queue.cc | 14 +- .../raylet/task_dependency_manager_test.cc | 7 +- streaming/python/includes/transfer.pxi | 23 +- streaming/python/jobworker.py | 10 +- .../python/tests/test_direct_transfer.py | 10 +- .../src/lib/java/streaming_jni_common.cc | 60 ++-- streaming/src/queue/transport.cc | 2 +- streaming/src/test/mock_actor.cc | 30 +- streaming/src/test/queue_tests_base.h | 12 +- 59 files changed, 1412 insertions(+), 580 deletions(-) create mode 100644 python/ray/cross_language.py create mode 100644 python/ray/includes/function_descriptor.pxd create mode 100644 python/ray/includes/function_descriptor.pxi create mode 100644 python/ray/tests/test_cross_language.py create mode 100644 src/ray/common/function_descriptor.cc create mode 100644 src/ray/common/function_descriptor.h diff --git a/doc/source/conf.py b/doc/source/conf.py index 34c7a124f..b2aa30faf 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -63,6 +63,19 @@ sys.modules["tensorflow"].VERSION = "9.9.9" # documentation root, use os.path.abspath to make it absolute, like shown here. sys.path.insert(0, os.path.abspath("../../python/")) +import ray + + +# Avoid @ray.remote run when doc generating +def fake_remote(*args, **kwargs): + def _inner_wrapper(cls_or_func): + return cls_or_func + + return _inner_wrapper + + +ray.remote = fake_remote + # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. 
diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java index 733b6d68b..dac4ee52a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/FunctionManager.java @@ -84,8 +84,8 @@ public class FunctionManager { SerializedLambda serializedLambda = LambdaUtils.getSerializedLambda(func); final String className = serializedLambda.getImplClass().replace('/', '.'); final String methodName = serializedLambda.getImplMethodName(); - final String typeDescriptor = serializedLambda.getImplMethodSignature(); - functionDescriptor = new JavaFunctionDescriptor(className, methodName, typeDescriptor); + final String signature = serializedLambda.getImplMethodSignature(); + functionDescriptor = new JavaFunctionDescriptor(className, methodName, signature); RAY_FUNC_CACHE.get().put(func.getClass(), functionDescriptor); } return getFunction(jobId, functionDescriptor); @@ -167,13 +167,26 @@ public class FunctionManager { } } } - return classFunctions.get(ImmutablePair.of(descriptor.name, descriptor.typeDescriptor)); + final Pair key = ImmutablePair.of(descriptor.name, descriptor.signature); + RayFunction func = classFunctions.get(key); + if (func == null) { + if (classFunctions.containsKey(key)) { + throw new RuntimeException( + String.format("RayFunction %s is overloaded, the signature can't be empty.", + descriptor.toString())); + } else { + throw new RuntimeException( + String.format("RayFunction %s not found", descriptor.toString())); + } + } + return func; } /** * Load all functions from a class. */ Map, RayFunction> loadFunctionsForClass(String className) { + // If RayFunction is null, the function is overloaded. 
Map, RayFunction> map = new HashMap<>(); try { Class clazz = Class.forName(className, true, classLoader); @@ -187,10 +200,17 @@ public class FunctionManager { final String methodName = e instanceof Method ? e.getName() : CONSTRUCTOR_NAME; final Type type = e instanceof Method ? Type.getType((Method) e) : Type.getType((Constructor) e); - final String typeDescriptor = type.getDescriptor(); + final String signature = type.getDescriptor(); RayFunction rayFunction = new RayFunction(e, classLoader, - new JavaFunctionDescriptor(className, methodName, typeDescriptor)); - map.put(ImmutablePair.of(methodName, typeDescriptor), rayFunction); + new JavaFunctionDescriptor(className, methodName, signature)); + map.put(ImmutablePair.of(methodName, signature), rayFunction); + // For cross language call java function without signature + final Pair emptyDescriptor = ImmutablePair.of(methodName, ""); + if (map.containsKey(emptyDescriptor)) { + map.put(emptyDescriptor, null); // Mark this function as overloaded. + } else { + map.put(emptyDescriptor, rayFunction); + } } } catch (Exception e) { throw new RuntimeException("Failed to load functions from class " + className, e); diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java index 25b34539c..edf16ffaa 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/JavaFunctionDescriptor.java @@ -19,14 +19,14 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor { */ public final String name; /** - * Function's type descriptor. + * Function's signature. 
*/ - public final String typeDescriptor; + public final String signature; - public JavaFunctionDescriptor(String className, String name, String typeDescriptor) { + public JavaFunctionDescriptor(String className, String name, String signature) { this.className = className; this.name = name; - this.typeDescriptor = typeDescriptor; + this.signature = signature; } @Override @@ -45,17 +45,17 @@ public final class JavaFunctionDescriptor implements FunctionDescriptor { JavaFunctionDescriptor that = (JavaFunctionDescriptor) o; return Objects.equal(className, that.className) && Objects.equal(name, that.name) && - Objects.equal(typeDescriptor, that.typeDescriptor); + Objects.equal(signature, that.signature); } @Override public int hashCode() { - return Objects.hashCode(className, name, typeDescriptor); + return Objects.hashCode(className, name, signature); } @Override public List toList() { - return ImmutableList.of(className, name, typeDescriptor); + return ImmutableList.of(className, name, signature); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java index 6845e79c7..912cca2e8 100644 --- a/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java +++ b/java/runtime/src/main/java/org/ray/runtime/functionmanager/PyFunctionDescriptor.java @@ -28,7 +28,7 @@ public class PyFunctionDescriptor implements FunctionDescriptor { @Override public List toList() { - return Arrays.asList(moduleName, className, functionName); + return Arrays.asList(moduleName, className, functionName, "" /* function hash */); } @Override diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 4e645ed69..fe5114e0b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -146,8 +146,12 @@ public class RunManager { e.printStackTrace(); } if (!p.isAlive()) { - throw new RuntimeException( - String.format("Failed to start %s. Exit code: %d.", name, p.exitValue())); + String message = String.format("Failed to start %s. Exit code: %d.", + name, p.exitValue()); + if (rayConfig.redirectOutput) { + message += String.format(" Logs are redirected to %s and %s.", stdout, stderr); + } + throw new RuntimeException(message); } processes.add(Pair.of(name, p)); if (LOGGER.isInfoEnabled()) { diff --git a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java index 9f8151909..8498dfd8a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java +++ b/java/runtime/src/main/java/org/ray/runtime/task/LocalModeTaskSubmitter.java @@ -28,6 +28,7 @@ import org.ray.runtime.actor.LocalModeRayActor; import org.ray.runtime.context.LocalModeWorkerContext; import org.ray.runtime.functionmanager.FunctionDescriptor; import org.ray.runtime.functionmanager.JavaFunctionDescriptor; +import org.ray.runtime.generated.Common; import org.ray.runtime.generated.Common.ActorCreationTaskSpec; import org.ray.runtime.generated.Common.ActorTaskSpec; import org.ray.runtime.generated.Common.Language; @@ -153,14 +154,20 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { List args) { byte[] taskIdBytes = new byte[TaskId.LENGTH]; new Random().nextBytes(taskIdBytes); + List functionDescriptorList = functionDescriptor.toList(); + Preconditions.checkState(functionDescriptorList.size() >= 3); return TaskSpec.newBuilder() .setType(taskType) .setLanguage(Language.JAVA) .setJobId( ByteString.copyFrom(runtime.getRayConfig().getJobId().getBytes())) .setTaskId(ByteString.copyFrom(taskIdBytes)) - 
.addAllFunctionDescriptor(functionDescriptor.toList().stream().map(ByteString::copyFromUtf8) - .collect(Collectors.toList())) + .setFunctionDescriptor(org.ray.runtime.generated.Common.FunctionDescriptor.newBuilder() + .setJavaFunctionDescriptor( + org.ray.runtime.generated.Common.JavaFunctionDescriptor.newBuilder() + .setClassName(functionDescriptorList.get(0)) + .setFunctionName(functionDescriptorList.get(1)) + .setSignature(functionDescriptorList.get(2)))) .addAllArgs(args.stream().map(arg -> arg.id != null ? TaskArg.newBuilder() .addObjectIds(ByteString.copyFrom(arg.id.getBytes())).build() : TaskArg.newBuilder().setData(ByteString.copyFrom(arg.value.data)) @@ -307,9 +314,17 @@ public class LocalModeTaskSubmitter implements TaskSubmitter { } private static JavaFunctionDescriptor getJavaFunctionDescriptor(TaskSpec taskSpec) { - List functionDescriptor = taskSpec.getFunctionDescriptorList(); - return new JavaFunctionDescriptor(functionDescriptor.get(0).toStringUtf8(), - functionDescriptor.get(1).toStringUtf8(), functionDescriptor.get(2).toStringUtf8()); + org.ray.runtime.generated.Common.FunctionDescriptor functionDescriptor = + taskSpec.getFunctionDescriptor(); + if (functionDescriptor.getFunctionDescriptorCase() == + Common.FunctionDescriptor.FunctionDescriptorCase.JAVA_FUNCTION_DESCRIPTOR) { + return new JavaFunctionDescriptor( + functionDescriptor.getJavaFunctionDescriptor().getClassName(), + functionDescriptor.getJavaFunctionDescriptor().getFunctionName(), + functionDescriptor.getJavaFunctionDescriptor().getSignature()); + } else { + throw new RuntimeException("Can't build non java function descriptor"); + } } private static List getFunctionArgs(TaskSpec taskSpec) { diff --git a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java index dedf693ed..230faf7e2 100644 --- a/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/task/TaskExecutor.java @@ -48,11 +48,11 @@ public abstract class TaskExecutor { List returnObjects = new ArrayList<>(); ClassLoader oldLoader = Thread.currentThread().getContextClassLoader(); - // Find the executable object. - RayFunction rayFunction = runtime.getFunctionManager() - .getFunction(jobId, parseFunctionDescriptor(rayFunctionInfo)); - Preconditions.checkNotNull(rayFunction); + JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo); + RayFunction rayFunction = null; try { + // Find the executable object. + rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader); @@ -91,7 +91,9 @@ public abstract class TaskExecutor { } catch (Exception e) { LOGGER.error("Error executing task " + taskId, e); if (taskType != TaskType.ACTOR_CREATION_TASK) { - if (rayFunction.hasReturn()) { + boolean hasReturn = rayFunction != null && rayFunction.hasReturn(); + boolean isCrossLanguage = functionDescriptor.signature.equals(""); + if (hasReturn || isCrossLanguage) { returnObjects.add(ObjectSerializer .serialize(new RayTaskException("Error executing task " + taskId, e))); } diff --git a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java index 0834106f9..937a22819 100644 --- a/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java +++ b/java/runtime/src/test/java/org/ray/runtime/functionmanager/FunctionManagerTest.java @@ -37,6 +37,14 @@ public class FunctionManagerTest { public Object bar() { return null; } + + public Object overloadFunction(int i) { + return null; + } + + public Object overloadFunction(double d) { + return null; + } } private static RayFunc0 fooFunc; @@ -45,6 +53,8 @@ 
public class FunctionManagerTest { private static JavaFunctionDescriptor fooDescriptor; private static JavaFunctionDescriptor barDescriptor; private static JavaFunctionDescriptor barConstructorDescriptor; + private static JavaFunctionDescriptor overloadFunctionDescriptorInt; + private static JavaFunctionDescriptor overloadFunctionDescriptorDouble; @BeforeClass public static void beforeClass() { @@ -58,6 +68,10 @@ public class FunctionManagerTest { barConstructorDescriptor = new JavaFunctionDescriptor(Bar.class.getName(), FunctionManager.CONSTRUCTOR_NAME, "()V"); + overloadFunctionDescriptorInt = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), + "overloadFunction", "(I)Ljava/lang/Object;"); + overloadFunctionDescriptorDouble = new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), + "overloadFunction", "(D)Ljava/lang/Object;"); } @Test @@ -102,6 +116,13 @@ public class FunctionManagerTest { Assert.assertTrue(func.isConstructor()); Assert.assertEquals(func.getFunctionDescriptor(), barConstructorDescriptor); Assert.assertNotNull(func.getRayRemoteAnnotation()); + + // Test raise overload exception + Assert.expectThrows(RuntimeException.class, () -> { + functionManager.getFunction(JobId.NIL, + new JavaFunctionDescriptor(FunctionManagerTest.class.getName(), + "overloadFunction", "")); + }); } @Test @@ -109,12 +130,31 @@ public class FunctionManagerTest { JobFunctionTable functionTable = new JobFunctionTable(getClass().getClassLoader()); Map, RayFunction> res = functionTable .loadFunctionsForClass(Bar.class.getName()); - // The result should 2 entries, one for the constructor, the other for bar. 
- Assert.assertEquals(res.size(), 2); + // The result should be 4 entries: + // 1, the constructor with signature + // 2, the constructor without signature + // 3, bar with signature + // 4, bar without signature + Assert.assertEquals(res.size(), 7); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barDescriptor.name, barDescriptor.typeDescriptor))); + ImmutablePair.of(barDescriptor.name, barDescriptor.signature))); Assert.assertTrue(res.containsKey( - ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.typeDescriptor))); + ImmutablePair.of(barConstructorDescriptor.name, barConstructorDescriptor.signature))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(barDescriptor.name, ""))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(barConstructorDescriptor.name, ""))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(overloadFunctionDescriptorInt.name, overloadFunctionDescriptorInt.signature))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(overloadFunctionDescriptorDouble.name, overloadFunctionDescriptorDouble.signature))); + Assert.assertTrue(res.containsKey( + ImmutablePair.of(overloadFunctionDescriptorInt.name, ""))); + Pair overloadKey = ImmutablePair.of(overloadFunctionDescriptorInt.name, ""); + RayFunction func = res.get(overloadKey); + // The function is overloaded. 
+ Assert.assertTrue(res.containsKey(overloadKey)); + Assert.assertNull(func); } @Test diff --git a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java index 0cde562b0..3612ca4ef 100644 --- a/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java +++ b/java/test/src/main/java/org/ray/api/test/CrossLanguageInvocationTest.java @@ -10,11 +10,14 @@ import org.ray.api.Ray; import org.ray.api.RayObject; import org.ray.api.RayPyActor; import org.ray.api.TestUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.testng.Assert; import org.testng.annotations.Test; public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { + private static final Logger LOGGER = LoggerFactory.getLogger(CrossLanguageInvocationTest.class); private static final String PYTHON_MODULE = "test_cross_language_invocation"; @Override @@ -46,6 +49,12 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { Assert.assertEquals(res.get(), "Response from Python: hello".getBytes()); } + @Test + public void testPythonCallJavaFunction() { + RayObject res = Ray.callPy(PYTHON_MODULE, "py_func_call_java_function", "hello".getBytes()); + Assert.assertEquals(res.get(), "[Python]py_func -> [Java]bytesEcho -> hello".getBytes()); + } + @Test(groups = {"directCall"}) public void testCallingPythonActor() { // Python worker doesn't support direct call yet. 
@@ -54,4 +63,32 @@ public class CrossLanguageInvocationTest extends BaseMultiLanguageTest { RayObject res = Ray.callPy(actor, "increase", "1".getBytes()); Assert.assertEquals(res.get(), "2".getBytes()); } + + @Test + public void testPythonCallJavaActor() { + RayObject res = Ray.callPy(PYTHON_MODULE, "py_func_call_java_actor", "1".getBytes()); + Assert.assertEquals(res.get(), "Counter1".getBytes()); + } + + public static byte[] bytesEcho(byte[] value) { + // This function will be called from test_cross_language_invocation.py + String valueStr = new String(value); + LOGGER.debug(String.format("bytesEcho called with: %s", valueStr)); + return ("[Java]bytesEcho -> " + valueStr).getBytes(); + } + + public static class TestActor { + public TestActor(byte[] v) { + value = v; + } + + public byte[] concat(byte[] v) { + byte[] c = new byte[value.length + v.length]; + System.arraycopy(value, 0, c, 0, value.length); + System.arraycopy(v, 0, c, value.length, v.length); + return c; + } + + private byte[] value; + } } diff --git a/java/test/src/main/resources/test_cross_language_invocation.py b/java/test/src/main/resources/test_cross_language_invocation.py index 0bc84c9f8..1ac99ac4b 100644 --- a/java/test/src/main/resources/test_cross_language_invocation.py +++ b/java/test/src/main/resources/test_cross_language_invocation.py @@ -12,6 +12,25 @@ def py_func(value): return b"Response from Python: " + value +@ray.remote +def py_func_call_java_function(value): + assert isinstance(value, bytes) + f = ray.java_function("org.ray.api.test.CrossLanguageInvocationTest", + "bytesEcho") + r = f.remote(value) + return b"[Python]py_func -> " + ray.get(r) + + +@ray.remote +def py_func_call_java_actor(value): + assert isinstance(value, bytes) + c = ray.java_actor_class( + "org.ray.api.test.CrossLanguageInvocationTest$TestActor") + java_actor = c.remote(b"Counter") + r = java_actor.concat.remote(value) + return ray.get(r) + + @ray.remote class Counter(object): def __init__(self, value): diff --git 
a/python/ray/__init__.py b/python/ray/__init__.py index bef7e5afd..4ae597d99 100644 --- a/python/ray/__init__.py +++ b/python/ray/__init__.py @@ -107,6 +107,7 @@ from ray._raylet import ( ObjectID, TaskID, UniqueID, + Language, ) # noqa: E402 _config = _Config() @@ -141,6 +142,7 @@ import ray.projects # noqa: E402 import ray.actor # noqa: F401 from ray.actor import method # noqa: E402 from ray.runtime_context import _get_runtime_context # noqa: E402 +from ray.cross_language import java_function, java_actor_class # noqa: E402 # Ray version string. __version__ = "0.9.0.dev0" @@ -182,6 +184,9 @@ __all__ = [ "shutdown", "show_in_webui", "wait", + "Language", + "java_function", + "java_actor_class", ] # ID types diff --git a/python/ray/_raylet.pxd b/python/ray/_raylet.pxd index 7fc0a991b..46e459232 100644 --- a/python/ray/_raylet.pxd +++ b/python/ray/_raylet.pxd @@ -20,6 +20,9 @@ from ray.includes.unique_ids cimport ( CObjectID, CActorID ) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) cdef class Buffer: cdef: @@ -66,4 +69,6 @@ cdef class CoreWorker: self, worker, outputs, const c_vector[CObjectID] return_ids, c_vector[shared_ptr[CRayObject]] *returns) -cdef c_vector[c_string] string_vector_from_list(list string_list) +cdef class FunctionDescriptor: + cdef: + CFunctionDescriptor descriptor diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 33932c836..02ea735e6 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -2,6 +2,7 @@ # distutils: language = c++ # cython: embedsignature = True # cython: language_level = 3 +# cython: c_string_encoding = default from cpython.exc cimport PyErr_CheckSignals @@ -87,7 +88,6 @@ from ray.exceptions import ( RayTimeoutError, ) from ray.experimental.no_return import NoReturn -from ray.function_manager import FunctionDescriptor from ray.utils import decode from ray.ray_constants import ( DEFAULT_PUT_OBJECT_DELAY, @@ -106,6 +106,7 @@ cimport cpython include "includes/unique_ids.pxi" 
include "includes/ray_config.pxi" +include "includes/function_descriptor.pxi" include "includes/task.pxi" include "includes/buffer.pxi" include "includes/common.pxi" @@ -206,6 +207,7 @@ def compute_task_id(ObjectID object_id): return TaskID(object_id.native().TaskId().Binary()) +@cython.auto_pickle(False) cdef class Language: cdef CLanguage lang @@ -218,7 +220,7 @@ cdef class Language: def __eq__(self, other): return (isinstance(other, Language) and - (self.lang) == (other.lang)) + (self.lang) == ((other).lang)) def __repr__(self): if self.lang == LANGUAGE_PYTHON: @@ -230,11 +232,12 @@ cdef class Language: else: raise Exception("Unexpected error") + def __reduce__(self): + return Language, (self.lang,) -# Programming language enum values. -cdef Language LANG_PYTHON = Language.from_native(LANGUAGE_PYTHON) -cdef Language LANG_CPP = Language.from_native(LANGUAGE_CPP) -cdef Language LANG_JAVA = Language.from_native(LANGUAGE_JAVA) + PYTHON = Language.from_native(LANGUAGE_PYTHON) + CPP = Language.from_native(LANGUAGE_CPP) + JAVA = Language.from_native(LANGUAGE_JAVA) cdef int prepare_resources( @@ -261,17 +264,8 @@ cdef int prepare_resources( return 0 -cdef c_vector[c_string] string_vector_from_list(list string_list): - cdef: - c_vector[c_string] out - for s in string_list: - if not isinstance(s, bytes): - raise TypeError("string_list elements must be bytes") - out.push_back(s) - return out - cdef void prepare_args( - CoreWorker core_worker, list args, c_vector[CTaskArg] *args_vector): + CoreWorker core_worker, args, c_vector[CTaskArg] *args_vector): cdef: size_t size int64_t put_threshold @@ -345,11 +339,10 @@ cdef execute_task( # Automatically restrict the GPUs available to this task. 
ray.utils.set_cuda_visible_devices(ray.get_gpu_ids()) - descriptor = tuple(ray_function.GetFunctionDescriptor()) + function_descriptor = CFunctionDescriptorToPython( + ray_function.GetFunctionDescriptor()) if task_type == TASK_TYPE_ACTOR_CREATION_TASK: - function_descriptor = FunctionDescriptor.from_bytes_list( - ray_function.GetFunctionDescriptor()) actor_class = manager.load_actor_class(job_id, function_descriptor) actor_id = core_worker.get_actor_id() worker.actors[actor_id] = actor_class.__new__(actor_class) @@ -359,13 +352,11 @@ cdef execute_task( last_checkpoint_timestamp=int(1000 * time.time()), checkpoint_ids=[])) - execution_info = execution_infos.get(descriptor) + execution_info = execution_infos.get(function_descriptor) if not execution_info: - function_descriptor = FunctionDescriptor.from_bytes_list( - ray_function.GetFunctionDescriptor()) execution_info = manager.get_execution_info( job_id, function_descriptor) - execution_infos[descriptor] = execution_info + execution_infos[function_descriptor] = execution_info function_name = execution_info.function_name extra_data = (b'{"name": ' + function_name.encode("ascii") + @@ -500,9 +491,6 @@ cdef execute_task( ray_signal.reset() if execution_info.max_calls != 0: - function_descriptor = FunctionDescriptor.from_bytes_list( - ray_function.GetFunctionDescriptor()) - # Reset the state of the worker for the next task to execute. # Increase the task execution counter. 
manager.increase_task_counter(job_id, function_descriptor) @@ -772,7 +760,8 @@ cdef class CoreWorker: message.decode("utf-8"))) def submit_task(self, - function_descriptor, + Language language, + FunctionDescriptor function_descriptor, args, int num_return_vals, c_bool is_direct_call, @@ -790,7 +779,7 @@ cdef class CoreWorker: task_options = CTaskOptions( num_return_vals, is_direct_call, c_resources) ray_function = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) + language.lang, function_descriptor.descriptor) prepare_args(self, args, &args_vector) with nogil: @@ -801,7 +790,8 @@ cdef class CoreWorker: return VectorToObjectIDs(return_ids) def create_actor(self, - function_descriptor, + Language language, + FunctionDescriptor function_descriptor, args, uint64_t max_reconstructions, resources, @@ -822,7 +812,7 @@ cdef class CoreWorker: prepare_resources(resources, &c_resources) prepare_resources(placement_resources, &c_placement_resources) ray_function = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) + language.lang, function_descriptor.descriptor) prepare_args(self, args, &args_vector) with nogil: @@ -837,8 +827,9 @@ cdef class CoreWorker: return ActorID(c_actor_id.Binary()) def submit_actor_task(self, + Language language, ActorID actor_id, - function_descriptor, + FunctionDescriptor function_descriptor, args, int num_return_vals, double num_method_cpus): @@ -856,7 +847,7 @@ cdef class CoreWorker: c_resources[b"CPU"] = num_method_cpus task_options = CTaskOptions(num_return_vals, False, c_resources) ray_function = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(function_descriptor)) + language.lang, function_descriptor.descriptor) prepare_args(self, args, &args_vector) with nogil: diff --git a/python/ray/actor.py b/python/ray/actor.py index 1d9ef73d7..a03bdc43a 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -7,12 +7,13 @@ import weakref from abc import ABCMeta, abstractmethod from 
collections import namedtuple -from ray.function_manager import FunctionDescriptor import ray.ray_constants as ray_constants import ray._raylet import ray.signature as signature import ray.worker -from ray import ActorID, ActorClassID +from ray import ActorID, ActorClassID, Language +from ray._raylet import PythonFunctionDescriptor +from ray import cross_language logger = logging.getLogger(__name__) @@ -143,8 +144,11 @@ class ActorClassMetadata: """Metadata for an actor class. Attributes: + language: The actor language, e.g. Python, Java. modified_class: The original class that was decorated (with some additional methods added like __ray_terminate__). + actor_creation_function_descriptor: The function descriptor for + the actor creation task. class_id: The ID of this actor class. class_name: The name of this class. num_cpus: The default number of CPUs required by the actor creation @@ -154,7 +158,6 @@ class ActorClassMetadata: memory: The heap memory quota for this actor. object_store_memory: The object store memory quota for this actor. resources: The default resources required by the actor creation task. - actor_method_cpus: The number of CPUs required by actor method tasks. last_export_session_and_job: A pair of the last exported session and job to help us to know whether this function was exported. This is an imperfect mechanism used to determine if we need to @@ -172,11 +175,17 @@ class ActorClassMetadata: each actor method. 
""" - def __init__(self, modified_class, class_id, max_reconstructions, num_cpus, - num_gpus, memory, object_store_memory, resources): + def __init__(self, language, modified_class, + actor_creation_function_descriptor, class_id, + max_reconstructions, num_cpus, num_gpus, memory, + object_store_memory, resources): + self.language = language self.modified_class = modified_class + self.actor_creation_function_descriptor = \ + actor_creation_function_descriptor + self.class_name = actor_creation_function_descriptor.class_name + self.is_cross_language = language != Language.PYTHON self.class_id = class_id - self.class_name = modified_class.__name__ self.max_reconstructions = max_reconstructions self.num_cpus = num_cpus self.num_gpus = num_gpus @@ -192,7 +201,8 @@ class ActorClassMetadata: ] constructor_name = "__init__" - if constructor_name not in self.actor_method_names: + if not self.is_cross_language and \ + constructor_name not in self.actor_method_names: # Add __init__ if it does not exist. # Actor creation will be executed with __init__ together. @@ -290,7 +300,10 @@ class ActorClass: def _ray_from_modified_class(cls, modified_class, class_id, max_reconstructions, num_cpus, num_gpus, memory, object_store_memory, resources): - for attribute in ["remote", "_remote", "_ray_from_modified_class"]: + for attribute in [ + "remote", "_remote", "_ray_from_modified_class", + "_ray_from_function_descriptor" + ]: if hasattr(modified_class, attribute): logger.warning("Creating an actor from class {} overwrites " "attribute {} of that class".format( @@ -307,10 +320,28 @@ class ActorClass: DerivedActorClass.__qualname__ = name # Construct the base object. self = DerivedActorClass.__new__(DerivedActorClass) + # Actor creation function descriptor. 
+ actor_creation_function_descriptor = PythonFunctionDescriptor( + modified_class.__module__, "__init__", modified_class.__name__) self.__ray_metadata__ = ActorClassMetadata( - modified_class, class_id, max_reconstructions, num_cpus, num_gpus, - memory, object_store_memory, resources) + Language.PYTHON, modified_class, + actor_creation_function_descriptor, class_id, max_reconstructions, + num_cpus, num_gpus, memory, object_store_memory, resources) + + return self + + @classmethod + def _ray_from_function_descriptor(cls, language, + actor_creation_function_descriptor, + max_reconstructions, num_cpus, num_gpus, + memory, object_store_memory, resources): + self = ActorClass.__new__(ActorClass) + + self.__ray_metadata__ = ActorClassMetadata( + language, None, actor_creation_function_descriptor, None, + max_reconstructions, num_cpus, num_gpus, memory, + object_store_memory, resources) return self @@ -460,29 +491,33 @@ class ActorClass: if meta.num_cpus is None else meta.num_cpus) actor_method_cpu = ray_constants.DEFAULT_ACTOR_METHOD_CPU_SPECIFIED - function_name = "__init__" - function_descriptor = FunctionDescriptor( - meta.modified_class.__module__, function_name, - meta.modified_class.__name__) - # Do not export the actor class or the actor if run in LOCAL_MODE # Instead, instantiate the actor locally and add it to the worker's # dictionary if worker.mode == ray.LOCAL_MODE: + assert not meta.is_cross_language, \ + "Cross language ActorClass cannot be executed locally." actor_id = ActorID.from_random() worker.actors[actor_id] = meta.modified_class( *copy.deepcopy(args), **copy.deepcopy(kwargs)) else: # Export the actor. - if (meta.last_export_session_and_job != - worker.current_session_and_job): + if not meta.is_cross_language and (meta.last_export_session_and_job + != + worker.current_session_and_job): # If this actor class was not exported in this session and job, # we need to export this function again, because current GCS # doesn't have it. 
meta.last_export_session_and_job = ( worker.current_session_and_job) + # After serialize / deserialize modified class, the __module__ + # of modified class will be ray.cloudpickle.cloudpickle. + # So, here pass actor_creation_function_descriptor to make + # sure export actor class correct. worker.function_actor_manager.export_actor_class( - meta.modified_class, meta.actor_method_names) + meta.modified_class, + meta.actor_creation_function_descriptor, + meta.actor_method_names) resources = ray.utils.resources_from_resource_arguments( cpus_to_use, meta.num_gpus, meta.memory, @@ -497,24 +532,28 @@ class ActorClass: if actor_method_cpu == 1: actor_placement_resources = resources.copy() actor_placement_resources["CPU"] += 1 - function_signature = meta.method_signatures[function_name] - creation_args = signature.flatten_args(function_signature, args, - kwargs) + if meta.is_cross_language: + creation_args = cross_language.format_args( + worker, args, kwargs) + else: + function_signature = meta.method_signatures["__init__"] + creation_args = signature.flatten_args(function_signature, + args, kwargs) actor_id = worker.core_worker.create_actor( - function_descriptor.get_function_descriptor_list(), + meta.language, meta.actor_creation_function_descriptor, creation_args, meta.max_reconstructions, resources, actor_placement_resources, is_direct_call, max_concurrency, detached, is_asyncio) actor_handle = ActorHandle( + meta.language, actor_id, - meta.modified_class.__module__, - meta.class_name, - meta.actor_method_names, meta.method_decorators, meta.method_signatures, meta.actor_method_num_return_vals, actor_method_cpu, + meta.is_cross_language, + meta.actor_creation_function_descriptor, worker.current_session_and_job, original_handle=True) @@ -536,9 +575,8 @@ class ActorHandle: cloudpickle). Attributes: + _ray_actor_language: The actor language. _ray_actor_id: Actor ID. - _ray_module_name: The module name of this actor. 
- _ray_actor_method_names: The names of the actor methods. _ray_method_decorators: Optional decorators for the function invocation. This can be used to change the behavior on the invocation side, whereas a regular decorator can be used to change @@ -546,48 +584,55 @@ class ActorHandle: _ray_method_signatures: The signatures of the actor methods. _ray_method_num_return_vals: The default number of return values for each method. - _ray_class_name: The name of the actor class. _ray_actor_method_cpus: The number of CPUs required by actor methods. _ray_original_handle: True if this is the original actor handle for a given actor. If this is true, then the actor will be destroyed when this handle goes out of scope. + _ray_is_cross_language: Whether this actor is cross language. + _ray_actor_creation_function_descriptor: The function descriptor + of the actor creation task. """ def __init__(self, + language, actor_id, - module_name, - class_name, - actor_method_names, method_decorators, method_signatures, method_num_return_vals, actor_method_cpus, + is_cross_language, + actor_creation_function_descriptor, session_and_job, original_handle=False): + self._ray_actor_language = language self._ray_actor_id = actor_id - self._ray_module_name = module_name self._ray_original_handle = original_handle - self._ray_actor_method_names = actor_method_names self._ray_method_decorators = method_decorators self._ray_method_signatures = method_signatures self._ray_method_num_return_vals = method_num_return_vals - self._ray_class_name = class_name self._ray_actor_method_cpus = actor_method_cpus self._ray_session_and_job = session_and_job - self._ray_function_descriptor_lists = { - method_name: FunctionDescriptor( - self._ray_module_name, method_name, - self._ray_class_name).get_function_descriptor_list() - for method_name in self._ray_method_signatures.keys() - } + self._ray_is_cross_language = is_cross_language + self._ray_actor_creation_function_descriptor = \ + 
actor_creation_function_descriptor + self._ray_function_descriptor = {} - for method_name in actor_method_names: - method = ActorMethod( - self, - method_name, - self._ray_method_num_return_vals[method_name], - decorator=self._ray_method_decorators.get(method_name)) - setattr(self, method_name, method) + if not self._ray_is_cross_language: + assert isinstance(actor_creation_function_descriptor, + PythonFunctionDescriptor) + module_name = actor_creation_function_descriptor.module_name + class_name = actor_creation_function_descriptor.class_name + for method_name in self._ray_method_signatures.keys(): + function_descriptor = PythonFunctionDescriptor( + module_name, method_name, class_name) + self._ray_function_descriptor[ + method_name] = function_descriptor + method = ActorMethod( + self, + method_name, + self._ray_method_num_return_vals[method_name], + decorator=self._ray_method_decorators.get(method_name)) + setattr(self, method_name, method) def _actor_method_call(self, method_name, @@ -615,22 +660,34 @@ class ActorHandle: args = args or [] kwargs = kwargs or {} - function_signature = self._ray_method_signatures[method_name] - - if not args and not kwargs and not function_signature: - list_args = [] + if self._ray_is_cross_language: + list_args = cross_language.format_args(worker, args, kwargs) + function_descriptor = \ + cross_language.get_function_descriptor_for_actor_method( + self._ray_actor_language, + self._ray_actor_creation_function_descriptor, method_name) else: - list_args = signature.flatten_args(function_signature, args, - kwargs) + function_signature = self._ray_method_signatures[method_name] + + if not args and not kwargs and not function_signature: + list_args = [] + else: + list_args = signature.flatten_args(function_signature, args, + kwargs) + function_descriptor = self._ray_function_descriptor[method_name] + if worker.mode == ray.LOCAL_MODE: + assert not self._ray_is_cross_language,\ + "Cross language remote actor method " \ + "cannot be 
executed locally." function = getattr(worker.actors[self._actor_id], method_name) object_ids = worker.local_mode_manager.execute( function, method_name, args, kwargs, num_return_vals) else: object_ids = worker.core_worker.submit_actor_task( - self._ray_actor_id, - self._ray_function_descriptor_lists[method_name], list_args, - num_return_vals, self._ray_actor_method_cpus) + self._ray_actor_language, self._ray_actor_id, + function_descriptor, list_args, num_return_vals, + self._ray_actor_method_cpus) if len(object_ids) == 1: object_ids = object_ids[0] @@ -639,13 +696,28 @@ class ActorHandle: return object_ids + def __getattr__(self, item): + if not self._ray_is_cross_language: + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, item)) + + return ActorMethod( + self, + item, + ray_constants. + # Currently, we use default num returns + DEFAULT_ACTOR_METHOD_NUM_RETURN_VALS, + # Currently, cross-lang actor method not support decorator + decorator=None) + # Make tab completion work. def __dir__(self): - return self._ray_actor_method_names + return self._ray_method_signatures.keys() def __repr__(self): - return "Actor({}, {})".format(self._ray_class_name, - self._actor_id.hex()) + return "Actor({}, {})".format( + self._ray_actor_creation_function_descriptor.class_name, + self._actor_id.hex()) def __del__(self): """Terminate the worker that is running this actor.""" @@ -666,7 +738,7 @@ class ActorHandle: logger.warning( "Actor is garbage collected in the wrong driver." + " Actor id = %s, class name = %s.", self._ray_actor_id, - self._ray_class_name) + self._ray_actor_creation_function_descriptor.class_name) return if worker.connected and self._ray_original_handle: # Note: in py2 the weakref is destroyed prior to calling __del__ @@ -708,17 +780,18 @@ class ActorHandle: worker = ray.worker.get_global_worker() worker.check_connected() state = { + "actor_language": self._ray_actor_language, # Local mode just uses the actor ID. 
"core_handle": worker.core_worker.serialize_actor_handle( self._ray_actor_id) if hasattr(worker, "core_worker") else self._ray_actor_id, - "module_name": self._ray_module_name, - "class_name": self._ray_class_name, - "actor_method_names": self._ray_actor_method_names, "method_decorators": self._ray_method_decorators, "method_signatures": self._ray_method_signatures, "method_num_return_vals": self._ray_method_num_return_vals, - "actor_method_cpus": self._ray_actor_method_cpus + "actor_method_cpus": self._ray_actor_method_cpus, + "is_cross_language": self._ray_is_cross_language, + "actor_creation_function_descriptor": self. + _ray_actor_creation_function_descriptor, } return state @@ -738,16 +811,16 @@ class ActorHandle: # TODO(swang): Accessing the worker's current task ID is not # thread-safe. # Local mode just uses the actor ID. + state["actor_language"], worker.core_worker.deserialize_and_register_actor_handle( state["core_handle"]) if hasattr(worker, "core_worker") else state["core_handle"], - state["module_name"], - state["class_name"], - state["actor_method_names"], state["method_decorators"], state["method_signatures"], state["method_num_return_vals"], state["actor_method_cpus"], + state["is_cross_language"], + state["actor_creation_function_descriptor"], worker.current_session_and_job) def __getstate__(self): diff --git a/python/ray/cross_language.py b/python/ray/cross_language.py new file mode 100644 index 000000000..90cf3a239 --- /dev/null +++ b/python/ray/cross_language.py @@ -0,0 +1,84 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from ray import Language +from ray._raylet import JavaFunctionDescriptor + +__all__ = [ + "java_function", + "java_actor_class", +] + + +def format_args(worker, args, kwargs): + """Format args for various languages. + + Args: + worker: The global worker instance. + args: The arguments for cross language. + kwargs: The keyword arguments for cross language. 
+ + Returns: + List of args and kwargs (if supported). + """ + if not worker.load_code_from_local: + raise Exception("Cross language feature needs " + "--load-code-from-local to be set.") + if kwargs: + raise Exception("Cross language remote functions " + "does not support kwargs.") + return args + + +def get_function_descriptor_for_actor_method( + language, actor_creation_function_descriptor, method_name): + """Get function descriptor for cross language actor method call. + + Args: + language: Target language. + actor_creation_function_descriptor: + The function signature for actor creation. + method_name: The name of actor method. + + Returns: + Function descriptor for cross language actor method call. + """ + if language == Language.JAVA: + return JavaFunctionDescriptor( + actor_creation_function_descriptor.class_name, + method_name, + # Currently not support call actor method with signature. + "") + else: + raise NotImplementedError("Cross language remote actor method " + "not support language {}".format(language)) + + +def java_function(class_name, function_name): + from ray.remote_function import RemoteFunction + return RemoteFunction( + Language.JAVA, + lambda *args, **kwargs: None, + JavaFunctionDescriptor(class_name, function_name, ""), + None, # num_cpus, + None, # num_gpus, + None, # memory, + None, # object_store_memory, + None, # resources, + None, # num_return_vals, + None, # max_calls, + None) # max_retries) + + +def java_actor_class(class_name): + from ray.actor import ActorClass + return ActorClass._ray_from_function_descriptor( + Language.JAVA, + JavaFunctionDescriptor(class_name, "", ""), + 0, # max_reconstructions, + None, # num_cpus, + None, # num_gpus, + None, # memory, + None, # object_store_memory, + None) # resources, diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py index 53bd772d5..c85efa868 100644 --- a/python/ray/dashboard/dashboard.py +++ b/python/ray/dashboard/dashboard.py @@ -441,12 +441,6 @@ class 
NodeStats(threading.Thread): if addr in self._addr_to_actor_id: actor_info = flattened_tree[self._addr_to_actor_id[ addr]] - if "currentTaskFuncDesc" in core_worker_stats: - core_worker_stats[ - "currentTaskFuncDesc"] = list( - map( - b64_decode, core_worker_stats[ - "currentTaskFuncDesc"])) format_reply_id(core_worker_stats) actor_info.update(core_worker_stats) actor_info["averageTaskExecutionSpeed"] = round( @@ -464,8 +458,6 @@ class NodeStats(threading.Thread): caller_id = self._addr_to_actor_id.get(caller_addr, "root") child_to_parent[actor_id] = caller_id infeasible_task["state"] = -1 - infeasible_task["functionDescriptor"] = list( - map(b64_decode, infeasible_task["functionDescriptor"])) format_reply_id(infeasible_tasks) flattened_tree[actor_id] = infeasible_task diff --git a/python/ray/function_manager.py b/python/ray/function_manager.py index 6bebc905a..c97d851ea 100644 --- a/python/ray/function_manager.py +++ b/python/ray/function_manager.py @@ -17,8 +17,8 @@ import ray from ray import profiling from ray import ray_constants from ray import cloudpickle as pickle +from ray._raylet import PythonFunctionDescriptor from ray.utils import ( - binary_to_hex, is_function_or_method, is_class_method, is_static_method, @@ -36,226 +36,6 @@ FunctionExecutionInfo = namedtuple("FunctionExecutionInfo", logger = logging.getLogger(__name__) -class FunctionDescriptor: - """A class used to describe a python function. - - Attributes: - module_name: the module name that the function belongs to. - class_name: the class name that the function belongs to if exists. - It could be empty is the function is not a class method. - function_name: the function name of the function. - function_hash: the hash code of the function source code if the - function code is available. - function_id: the function id calculated from this descriptor. - is_for_driver_task: whether this descriptor is for driver task. 
- """ - - def __init__(self, - module_name, - function_name, - class_name="", - function_source_hash=b""): - self._module_name = module_name - self._class_name = class_name - self._function_name = function_name - self._function_source_hash = function_source_hash - self._function_id = self._get_function_id() - - def __repr__(self): - return ("FunctionDescriptor:" + self._module_name + "." + - self._class_name + "." + self._function_name + "." + - binary_to_hex(self._function_source_hash)) - - @classmethod - def from_bytes_list(cls, function_descriptor_list): - """Create a FunctionDescriptor instance from list of bytes. - - This function is used to create the function descriptor from - backend data. - - Args: - cls: Current class which is required argument for classmethod. - function_descriptor_list: list of bytes to represent the - function descriptor. - - Returns: - The FunctionDescriptor instance created from the bytes list. - """ - assert isinstance(function_descriptor_list, list) - if len(function_descriptor_list) == 0: - # This is a function descriptor of driver task. - return FunctionDescriptor.for_driver_task() - elif (len(function_descriptor_list) == 3 - or len(function_descriptor_list) == 4): - module_name = ensure_str(function_descriptor_list[0]) - class_name = ensure_str(function_descriptor_list[1]) - function_name = ensure_str(function_descriptor_list[2]) - if len(function_descriptor_list) == 4: - return cls(module_name, function_name, class_name, - function_descriptor_list[3]) - else: - return cls(module_name, function_name, class_name) - else: - raise Exception( - "Invalid input for FunctionDescriptor.from_bytes_list") - - @classmethod - def from_function(cls, function, pickled_function): - """Create a FunctionDescriptor from a function instance. - - This function is used to create the function descriptor from - a python function. If a function is a class function, it should - not be used by this function. 
- - Args: - cls: Current class which is required argument for classmethod. - function: the python function used to create the function - descriptor. - pickled_function: This is factored in to ensure that any - modifications to the function result in a different function - descriptor. - - Returns: - The FunctionDescriptor instance created according to the function. - """ - module_name = function.__module__ - function_name = function.__name__ - class_name = "" - - pickled_function_hash = hashlib.sha1(pickled_function).digest() - - return cls(module_name, function_name, class_name, - pickled_function_hash) - - @classmethod - def from_class(cls, target_class): - """Create a FunctionDescriptor from a class. - - Args: - cls: Current class which is required argument for classmethod. - target_class: the python class used to create the function - descriptor. - - Returns: - The FunctionDescriptor instance created according to the class. - """ - module_name = target_class.__module__ - class_name = target_class.__name__ - return cls(module_name, "__init__", class_name) - - @classmethod - def for_driver_task(cls): - """Create a FunctionDescriptor instance for a driver task.""" - return cls("", "", "", b"") - - @property - def is_for_driver_task(self): - """See whether this function descriptor is for a driver or not. - - Returns: - True if this function descriptor is for driver tasks. - """ - return all( - len(x) == 0 - for x in [self.module_name, self.class_name, self.function_name]) - - @property - def module_name(self): - """Get the module name of current function descriptor. - - Returns: - The module name of the function descriptor. - """ - return self._module_name - - @property - def class_name(self): - """Get the class name of current function descriptor. - - Returns: - The class name of the function descriptor. It could be - empty if the function is not a class method. 
- """ - return self._class_name - - @property - def function_name(self): - """Get the function name of current function descriptor. - - Returns: - The function name of the function descriptor. - """ - return self._function_name - - @property - def function_hash(self): - """Get the hash code of the function source code. - - Returns: - The bytes with length of ray_constants.ID_SIZE if the source - code is available. Otherwise, the bytes length will be 0. - """ - return self._function_source_hash - - @property - def function_id(self): - """Get the function id calculated from this descriptor. - - Returns: - The value of ray.ObjectID that represents the function id. - """ - return self._function_id - - def _get_function_id(self): - """Calculate the function id of current function descriptor. - - This function id is calculated from all the fields of function - descriptor. - - Returns: - ray.ObjectID to represent the function descriptor. - """ - if self.is_for_driver_task: - return ray.FunctionID.nil() - function_id_hash = hashlib.sha1() - # Include the function module and name in the hash. - function_id_hash.update(self.module_name.encode("ascii")) - function_id_hash.update(self.function_name.encode("ascii")) - function_id_hash.update(self.class_name.encode("ascii")) - function_id_hash.update(self._function_source_hash) - # Compute the function ID. - function_id = function_id_hash.digest() - return ray.FunctionID(function_id) - - def get_function_descriptor_list(self): - """Return a list of bytes representing the function descriptor. - - This function is used to pass this function descriptor to backend. - - Returns: - A list of bytes. - """ - descriptor_list = [] - if self.is_for_driver_task: - # Driver task returns an empty list. 
- return descriptor_list - else: - descriptor_list.append(self.module_name.encode("ascii")) - descriptor_list.append(self.class_name.encode("ascii")) - descriptor_list.append(self.function_name.encode("ascii")) - if len(self._function_source_hash) != 0: - descriptor_list.append(self._function_source_hash) - return descriptor_list - - def is_actor_method(self): - """Wether this function descriptor is an actor method. - - Returns: - True if it's an actor method, False if it's a normal function. - """ - return len(self._class_name) > 0 - - class FunctionActorManager: """A class used to export/load remote functions and actors. @@ -488,7 +268,7 @@ class FunctionActorManager: self._num_task_executions[job_id][function_id] = 0 except Exception: logger.exception( - "Failed to load function %s.".format(function_name)) + "Failed to load function {}.".format(function_name)) raise Exception( "Function {} failed to be loaded from local code.".format( function_descriptor)) @@ -551,10 +331,10 @@ class FunctionActorManager: self._worker.redis_client.hmset(key, actor_class_info) self._worker.redis_client.rpush("Exports", key) - def export_actor_class(self, Class, actor_method_names): + def export_actor_class(self, Class, actor_creation_function_descriptor, + actor_method_names): if self._worker.load_code_from_local: return - function_descriptor = FunctionDescriptor.from_class(Class) # `current_job_id` shouldn't be NIL, unless: # 1) This worker isn't an actor; # 2) And a previous task started a background thread, which didn't @@ -565,10 +345,10 @@ class FunctionActorManager: "please make sure the thread finishes before the task finishes.") job_id = self._worker.current_job_id key = (b"ActorClass:" + job_id.binary() + b":" + - function_descriptor.function_id.binary()) + actor_creation_function_descriptor.function_id.binary()) actor_class_info = { - "class_name": Class.__name__, - "module": Class.__module__, + "class_name": actor_creation_function_descriptor.class_name, + "module": 
actor_creation_function_descriptor.module_name, "class": pickle.dumps(Class), "job_id": job_id.binary(), "collision_identifier": self.compute_collision_identifier(Class), @@ -617,7 +397,7 @@ class FunctionActorManager: actor_methods = inspect.getmembers( actor_class, predicate=is_function_or_method) for actor_method_name, actor_method in actor_methods: - method_descriptor = FunctionDescriptor( + method_descriptor = PythonFunctionDescriptor( module_name, actor_method_name, actor_class_name) method_id = method_descriptor.function_id executor = self._make_actor_method_executor( diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index 9b54f9178..cfbb72d11 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -13,6 +13,9 @@ from ray.includes.unique_ids cimport ( CObjectID, CTaskID, ) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) cdef extern from * namespace "polyfill": @@ -201,9 +204,9 @@ cdef extern from "ray/core_worker/common.h" nogil: cdef cppclass CRayFunction "ray::RayFunction": CRayFunction() CRayFunction(CLanguage language, - const c_vector[c_string] function_descriptor) + const CFunctionDescriptor &function_descriptor) CLanguage GetLanguage() - const c_vector[c_string]& GetFunctionDescriptor() + const CFunctionDescriptor GetFunctionDescriptor() cdef cppclass CTaskArg "ray::TaskArg": @staticmethod diff --git a/python/ray/includes/function_descriptor.pxd b/python/ray/includes/function_descriptor.pxd new file mode 100644 index 000000000..c1d516511 --- /dev/null +++ b/python/ray/includes/function_descriptor.pxd @@ -0,0 +1,69 @@ +from libc.stdint cimport uint8_t, uint64_t +from libcpp cimport bool as c_bool +from libcpp.memory cimport unique_ptr, shared_ptr +from libcpp.string cimport string as c_string +from libcpp.unordered_map cimport unordered_map +from libcpp.vector cimport vector as c_vector + +from ray.includes.common cimport ( + CLanguage, + ResourceSet, +) +from 
ray.includes.unique_ids cimport ( + CActorID, + CJobID, + CObjectID, + CTaskID, +) + +cdef extern from "ray/protobuf/common.pb.h" nogil: + cdef cppclass CFunctionDescriptorType \ + "ray::FunctionDescriptorType": + pass + + cdef CFunctionDescriptorType EmptyFunctionDescriptorType \ + "ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET" + cdef CFunctionDescriptorType JavaFunctionDescriptorType \ + "ray::FunctionDescriptorType::kJavaFunctionDescriptor" + cdef CFunctionDescriptorType PythonFunctionDescriptorType \ + "ray::FunctionDescriptorType::kPythonFunctionDescriptor" + + +cdef extern from "ray/common/function_descriptor.h" nogil: + cdef cppclass CFunctionDescriptorInterface \ + "ray::CFunctionDescriptorInterface": + CFunctionDescriptorType Type() + c_string ToString() + c_string Serialize() + + ctypedef shared_ptr[CFunctionDescriptorInterface] CFunctionDescriptor \ + "ray::FunctionDescriptor" + + cdef cppclass CFunctionDescriptorBuilder "ray::FunctionDescriptorBuilder": + @staticmethod + CFunctionDescriptor Empty() + + @staticmethod + CFunctionDescriptor BuildJava(const c_string &class_name, + const c_string &function_name, + const c_string &signature) + + @staticmethod + CFunctionDescriptor BuildPython(const c_string &module_name, + const c_string &class_name, + const c_string &function_name, + const c_string &function_source_hash) + + @staticmethod + CFunctionDescriptor Deserialize(const c_string &serialized_binary) + + cdef cppclass CJavaFunctionDescriptor "ray::JavaFunctionDescriptor": + c_string ClassName() + c_string FunctionName() + c_string Signature() + + cdef cppclass CPythonFunctionDescriptor "ray::PythonFunctionDescriptor": + c_string ModuleName() + c_string ClassName() + c_string FunctionName() + c_string FunctionHash() diff --git a/python/ray/includes/function_descriptor.pxi b/python/ray/includes/function_descriptor.pxi new file mode 100644 index 000000000..995d0efca --- /dev/null +++ b/python/ray/includes/function_descriptor.pxi @@ -0,0 +1,285 
@@ +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, + CFunctionDescriptorBuilder, + CPythonFunctionDescriptor, + CJavaFunctionDescriptor, + EmptyFunctionDescriptorType, + JavaFunctionDescriptorType, + PythonFunctionDescriptorType, +) + +import hashlib +import cython +import inspect + + +ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &) +cdef unordered_map[int, FunctionDescriptor_from_cpp] \ + FunctionDescriptor_constructor_map +cdef CFunctionDescriptorToPython(CFunctionDescriptor function_descriptor): + cdef int function_descriptor_type = function_descriptor.get().Type() + it = FunctionDescriptor_constructor_map.find(function_descriptor_type) + if it == FunctionDescriptor_constructor_map.end(): + raise Exception("Can't construct FunctionDescriptor from type {}" + .format(function_descriptor_type)) + else: + constructor = dereference(it).second + return constructor(function_descriptor) + + +@cython.auto_pickle(False) +cdef class FunctionDescriptor: + def __cinit__(self, *args, **kwargs): + if type(self) == FunctionDescriptor: + raise Exception("type {} is abstract".format(type(self).__name__)) + + def __hash__(self): + return hash(self.descriptor.get().ToString()) + + def __eq__(self, other): + return (type(self) == type(other) and + self.descriptor.get().ToString() == + (other).descriptor.get().ToString()) + + def __repr__(self): + return self.descriptor.get().ToString() + + def to_dict(self): + d = {"type": type(self).__name__} + for k, v in vars(type(self)).items(): + if inspect.isgetsetdescriptor(v): + d[k] = v.__get__(self) + return d + + +FunctionDescriptor_constructor_map[EmptyFunctionDescriptorType] = \ + EmptyFunctionDescriptor.from_cpp + + +@cython.auto_pickle(False) +cdef class EmptyFunctionDescriptor(FunctionDescriptor): + def __cinit__(self): + self.descriptor = CFunctionDescriptorBuilder.Empty() + + def __reduce__(self): + return EmptyFunctionDescriptor, () + + @staticmethod + cdef from_cpp(const 
CFunctionDescriptor &c_function_descriptor): + return EmptyFunctionDescriptor() + + +FunctionDescriptor_constructor_map[JavaFunctionDescriptorType] = \ + JavaFunctionDescriptor.from_cpp + + +@cython.auto_pickle(False) +cdef class JavaFunctionDescriptor(FunctionDescriptor): + cdef: + CJavaFunctionDescriptor *typed_descriptor + + def __cinit__(self, + class_name, + function_name, + signature): + self.descriptor = CFunctionDescriptorBuilder.BuildJava( + class_name, function_name, signature) + self.typed_descriptor = ( + self.descriptor.get()) + + def __reduce__(self): + return JavaFunctionDescriptor, (self.typed_descriptor.ClassName(), + self.typed_descriptor.FunctionName(), + self.typed_descriptor.Signature()) + + @staticmethod + cdef from_cpp(const CFunctionDescriptor &c_function_descriptor): + cdef CJavaFunctionDescriptor *typed_descriptor = \ + (c_function_descriptor.get()) + return JavaFunctionDescriptor(typed_descriptor.ClassName(), + typed_descriptor.FunctionName(), + typed_descriptor.Signature()) + + @property + def class_name(self): + """Get the class name of current function descriptor. + + Returns: + The class name of the function descriptor. It could be + empty if the function is not a class method. + """ + return self.typed_descriptor.ClassName() + + @property + def function_name(self): + """Get the function name of current function descriptor. + + Returns: + The function name of the function descriptor. + """ + return self.typed_descriptor.FunctionName() + + @property + def signature(self): + """Get the signature of current function descriptor. + + Returns: + The signature of the function descriptor. 
+ """ + return self.typed_descriptor.Signature() + + +FunctionDescriptor_constructor_map[PythonFunctionDescriptorType] = \ + PythonFunctionDescriptor.from_cpp + + +@cython.auto_pickle(False) +cdef class PythonFunctionDescriptor(FunctionDescriptor): + cdef: + CPythonFunctionDescriptor *typed_descriptor + object _function_id + + def __cinit__(self, + module_name, + function_name, + class_name="", + function_source_hash=""): + self.descriptor = CFunctionDescriptorBuilder.BuildPython( + module_name, class_name, function_name, function_source_hash) + self.typed_descriptor = ( + self.descriptor.get()) + + def __reduce__(self): + return PythonFunctionDescriptor, (self.typed_descriptor.ModuleName(), + self.typed_descriptor.FunctionName(), + self.typed_descriptor.ClassName(), + self.typed_descriptor.FunctionHash()) + + @staticmethod + cdef from_cpp(const CFunctionDescriptor &c_function_descriptor): + cdef CPythonFunctionDescriptor *typed_descriptor = \ + (c_function_descriptor.get()) + return PythonFunctionDescriptor(typed_descriptor.ModuleName(), + typed_descriptor.FunctionName(), + typed_descriptor.ClassName(), + typed_descriptor.FunctionHash()) + + @classmethod + def from_function(cls, function, pickled_function): + """Create a FunctionDescriptor from a function instance. + + This function is used to create the function descriptor from + a python function. If a function is a class function, it should + not be used by this function. + + Args: + cls: Current class which is required argument for classmethod. + function: the python function used to create the function + descriptor. + pickled_function: This is factored in to ensure that any + modifications to the function result in a different function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the function. 
+ """ + module_name = function.__module__ + function_name = function.__name__ + class_name = "" + + pickled_function_hash = hashlib.sha1(pickled_function).hexdigest() + + return cls(module_name, function_name, class_name, + pickled_function_hash) + + @classmethod + def from_class(cls, target_class): + """Create a FunctionDescriptor from a class. + + Args: + cls: Current class which is required argument for classmethod. + target_class: the python class used to create the function + descriptor. + + Returns: + The FunctionDescriptor instance created according to the class. + """ + module_name = target_class.__module__ + class_name = target_class.__name__ + return cls(module_name, "__init__", class_name) + + @property + def module_name(self): + """Get the module name of current function descriptor. + + Returns: + The module name of the function descriptor. + """ + return self.typed_descriptor.ModuleName() + + @property + def class_name(self): + """Get the class name of current function descriptor. + + Returns: + The class name of the function descriptor. It could be + empty if the function is not a class method. + """ + return self.typed_descriptor.ClassName() + + @property + def function_name(self): + """Get the function name of current function descriptor. + + Returns: + The function name of the function descriptor. + """ + return self.typed_descriptor.FunctionName() + + @property + def function_hash(self): + """Get the hash string of the function source code. + + Returns: + The hex of function hash if the source code is available. + Otherwise, it will be an empty string. + """ + return self.typed_descriptor.FunctionHash() + + @property + def function_id(self): + """Get the function id calculated from this descriptor. + + Returns: + The value of ray.FunctionID that represents the function id.
+ """ + if not self._function_id: + self._function_id = self._get_function_id() + return self._function_id + + def _get_function_id(self): + """Calculate the function id of current function descriptor. + + This function id is calculated from all the fields of function + descriptor. + + Returns: + ray.FunctionID to represent the function descriptor. + """ + function_id_hash = hashlib.sha1() + # Include the function module and name in the hash. + function_id_hash.update(self.typed_descriptor.ModuleName()) + function_id_hash.update(self.typed_descriptor.FunctionName()) + function_id_hash.update(self.typed_descriptor.ClassName()) + function_id_hash.update(self.typed_descriptor.FunctionHash()) + # Compute the function ID. + function_id = function_id_hash.digest() + return ray.FunctionID(function_id) + + def is_actor_method(self): + """Whether this function descriptor is an actor method. + + Returns: + True if it's an actor method, False if it's a normal function. + """ + return not self.typed_descriptor.ClassName().empty() diff --git a/python/ray/includes/task.pxd b/python/ray/includes/task.pxd index a46ed2162..d19387279 100644 --- a/python/ray/includes/task.pxd +++ b/python/ray/includes/task.pxd @@ -15,6 +15,9 @@ from ray.includes.unique_ids cimport ( CObjectID, CTaskID, ) +from ray.includes.function_descriptor cimport ( + CFunctionDescriptor, +) cdef extern from "ray/protobuf/common.pb.h" nogil: cdef cppclass RpcTaskSpec "ray::rpc::TaskSpec": @@ -44,7 +47,7 @@ cdef extern from "ray/common/task/task_spec.h" nogil: CJobID JobId() const CTaskID ParentTaskId() const uint64_t ParentCounter() const - c_vector[c_string] FunctionDescriptor() const + CFunctionDescriptor FunctionDescriptor() const c_string FunctionDescriptorString() const uint64_t NumArgs() const uint64_t NumReturns() const diff --git a/python/ray/includes/task.pxi b/python/ray/includes/task.pxi index 0bfaa5012..3512b7716 100644 --- a/python/ray/includes/task.pxi +++ b/python/ray/includes/task.pxi @@ -64,14 +64,10
@@ cdef class TaskSpec: """Return the parent counter of this task.""" return self.task_spec.get().ParentCounter() - def function_descriptor_list(self): + def function_descriptor(self): """Return the function descriptor for this task.""" - cdef c_vector[c_string] function_descriptor = ( + return CFunctionDescriptorToPython( self.task_spec.get().FunctionDescriptor()) - results = [] - for i in range(function_descriptor.size()): - results.append(function_descriptor[i]) - return results def arguments(self): """Return the arguments for the task.""" diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index a8b10e903..36014dd89 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -3,7 +3,8 @@ from functools import wraps from ray import cloudpickle as pickle from ray import ray_constants -from ray.function_manager import FunctionDescriptor +from ray._raylet import PythonFunctionDescriptor +from ray import cross_language, Language import ray.signature # Default parameters for remote functions. @@ -23,6 +24,7 @@ class RemoteFunction: This is a decorated function. It can be used to spawn tasks. Attributes: + _language: The target language. _function: The original function. _function_descriptor: The function descriptor. This is not defined until the remote function is first invoked because that is when the @@ -57,12 +59,15 @@ class RemoteFunction: different workers. """ - def __init__(self, function, num_cpus, num_gpus, memory, - object_store_memory, resources, num_return_vals, max_calls, - max_retries): + def __init__(self, language, function, function_descriptor, num_cpus, + num_gpus, memory, object_store_memory, resources, + num_return_vals, max_calls, max_retries): + self._language = language self._function = function self._function_name = ( self._function.__module__ + "." 
+ self._function.__name__) + self._function_descriptor = function_descriptor + self._is_cross_language = language != Language.PYTHON self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS if num_cpus is None else num_cpus) self._num_gpus = num_gpus @@ -80,11 +85,11 @@ class RemoteFunction: if max_retries is None else max_retries) self._decorator = getattr(function, "__ray_invocation_decorator__", None) - self._function_signature = ray.signature.extract_signature( self._function) self._last_export_session_and_job = None + # Override task.remote's signature and docstring @wraps(function) def _remote_proxy(*args, **kwargs): @@ -152,7 +157,9 @@ class RemoteFunction: # If this function was not exported in this session and job, we need to # export this function again, because the current GCS doesn't have it. - if self._last_export_session_and_job != worker.current_session_and_job: + if not self._is_cross_language and \ + self._last_export_session_and_job != \ + worker.current_session_and_job: # There is an interesting question here. If the remote function is # used by a subsequent driver (in the same script), should the # second driver pickle the function again? If yes, then the remote @@ -164,10 +171,8 @@ class RemoteFunction: # which we do here. 
self._pickled_function = pickle.dumps(self._function) - self._function_descriptor = FunctionDescriptor.from_function( + self._function_descriptor = PythonFunctionDescriptor.from_function( self._function, self._pickled_function) - self._function_descriptor_list = ( - self._function_descriptor.get_function_descriptor_list()) self._last_export_session_and_job = worker.current_session_and_job worker.function_actor_manager.export(self) @@ -188,20 +193,25 @@ class RemoteFunction: memory, object_store_memory, resources) def invocation(args, kwargs): - if not args and not kwargs and not self._function_signature: + if self._is_cross_language: + list_args = cross_language.format_args(worker, args, kwargs) + elif not args and not kwargs and not self._function_signature: list_args = [] else: list_args = ray.signature.flatten_args( self._function_signature, args, kwargs) if worker.mode == ray.worker.LOCAL_MODE: + assert not self._is_cross_language, \ + "Cross language remote function " \ + "cannot be executed locally." 
object_ids = worker.local_mode_manager.execute( self._function, self._function_descriptor, args, kwargs, num_return_vals) else: object_ids = worker.core_worker.submit_task( - self._function_descriptor_list, list_args, num_return_vals, - is_direct_call, resources, max_retries) + self._language, self._function_descriptor, list_args, + num_return_vals, is_direct_call, resources, max_retries) if len(object_ids) == 1: return object_ids[0] diff --git a/python/ray/state.py b/python/ray/state.py index e7b9a4acd..31d172861 100644 --- a/python/ray/state.py +++ b/python/ray/state.py @@ -5,7 +5,6 @@ import sys import time import ray -from ray.function_manager import FunctionDescriptor from ray import ( gcs_utils, @@ -393,9 +392,7 @@ class GlobalState: task = ray._raylet.TaskSpec.from_string( task_table_data.task.task_spec.SerializeToString()) - function_descriptor_list = task.function_descriptor_list() - function_descriptor = FunctionDescriptor.from_bytes_list( - function_descriptor_list) + function_descriptor = task.function_descriptor() task_spec_info = { "JobID": task.job_id().hex(), @@ -412,11 +409,7 @@ class GlobalState: "Args": task.arguments(), "ReturnObjectIDs": task.returns(), "RequiredResources": task.required_resources(), - "FunctionID": function_descriptor.function_id.hex(), - "FunctionHash": binary_to_hex(function_descriptor.function_hash), - "ModuleName": function_descriptor.module_name, - "ClassName": function_descriptor.class_name, - "FunctionName": function_descriptor.function_name, + "FunctionDescriptor": function_descriptor.to_dict(), } execution_spec = ray._raylet.TaskExecutionSpec.from_string( diff --git a/python/ray/tests/test_advanced_3.py b/python/ray/tests/test_advanced_3.py index 886478683..b5cdfcdef 100644 --- a/python/ray/tests/test_advanced_3.py +++ b/python/ray/tests/test_advanced_3.py @@ -151,14 +151,13 @@ def test_global_state_api(shutdown_only): assert len(task_table) == 1 assert driver_task_id == list(task_table.keys())[0] task_spec = 
task_table[driver_task_id]["TaskSpec"] - nil_unique_id_hex = ray.UniqueID.nil().hex() nil_actor_id_hex = ray.ActorID.nil().hex() assert task_spec["TaskID"] == driver_task_id assert task_spec["ActorID"] == nil_actor_id_hex assert task_spec["Args"] == [] assert task_spec["JobID"] == job_id.hex() - assert task_spec["FunctionID"] == nil_unique_id_hex + assert task_spec["FunctionDescriptor"]["type"] == "EmptyFunctionDescriptor" assert task_spec["ReturnObjectIDs"] == [] client_table = ray.nodes() @@ -172,7 +171,7 @@ def test_global_state_api(shutdown_only): def __init__(self): pass - _ = Actor.remote() + _ = Actor.remote() # noqa: F841 # Wait for actor to be created wait_for_num_actors(1) diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 2ce9d5bdb..0b5d027d4 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -9,6 +9,7 @@ import string import sys import threading import time +import pickle import numpy as np import pytest @@ -351,6 +352,23 @@ def test_complex_serialization_with_pickle(shutdown_only): complex_serialization(use_pickle=True) +def test_function_descriptor(): + python_descriptor = ray._raylet.PythonFunctionDescriptor( + "module_name", "function_name", "class_name", "function_hash") + python_descriptor2 = pickle.loads(pickle.dumps(python_descriptor)) + assert python_descriptor == python_descriptor2 + assert hash(python_descriptor) == hash(python_descriptor2) + assert python_descriptor.function_id == python_descriptor2.function_id + java_descriptor = ray._raylet.JavaFunctionDescriptor( + "class_name", "function_name", "signature") + java_descriptor2 = pickle.loads(pickle.dumps(java_descriptor)) + assert java_descriptor == java_descriptor2 + assert python_descriptor != java_descriptor + assert python_descriptor != object() + d = {python_descriptor: 123} + assert d.get(python_descriptor2) == 123 + + def test_nested_functions(ray_start_regular): # Make sure that remote functions can use other values 
that are defined # after the remote function but before the first function invocation. diff --git a/python/ray/tests/test_cross_language.py b/python/ray/tests/test_cross_language.py new file mode 100644 index 000000000..13323c9bd --- /dev/null +++ b/python/ray/tests/test_cross_language.py @@ -0,0 +1,15 @@ +import pytest + +import ray +import ray.cluster_utils +import ray.test_utils + + +def test_cross_language_raise_kwargs(shutdown_only): + ray.init(load_code_from_local=True, include_java=True) + + with pytest.raises(Exception, match="kwargs"): + ray.java_function("a", "b").remote(x="arg1") + + with pytest.raises(Exception, match="kwargs"): + ray.java_actor_class("a").remote(x="arg1") diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py index 56bb05522..648119032 100644 --- a/python/ray/tests/test_metrics.py +++ b/python/ray/tests/test_metrics.py @@ -189,7 +189,13 @@ def test_raylet_info_endpoint(shutdown_only): try: webui_url = addresses["webui_url"] webui_url = webui_url.replace("localhost", "http://127.0.0.1") - raylet_info = requests.get(webui_url + "/api/raylet_info").json() + response = requests.get(webui_url + "/api/raylet_info") + response.raise_for_status() + try: + raylet_info = response.json() + except Exception as ex: + print("failed response: {}".format(response.text)) + raise ex actor_info = raylet_info["result"]["actors"] try: assert len(actor_info) == 1 diff --git a/python/ray/worker.py b/python/ray/worker.py index 3820fa68f..c9fffa23b 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -36,6 +36,7 @@ from ray import ( ActorID, JobID, ObjectID, + Language, ) from ray import import_thread from ray import profiling @@ -553,6 +554,7 @@ def init(address=None, redis_password=ray_constants.REDIS_DEFAULT_PASSWORD, plasma_directory=None, huge_pages=False, + include_java=False, include_webui=None, webui_host="localhost", job_id=None, @@ -632,6 +634,7 @@ def init(address=None, be created. 
huge_pages: Boolean flag indicating whether to start the Object Store with hugetlbfs support. Requires plasma_directory. + include_java: Boolean flag indicating whether to enable java worker. include_webui: Boolean flag indicating whether to start the web UI, which displays the status of the Ray cluster. If this argument is None, then the UI will be started if the relevant dependencies @@ -725,6 +728,7 @@ def init(address=None, redis_password=redis_password, plasma_directory=plasma_directory, huge_pages=huge_pages, + include_java=include_java, include_webui=include_webui, webui_host=webui_host, memory=memory, @@ -1684,9 +1688,9 @@ def make_decorator(num_return_vals=None, "allowed for remote functions.") return ray.remote_function.RemoteFunction( - function_or_class, num_cpus, num_gpus, memory, - object_store_memory, resources, num_return_vals, max_calls, - max_retries) + Language.PYTHON, function_or_class, None, num_cpus, num_gpus, + memory, object_store_memory, resources, num_return_vals, + max_calls, max_retries) if inspect.isclass(function_or_class): if num_return_vals is not None: diff --git a/src/ray/common/function_descriptor.cc b/src/ray/common/function_descriptor.cc new file mode 100644 index 000000000..34b9eb5d7 --- /dev/null +++ b/src/ray/common/function_descriptor.cc @@ -0,0 +1,78 @@ +#include "ray/common/function_descriptor.h" + +namespace ray { +FunctionDescriptor FunctionDescriptorBuilder::Empty() { + static ray::FunctionDescriptor empty = + ray::FunctionDescriptor(new EmptyFunctionDescriptor()); + return empty; +} + +FunctionDescriptor FunctionDescriptorBuilder::BuildJava(const std::string &class_name, + const std::string &function_name, + const std::string &signature) { + rpc::FunctionDescriptor descriptor; + auto typed_descriptor = descriptor.mutable_java_function_descriptor(); + typed_descriptor->set_class_name(class_name); + typed_descriptor->set_function_name(function_name); + typed_descriptor->set_signature(signature); + return 
ray::FunctionDescriptor(new JavaFunctionDescriptor(std::move(descriptor))); +} + +FunctionDescriptor FunctionDescriptorBuilder::BuildPython( + const std::string &module_name, const std::string &class_name, + const std::string &function_name, const std::string &function_hash) { + rpc::FunctionDescriptor descriptor; + auto typed_descriptor = descriptor.mutable_python_function_descriptor(); + typed_descriptor->set_module_name(module_name); + typed_descriptor->set_class_name(class_name); + typed_descriptor->set_function_name(function_name); + typed_descriptor->set_function_hash(function_hash); + return ray::FunctionDescriptor(new PythonFunctionDescriptor(std::move(descriptor))); +} + +FunctionDescriptor FunctionDescriptorBuilder::FromProto(rpc::FunctionDescriptor message) { + switch (message.function_descriptor_case()) { + case ray::FunctionDescriptorType::kJavaFunctionDescriptor: + return ray::FunctionDescriptor(new ray::JavaFunctionDescriptor(std::move(message))); + case ray::FunctionDescriptorType::kPythonFunctionDescriptor: + return ray::FunctionDescriptor(new ray::PythonFunctionDescriptor(std::move(message))); + default: + break; + } + RAY_LOG(DEBUG) << "Unknown function descriptor case: " + << message.function_descriptor_case(); + // When TaskSpecification() constructed without function_descriptor set, + // we should return a valid ray::FunctionDescriptor instance. 
+ return FunctionDescriptorBuilder::Empty(); +} + +FunctionDescriptor FunctionDescriptorBuilder::FromVector( + rpc::Language language, const std::vector &function_descriptor_list) { + if (language == rpc::Language::JAVA) { + RAY_CHECK(function_descriptor_list.size() == 3); + return FunctionDescriptorBuilder::BuildJava( + function_descriptor_list[0], // class name + function_descriptor_list[1], // function name + function_descriptor_list[2] // signature + ); + } else if (language == rpc::Language::PYTHON) { + RAY_CHECK(function_descriptor_list.size() == 4); + return FunctionDescriptorBuilder::BuildPython( + function_descriptor_list[0], // module name + function_descriptor_list[1], // class name + function_descriptor_list[2], // function name + function_descriptor_list[3] // function hash + ); + } else { + RAY_LOG(FATAL) << "Unspported language " << language; + return FunctionDescriptorBuilder::Empty(); + } +} + +FunctionDescriptor FunctionDescriptorBuilder::Deserialize( + const std::string &serialized_binary) { + rpc::FunctionDescriptor descriptor; + descriptor.ParseFromString(serialized_binary); + return FunctionDescriptorBuilder::FromProto(std::move(descriptor)); +} +} // namespace ray diff --git a/src/ray/common/function_descriptor.h b/src/ray/common/function_descriptor.h new file mode 100644 index 000000000..59108afd1 --- /dev/null +++ b/src/ray/common/function_descriptor.h @@ -0,0 +1,189 @@ +#ifndef RAY_CORE_WORKER_FUNCTION_DESCRIPTOR_H +#define RAY_CORE_WORKER_FUNCTION_DESCRIPTOR_H + +#include + +#include "ray/common/grpc_util.h" +#include "ray/protobuf/common.pb.h" + +namespace ray { +/// See `common.proto` for definition of `FunctionDescriptor` oneof type. +using FunctionDescriptorType = rpc::FunctionDescriptor::FunctionDescriptorCase; +/// Wrap a protobuf message. +class FunctionDescriptorInterface : public MessageWrapper { + public: + /// Construct an empty FunctionDescriptor. 
+ FunctionDescriptorInterface() : MessageWrapper() {} + + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. + /// + /// \param message The protobuf message. + FunctionDescriptorInterface(rpc::FunctionDescriptor message) + : MessageWrapper(std::move(message)) {} + + ray::FunctionDescriptorType Type() const { + return message_->function_descriptor_case(); + } + + virtual size_t Hash() const = 0; + + virtual std::string ToString() const = 0; + + template + Subtype *As() { + return reinterpret_cast(this); + } +}; + +class EmptyFunctionDescriptor : public FunctionDescriptorInterface { + public: + /// Construct an empty function descriptor. + /// + /// Takes no protobuf message; the constructor checks that the wrapped + /// message has no function descriptor case set. + explicit EmptyFunctionDescriptor() : FunctionDescriptorInterface() { + RAY_CHECK(message_->function_descriptor_case() == + ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET); + } + + virtual size_t Hash() const { + return std::hash()(ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET); + } + + virtual std::string ToString() const { return "{type=EmptyFunctionDescriptor}"; } +}; + +class JavaFunctionDescriptor : public FunctionDescriptorInterface { + public: + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. + /// + /// \param message The protobuf message.
+ explicit JavaFunctionDescriptor(rpc::FunctionDescriptor message) + : FunctionDescriptorInterface(std::move(message)) { + RAY_CHECK(message_->function_descriptor_case() == + ray::FunctionDescriptorType::kJavaFunctionDescriptor); + typed_message_ = &(message_->java_function_descriptor()); + } + + virtual size_t Hash() const { + return std::hash()(ray::FunctionDescriptorType::kJavaFunctionDescriptor) ^ + std::hash()(typed_message_->class_name()) ^ + std::hash()(typed_message_->function_name()) ^ + std::hash()(typed_message_->signature()); + } + + virtual std::string ToString() const { + return "{type=JavaFunctionDescriptor, class_name=" + typed_message_->class_name() + + ", function_name=" + typed_message_->function_name() + + ", signature=" + typed_message_->signature() + "}"; + } + + std::string ClassName() const { return typed_message_->class_name(); } + + std::string FunctionName() const { return typed_message_->function_name(); } + + std::string Signature() const { return typed_message_->signature(); } + + private: + const rpc::JavaFunctionDescriptor *typed_message_; +}; + +class PythonFunctionDescriptor : public FunctionDescriptorInterface { + public: + /// Construct from a protobuf message object. + /// The input message will be **copied** into this object. + /// + /// \param message The protobuf message. 
+ explicit PythonFunctionDescriptor(rpc::FunctionDescriptor message) + : FunctionDescriptorInterface(std::move(message)) { + RAY_CHECK(message_->function_descriptor_case() == + ray::FunctionDescriptorType::kPythonFunctionDescriptor); + typed_message_ = &(message_->python_function_descriptor()); + } + + virtual size_t Hash() const { + return std::hash()(ray::FunctionDescriptorType::kPythonFunctionDescriptor) ^ + std::hash()(typed_message_->module_name()) ^ + std::hash()(typed_message_->class_name()) ^ + std::hash()(typed_message_->function_name()) ^ + std::hash()(typed_message_->function_hash()); + } + + virtual std::string ToString() const { + return "{type=PythonFunctionDescriptor, module_name=" + + typed_message_->module_name() + + ", class_name=" + typed_message_->class_name() + + ", function_name=" + typed_message_->function_name() + + ", function_hash=" + typed_message_->function_hash() + "}"; + } + + std::string ModuleName() const { return typed_message_->module_name(); } + + std::string ClassName() const { return typed_message_->class_name(); } + + std::string FunctionName() const { return typed_message_->function_name(); } + + std::string FunctionHash() const { return typed_message_->function_hash(); } + + private: + const rpc::PythonFunctionDescriptor *typed_message_; +}; + +typedef std::shared_ptr FunctionDescriptor; + +inline bool operator==(const FunctionDescriptor &left, const FunctionDescriptor &right) { + if (left.get() != nullptr && right.get() != nullptr && left->Type() == right->Type() && + left->ToString() == right->ToString()) { + return true; + } + return left.get() == right.get(); +} + +inline bool operator!=(const FunctionDescriptor &left, const FunctionDescriptor &right) { + return !(left == right); +} + +/// Helper class for building a `FunctionDescriptor` object. +class FunctionDescriptorBuilder { + public: + /// Build an EmptyFunctionDescriptor. 
+ /// + /// \return a ray::EmptyFunctionDescriptor + static FunctionDescriptor Empty(); + + /// Build a JavaFunctionDescriptor. + /// + /// \return a ray::JavaFunctionDescriptor + static FunctionDescriptor BuildJava(const std::string &class_name, + const std::string &function_name, + const std::string &signature); + + /// Build a PythonFunctionDescriptor. + /// + /// \return a ray::PythonFunctionDescriptor + static FunctionDescriptor BuildPython(const std::string &module_name, + const std::string &class_name, + const std::string &function_name, + const std::string &function_hash); + + /// Build a ray::FunctionDescriptor according to input message. + /// + /// \return new ray::FunctionDescriptor + static FunctionDescriptor FromProto(rpc::FunctionDescriptor message); + + /// Build a ray::FunctionDescriptor from language and vector. + /// + /// \return new ray::FunctionDescriptor + static FunctionDescriptor FromVector( + rpc::Language language, const std::vector &function_descriptor_list); + + /// Build a ray::FunctionDescriptor from serialized binary. 
+ /// + /// \return new ray::FunctionDescriptor + static FunctionDescriptor Deserialize(const std::string &serialized_binary); +}; +} // namespace ray + +#endif diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index 4aea8b4bf..8ad6f11ad 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -75,8 +75,8 @@ TaskID TaskSpecification::ParentTaskId() const { size_t TaskSpecification::ParentCounter() const { return message_->parent_counter(); } -std::vector TaskSpecification::FunctionDescriptor() const { - return VectorFromProtobuf(message_->function_descriptor()); +ray::FunctionDescriptor TaskSpecification::FunctionDescriptor() const { + return ray::FunctionDescriptorBuilder::FromProto(message_->function_descriptor()); } const SchedulingClass TaskSpecification::GetSchedulingClass() const { @@ -145,8 +145,7 @@ const ResourceSet &TaskSpecification::GetRequiredPlacementResources() const { } bool TaskSpecification::IsDriverTask() const { - // Driver tasks are empty tasks that have no function ID set. - return FunctionDescriptor().empty(); + return message_->type() == TaskType::DRIVER_TASK; } Language TaskSpecification::GetLanguage() const { return message_->language(); } @@ -249,15 +248,7 @@ std::string TaskSpecification::DebugString() const { << ", function_descriptor="; // Print function descriptor. - const auto list = VectorFromProtobuf(message_->function_descriptor()); - // The 4th is the code hash which is binary bits. No need to output it. 
- const size_t size = std::min(static_cast(3), list.size()); - for (size_t i = 0; i < size; ++i) { - if (i != 0) { - stream << ","; - } - stream << list[i]; - } + stream << FunctionDescriptor()->ToString(); stream << ", task_id=" << TaskId() << ", job_id=" << JobId() << ", num_args=" << NumArgs() << ", num_returns=" << NumReturns(); diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index 9f4f7d1ee..9065255e4 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -7,6 +7,7 @@ #include #include "absl/synchronization/mutex.h" +#include "ray/common/function_descriptor.h" #include "ray/common/grpc_util.h" #include "ray/common/id.h" #include "ray/common/task/scheduling_resources.h" @@ -17,9 +18,7 @@ extern "C" { } namespace ray { - -typedef std::vector FunctionDescriptor; -typedef std::pair SchedulingClassDescriptor; +typedef std::pair SchedulingClassDescriptor; typedef int SchedulingClass; /// Wrapper class of protobuf `TaskSpec`, see `common.proto` for details. @@ -63,7 +62,7 @@ class TaskSpecification : public MessageWrapper { size_t ParentCounter() const; - std::vector FunctionDescriptor() const; + ray::FunctionDescriptor FunctionDescriptor() const; size_t NumArgs() const; @@ -202,9 +201,7 @@ template <> struct hash { size_t operator()(ray::SchedulingClassDescriptor const &k) const { size_t seed = std::hash()(k.first); - for (const auto &str : k.second) { - seed ^= std::hash()(str); - } + seed ^= k.second->Hash(); return seed; } }; diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c383a9ad5..c7ab30952 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -25,16 +25,14 @@ class TaskSpecBuilder { /// \return Reference to the builder object itself. 
TaskSpecBuilder &SetCommonTaskSpec( const TaskID &task_id, const Language &language, - const std::vector &function_descriptor, const JobID &job_id, + const ray::FunctionDescriptor &function_descriptor, const JobID &job_id, const TaskID &parent_task_id, uint64_t parent_counter, const TaskID &caller_id, const rpc::Address &caller_address, uint64_t num_returns, bool is_direct_call, const std::unordered_map &required_resources, const std::unordered_map &required_placement_resources) { message_->set_type(TaskType::NORMAL_TASK); message_->set_language(language); - for (const auto &fd : function_descriptor) { - message_->add_function_descriptor(fd); - } + *message_->mutable_function_descriptor() = function_descriptor->GetMessage(); message_->set_job_id(job_id.Binary()); message_->set_task_id(task_id.Binary()); message_->set_parent_task_id(parent_task_id.Binary()); @@ -50,6 +48,27 @@ class TaskSpecBuilder { return *this; } + /// Set the driver attributes of the task spec. + /// See `common.proto` for meaning of the arguments. + /// + /// \return Reference to the builder object itself. + TaskSpecBuilder &SetDriverTaskSpec(const TaskID &task_id, const Language &language, + const JobID &job_id, const TaskID &parent_task_id, + const TaskID &caller_id, + const rpc::Address &caller_address) { + message_->set_type(TaskType::DRIVER_TASK); + message_->set_language(language); + message_->set_job_id(job_id.Binary()); + message_->set_task_id(task_id.Binary()); + message_->set_parent_task_id(parent_task_id.Binary()); + message_->set_parent_counter(0); + message_->set_caller_id(caller_id.Binary()); + message_->mutable_caller_address()->CopyFrom(caller_address); + message_->set_num_returns(0); + message_->set_is_direct_call(false); + return *this; + } + /// Add a by-reference argument to the task. /// /// \param arg_id Id of the argument. 
diff --git a/src/ray/core_worker/actor_handle.cc b/src/ray/core_worker/actor_handle.cc index 461a0ab14..9bcd2daca 100644 --- a/src/ray/core_worker/actor_handle.cc +++ b/src/ray/core_worker/actor_handle.cc @@ -7,14 +7,13 @@ namespace { ray::rpc::ActorHandle CreateInnerActorHandle( const class ActorID &actor_id, const class JobID &job_id, const ObjectID &initial_cursor, const Language actor_language, bool is_direct_call, - const std::vector &actor_creation_task_function_descriptor) { + const ray::FunctionDescriptor &actor_creation_task_function_descriptor) { ray::rpc::ActorHandle inner; inner.set_actor_id(actor_id.Data(), actor_id.Size()); inner.set_creation_job_id(job_id.Data(), job_id.Size()); inner.set_actor_language(actor_language); - *inner.mutable_actor_creation_task_function_descriptor() = { - actor_creation_task_function_descriptor.begin(), - actor_creation_task_function_descriptor.end()}; + *inner.mutable_actor_creation_task_function_descriptor() = + actor_creation_task_function_descriptor->GetMessage(); inner.set_actor_cursor(initial_cursor.Binary()); inner.set_is_direct_call(is_direct_call); return inner; @@ -33,7 +32,7 @@ namespace ray { ActorHandle::ActorHandle( const class ActorID &actor_id, const class JobID &job_id, const ObjectID &initial_cursor, const Language actor_language, bool is_direct_call, - const std::vector &actor_creation_task_function_descriptor) + const ray::FunctionDescriptor &actor_creation_task_function_descriptor) : ActorHandle(CreateInnerActorHandle(actor_id, job_id, initial_cursor, actor_language, is_direct_call, actor_creation_task_function_descriptor)) {} diff --git a/src/ray/core_worker/actor_handle.h b/src/ray/core_worker/actor_handle.h index 68554de0b..e7b225283 100644 --- a/src/ray/core_worker/actor_handle.h +++ b/src/ray/core_worker/actor_handle.h @@ -21,7 +21,7 @@ class ActorHandle { ActorHandle(const ActorID &actor_id, const JobID &job_id, const ObjectID &initial_cursor, const Language actor_language, bool is_direct_call, 
- const std::vector &actor_creation_task_function_descriptor); + const ray::FunctionDescriptor &actor_creation_task_function_descriptor); /// Constructs an ActorHandle from a serialized string. ActorHandle(const std::string &serialized); @@ -34,8 +34,9 @@ class ActorHandle { Language ActorLanguage() const { return inner_.actor_language(); }; - std::vector ActorCreationTaskFunctionDescriptor() const { - return VectorFromProtobuf(inner_.actor_creation_task_function_descriptor()); + ray::FunctionDescriptor ActorCreationTaskFunctionDescriptor() const { + return ray::FunctionDescriptorBuilder::FromProto( + inner_.actor_creation_task_function_descriptor()); }; bool IsDirectCallActor() const { return inner_.is_direct_call(); } diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index fe86decc3..1e277bfbb 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -23,18 +23,18 @@ std::string LanguageString(Language language); class RayFunction { public: RayFunction() {} - RayFunction(Language language, const std::vector &function_descriptor) + RayFunction(Language language, const ray::FunctionDescriptor &function_descriptor) : language_(language), function_descriptor_(function_descriptor) {} Language GetLanguage() const { return language_; } - const std::vector &GetFunctionDescriptor() const { + const ray::FunctionDescriptor &GetFunctionDescriptor() const { return function_descriptor_; } private: Language language_; - std::vector function_descriptor_; + ray::FunctionDescriptor function_descriptor_; }; /// Argument of a task. diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 0e6226cb8..9124378ea 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -200,13 +200,10 @@ CoreWorker::CoreWorker(const WorkerType worker_type, const Language language, // rerun the driver. 
if (worker_type_ == WorkerType::DRIVER) { TaskSpecBuilder builder; - std::vector empty_descriptor; - std::unordered_map empty_resources; const TaskID task_id = TaskID::ForDriverTask(worker_context_.GetCurrentJobID()); - builder.SetCommonTaskSpec( - task_id, language_, empty_descriptor, worker_context_.GetCurrentJobID(), - TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), 0, GetCallerId(), - rpc_address_, 0, false, empty_resources, empty_resources); + builder.SetDriverTaskSpec(task_id, language_, worker_context_.GetCurrentJobID(), + TaskID::ComputeDriverTaskId(worker_context_.GetWorkerID()), + GetCallerId(), rpc_address_); std::shared_ptr data = std::make_shared(); data->mutable_task()->mutable_task_spec()->CopyFrom(builder.Build().GetMessage()); @@ -1194,12 +1191,7 @@ void CoreWorker::HandleGetCoreWorkerStats(const rpc::GetCoreWorkerStatsRequest & stats->set_task_queue_length(task_queue_length_); stats->set_num_executed_tasks(num_executed_tasks_); stats->set_num_object_ids_in_scope(reference_counter_->NumObjectIDsInScope()); - if (!current_task_.TaskId().IsNil()) { - stats->set_current_task_desc(current_task_.DebugString()); - for (auto const it : current_task_.FunctionDescriptor()) { - stats->add_current_task_func_desc(it); - } - } + stats->set_current_task_func_desc(current_task_.FunctionDescriptor()->ToString()); stats->set_ip_address(rpc_address_.ip_address()); stats->set_port(rpc_address_.port()); stats->set_actor_id(actor_id_.Binary()); diff --git a/src/ray/core_worker/lib/java/jni_utils.h b/src/ray/core_worker/lib/java/jni_utils.h index 2024c9b3d..e9f59fbe4 100644 --- a/src/ray/core_worker/lib/java/jni_utils.h +++ b/src/ray/core_worker/lib/java/jni_utils.h @@ -3,6 +3,7 @@ #include #include "ray/common/buffer.h" +#include "ray/common/function_descriptor.h" #include "ray/common/id.h" #include "ray/common/ray_object.h" #include "ray/common/status.h" @@ -344,4 +345,26 @@ inline jobject NativeRayObjectToJavaNativeRayObject( return java_obj; } +// 
TODO(po): Convert C++ ray::FunctionDescriptor to Java FunctionDescriptor +inline jobject NativeRayFunctionDescriptorToJavaStringList( + JNIEnv *env, const ray::FunctionDescriptor &function_descriptor) { + if (function_descriptor->Type() == + ray::FunctionDescriptorType::kJavaFunctionDescriptor) { + auto typed_descriptor = function_descriptor->As(); + std::vector function_descriptor_list = {typed_descriptor->ClassName(), + typed_descriptor->FunctionName(), + typed_descriptor->Signature()}; + return NativeStringVectorToJavaStringList(env, function_descriptor_list); + } else if (function_descriptor->Type() == + ray::FunctionDescriptorType::kPythonFunctionDescriptor) { + auto typed_descriptor = function_descriptor->As(); + std::vector function_descriptor_list = { + typed_descriptor->ModuleName(), typed_descriptor->ClassName(), + typed_descriptor->FunctionName(), typed_descriptor->FunctionHash()}; + return NativeStringVectorToJavaStringList(env, function_descriptor_list); + } + RAY_LOG(FATAL) << "Unknown function descriptor type: " << function_descriptor->Type(); + return NativeStringVectorToJavaStringList(env, std::vector()); +} + #endif // RAY_COMMON_JAVA_JNI_UTILS_H diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc index 9eaa9c7ff..8a98a9cd3 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_RayNativeRuntime.cc @@ -44,8 +44,8 @@ JNIEXPORT jlong JNICALL Java_org_ray_runtime_RayNativeRuntime_nativeInitCoreWork RAY_CHECK(env); RAY_CHECK(local_java_task_executor); // convert RayFunction - jobject ray_function_array_list = - NativeStringVectorToJavaStringList(env, ray_function.GetFunctionDescriptor()); + jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList( + env, ray_function.GetFunctionDescriptor()); // convert args // TODO (kfstorm): Avoid copying binary data from Java to 
C++ jobject args_array_list = NativeVectorToJavaList>( diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc index 636118d7d..2b6d4e799 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_actor_NativeRayActor.cc @@ -33,7 +33,7 @@ Java_org_ray_runtime_actor_NativeRayActor_nativeGetActorCreationTaskFunctionDesc .GetActorHandle(actor_id, &native_actor_handle); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); auto function_descriptor = native_actor_handle->ActorCreationTaskFunctionDescriptor(); - return NativeStringVectorToJavaStringList(env, function_descriptor); + return NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor); } JNIEXPORT jbyteArray JNICALL Java_org_ray_runtime_actor_NativeRayActor_nativeSerialize( diff --git a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc index a4141306a..87c801c6f 100644 --- a/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/org_ray_runtime_task_NativeTaskSubmitter.cc @@ -10,17 +10,20 @@ inline ray::CoreWorker &GetCoreWorker(jlong nativeCoreWorkerPointer) { } inline ray::RayFunction ToRayFunction(JNIEnv *env, jobject functionDescriptor) { - std::vector function_descriptor; + std::vector function_descriptor_list; jobject list = env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list); RAY_CHECK_JAVA_EXCEPTION(env); - JavaStringListToNativeStringVector(env, list, &function_descriptor); + JavaStringListToNativeStringVector(env, list, &function_descriptor_list); jobject java_language = env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); RAY_CHECK_JAVA_EXCEPTION(env); - int language = env->CallIntMethod(java_language, 
java_language_get_number); + auto language = static_cast<::Language>( + env->CallIntMethod(java_language, java_language_get_number)); RAY_CHECK_JAVA_EXCEPTION(env); - ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor}; + ray::FunctionDescriptor function_descriptor = + ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); + ray::RayFunction ray_function{language, function_descriptor}; return ray_function; } @@ -134,7 +137,8 @@ JNIEXPORT jobject JNICALL Java_org_ray_runtime_task_NativeTaskSubmitter_nativeSu std::vector return_ids; auto status = GetCoreWorker(nativeCoreWorkerPointer) - .SubmitTask(ray_function, task_args, task_options, &return_ids, /*max_retries=*/1); + .SubmitTask(ray_function, task_args, task_options, &return_ids, + /*max_retries=*/1); THROW_EXCEPTION_AND_RETURN_IF_NOT_OK(env, status, nullptr); diff --git a/src/ray/core_worker/test/core_worker_test.cc b/src/ray/core_worker/test/core_worker_test.cc index 6e366b601..5ebf25a68 100644 --- a/src/ray/core_worker/test/core_worker_test.cc +++ b/src/ray/core_worker/test/core_worker_test.cc @@ -51,14 +51,15 @@ ActorID CreateActorHelper(CoreWorker &worker, uint8_t array[] = {1, 2, 3}; auto buffer = std::make_shared(array, sizeof(array)); - RayFunction func(ray::Language::PYTHON, {"actor creation task"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "actor creation task", "", "", "")); std::vector args; args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); - ActorCreationOptions actor_options{ - max_reconstructions, is_direct_call, - /*max_concurrency*/ 1, resources, resources, {}, - /*is_detached*/ false, /*is_asyncio*/ false}; + ActorCreationOptions actor_options{max_reconstructions, is_direct_call, + /*max_concurrency*/ 1, resources, resources, {}, + /*is_detached*/ false, + /*is_asyncio*/ false}; // Create an actor. 
ActorID actor_id; @@ -284,7 +285,8 @@ int CoreWorkerTest::GetActorPid(CoreWorker &worker, const ActorID &actor_id, std::vector args; TaskOptions options{1, is_direct_call, resources}; std::vector return_ids; - RayFunction func{Language::PYTHON, {"GetWorkerPid"}}; + RayFunction func{Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "GetWorkerPid", "", "", "")}; RAY_CHECK_OK(worker.SubmitActorTask(actor_id, func, args, options, &return_ids)); @@ -321,7 +323,8 @@ void CoreWorkerTest::TestNormalTask(std::unordered_map &res TaskArg::PassByValue(std::make_shared(buffer1, nullptr))); args.emplace_back(TaskArg::PassByReference(object_id)); - RayFunction func(ray::Language::PYTHON, {"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); TaskOptions options; options.is_direct_call = true; @@ -369,7 +372,8 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso TaskOptions options{1, false, resources}; std::vector return_ids; - RayFunction func(ray::Language::PYTHON, {"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); ASSERT_EQ(return_ids.size(), 1); @@ -412,7 +416,8 @@ void CoreWorkerTest::TestActorTask(std::unordered_map &reso TaskOptions options{1, false, resources}; std::vector return_ids; - RayFunction func(ray::Language::PYTHON, {"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); auto status = driver.SubmitActorTask(actor_id, func, args, options, &return_ids); ASSERT_TRUE(status.ok()); @@ -477,7 +482,8 @@ void CoreWorkerTest::TestActorReconstruction( TaskOptions options{1, false, resources}; std::vector return_ids; - RayFunction func(ray::Language::PYTHON, 
{"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); ASSERT_EQ(return_ids.size(), 1); @@ -522,7 +528,8 @@ void CoreWorkerTest::TestActorFailure(std::unordered_map &r TaskOptions options{1, false, resources}; std::vector return_ids; - RayFunction func(ray::Language::PYTHON, {"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); @@ -587,7 +594,8 @@ TEST_F(ZeroNodeTest, TestTaskSpecPerf) { // to benchmark performance. uint8_t array[] = {1, 2, 3}; auto buffer = std::make_shared(array, sizeof(array)); - RayFunction function(ray::Language::PYTHON, {}); + RayFunction function(ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython("", "", "", "")); std::vector args; args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr))); @@ -670,7 +678,8 @@ TEST_F(SingleNodeTest, TestDirectActorTaskSubmissionPerf) { TaskOptions options{1, false, resources}; std::vector return_ids; - RayFunction func(ray::Language::PYTHON, {"MergeInputArgsAsOutput"}); + RayFunction func(ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "MergeInputArgsAsOutput", "", "", "")); RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); ASSERT_EQ(return_ids.size(), 1); @@ -717,7 +726,7 @@ TEST_F(ZeroNodeTest, TestActorHandle) { JobID job_id = NextJobId(); ActorHandle original(ActorID::Of(job_id, TaskID::ForDriverTask(job_id), 0), job_id, ObjectID::FromRandom(), Language::PYTHON, /*is_direct_call=*/false, - {}); + ray::FunctionDescriptorBuilder::BuildPython("", "", "", "")); std::string output; original.Serialize(&output); ActorHandle deserialized(output); diff --git 
a/src/ray/core_worker/test/direct_task_transport_test.cc b/src/ray/core_worker/test/direct_task_transport_test.cc index e70116ed3..4f189bb0f 100644 --- a/src/ray/core_worker/test/direct_task_transport_test.cc +++ b/src/ray/core_worker/test/direct_task_transport_test.cc @@ -241,7 +241,7 @@ TEST(LocalDependencyResolverTest, TestInlinePendingDependencies) { } TaskSpecification BuildTaskSpec(const std::unordered_map &resources, - const std::vector &function_descriptor) { + const ray::FunctionDescriptor &function_descriptor) { TaskSpecBuilder builder; rpc::Address empty_address; builder.SetCommonTaskSpec(TaskID::Nil(), Language::PYTHON, function_descriptor, @@ -261,7 +261,8 @@ TEST(DirectTaskTransportTest, TestSubmitOneTask) { task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -291,7 +292,8 @@ TEST(DirectTaskTransportTest, TestHandleTaskFailure) { CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -315,7 +317,8 @@ TEST(DirectTaskTransportTest, TestConcurrentWorkerLeases) { CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task1 
= BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor); @@ -360,7 +363,8 @@ TEST(DirectTaskTransportTest, TestReuseWorkerLease) { CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor); @@ -408,7 +412,8 @@ TEST(DirectTaskTransportTest, TestWorkerNotReusedOnError) { CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor); @@ -446,7 +451,8 @@ TEST(DirectTaskTransportTest, TestWorkerNotReturnedOnExit) { CoreWorkerDirectTaskSubmitter submitter(address, raylet_client, factory, nullptr, store, task_finisher, ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task1).ok()); @@ -484,7 +490,8 @@ TEST(DirectTaskTransportTest, TestSpillback) { lease_client_factory, store, task_finisher, 
ClientID::Nil(), kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -534,7 +541,8 @@ TEST(DirectTaskTransportTest, TestSpillbackRoundTrip) { lease_client_factory, store, task_finisher, local_raylet_id, kLongTimeout); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task = BuildTaskSpec(empty_resources, empty_descriptor); ASSERT_TRUE(submitter.SubmitTask(task).ok()); @@ -620,8 +628,10 @@ TEST(DirectTaskTransportTest, TestSchedulingKeys) { std::unordered_map resources1({{"a", 1.0}}); std::unordered_map resources2({{"b", 2.0}}); - std::vector descriptor1({"a"}); - std::vector descriptor2({"b"}); + ray::FunctionDescriptor descriptor1 = + ray::FunctionDescriptorBuilder::BuildPython("a", "", "", ""); + ray::FunctionDescriptor descriptor2 = + ray::FunctionDescriptorBuilder::BuildPython("b", "", "", ""); // Tasks with different resources should request different worker leases. 
RAY_LOG(INFO) << "Test different resources"; @@ -682,7 +692,8 @@ TEST(DirectTaskTransportTest, TestWorkerLeaseTimeout) { task_finisher, ClientID::Nil(), /*lease_timeout_ms=*/5); std::unordered_map empty_resources; - std::vector empty_descriptor; + ray::FunctionDescriptor empty_descriptor = + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""); TaskSpecification task1 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task2 = BuildTaskSpec(empty_resources, empty_descriptor); TaskSpecification task3 = BuildTaskSpec(empty_resources, empty_descriptor); diff --git a/src/ray/core_worker/test/mock_worker.cc b/src/ray/core_worker/test/mock_worker.cc index 754d165b0..059705429 100644 --- a/src/ray/core_worker/test/mock_worker.cc +++ b/src/ray/core_worker/test/mock_worker.cc @@ -35,20 +35,23 @@ class MockWorker { const std::vector &return_ids, std::vector> *results) { // Note that this doesn't include dummy object id. - const std::vector &function_descriptor = + const ray::FunctionDescriptor function_descriptor = ray_function.GetFunctionDescriptor(); - RAY_CHECK(return_ids.size() >= 0 && 1 == function_descriptor.size()); + RAY_CHECK(function_descriptor->Type() == + ray::FunctionDescriptorType::kPythonFunctionDescriptor); + auto typed_descriptor = function_descriptor->As(); - if ("actor creation task" == function_descriptor[0]) { + if ("actor creation task" == typed_descriptor->ModuleName()) { return Status::OK(); - } else if ("GetWorkerPid" == function_descriptor[0]) { + } else if ("GetWorkerPid" == typed_descriptor->ModuleName()) { // Get mock worker pid return GetWorkerPid(results); - } else if ("MergeInputArgsAsOutput" == function_descriptor[0]) { + } else if ("MergeInputArgsAsOutput" == typed_descriptor->ModuleName()) { // Merge input args and write the merged content to each of return ids return MergeInputArgsAsOutput(args, return_ids, results); } else { - return Status::TypeError("Unknown function descriptor: " + function_descriptor[0]); + 
return Status::TypeError("Unknown function descriptor: " + + typed_descriptor->ModuleName()); } } diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index d6ed85bba..c37c805c1 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -25,6 +25,8 @@ enum TaskType { ACTOR_CREATION_TASK = 1; // Actor task. ACTOR_TASK = 2; + // Driver task. + DRIVER_TASK = 3; } // Address of a worker or node manager. @@ -36,6 +38,29 @@ message Address { bytes worker_id = 4; } +/// Function descriptor for Java. +message JavaFunctionDescriptor { + string class_name = 1; + string function_name = 2; + string signature = 3; +} + +/// Function descriptor for Python. +message PythonFunctionDescriptor { + string module_name = 1; + string class_name = 2; + string function_name = 3; + string function_hash = 4; +} + +// A union wrapper for various function descriptor types. +message FunctionDescriptor { + oneof function_descriptor { + JavaFunctionDescriptor java_function_descriptor = 1; + PythonFunctionDescriptor python_function_descriptor = 2; + } +} + /// The task specification encapsulates all immutable information about the /// task. These fields are determined at submission time, converse to the /// `TaskExecutionSpec` may change at execution time. @@ -44,11 +69,8 @@ message TaskSpec { TaskType type = 1; // Language of this task. Language language = 2; - // Function descriptor of this task, which is a list of strings that can - // uniquely describe the function to execute. - // For a Python function, it should be: [module_name, class_name, function_name] - // For a Java function, it should be: [class_name, method_name, type_descriptor] - repeated bytes function_descriptor = 3; + // Function descriptor of this task, which uniquely describes the function to execute. + FunctionDescriptor function_descriptor = 3; // ID of the job that this task belongs to. bytes job_id = 4; // Task ID of the task. 
@@ -194,8 +216,8 @@ message CoreWorkerStats { int32 num_pending_tasks = 2; // Number of object ids in local scope. int32 num_object_ids_in_scope = 3; - // Function descriptor of the currently executing task. - repeated bytes current_task_func_desc = 4; + // String representation of the function descriptor of the currently executing task. + string current_task_func_desc = 4; // IP address of the core worker. string ip_address = 6; // Port of the core worker. diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 7c77857ac..c6e9c2c35 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -21,7 +21,7 @@ message ActorHandle { Language actor_language = 4; // Function descriptor of actor creation task. - repeated string actor_creation_task_function_descriptor = 5; + FunctionDescriptor actor_creation_task_function_descriptor = 5; // The unique id of the dummy object returned by the actor creation task. // It's used as a dependency for the first task. 
diff --git a/src/ray/ray_exported_symbols.lds b/src/ray/ray_exported_symbols.lds index 392a5b12d..446542103 100644 --- a/src/ray/ray_exported_symbols.lds +++ b/src/ray/ray_exported_symbols.lds @@ -7,6 +7,7 @@ *ray*RayObject* *ray*Status* *ray*RayFunction* +*ray*FunctionDescriptorBuilder* *ray*TaskArg* *ray*TaskOptions* *ray*Buffer* diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds index 0be55a0b9..c778015f5 100644 --- a/src/ray/ray_version_script.lds +++ b/src/ray/ray_version_script.lds @@ -9,6 +9,7 @@ VERSION_1.0 { *ray*RayObject*; *ray*Status*; *ray*RayFunction*; + *ray*FunctionDescriptorBuilder*; *ray*TaskArg*; *ray*TaskOptions*; *ray*Buffer*; diff --git a/src/ray/raylet/lineage_cache_test.cc b/src/ray/raylet/lineage_cache_test.cc index bb621b560..f08241b96 100644 --- a/src/ray/raylet/lineage_cache_test.cc +++ b/src/ray/raylet/lineage_cache_test.cc @@ -183,9 +183,10 @@ static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; rpc::Address address; - builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), - RandomTaskId(), 0, RandomTaskId(), address, num_returns, - false, {}, {}); + builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython("", "", "", ""), + JobID::Nil(), RandomTaskId(), 0, RandomTaskId(), address, + num_returns, false, {}, {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } diff --git a/src/ray/raylet/scheduling_queue.cc b/src/ray/raylet/scheduling_queue.cc index 85b2efba1..e9ded0f63 100644 --- a/src/ray/raylet/scheduling_queue.cc +++ b/src/ray/raylet/scheduling_queue.cc @@ -447,19 +447,7 @@ std::string SchedulingQueue::DebugString() const { for (const auto &pair : num_running_tasks_) { result << "\n- "; auto desc = TaskSpecification::GetSchedulingClassDescriptor(pair.first); - for (const auto &str : desc.second) { - // Only print the ASCII parts of the function 
descriptor. - bool ok = str.size() > 0; - for (char c : str) { - if (!isprint(c)) { - ok = false; - } - } - if (ok) { - result << str; - result << "."; - } - } + result << desc.second->ToString(); result << desc.first.ToString(); result << ": " << pair.second; total += pair.second; diff --git a/src/ray/raylet/task_dependency_manager_test.cc b/src/ray/raylet/task_dependency_manager_test.cc index d7cff6c4f..a3fd5f6c0 100644 --- a/src/ray/raylet/task_dependency_manager_test.cc +++ b/src/ray/raylet/task_dependency_manager_test.cc @@ -94,9 +94,10 @@ static inline Task ExampleTask(const std::vector &arguments, uint64_t num_returns) { TaskSpecBuilder builder; rpc::Address address; - builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, {"", "", ""}, JobID::Nil(), - RandomTaskId(), 0, RandomTaskId(), address, num_returns, - false, {}, {}); + builder.SetCommonTaskSpec(RandomTaskId(), Language::PYTHON, + FunctionDescriptorBuilder::BuildPython("", "", "", ""), + JobID::Nil(), RandomTaskId(), 0, RandomTaskId(), address, + num_returns, false, {}, {}); for (const auto &arg : arguments) { builder.AddByRefArg(arg); } diff --git a/streaming/python/includes/transfer.pxi b/streaming/python/includes/transfer.pxi index 7830cf8ab..7dcd91f28 100644 --- a/streaming/python/includes/transfer.pxi +++ b/streaming/python/includes/transfer.pxi @@ -22,7 +22,7 @@ from ray._raylet cimport ( CoreWorker, ActorID, ObjectID, - string_vector_from_list + FunctionDescriptor, ) from ray.includes.libcoreworker cimport CCoreWorker @@ -42,7 +42,6 @@ from ray.streaming.includes.libstreaming cimport ( ) import logging -from ray.function_manager import FunctionDescriptor channel_logger = logging.getLogger(__name__) @@ -54,16 +53,14 @@ cdef class ReaderClient: def __cinit__(self, CoreWorker worker, - async_func: FunctionDescriptor, - sync_func: FunctionDescriptor): + FunctionDescriptor async_func, + FunctionDescriptor sync_func): cdef: CCoreWorker *core_worker = worker.core_worker.get() CRayFunction 
async_native_func CRayFunction sync_native_func - async_native_func = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list())) - sync_native_func = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list())) + async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor) + sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor) self.client = new CReaderClient(core_worker, async_native_func, sync_native_func) def __dealloc__(self): @@ -95,16 +92,14 @@ cdef class WriterClient: def __cinit__(self, CoreWorker worker, - async_func: FunctionDescriptor, - sync_func: FunctionDescriptor): + FunctionDescriptor async_func, + FunctionDescriptor sync_func): cdef: CCoreWorker *core_worker = worker.core_worker.get() CRayFunction async_native_func CRayFunction sync_native_func - async_native_func = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(async_func.get_function_descriptor_list())) - sync_native_func = CRayFunction( - LANGUAGE_PYTHON, string_vector_from_list(sync_func.get_function_descriptor_list())) + async_native_func = CRayFunction(LANGUAGE_PYTHON, async_func.descriptor) + sync_native_func = CRayFunction(LANGUAGE_PYTHON, sync_func.descriptor) self.client = new CWriterClient(core_worker, async_native_func, sync_native_func) def __dealloc__(self): diff --git a/streaming/python/jobworker.py b/streaming/python/jobworker.py index ed97f6735..07cbd0fb8 100644 --- a/streaming/python/jobworker.py +++ b/streaming/python/jobworker.py @@ -5,7 +5,7 @@ import threading import ray import ray.streaming._streaming as _streaming from ray.streaming.config import Config -from ray.function_manager import FunctionDescriptor +from ray._raylet import PythonFunctionDescriptor from ray.streaming.communication import DataInput, DataOutput logger = logging.getLogger(__name__) @@ -47,18 +47,18 @@ class JobWorker: if env.config.channel_type == Config.NATIVE_CHANNEL: core_worker = 
ray.worker.global_worker.core_worker - reader_async_func = FunctionDescriptor( + reader_async_func = PythonFunctionDescriptor( __name__, self.on_reader_message.__name__, self.__class__.__name__) - reader_sync_func = FunctionDescriptor( + reader_sync_func = PythonFunctionDescriptor( __name__, self.on_reader_message_sync.__name__, self.__class__.__name__) self.reader_client = _streaming.ReaderClient( core_worker, reader_async_func, reader_sync_func) - writer_async_func = FunctionDescriptor( + writer_async_func = PythonFunctionDescriptor( __name__, self.on_writer_message.__name__, self.__class__.__name__) - writer_sync_func = FunctionDescriptor( + writer_sync_func = PythonFunctionDescriptor( __name__, self.on_writer_message_sync.__name__, self.__class__.__name__) self.writer_client = _streaming.WriterClient( diff --git a/streaming/python/tests/test_direct_transfer.py b/streaming/python/tests/test_direct_transfer.py index 42321769f..57805cbfd 100644 --- a/streaming/python/tests/test_direct_transfer.py +++ b/streaming/python/tests/test_direct_transfer.py @@ -5,7 +5,7 @@ import time import ray import ray.streaming._streaming as _streaming import ray.streaming.runtime.transfer as transfer -from ray.function_manager import FunctionDescriptor +from ray._raylet import PythonFunctionDescriptor from ray.streaming.config import Config @@ -13,16 +13,16 @@ from ray.streaming.config import Config class Worker: def __init__(self): core_worker = ray.worker.global_worker.core_worker - writer_async_func = FunctionDescriptor( + writer_async_func = PythonFunctionDescriptor( __name__, self.on_writer_message.__name__, self.__class__.__name__) - writer_sync_func = FunctionDescriptor( + writer_sync_func = PythonFunctionDescriptor( __name__, self.on_writer_message_sync.__name__, self.__class__.__name__) self.writer_client = _streaming.WriterClient( core_worker, writer_async_func, writer_sync_func) - reader_async_func = FunctionDescriptor( + reader_async_func = PythonFunctionDescriptor( 
__name__, self.on_reader_message.__name__, self.__class__.__name__) - reader_sync_func = FunctionDescriptor( + reader_sync_func = PythonFunctionDescriptor( __name__, self.on_reader_message_sync.__name__, self.__class__.__name__) self.reader_client = _streaming.ReaderClient( diff --git a/streaming/src/lib/java/streaming_jni_common.cc b/streaming/src/lib/java/streaming_jni_common.cc index 89dd7b75c..f6b0ba213 100644 --- a/streaming/src/lib/java/streaming_jni_common.cc +++ b/streaming/src/lib/java/streaming_jni_common.cc @@ -1,27 +1,25 @@ #include "streaming_jni_common.h" -std::vector -jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) { +std::vector jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) { int stringCount = env->GetArrayLength(jarr); std::vector object_id_vec; for (int i = 0; i < stringCount; i++) { - auto jstr = (jbyteArray) (env->GetObjectArrayElement(jarr, i)); + auto jstr = (jbyteArray)(env->GetObjectArrayElement(jarr, i)); UniqueIdFromJByteArray idFromJByteArray(env, jstr); object_id_vec.push_back(idFromJByteArray.PID); } - return object_id_vec; + return object_id_vec; } -std::vector -jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) { +std::vector jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) { int count = env->GetArrayLength(jarr); std::vector actor_id_vec; for (int i = 0; i < count; i++) { auto bytes = (jbyteArray)(env->GetObjectArrayElement(jarr, i)); std::string id_str(ray::ActorID::Size(), 0); env->GetByteArrayRegion(bytes, 0, ray::ActorID::Size(), - reinterpret_cast(&id_str.front())); - actor_id_vec.push_back(ActorID::FromBinary(id_str)); + reinterpret_cast(&id_str.front())); + actor_id_vec.push_back(ActorID::FromBinary(id_str)); } return actor_id_vec; @@ -38,17 +36,22 @@ jint throwChannelInitException(JNIEnv *env, const char *message, const std::vector &abnormal_queues) { jclass array_list_class = env->FindClass("java/util/ArrayList"); jmethodID array_list_constructor = env->GetMethodID(array_list_class, "", "()V"); - 
jmethodID array_list_add = env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + jmethodID array_list_add = + env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); jobject array_list = env->NewObject(array_list_class, array_list_constructor); for (auto &q_id : abnormal_queues) { jbyteArray jbyte_array = env->NewByteArray(kUniqueIDSize); - env->SetByteArrayRegion(jbyte_array, 0, kUniqueIDSize, const_cast(reinterpret_cast(q_id.Data()))); + env->SetByteArrayRegion( + jbyte_array, 0, kUniqueIDSize, + const_cast(reinterpret_cast(q_id.Data()))); env->CallBooleanMethod(array_list, array_list_add, jbyte_array); } - jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInitException"); - jmethodID ex_constructor = env->GetMethodID(ex_class, "", "(Ljava/lang/String;Ljava/util/List;)V"); + jclass ex_class = + env->FindClass("org/ray/streaming/runtime/transfer/ChannelInitException"); + jmethodID ex_constructor = + env->GetMethodID(ex_class, "", "(Ljava/lang/String;Ljava/util/List;)V"); jstring message_jstr = env->NewStringUTF(message); jobject ex_obj = env->NewObject(ex_class, ex_constructor, message_jstr, array_list); env->DeleteLocalRef(message_jstr); @@ -56,7 +59,8 @@ jint throwChannelInitException(JNIEnv *env, const char *message, } jint throwChannelInterruptException(JNIEnv *env, const char *message) { - jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInterruptException"); + jclass ex_class = + env->FindClass("org/ray/streaming/runtime/transfer/ChannelInterruptException"); return env->ThrowNew(ex_class, message); } @@ -69,12 +73,13 @@ jclass LoadClass(JNIEnv *env, const char *class_name) { } template -void JavaListToNativeVector( - JNIEnv *env, jobject java_list, std::vector *native_vector, - std::function element_converter) { +void JavaListToNativeVector(JNIEnv *env, jobject java_list, + std::vector *native_vector, + std::function element_converter) { jclass java_list_class = LoadClass(env, 
"java/util/List"); jmethodID java_list_size = env->GetMethodID(java_list_class, "size", "()I"); - jmethodID java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); + jmethodID java_list_get = + env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); int size = env->CallIntMethod(java_list, java_list_size); native_vector->clear(); for (int i = 0; i < size; i++) { @@ -100,24 +105,29 @@ void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, }); } -ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor) { - jclass java_language_class = LoadClass(env, "org/ray/runtime/generated/Common$Language"); +ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, + jobject functionDescriptor) { + jclass java_language_class = + LoadClass(env, "org/ray/runtime/generated/Common$Language"); jclass java_function_descriptor_class = LoadClass(env, "org/ray/runtime/functionmanager/FunctionDescriptor"); - jmethodID java_language_get_number = env->GetMethodID(java_language_class, "getNumber", "()I"); + jmethodID java_language_get_number = + env->GetMethodID(java_language_class, "getNumber", "()I"); jmethodID java_function_descriptor_get_language = env->GetMethodID(java_function_descriptor_class, "getLanguage", "()Lorg/ray/runtime/generated/Common$Language;"); jobject java_language = env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); - int language = env->CallIntMethod(java_language, java_language_get_number); - std::vector function_descriptor; + auto language = static_cast<::Language>( + env->CallIntMethod(java_language, java_language_get_number)); + std::vector function_descriptor_list; jmethodID java_function_descriptor_to_list = env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;"); JavaStringListToNativeStringVector( env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list), - &function_descriptor); - 
ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor}; + &function_descriptor_list); + ray::FunctionDescriptor function_descriptor = + ray::FunctionDescriptorBuilder::FromVector(language, function_descriptor_list); + ray::RayFunction ray_function{language, function_descriptor}; return ray_function; } - diff --git a/streaming/src/queue/transport.cc b/streaming/src/queue/transport.cc index 18c1cd395..d3a55f162 100644 --- a/streaming/src/queue/transport.cc +++ b/streaming/src/queue/transport.cc @@ -77,7 +77,7 @@ std::shared_ptr Transport::SendForResultWithRetry( int64_t timeout_ms) { STREAMING_LOG(INFO) << "SendForResultWithRetry retry_cnt: " << retry_cnt << " timeout_ms: " << timeout_ms - << " function: " << function.GetFunctionDescriptor()[0]; + << " function: " << function.GetFunctionDescriptor()->ToString(); std::shared_ptr buffer_shared = std::move(buffer); for (int cnt = 0; cnt < retry_cnt; cnt++) { auto result = SendForResult(buffer_shared, function, timeout_ms); diff --git a/streaming/src/test/mock_actor.cc b/streaming/src/test/mock_actor.cc index 0effacf83..1bfce6276 100644 --- a/streaming/src/test/mock_actor.cc +++ b/streaming/src/test/mock_actor.cc @@ -287,10 +287,18 @@ class StreamingWorker { JobID::FromInt(1), gcs_options, "", "127.0.0.1", node_manager_port, std::bind(&StreamingWorker::ExecuteTask, this, _1, _2, _3, _4, _5, _6, _7)); - RayFunction reader_async_call_func{ray::Language::PYTHON, {"reader_async_call_func"}}; - RayFunction reader_sync_call_func{ray::Language::PYTHON, {"reader_sync_call_func"}}; - RayFunction writer_async_call_func{ray::Language::PYTHON, {"writer_async_call_func"}}; - RayFunction writer_sync_call_func{ray::Language::PYTHON, {"writer_sync_call_func"}}; + RayFunction reader_async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython( + "reader_async_call_func", "", "", "")}; + RayFunction reader_sync_call_func{ + ray::Language::PYTHON, + 
ray::FunctionDescriptorBuilder::BuildPython("reader_sync_call_func", "", "", "")}; + RayFunction writer_async_call_func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython( + "writer_async_call_func", "", "", "")}; + RayFunction writer_sync_call_func{ + ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython("writer_sync_call_func", "", "", "")}; reader_client_ = std::make_shared(worker_.get(), reader_async_call_func, reader_sync_call_func); @@ -314,18 +322,22 @@ class StreamingWorker { // Only one arg param used in streaming. STREAMING_CHECK(args.size() >= 1) << "args.size() = " << args.size(); - std::vector function_descriptor = ray_function.GetFunctionDescriptor(); - STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " << function_descriptor[0]; + ray::FunctionDescriptor function_descriptor = ray_function.GetFunctionDescriptor(); + RAY_CHECK(function_descriptor->Type() == + ray::FunctionDescriptorType::kPythonFunctionDescriptor); + auto typed_descriptor = function_descriptor->As(); + STREAMING_LOG(INFO) << "StreamingWorker::ExecuteTask " + << typed_descriptor->ModuleName(); - std::string func_name = function_descriptor[0]; + std::string func_name = typed_descriptor->ModuleName(); if (func_name == "init") { std::shared_ptr local_buffer = std::make_shared(args[0]->GetData()->Data(), args[0]->GetData()->Size(), true); HandleInitTask(local_buffer); } else if (func_name == "execute_test") { - STREAMING_LOG(INFO) << "Test name: " << function_descriptor[1]; - test_suite_->ExecuteTest(function_descriptor[1]); + STREAMING_LOG(INFO) << "Test name: " << typed_descriptor->ClassName(); + test_suite_->ExecuteTest(typed_descriptor->ClassName()); } else if (func_name == "check_current_test_status") { results->push_back( std::make_shared(test_suite_->CheckCurTestStatus(), nullptr)); diff --git a/streaming/src/test/queue_tests_base.h b/streaming/src/test/queue_tests_base.h index ea4b7bfd0..6e3da6561 100644 --- 
a/streaming/src/test/queue_tests_base.h +++ b/streaming/src/test/queue_tests_base.h @@ -162,7 +162,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::unordered_map resources; TaskOptions options{0, true, resources}; std::vector return_ids; - RayFunction func{ray::Language::PYTHON, {"init"}}; + RayFunction func{ray::Language::PYTHON, + ray::FunctionDescriptorBuilder::BuildPython("init", "", "", "")}; RAY_CHECK_OK(driver.SubmitActorTask(self_actor_id, func, args, options, &return_ids)); } @@ -176,7 +177,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::unordered_map resources; TaskOptions options{0, true, resources}; std::vector return_ids; - RayFunction func{ray::Language::PYTHON, {"execute_test", test}}; + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "execute_test", test, "", "")}; RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); } @@ -190,7 +192,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { std::unordered_map resources; TaskOptions options{1, true, resources}; std::vector return_ids; - RayFunction func{ray::Language::PYTHON, {"check_current_test_status"}}; + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "check_current_test_status", "", "", "")}; RAY_CHECK_OK(driver.SubmitActorTask(actor_id, func, args, options, &return_ids)); @@ -250,7 +253,8 @@ class StreamingQueueTestBase : public ::testing::TestWithParam { uint8_t array[] = {1, 2, 3}; auto buffer = std::make_shared(array, sizeof(array)); - RayFunction func{ray::Language::PYTHON, {"actor creation task"}}; + RayFunction func{ray::Language::PYTHON, ray::FunctionDescriptorBuilder::BuildPython( + "actor creation task", "", "", "")}; std::vector args; args.emplace_back(TaskArg::PassByValue(std::make_shared(buffer, nullptr)));