From fb7ccef49309eeacb6aae65f5e603ea11b1a31ff Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Tue, 30 Aug 2016 15:14:02 -0700 Subject: [PATCH] Allow remote decorator to be used with no parentheses. --- README.md | 2 +- doc/aliasing.md | 6 +- doc/remote-functions.md | 2 +- doc/tutorial.md | 12 +- examples/alexnet/README.md | 2 +- examples/alexnet/alexnet.py | 10 +- examples/hyperopt/README.md | 2 +- examples/hyperopt/hyperopt.py | 2 +- examples/lbfgs/README.md | 4 +- examples/lbfgs/driver.py | 4 +- examples/trpo/README.md | 2 +- lib/python/ray/array/distributed/core.py | 28 ++--- lib/python/ray/array/distributed/linalg.py | 6 +- lib/python/ray/array/distributed/random.py | 2 +- lib/python/ray/array/remote/core.py | 36 +++--- lib/python/ray/array/remote/linalg.py | 26 ++-- lib/python/ray/array/remote/random.py | 2 +- lib/python/ray/serialization.py | 2 +- lib/python/ray/worker.py | 136 ++++++++++++--------- test/failure_test.py | 4 +- test/runtest.py | 42 +++---- test/test_functions.py | 36 +++--- 22 files changed, 193 insertions(+), 175 deletions(-) diff --git a/README.md b/README.md index 81d5f4766..3606914c0 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ import numpy as np ray.init(start_ray_local=True, num_workers=10) # Define a remote function for estimating pi. -@ray.remote() +@ray.remote def estimate_pi(n): x = np.random.uniform(size=n) y = np.random.uniform(size=n) diff --git a/doc/aliasing.md b/doc/aliasing.md index 5ff153766..d2763d02b 100644 --- a/doc/aliasing.md +++ b/doc/aliasing.md @@ -10,15 +10,15 @@ However, to provide a more flexible API, we allow tasks to not only return values, but to also return object ids to values. As an examples, consider the following code. ```python -@ray.remote() +@ray.remote def f() return np.zeros(5) -@ray.remote() +@ray.remote def g() return f() -@ray.remote() +@ray.remote def h() return g() ``` diff --git a/doc/remote-functions.md b/doc/remote-functions.md index 8ea56552f..29590ad20 100644 --- a/doc/remote-functions.md +++ b/doc/remote-functions.md @@ -5,7 +5,7 @@ functions. Remote functions are written like regular Python functions, but with the `@ray.remote` decorator on top. ```python -@ray.remote() +@ray.remote def increment(n): return n + 1 ``` diff --git a/doc/tutorial.md b/doc/tutorial.md index aea656175..f031f0e50 100644 --- a/doc/tutorial.md +++ b/doc/tutorial.md @@ -107,7 +107,7 @@ def add(a, b): ``` A remote function in Ray looks like this. ```python -@ray.remote() +@ray.remote def add(a, b): return a + b ``` @@ -194,7 +194,7 @@ around `time.sleep`. ```python import time -@ray.remote() +@ray.remote def sleep(n): time.sleep(n) return 0 @@ -245,11 +245,11 @@ Computation graphs encode dependencies. For example, suppose we define ```python import numpy as np -@ray.remote() +@ray.remote def zeros(shape): return np.zeros(shape) -@ray.remote() +@ray.remote def dot(a, b): return np.dot(a, b) ``` @@ -282,12 +282,12 @@ processes can also call remote functions. To illustrate this, consider the following example. ```python -@ray.remote() +@ray.remote def sub_experiment(i, j): # Run the jth sub-experiment for the ith experiment. return i + j -@ray.remote() +@ray.remote def run_experiment(i): sub_results = [] # Launch tasks to perform 10 sub-experiments in parallel. diff --git a/examples/alexnet/README.md b/examples/alexnet/README.md index dc6be5c83..23e7da2b5 100644 --- a/examples/alexnet/README.md +++ b/examples/alexnet/README.md @@ -85,7 +85,7 @@ The other parallel component of this application is the training procedure. This is built on top of the remote function `compute_grad`. ```python -@ray.remote() +@ray.remote def compute_grad(X, Y, mean, weights): # Load the weights into the network. # Subtract the mean and crop the images. diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index 8fa7a08bb..c2211930e 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -230,7 +230,7 @@ def net_initialization(): def net_reinitialization(net_vars): return net_vars -@ray.remote() +@ray.remote def num_images(batches): """Counts number of images in batches. @@ -243,7 +243,7 @@ def num_images(batches): shape_ids = [ra.shape.remote(batch) for batch in batches] return sum([ray.get(shape_id)[0] for shape_id in shape_ids]) -@ray.remote() +@ray.remote def compute_mean_image(batches): """Computes the mean image given a list of batches of images. @@ -305,7 +305,7 @@ def shuffle_pair(first_batch, second_batch): images1, labels1, images2, labels2 = shuffle_arrays.remote(first_batch[0], first_batch[1], second_batch[0], second_batch[1]) return (images1, labels1), (images2, labels2) -@ray.remote() +@ray.remote def filenames_to_labels(filenames, filename_label_dict): """Converts filename strings to integer labels. @@ -380,7 +380,7 @@ def shuffle(batches): new_batches.append(permuted_batches[-1]) return new_batches -@ray.remote() +@ray.remote def compute_grad(X, Y, mean, weights): """Computes the gradient of the network. @@ -405,7 +405,7 @@ def compute_grad(X, Y, mean, weights): # Compute the gradients. return sess.run([g for (g, v) in comp_grads], feed_dict={images: subset_X, y_true: subset_Y, dropout: 0.5}) -@ray.remote() +@ray.remote def compute_accuracy(X, Y, weights): """Returns the accuracy of the network diff --git a/examples/hyperopt/README.md b/examples/hyperopt/README.md index 82258d307..cf3ea399d 100644 --- a/examples/hyperopt/README.md +++ b/examples/hyperopt/README.md @@ -82,7 +82,7 @@ complicated version of this remote function is defined in [hyperopt.py](hyperopt.py). ```python -@ray.remote() +@ray.remote def train_cnn_and_compute_accuracy(hyperparameters, train_images, train_labels, validation_images, validation_labels): # Actual work omitted. return validation_accuracy diff --git a/examples/hyperopt/hyperopt.py b/examples/hyperopt/hyperopt.py index f8b4b07bf..9d3f8e51b 100644 --- a/examples/hyperopt/hyperopt.py +++ b/examples/hyperopt/hyperopt.py @@ -51,7 +51,7 @@ def cnn_setup(x, y, keep_prob, lr, stddev): # Define a remote function that takes a set of hyperparameters as well as the # data, consructs and trains a network, and returns the validation accuracy. -@ray.remote() +@ray.remote def train_cnn_and_compute_accuracy(params, steps, train_images, train_labels, validation_images, validation_labels): # Extract the hyperparameters from the params dictionary. learning_rate = params["learning_rate"] diff --git a/examples/lbfgs/README.md b/examples/lbfgs/README.md index 4af9f267c..addfc5dbc 100644 --- a/examples/lbfgs/README.md +++ b/examples/lbfgs/README.md @@ -91,12 +91,12 @@ use remote functions to distribute the loading of the data. Now, lets turn `loss` and `grad` into remote functions. ```python -@ray.remote() +@ray.remote def loss(theta, xs, ys): # compute the loss return loss -@ray.remote() +@ray.remote def grad(theta, xs, ys): # compute the gradient return grad diff --git a/examples/lbfgs/driver.py b/examples/lbfgs/driver.py index 9a5de5d84..5d7fff6fd 100644 --- a/examples/lbfgs/driver.py +++ b/examples/lbfgs/driver.py @@ -74,14 +74,14 @@ if __name__ == "__main__": sess.run([update_w, update_b], feed_dict={w_new: theta[:w_size].reshape(w_shape), b_new: theta[w_size:]}) # Compute the loss on a batch of data. - @ray.remote() + @ray.remote def loss(theta, xs, ys): sess, _, _, cross_entropy, _, x, y_, _, _ = ray.reusables.net_vars load_weights(theta) return float(sess.run(cross_entropy, feed_dict={x: xs, y_: ys})) # Compute the gradient of the loss on a batch of data. - @ray.remote() + @ray.remote def grad(theta, xs, ys): sess, _, _, _, cross_entropy_grads, x, y_, _, _ = ray.reusables.net_vars load_weights(theta) diff --git a/examples/trpo/README.md b/examples/trpo/README.md index a1fd8aa1d..053c02e3e 100644 --- a/examples/trpo/README.md +++ b/examples/trpo/README.md @@ -79,7 +79,7 @@ we use reusable variables to store the gym environment and the neural network po then used in the remote `do_rollout` function to do a remote rollout: ```python -@ray.remote() +@ray.remote def do_rollout(policy, timestep_limit, seed): # Retrieve the game environment. env = ray.reusables.env diff --git a/lib/python/ray/array/distributed/core.py b/lib/python/ray/array/distributed/core.py index 3938160ec..4e6b41ef6 100644 --- a/lib/python/ray/array/distributed/core.py +++ b/lib/python/ray/array/distributed/core.py @@ -65,12 +65,12 @@ class DistArray(object): a = self.assemble() return a[sliced] -@ray.remote() +@ray.remote def assemble(a): return a.assemble() # TODO(rkn): what should we call this method -@ray.remote() +@ray.remote def numpy_to_dist(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): @@ -79,28 +79,28 @@ def numpy_to_dist(a): result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]]) return result -@ray.remote() +@ray.remote def zeros(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result -@ray.remote() +@ray.remote def ones(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result -@ray.remote() +@ray.remote def copy(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable. return result -@ray.remote() +@ray.remote def eye(dim1, dim2=-1, dtype_name="float"): dim2 = dim1 if dim2 == -1 else dim2 shape = [dim1, dim2] @@ -113,7 +113,7 @@ def eye(dim1, dim2=-1, dtype_name="float"): result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name) return result -@ray.remote() +@ray.remote def triu(a): if a.ndim != 2: raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) @@ -127,7 +127,7 @@ def triu(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result -@ray.remote() +@ray.remote def tril(a): if a.ndim != 2: raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) @@ -141,7 +141,7 @@ def tril(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result -@ray.remote() +@ray.remote def blockwise_dot(*matrices): n = len(matrices) if n % 2 != 0: @@ -152,7 +152,7 @@ def blockwise_dot(*matrices): result += np.dot(matrices[i], matrices[n / 2 + i]) return result -@ray.remote() +@ray.remote def dot(a, b): if a.ndim != 2: raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim)) @@ -167,7 +167,7 @@ def dot(a, b): result.objectids[i, j] = blockwise_dot.remote(*args) return result -@ray.remote() +@ray.remote def subblocks(a, *ranges): """ This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example, @@ -197,7 +197,7 @@ def subblocks(a, *ranges): result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])] return result -@ray.remote() +@ray.remote def transpose(a): if a.ndim != 2: raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape)) @@ -208,7 +208,7 @@ def transpose(a): return result # TODO(rkn): support broadcasting? -@ray.remote() +@ray.remote def add(x1, x2): if x1.shape != x2.shape: raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) @@ -218,7 +218,7 @@ def add(x1, x2): return result # TODO(rkn): support broadcasting? -@ray.remote() +@ray.remote def subtract(x1, x2): if x1.shape != x2.shape: raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) diff --git a/lib/python/ray/array/distributed/linalg.py b/lib/python/ray/array/distributed/linalg.py index 5e883da67..270441999 100644 --- a/lib/python/ray/array/distributed/linalg.py +++ b/lib/python/ray/array/distributed/linalg.py @@ -112,7 +112,7 @@ def tsqr_hr_helper1(u, s, y_top_block, b): t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T)) return t, y_top -@ray.remote() +@ray.remote def tsqr_hr_helper2(s, r_temp): s_full = np.diag(s) return np.dot(s_full, r_temp) @@ -127,11 +127,11 @@ def tsqr_hr(a): r = tsqr_hr_helper2.remote(s, r_temp) return y, t, y_top, r -@ray.remote() +@ray.remote def qr_helper1(a_rc, y_ri, t, W_c): return a_rc - np.dot(y_ri, np.dot(t.T, W_c)) -@ray.remote() +@ray.remote def qr_helper2(y_ri, a_rc): return np.dot(y_ri.T, a_rc) diff --git a/lib/python/ray/array/distributed/random.py b/lib/python/ray/array/distributed/random.py index 4f905671e..c7b44a713 100644 --- a/lib/python/ray/array/distributed/random.py +++ b/lib/python/ray/array/distributed/random.py @@ -4,7 +4,7 @@ import ray from core import * -@ray.remote() +@ray.remote def normal(shape): num_blocks = DistArray.compute_num_blocks(shape) objectids = np.empty(num_blocks, dtype=object) diff --git a/lib/python/ray/array/remote/core.py b/lib/python/ray/array/remote/core.py index b96606780..c8cd3ecf4 100644 --- a/lib/python/ray/array/remote/core.py +++ b/lib/python/ray/array/remote/core.py @@ -3,80 +3,80 @@ import ray __all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"] -@ray.remote() +@ray.remote def zeros(shape, dtype_name="float", order="C"): return np.zeros(shape, dtype=np.dtype(dtype_name), order=order) -@ray.remote() +@ray.remote def zeros_like(a, dtype_name="None", order="K", subok=True): dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) -@ray.remote() +@ray.remote def ones(shape, dtype_name="float", order="C"): return np.ones(shape, dtype=np.dtype(dtype_name), order=order) -@ray.remote() +@ray.remote def eye(N, M=-1, k=0, dtype_name="float"): M = N if M == -1 else M return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) -@ray.remote() +@ray.remote def dot(a, b): return np.dot(a, b) -@ray.remote() +@ray.remote def vstack(*xs): return np.vstack(xs) -@ray.remote() +@ray.remote def hstack(*xs): return np.hstack(xs) # TODO(rkn): instead of this, consider implementing slicing -@ray.remote() +@ray.remote def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices" return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] -@ray.remote() +@ray.remote def copy(a, order="K"): return np.copy(a, order=order) -@ray.remote() +@ray.remote def tril(m, k=0): return np.tril(m, k=k) -@ray.remote() +@ray.remote def triu(m, k=0): return np.triu(m, k=k) -@ray.remote() +@ray.remote def diag(v, k=0): return np.diag(v, k=k) -@ray.remote() +@ray.remote def transpose(a, axes=[]): axes = None if axes == [] else axes return np.transpose(a, axes=axes) -@ray.remote() +@ray.remote def add(x1, x2): return np.add(x1, x2) -@ray.remote() +@ray.remote def subtract(x1, x2): return np.subtract(x1, x2) -@ray.remote() +@ray.remote def sum(x, axis=-1): return np.sum(x, axis=axis if axis != -1 else None) -@ray.remote() +@ray.remote def shape(a): return np.shape(a) # We use Any to allow different numerical types as well as numpy arrays. # TODO(rkn):this isn't in the numpy API, so be careful about exposing this. -@ray.remote() +@ray.remote def sum_list(*xs): return np.sum(xs, axis=0) diff --git a/lib/python/ray/array/remote/linalg.py b/lib/python/ray/array/remote/linalg.py index 8cccf74da..ab01788c3 100644 --- a/lib/python/ray/array/remote/linalg.py +++ b/lib/python/ray/array/remote/linalg.py @@ -6,11 +6,11 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank", "LinAlgError", "multi_dot"] -@ray.remote() +@ray.remote def matrix_power(M, n): return np.linalg.matrix_power(M, n) -@ray.remote() +@ray.remote def solve(a, b): return np.linalg.solve(a, b) @@ -22,31 +22,31 @@ def tensorsolve(a): def tensorinv(a): raise NotImplementedError -@ray.remote() +@ray.remote def inv(a): return np.linalg.inv(a) -@ray.remote() +@ray.remote def cholesky(a): return np.linalg.cholesky(a) -@ray.remote() +@ray.remote def eigvals(a): return np.linalg.eigvals(a) -@ray.remote() +@ray.remote def eigvalsh(a): raise NotImplementedError -@ray.remote() +@ray.remote def pinv(a): return np.linalg.pinv(a) -@ray.remote() +@ray.remote def slogdet(a): raise NotImplementedError -@ray.remote() +@ray.remote def det(a): return np.linalg.det(a) @@ -66,7 +66,7 @@ def eigh(a): def lstsq(a, b): return np.linalg.lstsq(a) -@ray.remote() +@ray.remote def norm(x): return np.linalg.norm(x) @@ -74,14 +74,14 @@ def norm(x): def qr(a): return np.linalg.qr(a) -@ray.remote() +@ray.remote def cond(x): return np.linalg.cond(x) -@ray.remote() +@ray.remote def matrix_rank(M): return np.linalg.matrix_rank(M) -@ray.remote() +@ray.remote def multi_dot(*a): raise NotImplementedError diff --git a/lib/python/ray/array/remote/random.py b/lib/python/ray/array/remote/random.py index e86d321d1..7fb3def0e 100644 --- a/lib/python/ray/array/remote/random.py +++ b/lib/python/ray/array/remote/random.py @@ -1,6 +1,6 @@ import numpy as np import ray -@ray.remote() +@ray.remote def normal(shape): return np.random.normal(size=shape) diff --git a/lib/python/ray/serialization.py b/lib/python/ray/serialization.py index 8cdbfa558..153fb89f4 100644 --- a/lib/python/ray/serialization.py +++ b/lib/python/ray/serialization.py @@ -28,7 +28,7 @@ class Tuple(tuple): class Str(str): pass - + class Unicode(unicode): pass diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 2ad41bcf5..259755e18 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -964,7 +964,7 @@ def main_loop(worker=global_worker): # TODO(rkn): Why is the below line necessary? function.__module__ = module assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in." - worker.functions[function_name] = remote(num_return_vals, worker)(function) + worker.functions[function_name] = remote(num_return_vals=num_return_vals)(function) _logger().info("Successfully imported remote function {}.".format(function_name)) # Noify the scheduler that the remote function imported successfully. # We pass an empty error message string because the import succeeded. @@ -1069,71 +1069,89 @@ def _export_reusable_variable(name, reusable, worker=global_worker): raise Exception("_export_reusable_variable can only be called on a driver.") raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer)) -def remote(num_return_vals=1, worker=global_worker): +def remote(*args, **kwargs): """This decorator is used to create remote functions. Args: num_return_vals (int): The number of object IDs that a call to this function should return. """ - def remote_decorator(func): - def func_call(*args, **kwargs): - """This gets run immediately when a worker calls a remote function.""" - check_connected() - args = list(args) - args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments - if _mode() == raylib.PYTHON_MODE: - # In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the - # arguments to prevent the function call from mutating them and to match - # the usual behavior of immutable remote objects. - return func(*copy.deepcopy(args)) - objectids = _submit_task(func_name, args) - if len(objectids) == 1: - return objectids[0] - elif len(objectids) > 1: - return objectids - def func_executor(arguments): - """This gets run when the remote function is executed.""" - _logger().info("Calling function {}".format(func.__name__)) - start_time = time.time() - result = func(*arguments) - end_time = time.time() - _logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time)) - return result - def func_invoker(*args, **kwargs): - """This is returned by the decorator and used to invoke the function.""" - raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name)) - func_invoker.remote = func_call - func_invoker.executor = func_executor - func_invoker.is_remote = True - func_name = "{}.{}".format(func.__module__, func.__name__) - func_invoker.func_name = func_name - func_invoker.func_doc = func.func_doc - sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()] - keyword_defaults = [(k, v.default) for k, v in sig_params] - has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params]) - func_invoker.has_vararg_param = has_vararg_param - has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) - check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) + worker = global_worker + def make_remote_decorator(num_return_vals): + def remote_decorator(func): + def func_call(*args, **kwargs): + """This gets run immediately when a worker calls a remote function.""" + check_connected() + args = list(args) + args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments + if any([arg is funcsigs._empty for arg in args]): + raise Exception("Not enough arguments were provided to {}.".format(func_name)) + if _mode() == raylib.PYTHON_MODE: + # In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the + # arguments to prevent the function call from mutating them and to match + # the usual behavior of immutable remote objects. + return func(*copy.deepcopy(args)) + objectids = _submit_task(func_name, args) + if len(objectids) == 1: + return objectids[0] + elif len(objectids) > 1: + return objectids + def func_executor(arguments): + """This gets run when the remote function is executed.""" + _logger().info("Calling function {}".format(func.__name__)) + start_time = time.time() + result = func(*arguments) + end_time = time.time() + _logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time)) + return result + def func_invoker(*args, **kwargs): + """This is returned by the decorator and used to invoke the function.""" + raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name)) + func_invoker.remote = func_call + func_invoker.executor = func_executor + func_invoker.is_remote = True + func_name = "{}.{}".format(func.__module__, func.__name__) + func_invoker.func_name = func_name + func_invoker.func_doc = func.func_doc - # Everything ready - export the function - if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]: - func_name_global_valid = func.__name__ in func.__globals__ - func_name_global_value = func.__globals__.get(func.__name__) - # Set the function globally to make it refer to itself - func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable - try: - to_export = pickling.dumps((func, num_return_vals, func.__module__)) - finally: - # Undo our changes - if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value - else: del func.__globals__[func.__name__] - if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: - raylib.export_remote_function(worker.handle, func_name, to_export) - elif worker.mode is None: - worker.cached_remote_functions.append((func_name, to_export)) - return func_invoker - return remote_decorator + sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()] + keyword_defaults = [(k, v.default) for k, v in sig_params] + has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params]) + func_invoker.has_vararg_param = has_vararg_param + has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params]) + check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name) + + # Everything ready - export the function + if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + func_name_global_valid = func.__name__ in func.__globals__ + func_name_global_value = func.__globals__.get(func.__name__) + # Set the function globally to make it refer to itself + func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable + try: + to_export = pickling.dumps((func, num_return_vals, func.__module__)) + finally: + # Undo our changes + if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value + else: del func.__globals__[func.__name__] + if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]: + raylib.export_remote_function(worker.handle, func_name, to_export) + elif worker.mode is None: + worker.cached_remote_functions.append((func_name, to_export)) + return func_invoker + + return remote_decorator + + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # This is the case where the decorator is just @ray.remote. + num_return_vals = 1 + func = args[0] + return make_remote_decorator(num_return_vals)(func) + else: + # This is the case where the decorator is something like + # @ray.remote(num_return_vals=2). + assert len(args) == 0 and "num_return_vals" in kwargs.keys(), "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'." + num_return_vals = kwargs["num_return_vals"] + return make_remote_decorator(num_return_vals) def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name): """Check if we support the signature of this function. diff --git a/test/failure_test.py b/test/failure_test.py index adc889b76..87c9d33a3 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -76,7 +76,7 @@ class TaskStatusTest(unittest.TestCase): return reducer, () def __call__(self): return - ray.remote()(Foo()) + ray.remote(Foo()) for _ in range(100): # Retry if we need to wait longer. if len(ray.task_info()["failed_remote_function_imports"]) >= 1: break @@ -112,7 +112,7 @@ class TaskStatusTest(unittest.TestCase): def reinitializer(foo): raise Exception("The reinitializer failed.") ray.reusables.foo = ray.Reusable(initializer, reinitializer) - @ray.remote() + @ray.remote def use_foo(): ray.reusables.foo use_foo.remote() diff --git a/test/runtest.py b/test/runtest.py index 9c85b58ee..99f2c74d7 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -262,42 +262,42 @@ class APITest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=2) # Test that we can define a remote function in the shell. - @ray.remote() + @ray.remote def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) # Test that we can redefine the remote function. - @ray.remote() + @ray.remote def f(x): return x + 10 self.assertEqual(ray.get(f.remote(0)), 10) # Test that we can close over plain old data. data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}] - @ray.remote() + @ray.remote def g(): return data ray.get(g.remote()) # Test that we can close over modules. - @ray.remote() + @ray.remote def h(): return np.zeros([3, 5]) assert_equal(ray.get(h.remote()), np.zeros([3, 5])) - @ray.remote() + @ray.remote def j(): return time.time() ray.get(j.remote()) # Test that we can define remote functions that call other remote functions. - @ray.remote() + @ray.remote def k(x): return x + 1 - @ray.remote() + @ray.remote def l(x): return k.remote(x) - @ray.remote() + @ray.remote def m(x): return ray.get(l.remote(x)) self.assertEqual(ray.get(k.remote(1)), 2) @@ -309,7 +309,7 @@ class APITest(unittest.TestCase): def testSelect(self): ray.init(start_ray_local=True, num_workers=4) - @ray.remote() + @ray.remote def f(delay): time.sleep(delay) return 1 @@ -345,10 +345,10 @@ class APITest(unittest.TestCase): ray.reusables.foo = ray.Reusable(foo_initializer) ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer) - @ray.remote() + @ray.remote def use_foo(): return ray.reusables.foo - @ray.remote() + @ray.remote def use_bar(): ray.reusables.bar.append(1) return ray.reusables.bar @@ -368,7 +368,7 @@ class APITest(unittest.TestCase): def f(): sys.path.append("fake_directory") ray.worker.global_worker.run_function_on_all_workers(f) - @ray.remote() + @ray.remote def get_path(): return sys.path self.assertEqual("fake_directory", ray.get(get_path.remote())[-1]) @@ -509,7 +509,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote() + @ray.remote def f(): return sys.getrefcount(None) first_count = ray.get(f.remote()) @@ -522,7 +522,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote() + @ray.remote def f(): return sys.getrefcount(True) first_count = ray.get(f.remote()) @@ -535,7 +535,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote() + @ray.remote def f(): return sys.getrefcount(False) first_count = ray.get(f.remote()) @@ -559,7 +559,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer) self.assertEqual(ray.reusables.foo, 1) - @ray.remote() + @ray.remote def use_foo(): return ray.reusables.foo self.assertEqual(ray.get(use_foo.remote()), 1) @@ -573,7 +573,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.bar = ray.Reusable(bar_initializer) - @ray.remote() + @ray.remote def use_bar(): ray.reusables.bar.append(4) return ray.reusables.bar @@ -592,7 +592,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer) - @ray.remote() + @ray.remote def use_baz(i): baz = ray.reusables.baz baz[i] = 1 @@ -613,7 +613,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer) - @ray.remote() + @ray.remote def use_qux(): return ray.reusables.qux self.assertEqual(ray.get(use_qux.remote()), 0) @@ -634,7 +634,7 @@ class ClusterAttachingTest(unittest.TestCase): ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) - @ray.remote() + @ray.remote def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) @@ -653,7 +653,7 @@ class ClusterAttachingTest(unittest.TestCase): ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) - @ray.remote() + @ray.remote def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) diff --git a/test/test_functions.py b/test/test_functions.py index d6be66ec6..d12585730 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -10,54 +10,54 @@ def handle_int(a, b): # Test aliasing -@ray.remote() +@ray.remote def test_alias_f(): return np.ones([3, 4, 5]) -@ray.remote() +@ray.remote def test_alias_g(): return test_alias_f.remote() -@ray.remote() +@ray.remote def test_alias_h(): return test_alias_g.remote() # Test timing -@ray.remote() +@ray.remote def empty_function(): pass -@ray.remote() +@ray.remote def trivial_function(): return 1 # Test keyword arguments -@ray.remote() +@ray.remote def keyword_fct1(a, b="hello"): return "{} {}".format(a, b) -@ray.remote() +@ray.remote def keyword_fct2(a="hello", b="world"): return "{} {}".format(a, b) -@ray.remote() +@ray.remote def keyword_fct3(a, b, c="hello", d="world"): return "{} {} {} {}".format(a, b, c, d) # Test variable numbers of arguments -@ray.remote() +@ray.remote def varargs_fct1(*a): return " ".join(map(str, a)) -@ray.remote() +@ray.remote def varargs_fct2(a, *b): return " ".join(map(str, b)) try: - @ray.remote() + @ray.remote def kwargs_throw_exception(**c): return () kwargs_exception_thrown = False @@ -65,7 +65,7 @@ except: kwargs_exception_thrown = True try: - @ray.remote() + @ray.remote def varargs_and_kwargs_throw_exception(a, b="hi", *c): return "{} {} {}".format(a, b, c) varargs_and_kwargs_exception_thrown = False @@ -74,11 +74,11 @@ except: # test throwing an exception -@ray.remote() +@ray.remote def throw_exception_fct1(): raise Exception("Test function 1 intentionally failed.") -@ray.remote() +@ray.remote def throw_exception_fct2(): raise Exception("Test function 2 intentionally failed.") @@ -88,18 +88,18 @@ def throw_exception_fct3(x): # test Python mode -@ray.remote() +@ray.remote def python_mode_f(): return np.array([0, 0]) -@ray.remote() +@ray.remote def python_mode_g(x): x[0] = 1 return x # test no return values -@ray.remote() +@ray.remote def no_op(): pass @@ -107,6 +107,6 @@ class TestClass(object): def __init__(self): self.a = 5 -@ray.remote() +@ray.remote def test_unknown_type(): return TestClass()