mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 10:01:11 +08:00
[xlang] Cross language Python support (#6709)
This commit is contained in:
@@ -3,7 +3,8 @@ from functools import wraps
|
||||
|
||||
from ray import cloudpickle as pickle
|
||||
from ray import ray_constants
|
||||
from ray.function_manager import FunctionDescriptor
|
||||
from ray._raylet import PythonFunctionDescriptor
|
||||
from ray import cross_language, Language
|
||||
import ray.signature
|
||||
|
||||
# Default parameters for remote functions.
|
||||
@@ -23,6 +24,7 @@ class RemoteFunction:
|
||||
This is a decorated function. It can be used to spawn tasks.
|
||||
|
||||
Attributes:
|
||||
_language: The target language.
|
||||
_function: The original function.
|
||||
_function_descriptor: The function descriptor. This is not defined
|
||||
until the remote function is first invoked because that is when the
|
||||
@@ -57,12 +59,15 @@ class RemoteFunction:
|
||||
different workers.
|
||||
"""
|
||||
|
||||
def __init__(self, function, num_cpus, num_gpus, memory,
|
||||
object_store_memory, resources, num_return_vals, max_calls,
|
||||
max_retries):
|
||||
def __init__(self, language, function, function_descriptor, num_cpus,
|
||||
num_gpus, memory, object_store_memory, resources,
|
||||
num_return_vals, max_calls, max_retries):
|
||||
self._language = language
|
||||
self._function = function
|
||||
self._function_name = (
|
||||
self._function.__module__ + "." + self._function.__name__)
|
||||
self._function_descriptor = function_descriptor
|
||||
self._is_cross_language = language != Language.PYTHON
|
||||
self._num_cpus = (DEFAULT_REMOTE_FUNCTION_CPUS
|
||||
if num_cpus is None else num_cpus)
|
||||
self._num_gpus = num_gpus
|
||||
@@ -80,11 +85,11 @@ class RemoteFunction:
|
||||
if max_retries is None else max_retries)
|
||||
self._decorator = getattr(function, "__ray_invocation_decorator__",
|
||||
None)
|
||||
|
||||
self._function_signature = ray.signature.extract_signature(
|
||||
self._function)
|
||||
|
||||
self._last_export_session_and_job = None
|
||||
|
||||
# Override task.remote's signature and docstring
|
||||
@wraps(function)
|
||||
def _remote_proxy(*args, **kwargs):
|
||||
@@ -152,7 +157,9 @@ class RemoteFunction:
|
||||
|
||||
# If this function was not exported in this session and job, we need to
|
||||
# export this function again, because the current GCS doesn't have it.
|
||||
if self._last_export_session_and_job != worker.current_session_and_job:
|
||||
if not self._is_cross_language and \
|
||||
self._last_export_session_and_job != \
|
||||
worker.current_session_and_job:
|
||||
# There is an interesting question here. If the remote function is
|
||||
# used by a subsequent driver (in the same script), should the
|
||||
# second driver pickle the function again? If yes, then the remote
|
||||
@@ -164,10 +171,8 @@ class RemoteFunction:
|
||||
# which we do here.
|
||||
self._pickled_function = pickle.dumps(self._function)
|
||||
|
||||
self._function_descriptor = FunctionDescriptor.from_function(
|
||||
self._function_descriptor = PythonFunctionDescriptor.from_function(
|
||||
self._function, self._pickled_function)
|
||||
self._function_descriptor_list = (
|
||||
self._function_descriptor.get_function_descriptor_list())
|
||||
|
||||
self._last_export_session_and_job = worker.current_session_and_job
|
||||
worker.function_actor_manager.export(self)
|
||||
@@ -188,20 +193,25 @@ class RemoteFunction:
|
||||
memory, object_store_memory, resources)
|
||||
|
||||
def invocation(args, kwargs):
|
||||
if not args and not kwargs and not self._function_signature:
|
||||
if self._is_cross_language:
|
||||
list_args = cross_language.format_args(worker, args, kwargs)
|
||||
elif not args and not kwargs and not self._function_signature:
|
||||
list_args = []
|
||||
else:
|
||||
list_args = ray.signature.flatten_args(
|
||||
self._function_signature, args, kwargs)
|
||||
|
||||
if worker.mode == ray.worker.LOCAL_MODE:
|
||||
assert not self._is_cross_language, \
|
||||
"Cross language remote function " \
|
||||
"cannot be executed locally."
|
||||
object_ids = worker.local_mode_manager.execute(
|
||||
self._function, self._function_descriptor, args, kwargs,
|
||||
num_return_vals)
|
||||
else:
|
||||
object_ids = worker.core_worker.submit_task(
|
||||
self._function_descriptor_list, list_args, num_return_vals,
|
||||
is_direct_call, resources, max_retries)
|
||||
self._language, self._function_descriptor, list_args,
|
||||
num_return_vals, is_direct_call, resources, max_retries)
|
||||
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
|
||||
Reference in New Issue
Block a user