mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 22:23:13 +08:00
[core] Support kwargs and positionals in Ray remote calls (#5606)
This commit is contained in:
+10
-14
@@ -193,7 +193,6 @@ class ActorClassMetadata(object):
|
||||
# supported. We don't raise an exception because if the actor
|
||||
# inherits from a class that has a method whose signature we
|
||||
# don't support, there may not be much the user can do about it.
|
||||
signature.check_signature_supported(method, warn=True)
|
||||
self.method_signatures[method_name] = signature.extract_signature(
|
||||
method, ignore_first=not ray.utils.is_class_method(method))
|
||||
# Set the default number of return values for this method.
|
||||
@@ -277,7 +276,6 @@ class ActorClass(object):
|
||||
DerivedActorClass.__module__ = modified_class.__module__
|
||||
DerivedActorClass.__name__ = name
|
||||
DerivedActorClass.__qualname__ = name
|
||||
|
||||
# Construct the base object.
|
||||
self = DerivedActorClass.__new__(DerivedActorClass)
|
||||
|
||||
@@ -397,10 +395,9 @@ class ActorClass(object):
|
||||
if actor_method_cpu == 1:
|
||||
actor_placement_resources = resources.copy()
|
||||
actor_placement_resources["CPU"] += 1
|
||||
|
||||
function_signature = meta.method_signatures[function_name]
|
||||
creation_args = signature.extend_args(function_signature, args,
|
||||
kwargs)
|
||||
creation_args = signature.flatten_args(function_signature, args,
|
||||
kwargs)
|
||||
actor_id = worker.core_worker.create_actor(
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
creation_args, meta.max_reconstructions, resources,
|
||||
@@ -499,25 +496,24 @@ class ActorHandle(object):
|
||||
worker.check_connected()
|
||||
|
||||
function_signature = self._ray_method_signatures[method_name]
|
||||
if args is None:
|
||||
args = []
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
args = signature.extend_args(function_signature, args, kwargs)
|
||||
args = args or []
|
||||
kwargs = kwargs or {}
|
||||
|
||||
list_args = signature.flatten_args(function_signature, args, kwargs)
|
||||
function_descriptor = FunctionDescriptor(
|
||||
self._ray_module_name, method_name, self._ray_class_name)
|
||||
|
||||
with profiling.profile("submit_task"):
|
||||
if worker.mode == ray.LOCAL_MODE:
|
||||
function = getattr(worker.actors[self._actor_id], method_name)
|
||||
object_ids = worker.local_mode_manager.execute(
|
||||
function, function_descriptor, args, num_return_vals)
|
||||
function, function_descriptor, args, kwargs,
|
||||
num_return_vals)
|
||||
else:
|
||||
object_ids = worker.core_worker.submit_actor_task(
|
||||
self._ray_actor_id,
|
||||
function_descriptor.get_function_descriptor_list(), args,
|
||||
num_return_vals, {"CPU": self._ray_actor_method_cpus})
|
||||
function_descriptor.get_function_descriptor_list(),
|
||||
list_args, num_return_vals,
|
||||
{"CPU": self._ray_actor_method_cpus})
|
||||
|
||||
if len(object_ids) == 1:
|
||||
object_ids = object_ids[0]
|
||||
|
||||
@@ -763,7 +763,7 @@ class FunctionActorManager(object):
|
||||
worker's internal state to record the executed method.
|
||||
"""
|
||||
|
||||
def actor_method_executor(dummy_return_id, actor, *args):
|
||||
def actor_method_executor(dummy_return_id, actor, *args, **kwargs):
|
||||
# Update the actor's task counter to reflect the task we're about
|
||||
# to execute.
|
||||
self._worker.actor_task_counter += 1
|
||||
@@ -771,9 +771,9 @@ class FunctionActorManager(object):
|
||||
# Execute the assigned method and save a checkpoint if necessary.
|
||||
try:
|
||||
if is_class_method(method):
|
||||
method_returns = method(*args)
|
||||
method_returns = method(*args, **kwargs)
|
||||
else:
|
||||
method_returns = method(actor, *args)
|
||||
method_returns = method(actor, *args, **kwargs)
|
||||
except Exception as e:
|
||||
# Save the checkpoint before allowing the method exception
|
||||
# to be thrown, but don't save the checkpoint for actor
|
||||
|
||||
@@ -29,7 +29,8 @@ class LocalModeManager(object):
|
||||
def __init__(self):
|
||||
"""Initialize a LocalModeManager."""
|
||||
|
||||
def execute(self, function, function_descriptor, args, num_return_vals):
|
||||
def execute(self, function, function_descriptor, args, kwargs,
|
||||
num_return_vals):
|
||||
"""Synchronously executes a "remote" function or actor method.
|
||||
|
||||
Stores results directly in the generated and returned
|
||||
@@ -42,6 +43,7 @@ class LocalModeManager(object):
|
||||
function_descriptor: Metadata about the function.
|
||||
args: Arguments to the function. These will not be modified by
|
||||
the function execution.
|
||||
kwargs: Keyword arguments to the function.
|
||||
num_return_vals: Number of expected return values specified in the
|
||||
function's decorator.
|
||||
|
||||
@@ -52,7 +54,7 @@ class LocalModeManager(object):
|
||||
LocalModeObjectID.from_random() for _ in range(num_return_vals)
|
||||
]
|
||||
try:
|
||||
results = function(*copy.deepcopy(args))
|
||||
results = function(*copy.deepcopy(args), **copy.deepcopy(kwargs))
|
||||
if num_return_vals == 1:
|
||||
object_ids[0].value = results
|
||||
else:
|
||||
|
||||
@@ -75,9 +75,9 @@ class RemoteFunction(object):
|
||||
self._decorator = getattr(function, "__ray_invocation_decorator__",
|
||||
None)
|
||||
|
||||
ray.signature.check_signature_supported(self._function)
|
||||
self._function_signature = ray.signature.extract_signature(
|
||||
self._function)
|
||||
|
||||
self._last_export_session_and_job = None
|
||||
# Override task.remote's signature and docstring
|
||||
@wraps(function)
|
||||
@@ -140,17 +140,17 @@ class RemoteFunction(object):
|
||||
memory, object_store_memory, resources)
|
||||
|
||||
def invocation(args, kwargs):
|
||||
args = ray.signature.extend_args(self._function_signature, args,
|
||||
kwargs)
|
||||
list_args = ray.signature.flatten_args(self._function_signature,
|
||||
args, kwargs)
|
||||
|
||||
if worker.mode == ray.worker.LOCAL_MODE:
|
||||
object_ids = worker.local_mode_manager.execute(
|
||||
self._function, self._function_descriptor, args,
|
||||
self._function, self._function_descriptor, args, kwargs,
|
||||
num_return_vals)
|
||||
else:
|
||||
object_ids = worker.core_worker.submit_task(
|
||||
self._function_descriptor.get_function_descriptor_list(),
|
||||
args, num_return_vals, resources)
|
||||
list_args, num_return_vals, resources)
|
||||
|
||||
if len(object_ids) == 1:
|
||||
return object_ids[0]
|
||||
|
||||
+134
-131
@@ -14,30 +14,34 @@ from ray.utils import is_cython
|
||||
# entry/init points.
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FunctionSignature = namedtuple("FunctionSignature", [
|
||||
"arg_names", "arg_defaults", "arg_is_positionals", "keyword_names",
|
||||
"function_name"
|
||||
])
|
||||
"""This class is used to represent a function signature.
|
||||
RayParameter = namedtuple(
|
||||
"RayParameter",
|
||||
["name", "kind_int", "default", "annotation", "partial_kwarg"])
|
||||
"""This class is used to represent a function parameter in Ray.
|
||||
|
||||
Note that this is different from the funcsigs.Parameter object because
|
||||
we replace the funcsigs ParameterKind with an int. This is needed because
|
||||
ParameterKind objects are currently non-serializable and the package is not
|
||||
being updated. Replacement is done in `_scrub_parameters` and
|
||||
`_restore_parameters`.
|
||||
|
||||
Attributes:
|
||||
arg_names: A list containing the name of all arguments.
|
||||
arg_defaults: A dictionary mapping from argument name to argument default
|
||||
value. If the argument is not a keyword argument, the default value
|
||||
will be funcsigs._empty.
|
||||
arg_is_positionals: A dictionary mapping from argument name to a bool. The
|
||||
bool will be true if the argument is a *args argument. Otherwise it
|
||||
will be false.
|
||||
keyword_names: A set containing the names of the keyword arguments.
|
||||
Note most arguments in Python can be called as positional or keyword
|
||||
arguments, so this overlaps (sometimes completely) with arg_names.
|
||||
function_name: The name of the function whose signature is being
|
||||
inspected. This is used for printing better error messages.
|
||||
name (str): The name of the parameter as a string.
|
||||
kind (int): Describes how argument values are bound to the parameter. See
|
||||
funcsigs.Parameter and `_convert_to_parameter_kind`.
|
||||
default (object): The default value for the parameter if specified. If the
|
||||
parameter has no default value, this attribute is not set.
|
||||
annotation: The annotation for the parameter if specified. If the
|
||||
parameter has no annotation, this attribute is not set.
|
||||
partial_kwarg (bool): True if the parameter is mapped
|
||||
by 'functools.partial'.
|
||||
"""
|
||||
|
||||
DUMMY_TYPE = "__RAY_DUMMY__"
|
||||
|
||||
def get_signature_params(func):
|
||||
"""Get signature parameters
|
||||
|
||||
def get_signature(func):
|
||||
"""Get signature parameters.
|
||||
|
||||
Support Cython functions by grabbing relevant attributes from the Cython
|
||||
function and attaching to a no-op function. This is somewhat brittle, since
|
||||
@@ -50,6 +54,10 @@ def get_signature_params(func):
|
||||
Args:
|
||||
func: The function whose signature should be checked.
|
||||
|
||||
Returns:
|
||||
A function signature object, which includes the names of the keyword
|
||||
arguments as well as their default values.
|
||||
|
||||
Raises:
|
||||
TypeError: A type error if the signature is not supported
|
||||
"""
|
||||
@@ -72,51 +80,7 @@ def get_signature_params(func):
|
||||
raise TypeError("{!r} is not a Python function we can process"
|
||||
.format(func))
|
||||
|
||||
return list(funcsigs.signature(func).parameters.items())
|
||||
|
||||
|
||||
def check_signature_supported(func, warn=False):
|
||||
"""Check if we support the signature of this function.
|
||||
|
||||
We currently do not allow remote functions to have **kwargs. We also do not
|
||||
support keyword arguments in conjunction with a *args argument.
|
||||
|
||||
Args:
|
||||
func: The function whose signature should be checked.
|
||||
warn: If this is true, a warning will be printed if the signature is
|
||||
not supported. If it is false, an exception will be raised if the
|
||||
signature is not supported.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the signature is not supported.
|
||||
"""
|
||||
function_name = func.__name__
|
||||
sig_params = get_signature_params(func)
|
||||
|
||||
has_kwargs_param = False
|
||||
has_kwonly_param = False
|
||||
for keyword_name, parameter in sig_params:
|
||||
if parameter.kind == Parameter.VAR_KEYWORD:
|
||||
has_kwargs_param = True
|
||||
if parameter.kind == Parameter.KEYWORD_ONLY:
|
||||
has_kwonly_param = True
|
||||
|
||||
if has_kwargs_param:
|
||||
message = ("The function {} has a **kwargs argument, which is "
|
||||
"currently not supported.".format(function_name))
|
||||
if warn:
|
||||
logger.debug(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
|
||||
if has_kwonly_param:
|
||||
message = ("The function {} has a keyword only argument "
|
||||
"(defined after * or *args), which is currently "
|
||||
"not supported.".format(function_name))
|
||||
if warn:
|
||||
logger.debug(message)
|
||||
else:
|
||||
raise Exception(message)
|
||||
return funcsigs.signature(func)
|
||||
|
||||
|
||||
def extract_signature(func, ignore_first=False):
|
||||
@@ -128,95 +92,134 @@ def extract_signature(func, ignore_first=False):
|
||||
be used when func is a method of a class.
|
||||
|
||||
Returns:
|
||||
A function signature object, which includes the names of the keyword
|
||||
arguments as well as their default values.
|
||||
List of RayParameter objects representing the function signature.
|
||||
"""
|
||||
sig_params = get_signature_params(func)
|
||||
signature_parameters = list(get_signature(func).parameters.values())
|
||||
|
||||
if ignore_first:
|
||||
if len(sig_params) == 0:
|
||||
if len(signature_parameters) == 0:
|
||||
raise Exception("Methods must take a 'self' argument, but the "
|
||||
"method '{}' does not have one.".format(
|
||||
func.__name__))
|
||||
sig_params = sig_params[1:]
|
||||
signature_parameters = signature_parameters[1:]
|
||||
|
||||
# Construct the argument default values and other argument information.
|
||||
arg_names = []
|
||||
arg_defaults = []
|
||||
arg_is_positionals = []
|
||||
keyword_names = set()
|
||||
for arg_name, parameter in sig_params:
|
||||
arg_names.append(arg_name)
|
||||
arg_defaults.append(parameter.default)
|
||||
arg_is_positionals.append(parameter.kind == parameter.VAR_POSITIONAL)
|
||||
if parameter.kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
# Note KEYWORD_ONLY arguments currently unsupported.
|
||||
keyword_names.add(arg_name)
|
||||
|
||||
return FunctionSignature(arg_names, arg_defaults, arg_is_positionals,
|
||||
keyword_names, func.__name__)
|
||||
return _scrub_parameters(signature_parameters)
|
||||
|
||||
|
||||
def extend_args(function_signature, args, kwargs):
|
||||
"""Extend the arguments that were passed into a function.
|
||||
def flatten_args(signature_parameters, args, kwargs):
|
||||
"""Validates the arguments against the signature and flattens them.
|
||||
|
||||
This extends the arguments that were passed into a function with the
|
||||
default arguments provided in the function definition.
|
||||
The flat list representation is a serializable format for arguments.
|
||||
Since the flatbuffer representation of function arguments is a list, we
|
||||
combine both keyword arguments and positional arguments. We represent
|
||||
this with two entries per argument value - [DUMMY_TYPE, x] for positional
|
||||
arguments and [KEY, VALUE] for keyword arguments. See the below example.
|
||||
See `recover_args` for logic restoring the flat list back to args/kwargs.
|
||||
|
||||
Args:
|
||||
function_signature: The function signature of the function being
|
||||
called.
|
||||
signature_parameters (list): The list of RayParameter objects
|
||||
representing the function signature, obtained from
|
||||
`extract_signature`.
|
||||
args: The non-keyword arguments passed into the function.
|
||||
kwargs: The keyword arguments passed into the function.
|
||||
|
||||
Returns:
|
||||
An extended list of arguments to pass into the function.
|
||||
List of args and kwargs. Non-keyword arguments are prefixed
|
||||
by internal enum DUMMY_TYPE.
|
||||
|
||||
Raises:
|
||||
Exception: An exception may be raised if the function cannot be called
|
||||
with these arguments.
|
||||
TypeError: Raised if arguments do not fit in the function signature.
|
||||
|
||||
Example:
|
||||
>>> flatten_args([1, 2, 3], {"a": 4})
|
||||
[None, 1, None, 2, None, 3, "a", 4]
|
||||
"""
|
||||
arg_names = function_signature.arg_names
|
||||
arg_defaults = function_signature.arg_defaults
|
||||
arg_is_positionals = function_signature.arg_is_positionals
|
||||
keyword_names = function_signature.keyword_names
|
||||
function_name = function_signature.function_name
|
||||
restored = _restore_parameters(signature_parameters)
|
||||
reconstructed_signature = funcsigs.Signature(parameters=restored)
|
||||
try:
|
||||
reconstructed_signature.bind(*args, **kwargs)
|
||||
except TypeError as exc:
|
||||
raise TypeError(str(exc))
|
||||
list_args = []
|
||||
for arg in args:
|
||||
list_args += [DUMMY_TYPE, arg]
|
||||
|
||||
args = list(args)
|
||||
for keyword, arg in kwargs.items():
|
||||
list_args += [keyword, arg]
|
||||
return list_args
|
||||
|
||||
for keyword_name in kwargs:
|
||||
if keyword_name not in keyword_names:
|
||||
raise Exception("The name '{}' is not a valid keyword argument "
|
||||
"for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
|
||||
# Fill in the remaining arguments.
|
||||
for skipped_name in arg_names[0:len(args)]:
|
||||
if skipped_name in kwargs:
|
||||
raise Exception("Positional and keyword value provided for the "
|
||||
"argument '{}' for the function '{}'".format(
|
||||
keyword_name, function_name))
|
||||
def recover_args(flattened_args):
|
||||
"""Recreates `args` and `kwargs` from the flattened arg list.
|
||||
|
||||
zipped_info = zip(arg_names, arg_defaults, arg_is_positionals)
|
||||
zipped_info = list(zipped_info)[len(args):]
|
||||
for keyword_name, default_value, is_positional in zipped_info:
|
||||
if keyword_name in kwargs:
|
||||
args.append(kwargs[keyword_name])
|
||||
Args:
|
||||
flattened_args: List of args and kwargs. This should be the output of
|
||||
`flatten_args`.
|
||||
|
||||
Returns:
|
||||
args: The non-keyword arguments passed into the function.
|
||||
kwargs: The keyword arguments passed into the function.
|
||||
"""
|
||||
assert len(flattened_args) % 2 == 0, (
|
||||
"Flattened arguments need to be even-numbered. See `flatten_args`.")
|
||||
args = []
|
||||
kwargs = {}
|
||||
for name_index in range(0, len(flattened_args), 2):
|
||||
name, arg = flattened_args[name_index], flattened_args[name_index + 1]
|
||||
if name == DUMMY_TYPE:
|
||||
args.append(arg)
|
||||
else:
|
||||
if default_value != funcsigs._empty:
|
||||
args.append(default_value)
|
||||
else:
|
||||
# This means that there is a missing argument. Unless this is
|
||||
# the last argument and it is a *args argument in which case it
|
||||
# can be omitted.
|
||||
if not is_positional:
|
||||
raise Exception("No value was provided for the argument "
|
||||
"'{}' for the function '{}'.".format(
|
||||
keyword_name, function_name))
|
||||
kwargs[name] = arg
|
||||
|
||||
no_positionals = len(arg_is_positionals) == 0 or not arg_is_positionals[-1]
|
||||
too_many_arguments = len(args) > len(arg_names) and no_positionals
|
||||
if too_many_arguments:
|
||||
raise Exception("Too many arguments were passed to the function '{}'"
|
||||
.format(function_name))
|
||||
return args
|
||||
return args, kwargs
|
||||
|
||||
|
||||
def _scrub_parameters(parameters):
|
||||
"""Returns a scrubbed list of RayParameters."""
|
||||
return [
|
||||
RayParameter(
|
||||
name=param.name,
|
||||
kind_int=_convert_from_parameter_kind(param.kind),
|
||||
default=param.default,
|
||||
annotation=param.annotation,
|
||||
partial_kwarg=param._partial_kwarg) for param in parameters
|
||||
]
|
||||
|
||||
|
||||
def _restore_parameters(ray_parameters):
|
||||
"""Reconstructs the funcsigs.Parameter objects."""
|
||||
return [
|
||||
Parameter(
|
||||
rayparam.name,
|
||||
_convert_to_parameter_kind(rayparam.kind_int),
|
||||
default=rayparam.default,
|
||||
annotation=rayparam.annotation,
|
||||
_partial_kwarg=rayparam.partial_kwarg)
|
||||
for rayparam in ray_parameters
|
||||
]
|
||||
|
||||
|
||||
def _convert_from_parameter_kind(kind):
|
||||
if kind == Parameter.POSITIONAL_ONLY:
|
||||
return 0
|
||||
if kind == Parameter.POSITIONAL_OR_KEYWORD:
|
||||
return 1
|
||||
if kind == Parameter.VAR_POSITIONAL:
|
||||
return 2
|
||||
if kind == Parameter.KEYWORD_ONLY:
|
||||
return 3
|
||||
if kind == Parameter.VAR_KEYWORD:
|
||||
return 4
|
||||
|
||||
|
||||
def _convert_to_parameter_kind(value):
|
||||
if value == 0:
|
||||
return Parameter.POSITIONAL_ONLY
|
||||
if value == 1:
|
||||
return Parameter.POSITIONAL_OR_KEYWORD
|
||||
if value == 2:
|
||||
return Parameter.VAR_POSITIONAL
|
||||
if value == 3:
|
||||
return Parameter.KEYWORD_ONLY
|
||||
if value == 4:
|
||||
return Parameter.VAR_KEYWORD
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
# coding: utf-8
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
import ray.tests.cluster_utils
|
||||
import ray.tests.utils
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": True
|
||||
}, {
|
||||
"local_mode": False
|
||||
}],
|
||||
indirect=True)
|
||||
def test_args_force_positional(ray_start_regular):
|
||||
def force_positional(*, a="hello", b="helxo", **kwargs):
|
||||
return a, b, kwargs
|
||||
|
||||
class TestActor():
|
||||
def force_positional(self, a="hello", b="heo", *args, **kwargs):
|
||||
return a, b, args, kwargs
|
||||
|
||||
def test_function(fn, remote_fn):
|
||||
assert fn(a=1, b=3, c=5) == ray.get(remote_fn.remote(a=1, b=3, c=5))
|
||||
assert fn(a=1) == ray.get(remote_fn.remote(a=1))
|
||||
assert fn(a=1) == ray.get(remote_fn.remote(a=1))
|
||||
|
||||
remote_test_function = ray.remote(test_function)
|
||||
|
||||
remote_force_positional = ray.remote(force_positional)
|
||||
test_function(force_positional, remote_force_positional)
|
||||
ray.get(
|
||||
remote_test_function.remote(force_positional, remote_force_positional))
|
||||
|
||||
remote_actor_class = ray.remote(TestActor)
|
||||
remote_actor = remote_actor_class.remote()
|
||||
actor_method = remote_actor.force_positional
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.force_positional
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": False
|
||||
}, {
|
||||
"local_mode": True
|
||||
}],
|
||||
indirect=True)
|
||||
def test_args_intertwined(ray_start_regular):
|
||||
def args_intertwined(a, *args, x="hello", **kwargs):
|
||||
return a, args, x, kwargs
|
||||
|
||||
class TestActor():
|
||||
def args_intertwined(self, a, *args, x="hello", **kwargs):
|
||||
return a, args, x, kwargs
|
||||
|
||||
@classmethod
|
||||
def cls_args_intertwined(cls, a, *args, x="hello", **kwargs):
|
||||
return a, args, x, kwargs
|
||||
|
||||
def test_function(fn, remote_fn):
|
||||
assert fn(
|
||||
1, 2, 3, x="hi", y="hello") == ray.get(
|
||||
remote_fn.remote(1, 2, 3, x="hi", y="hello"))
|
||||
assert fn(
|
||||
1, 2, 3, y="1hello") == ray.get(
|
||||
remote_fn.remote(1, 2, 3, y="1hello"))
|
||||
assert fn(1, y="1hello") == ray.get(remote_fn.remote(1, y="1hello"))
|
||||
|
||||
remote_test_function = ray.remote(test_function)
|
||||
|
||||
remote_args_intertwined = ray.remote(args_intertwined)
|
||||
test_function(args_intertwined, remote_args_intertwined)
|
||||
ray.get(
|
||||
remote_test_function.remote(args_intertwined, remote_args_intertwined))
|
||||
|
||||
remote_actor_class = ray.remote(TestActor)
|
||||
remote_actor = remote_actor_class.remote()
|
||||
actor_method = remote_actor.args_intertwined
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.args_intertwined
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
actor_method = remote_actor.cls_args_intertwined
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.cls_args_intertwined
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
+120
-13
@@ -29,6 +29,7 @@ import pickle
|
||||
import pytest
|
||||
|
||||
import ray
|
||||
from ray import signature
|
||||
import ray.ray_constants as ray_constants
|
||||
import ray.tests.cluster_utils
|
||||
import ray.tests.utils
|
||||
@@ -781,6 +782,121 @@ def test_keyword_args(ray_start_regular):
|
||||
assert ray.get(f3.remote(4)) == 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": True
|
||||
}, {
|
||||
"local_mode": False
|
||||
}],
|
||||
indirect=True)
|
||||
def test_args_starkwargs(ray_start_regular):
|
||||
def starkwargs(a, b, **kwargs):
|
||||
return a, b, kwargs
|
||||
|
||||
class TestActor(object):
|
||||
def starkwargs(self, a, b, **kwargs):
|
||||
return a, b, kwargs
|
||||
|
||||
def test_function(fn, remote_fn):
|
||||
assert fn(1, 2, x=3) == ray.get(remote_fn.remote(1, 2, x=3))
|
||||
with pytest.raises(TypeError):
|
||||
remote_fn.remote(3)
|
||||
|
||||
remote_test_function = ray.remote(test_function)
|
||||
|
||||
remote_starkwargs = ray.remote(starkwargs)
|
||||
test_function(starkwargs, remote_starkwargs)
|
||||
ray.get(remote_test_function.remote(starkwargs, remote_starkwargs))
|
||||
|
||||
remote_actor_class = ray.remote(TestActor)
|
||||
remote_actor = remote_actor_class.remote()
|
||||
actor_method = remote_actor.starkwargs
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.starkwargs
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": True
|
||||
}, {
|
||||
"local_mode": False
|
||||
}],
|
||||
indirect=True)
|
||||
def test_args_named_and_star(ray_start_regular):
|
||||
def hello(a, x="hello", **kwargs):
|
||||
return a, x, kwargs
|
||||
|
||||
class TestActor(object):
|
||||
def hello(self, a, x="hello", **kwargs):
|
||||
return a, x, kwargs
|
||||
|
||||
def test_function(fn, remote_fn):
|
||||
assert fn(1, x=2, y=3) == ray.get(remote_fn.remote(1, x=2, y=3))
|
||||
assert fn(1, 2, y=3) == ray.get(remote_fn.remote(1, 2, y=3))
|
||||
assert fn(1, y=3) == ray.get(remote_fn.remote(1, y=3))
|
||||
|
||||
assert fn(1, ) == ray.get(remote_fn.remote(1, ))
|
||||
assert fn(1) == ray.get(remote_fn.remote(1))
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
remote_fn.remote(1, 2, x=3)
|
||||
|
||||
remote_test_function = ray.remote(test_function)
|
||||
|
||||
remote_hello = ray.remote(hello)
|
||||
test_function(hello, remote_hello)
|
||||
ray.get(remote_test_function.remote(hello, remote_hello))
|
||||
|
||||
remote_actor_class = ray.remote(TestActor)
|
||||
remote_actor = remote_actor_class.remote()
|
||||
actor_method = remote_actor.hello
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.hello
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_regular", [{
|
||||
"local_mode": True
|
||||
}, {
|
||||
"local_mode": False
|
||||
}],
|
||||
indirect=True)
|
||||
def test_args_stars_after(ray_start_regular):
|
||||
def star_args_after(a="hello", b="heo", *args, **kwargs):
|
||||
return a, b, args, kwargs
|
||||
|
||||
class TestActor(object):
|
||||
def star_args_after(self, a="hello", b="heo", *args, **kwargs):
|
||||
return a, b, args, kwargs
|
||||
|
||||
def test_function(fn, remote_fn):
|
||||
assert fn("hi", "hello", 2) == ray.get(
|
||||
remote_fn.remote("hi", "hello", 2))
|
||||
assert fn(
|
||||
"hi", "hello", 2, hi="hi") == ray.get(
|
||||
remote_fn.remote("hi", "hello", 2, hi="hi"))
|
||||
assert fn(hi="hi") == ray.get(remote_fn.remote(hi="hi"))
|
||||
|
||||
remote_test_function = ray.remote(test_function)
|
||||
|
||||
remote_star_args_after = ray.remote(star_args_after)
|
||||
test_function(star_args_after, remote_star_args_after)
|
||||
ray.get(
|
||||
remote_test_function.remote(star_args_after, remote_star_args_after))
|
||||
|
||||
remote_actor_class = ray.remote(TestActor)
|
||||
remote_actor = remote_actor_class.remote()
|
||||
actor_method = remote_actor.star_args_after
|
||||
local_actor = TestActor()
|
||||
local_method = local_actor.star_args_after
|
||||
test_function(local_method, actor_method)
|
||||
ray.get(remote_test_function.remote(local_method, actor_method))
|
||||
|
||||
|
||||
def test_variable_number_of_args(shutdown_only):
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
@@ -790,16 +906,6 @@ def test_variable_number_of_args(shutdown_only):
|
||||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
|
||||
try:
|
||||
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
|
||||
kwargs_exception_thrown = False
|
||||
except Exception:
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
x = varargs_fct1.remote(0, 1, 2)
|
||||
@@ -807,8 +913,6 @@ def test_variable_number_of_args(shutdown_only):
|
||||
x = varargs_fct2.remote(0, 1, 2)
|
||||
assert ray.get(x) == "1 2"
|
||||
|
||||
assert kwargs_exception_thrown
|
||||
|
||||
@ray.remote
|
||||
def f1(*args):
|
||||
return args
|
||||
@@ -2688,7 +2792,10 @@ def test_global_state_api(shutdown_only):
|
||||
|
||||
task_spec = task_table[task_id]["TaskSpec"]
|
||||
assert task_spec["ActorID"] == nil_actor_id_hex
|
||||
assert task_spec["Args"] == [1, "hi", x_id]
|
||||
assert task_spec["Args"] == [
|
||||
signature.DUMMY_TYPE, 1, signature.DUMMY_TYPE, "hi",
|
||||
signature.DUMMY_TYPE, x_id
|
||||
]
|
||||
assert task_spec["JobID"] == job_id.hex()
|
||||
assert task_spec["ReturnObjectIDs"] == [result_id]
|
||||
|
||||
|
||||
+13
-10
@@ -677,7 +677,7 @@ class Worker(object):
|
||||
else:
|
||||
arguments[object_indices[i]] = value
|
||||
|
||||
return arguments
|
||||
return ray.signature.recover_args(arguments)
|
||||
|
||||
def _store_outputs_in_object_store(self, object_ids, outputs):
|
||||
"""Store the outputs of a remote function in the local object store.
|
||||
@@ -744,7 +744,7 @@ class Worker(object):
|
||||
|
||||
function_descriptor = FunctionDescriptor.from_bytes_list(
|
||||
task.function_descriptor_list())
|
||||
args = task.arguments()
|
||||
serialized_args = task.arguments()
|
||||
return_object_ids = task.returns()
|
||||
if task.is_actor_task() or task.is_actor_creation_task():
|
||||
dummy_return_id = return_object_ids.pop()
|
||||
@@ -757,8 +757,9 @@ class Worker(object):
|
||||
self.reraise_actor_init_error()
|
||||
self.memory_monitor.raise_if_low_memory()
|
||||
with profiling.profile("task:deserialize_arguments"):
|
||||
arguments = self._get_arguments_for_execution(
|
||||
function_name, args)
|
||||
function_args, function_kwargs = (
|
||||
self._get_arguments_for_execution(function_name,
|
||||
serialized_args))
|
||||
except Exception as e:
|
||||
self._handle_process_task_failure(
|
||||
function_descriptor, return_object_ids, e,
|
||||
@@ -770,7 +771,8 @@ class Worker(object):
|
||||
self._current_task = task
|
||||
with profiling.profile("task:execute"):
|
||||
if task.is_normal_task():
|
||||
outputs = function_executor(*arguments)
|
||||
outputs = function_executor(*function_args,
|
||||
**function_kwargs)
|
||||
else:
|
||||
if task.is_actor_task():
|
||||
key = task.actor_id()
|
||||
@@ -790,8 +792,9 @@ class Worker(object):
|
||||
ray_constants.from_memory_units(
|
||||
task.required_resources()[
|
||||
"object_store_memory"])))
|
||||
outputs = function_executor(dummy_return_id,
|
||||
self.actors[key], *arguments)
|
||||
outputs = function_executor(
|
||||
dummy_return_id, self.actors[key], *function_args,
|
||||
**function_kwargs)
|
||||
except Exception as e:
|
||||
# Determine whether the exception occured during a task, not an
|
||||
# actor method.
|
||||
@@ -1115,14 +1118,14 @@ def _initialize_serialization(job_id, worker=global_worker):
|
||||
local=True,
|
||||
job_id=job_id,
|
||||
class_id="type")
|
||||
# Tell Ray to serialize FunctionSignatures as dictionaries. This is
|
||||
# Tell Ray to serialize RayParameters as dictionaries. This is
|
||||
# used when passing around actor handles.
|
||||
_register_custom_serializer(
|
||||
ray.signature.FunctionSignature,
|
||||
ray.signature.RayParameter,
|
||||
use_dict=True,
|
||||
local=True,
|
||||
job_id=job_id,
|
||||
class_id="ray.signature.FunctionSignature")
|
||||
class_id="ray.signature.RayParameter")
|
||||
# Tell Ray to serialize StringIO with pickle. We do this because
|
||||
# Ray's default __dict__ serialization is incorrect for this type
|
||||
# (the object's __dict__ is empty and therefore doesn't
|
||||
|
||||
Reference in New Issue
Block a user