From d7f313a02657bc68dbc62e321ab7870ae84217b8 Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Mon, 29 Aug 2016 22:05:59 -0700 Subject: [PATCH] Remove type information from remote decorator. --- README.md | 2 +- doc/aliasing.md | 6 +- doc/remote-functions.md | 4 +- doc/tutorial.md | 14 +- examples/alexnet/README.md | 4 +- examples/alexnet/alexnet.py | 14 +- examples/hyperopt/README.md | 2 +- examples/hyperopt/hyperopt.py | 2 +- examples/lbfgs/README.md | 4 +- examples/lbfgs/driver.py | 4 +- examples/rl_pong/README.md | 2 +- examples/rl_pong/driver.py | 2 +- examples/trpo/README.md | 2 +- lib/python/ray/array/distributed/core.py | 28 ++-- lib/python/ray/array/distributed/linalg.py | 16 +- lib/python/ray/array/distributed/random.py | 2 +- lib/python/ray/array/remote/core.py | 36 ++--- lib/python/ray/array/remote/linalg.py | 40 ++--- lib/python/ray/array/remote/random.py | 2 +- lib/python/ray/worker.py | 176 ++------------------- test/failure_test.py | 2 +- test/runtest.py | 42 ++--- test/test_functions.py | 46 +++--- 23 files changed, 147 insertions(+), 305 deletions(-) diff --git a/README.md b/README.md index e3d50729e..81d5f4766 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ import numpy as np ray.init(start_ray_local=True, num_workers=10) # Define a remote function for estimating pi. -@ray.remote([int], [float]) +@ray.remote() def estimate_pi(n): x = np.random.uniform(size=n) y = np.random.uniform(size=n) diff --git a/doc/aliasing.md b/doc/aliasing.md index 1f227541e..5ff153766 100644 --- a/doc/aliasing.md +++ b/doc/aliasing.md @@ -10,15 +10,15 @@ However, to provide a more flexible API, we allow tasks to not only return values, but to also return object ids to values. As an examples, consider the following code. ```python -@ray.remote([], [np.ndarray]) +@ray.remote() def f() return np.zeros(5) -@ray.remote([], [np.ndarray]) +@ray.remote() def g() return f() -@ray.remote([], [np.ndarray]) +@ray.remote() def h() return g() ``` diff --git a/doc/remote-functions.md b/doc/remote-functions.md index e8bdc3a02..8ea56552f 100644 --- a/doc/remote-functions.md +++ b/doc/remote-functions.md @@ -5,7 +5,7 @@ functions. Remote functions are written like regular Python functions, but with the `@ray.remote` decorator on top. ```python -@ray.remote([int], [int]) +@ray.remote() def increment(n): return n + 1 ``` @@ -68,7 +68,7 @@ class ExampleClass(object): # This example assumes that field1 and field2 are serializable types. self.field1 = field1 self.field2 = field2 - + @staticmethod def deserialize(primitives): (field1, field2) = primitives diff --git a/doc/tutorial.md b/doc/tutorial.md index 08b8bdba4..aea656175 100644 --- a/doc/tutorial.md +++ b/doc/tutorial.md @@ -107,7 +107,7 @@ def add(a, b): ``` A remote function in Ray looks like this. ```python -@ray.remote([int, int], [int]) +@ray.remote() def add(a, b): return a + b ``` @@ -156,7 +156,7 @@ passed into the actual execution of the remote function. Note that a remote function can return multiple object IDs. ```python -@ray.remote([], [int, float, str]) +@ray.remote(num_return_vals=3) def return_multiple(): return 0, 0.0, "zero" @@ -194,7 +194,7 @@ around `time.sleep`. ```python import time -@ray.remote([int], [int]) +@ray.remote() def sleep(n): time.sleep(n) return 0 @@ -245,11 +245,11 @@ Computation graphs encode dependencies. For example, suppose we define ```python import numpy as np -@ray.remote([list], [np.ndarray]) +@ray.remote() def zeros(shape): return np.zeros(shape) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def dot(a, b): return np.dot(a, b) ``` @@ -282,12 +282,12 @@ processes can also call remote functions. To illustrate this, consider the following example. ```python -@ray.remote([int, int], [int]) +@ray.remote() def sub_experiment(i, j): # Run the jth sub-experiment for the ith experiment. return i + j -@ray.remote([int], [int]) +@ray.remote() def run_experiment(i): sub_results = [] # Launch tasks to perform 10 sub-experiments in parallel. diff --git a/examples/alexnet/README.md b/examples/alexnet/README.md index 23cde1440..dc6be5c83 100644 --- a/examples/alexnet/README.md +++ b/examples/alexnet/README.md @@ -59,7 +59,7 @@ workers. At the core of our loading code is the remote function retrieves the appropriate object. ```python -@ray.remote([str, str, List], [np.ndarray, List]) +@ray.remote(num_return_vals=2) def load_tarfile_from_s3(bucket, s3_key, size=[]): # Pull the object with the given key and bucket from S3, untar the contents, # and return it. @@ -85,7 +85,7 @@ The other parallel component of this application is the training procedure. This is built on top of the remote function `compute_grad`. ```python -@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List]) +@ray.remote() def compute_grad(X, Y, mean, weights): # Load the weights into the network. # Subtract the mean and crop the images. diff --git a/examples/alexnet/alexnet.py b/examples/alexnet/alexnet.py index cb07aa211..f3de56044 100644 --- a/examples/alexnet/alexnet.py +++ b/examples/alexnet/alexnet.py @@ -39,7 +39,7 @@ def load_chunk(tarfile, size=None): filenames.append(filename) return np.concatenate(result), filenames -@ray.remote([str, str, List], [np.ndarray, List]) +@ray.remote(num_return_vals=2) def load_tarfile_from_s3(bucket, s3_key, size=[]): """Load an imagenet .tar file. @@ -231,7 +231,7 @@ def net_initialization(): def net_reinitialization(net_vars): return net_vars -@ray.remote([List], [int]) +@ray.remote() def num_images(batches): """Counts number of images in batches. @@ -244,7 +244,7 @@ def num_images(batches): shape_ids = [ra.shape.remote(batch) for batch in batches] return sum([ray.get(shape_id)[0] for shape_id in shape_ids]) -@ray.remote([List], [np.ndarray]) +@ray.remote() def compute_mean_image(batches): """Computes the mean image given a list of batches of images. @@ -261,7 +261,7 @@ def compute_mean_image(batches): n_images = num_images.remote(batches) return np.sum(sum_images, axis=0).astype("float64") / ray.get(n_images) -@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=4) def shuffle_arrays(first_images, first_labels, second_images, second_labels): """Shuffles the images and labels from two batches. @@ -306,7 +306,7 @@ def shuffle_pair(first_batch, second_batch): images1, labels1, images2, labels2 = shuffle_arrays.remote(first_batch[0], first_batch[1], second_batch[0], second_batch[1]) return (images1, labels1), (images2, labels2) -@ray.remote([list, dict], [np.ndarray]) +@ray.remote() def filenames_to_labels(filenames, filename_label_dict): """Converts filename strings to integer labels. @@ -381,7 +381,7 @@ def shuffle(batches): new_batches.append(permuted_batches[-1]) return new_batches -@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List]) +@ray.remote() def compute_grad(X, Y, mean, weights): """Computes the gradient of the network. @@ -406,7 +406,7 @@ def compute_grad(X, Y, mean, weights): # Compute the gradients. return sess.run([g for (g, v) in comp_grads], feed_dict={images: subset_X, y_true: subset_Y, dropout: 0.5}) -@ray.remote([np.ndarray, np.ndarray, List], [np.float32]) +@ray.remote() def compute_accuracy(X, Y, weights): """Returns the accuracy of the network diff --git a/examples/hyperopt/README.md b/examples/hyperopt/README.md index 5ee4eaaa1..82258d307 100644 --- a/examples/hyperopt/README.md +++ b/examples/hyperopt/README.md @@ -82,7 +82,7 @@ complicated version of this remote function is defined in [hyperopt.py](hyperopt.py). ```python -@ray.remote([dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float]) +@ray.remote() def train_cnn_and_compute_accuracy(hyperparameters, train_images, train_labels, validation_images, validation_labels): # Actual work omitted. return validation_accuracy diff --git a/examples/hyperopt/hyperopt.py b/examples/hyperopt/hyperopt.py index 9d7b12a14..f8b4b07bf 100644 --- a/examples/hyperopt/hyperopt.py +++ b/examples/hyperopt/hyperopt.py @@ -51,7 +51,7 @@ def cnn_setup(x, y, keep_prob, lr, stddev): # Define a remote function that takes a set of hyperparameters as well as the # data, consructs and trains a network, and returns the validation accuracy. -@ray.remote([dict, int, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float]) +@ray.remote() def train_cnn_and_compute_accuracy(params, steps, train_images, train_labels, validation_images, validation_labels): # Extract the hyperparameters from the params dictionary. learning_rate = params["learning_rate"] diff --git a/examples/lbfgs/README.md b/examples/lbfgs/README.md index 8e0cf8c04..4af9f267c 100644 --- a/examples/lbfgs/README.md +++ b/examples/lbfgs/README.md @@ -91,12 +91,12 @@ use remote functions to distribute the loading of the data. Now, lets turn `loss` and `grad` into remote functions. ```python -@ray.remote([np.ndarray, np.ndarray, np.ndarray], [float]) +@ray.remote() def loss(theta, xs, ys): # compute the loss return loss -@ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def grad(theta, xs, ys): # compute the gradient return grad diff --git a/examples/lbfgs/driver.py b/examples/lbfgs/driver.py index d3556456f..9a5de5d84 100644 --- a/examples/lbfgs/driver.py +++ b/examples/lbfgs/driver.py @@ -74,14 +74,14 @@ if __name__ == "__main__": sess.run([update_w, update_b], feed_dict={w_new: theta[:w_size].reshape(w_shape), b_new: theta[w_size:]}) # Compute the loss on a batch of data. - @ray.remote([np.ndarray, np.ndarray, np.ndarray], [float]) + @ray.remote() def loss(theta, xs, ys): sess, _, _, cross_entropy, _, x, y_, _, _ = ray.reusables.net_vars load_weights(theta) return float(sess.run(cross_entropy, feed_dict={x: xs, y_: ys})) # Compute the gradient of the loss on a batch of data. - @ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray]) + @ray.remote() def grad(theta, xs, ys): sess, _, _, _, cross_entropy_grads, x, y_, _, _ = ray.reusables.net_vars load_weights(theta) diff --git a/examples/rl_pong/README.md b/examples/rl_pong/README.md index 907624fe0..daafaecd0 100644 --- a/examples/rl_pong/README.md +++ b/examples/rl_pong/README.md @@ -32,7 +32,7 @@ estimate of the gradient. Below is a simplified pseudocode version of this function. ```python -@ray.remote([dict], [dict, float]) +@ray.remote(num_return_vals=2) def compute_gradient(model): # Retrieve the game environment. env = ray.reusables.env diff --git a/examples/rl_pong/driver.py b/examples/rl_pong/driver.py index cbec69df3..6fe4bd58f 100644 --- a/examples/rl_pong/driver.py +++ b/examples/rl_pong/driver.py @@ -72,7 +72,7 @@ def policy_backward(eph, epx, epdlogp, model): dW1 = np.dot(dh.T, epx) return {"W1": dW1, "W2": dW2} -@ray.remote([dict], [dict, float]) +@ray.remote(num_return_vals=2) def compute_gradient(model): env = ray.reusables.env observation = env.reset() diff --git a/examples/trpo/README.md b/examples/trpo/README.md index 52f510d68..a1fd8aa1d 100644 --- a/examples/trpo/README.md +++ b/examples/trpo/README.md @@ -79,7 +79,7 @@ we use reusable variables to store the gym environment and the neural network po then used in the remote `do_rollout` function to do a remote rollout: ```python -@ray.remote([np.ndarray, int, int], [dict]) +@ray.remote() def do_rollout(policy, timestep_limit, seed): # Retrieve the game environment. env = ray.reusables.env diff --git a/lib/python/ray/array/distributed/core.py b/lib/python/ray/array/distributed/core.py index d9fa158dc..e3347fad0 100644 --- a/lib/python/ray/array/distributed/core.py +++ b/lib/python/ray/array/distributed/core.py @@ -66,12 +66,12 @@ class DistArray(object): a = self.assemble() return a[sliced] -@ray.remote([DistArray], [np.ndarray]) +@ray.remote() def assemble(a): return a.assemble() # TODO(rkn): what should we call this method -@ray.remote([np.ndarray], [DistArray]) +@ray.remote() def numpy_to_dist(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): @@ -80,28 +80,28 @@ def numpy_to_dist(a): result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]]) return result -@ray.remote([List, str], [DistArray]) +@ray.remote() def zeros(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result -@ray.remote([List, str], [DistArray]) +@ray.remote() def ones(shape, dtype_name="float"): result = DistArray(shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name) return result -@ray.remote([DistArray], [DistArray]) +@ray.remote() def copy(a): result = DistArray(a.shape) for index in np.ndindex(*result.num_blocks): result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable. return result -@ray.remote([int, int, str], [DistArray]) +@ray.remote() def eye(dim1, dim2=-1, dtype_name="float"): dim2 = dim1 if dim2 == -1 else dim2 shape = [dim1, dim2] @@ -114,7 +114,7 @@ def eye(dim1, dim2=-1, dtype_name="float"): result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name) return result -@ray.remote([DistArray], [DistArray]) +@ray.remote() def triu(a): if a.ndim != 2: raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) @@ -128,7 +128,7 @@ def triu(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result -@ray.remote([DistArray], [DistArray]) +@ray.remote() def tril(a): if a.ndim != 2: raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim)) @@ -142,7 +142,7 @@ def tril(a): result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j]) return result -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def blockwise_dot(*matrices): n = len(matrices) if n % 2 != 0: @@ -153,7 +153,7 @@ def blockwise_dot(*matrices): result += np.dot(matrices[i], matrices[n / 2 + i]) return result -@ray.remote([DistArray, DistArray], [DistArray]) +@ray.remote() def dot(a, b): if a.ndim != 2: raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim)) @@ -168,7 +168,7 @@ def dot(a, b): result.objectids[i, j] = blockwise_dot.remote(*args) return result -@ray.remote([DistArray, List], [DistArray]) +@ray.remote() def subblocks(a, *ranges): """ This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example, @@ -198,7 +198,7 @@ def subblocks(a, *ranges): result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])] return result -@ray.remote([DistArray], [DistArray]) +@ray.remote() def transpose(a): if a.ndim != 2: raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape)) @@ -209,7 +209,7 @@ def transpose(a): return result # TODO(rkn): support broadcasting? -@ray.remote([DistArray, DistArray], [DistArray]) +@ray.remote() def add(x1, x2): if x1.shape != x2.shape: raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) @@ -219,7 +219,7 @@ def add(x1, x2): return result # TODO(rkn): support broadcasting? -@ray.remote([DistArray, DistArray], [DistArray]) +@ray.remote() def subtract(x1, x2): if x1.shape != x2.shape: raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape)) diff --git a/lib/python/ray/array/distributed/linalg.py b/lib/python/ray/array/distributed/linalg.py index 11c8702ef..5e883da67 100644 --- a/lib/python/ray/array/distributed/linalg.py +++ b/lib/python/ray/array/distributed/linalg.py @@ -6,7 +6,7 @@ from core import * __all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"] -@ray.remote([DistArray], [DistArray, np.ndarray]) +@ray.remote(num_return_vals=2) def tsqr(a): """ arguments: @@ -75,7 +75,7 @@ def tsqr(a): return q_result, r # TODO(rkn): This is unoptimized, we really want a block version of this. -@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=3) def modified_lu(q): """ Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf @@ -105,19 +105,19 @@ def modified_lu(q): U = np.triu(q_work)[:b, :] return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put -@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def tsqr_hr_helper1(u, s, y_top_block, b): y_top = y_top_block[:b, :b] s_full = np.diag(s) t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T)) return t, y_top -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def tsqr_hr_helper2(s, r_temp): s_full = np.diag(s) return np.dot(s_full, r_temp) -@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=4) def tsqr_hr(a): """Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf""" q, r_temp = tsqr.remote(a) @@ -127,15 +127,15 @@ def tsqr_hr(a): r = tsqr_hr_helper2.remote(s, r_temp) return y, t, y_top, r -@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def qr_helper1(a_rc, y_ri, t, W_c): return a_rc - np.dot(y_ri, np.dot(t.T, W_c)) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def qr_helper2(y_ri, a_rc): return np.dot(y_ri.T, a_rc) -@ray.remote([DistArray], [DistArray, DistArray]) +@ray.remote(num_return_vals=2) def qr(a): """Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf""" m, n = a.shape[0], a.shape[1] diff --git a/lib/python/ray/array/distributed/random.py b/lib/python/ray/array/distributed/random.py index 951ade7df..4c42fe433 100644 --- a/lib/python/ray/array/distributed/random.py +++ b/lib/python/ray/array/distributed/random.py @@ -6,7 +6,7 @@ import ray from core import * -@ray.remote([List], [DistArray]) +@ray.remote() def normal(shape): num_blocks = DistArray.compute_num_blocks(shape) objectids = np.empty(num_blocks, dtype=object) diff --git a/lib/python/ray/array/remote/core.py b/lib/python/ray/array/remote/core.py index 1b9bf5318..bab484696 100644 --- a/lib/python/ray/array/remote/core.py +++ b/lib/python/ray/array/remote/core.py @@ -4,80 +4,80 @@ import ray __all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"] -@ray.remote([List, str, str], [np.ndarray]) +@ray.remote() def zeros(shape, dtype_name="float", order="C"): return np.zeros(shape, dtype=np.dtype(dtype_name), order=order) -@ray.remote([np.ndarray, str, str, bool], [np.ndarray]) +@ray.remote() def zeros_like(a, dtype_name="None", order="K", subok=True): dtype_val = None if dtype_name == "None" else np.dtype(dtype_name) return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok) -@ray.remote([List, str, str], [np.ndarray]) +@ray.remote() def ones(shape, dtype_name="float", order="C"): return np.ones(shape, dtype=np.dtype(dtype_name), order=order) -@ray.remote([int, int, int, str], [np.ndarray]) +@ray.remote() def eye(N, M=-1, k=0, dtype_name="float"): M = N if M == -1 else M return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name)) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def dot(a, b): return np.dot(a, b) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def vstack(*xs): return np.vstack(xs) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def hstack(*xs): return np.hstack(xs) # TODO(rkn): instead of this, consider implementing slicing -@ray.remote([np.ndarray, List, List], [np.ndarray]) +@ray.remote() def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices" return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]] -@ray.remote([np.ndarray, str], [np.ndarray]) +@ray.remote() def copy(a, order="K"): return np.copy(a, order=order) -@ray.remote([np.ndarray, int], [np.ndarray]) +@ray.remote() def tril(m, k=0): return np.tril(m, k=k) -@ray.remote([np.ndarray, int], [np.ndarray]) +@ray.remote() def triu(m, k=0): return np.triu(m, k=k) -@ray.remote([np.ndarray, int], [np.ndarray]) +@ray.remote() def diag(v, k=0): return np.diag(v, k=k) -@ray.remote([np.ndarray, List], [np.ndarray]) +@ray.remote() def transpose(a, axes=[]): axes = None if axes == [] else axes return np.transpose(a, axes=axes) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def add(x1, x2): return np.add(x1, x2) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def subtract(x1, x2): return np.subtract(x1, x2) -@ray.remote([np.ndarray, int], [np.ndarray]) +@ray.remote() def sum(x, axis=-1): return np.sum(x, axis=axis if axis != -1 else None) -@ray.remote([np.ndarray], [tuple]) +@ray.remote() def shape(a): return np.shape(a) # We use Any to allow different numerical types as well as numpy arrays. # TODO(rkn):this isn't in the numpy API, so be careful about exposing this. -@ray.remote([Any], [Any]) +@ray.remote() def sum_list(*xs): return np.sum(xs, axis=0) diff --git a/lib/python/ray/array/remote/linalg.py b/lib/python/ray/array/remote/linalg.py index f79e94684..92aea45f0 100644 --- a/lib/python/ray/array/remote/linalg.py +++ b/lib/python/ray/array/remote/linalg.py @@ -7,82 +7,82 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv", "svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank", "LinAlgError", "multi_dot"] -@ray.remote([np.ndarray, int], [np.ndarray]) +@ray.remote() def matrix_power(M, n): return np.linalg.matrix_power(M, n) -@ray.remote([np.ndarray, np.ndarray], [np.ndarray]) +@ray.remote() def solve(a, b): return np.linalg.solve(a, b) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def tensorsolve(a): raise NotImplementedError -@ray.remote([np.ndarray], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def tensorinv(a): raise NotImplementedError -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def inv(a): return np.linalg.inv(a) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def cholesky(a): return np.linalg.cholesky(a) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def eigvals(a): return np.linalg.eigvals(a) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def eigvalsh(a): raise NotImplementedError -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def pinv(a): return np.linalg.pinv(a) -@ray.remote([np.ndarray], [int]) +@ray.remote() def slogdet(a): raise NotImplementedError -@ray.remote([np.ndarray], [float]) +@ray.remote() def det(a): return np.linalg.det(a) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=3) def svd(a): return np.linalg.svd(a) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def eig(a): return np.linalg.eig(a) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def eigh(a): return np.linalg.eigh(a) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray, int, np.ndarray]) +@ray.remote(num_return_vals=4) def lstsq(a, b): return np.linalg.lstsq(a) -@ray.remote([np.ndarray], [float]) +@ray.remote() def norm(x): return np.linalg.norm(x) -@ray.remote([np.ndarray], [np.ndarray, np.ndarray]) +@ray.remote(num_return_vals=2) def qr(a): return np.linalg.qr(a) -@ray.remote([np.ndarray], [float]) +@ray.remote() def cond(x): return np.linalg.cond(x) -@ray.remote([np.ndarray], [int]) +@ray.remote() def matrix_rank(M): return np.linalg.matrix_rank(M) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def multi_dot(*a): raise NotImplementedError diff --git a/lib/python/ray/array/remote/random.py b/lib/python/ray/array/remote/random.py index d76bb5dc9..7aa5ddbd8 100644 --- a/lib/python/ray/array/remote/random.py +++ b/lib/python/ray/array/remote/random.py @@ -2,6 +2,6 @@ from typing import List import numpy as np import ray -@ray.remote([List], [np.ndarray]) +@ray.remote() def normal(shape): return np.random.normal(size=shape) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index e580021bf..ba27ee3b6 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -4,7 +4,6 @@ import time import traceback import copy import logging -from types import ModuleType import typing import funcsigs import numpy as np @@ -45,7 +44,7 @@ class RayTaskError(Exception): def __init__(self, function_name, exception, traceback_str): """Initialize a RayTaskError.""" self.function_name = function_name - if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError): + if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError): self.exception = exception else: self.exception = None @@ -59,8 +58,6 @@ class RayTaskError(Exception): exception = RayGetError.deserialize(exception[1]) elif exception[0] == "RayGetArgumentError": exception = RayGetArgumentError.deserialize(exception[1]) - elif exception[0] == "RayGetArgumentTypeError": - exception = RayGetArgumentTypeError.deserialize(exception[1]) elif exception[0] == "None": exception = None else: @@ -73,8 +70,6 @@ class RayTaskError(Exception): serialized_exception = ("RayGetError", self.exception.serialize()) elif isinstance(self.exception, RayGetArgumentError): serialized_exception = ("RayGetArgumentError", self.exception.serialize()) - elif isinstance(self.exception, RayGetArgumentTypeError): - serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize()) elif self.exception is None: serialized_exception = ("None",) else: @@ -151,41 +146,6 @@ class RayGetArgumentError(Exception): """Format a RayGetArgumentError as a string.""" return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error) -class RayGetArgumentTypeError(Exception): - """An exception used when a task's argument doesn't type check. - - Attributes: - function_name (str): The name of the function for the current task. - argument_index (int): The index (zero indexed) of the argument in the - present task's remote function call. - received_type: The type of the argument that was passed in. - expected_type: The type that was expected. This is determined by the remote - decorator. - """ - - def __init__(self, function_name, argument_index, received_type, expected_type): - """Initialize a RayGetArgumentTypeError object.""" - self.function_name = function_name - self.argument_index = argument_index - # TODO(rkn): when we support the serialization of types, then we should - # remove the string conversions below. - self.received_type = str(received_type) - self.expected_type = str(expected_type) - - @staticmethod - def deserialize(primitives): - """Create a RayGetArgumentTypeError from a primitive object.""" - function_name, argument_index, received_type, expected_type = primitives - return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type) - - def serialize(self): - """Turn a RayGetArgumentTypeError into a primitive object.""" - return (self.function_name, self.argument_index, self.received_type, self.expected_type) - - def __str__(self): - """Format a RayGetArgumentTypeError as a string.""" - return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type) - class RayDealloc(object): """An object used internally to properly implement reference counting. @@ -993,7 +953,7 @@ def main_loop(worker=global_worker): def process_remote_function(function_name, serialized_function): """Import a remote function.""" try: - (function, arg_types, return_types, module) = pickling.loads(serialized_function) + function, num_return_vals, module = pickling.loads(serialized_function) except: # If an exception was thrown when the remote function was imported, we # record the traceback and notify the scheduler of the failure. @@ -1005,11 +965,11 @@ def main_loop(worker=global_worker): # TODO(rkn): Why is the below line necessary? function.__module__ = module assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in." - worker.functions[function_name] = remote(arg_types, return_types, worker)(function) + worker.functions[function_name] = remote(num_return_vals, worker)(function) _logger().info("Successfully imported remote function {}.".format(function_name)) # Noify the scheduler that the remote function imported successfully. # We pass an empty error message string because the import succeeded. - raylib.register_remote_function(worker.handle, function_name, len(return_types)) + raylib.register_remote_function(worker.handle, function_name, num_return_vals) def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str): """Import a reusable variable.""" @@ -1110,12 +1070,12 @@ def _export_reusable_variable(name, reusable, worker=global_worker): raise Exception("_export_reusable_variable can only be called on a driver.") raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer)) -def remote(arg_types, return_types, worker=global_worker): +def remote(num_return_vals=1, worker=global_worker): """This decorator is used to create remote functions. Args: - arg_types (List[type]): List of Python types of the function arguments. - return_types (List[type]): List of Python types of the return values. + num_return_vals (int): The number of object IDs that a call to this function + should return. """ def remote_decorator(func): def func_call(*args, **kwargs): @@ -1128,7 +1088,6 @@ def remote(arg_types, return_types, worker=global_worker): # arguments to prevent the function call from mutating them and to match # the usual behavior of immutable remote objects. return func(*copy.deepcopy(args)) - check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid objectids = _submit_task(func_name, args) if len(objectids) == 1: return objectids[0] @@ -1140,7 +1099,6 @@ def remote(arg_types, return_types, worker=global_worker): start_time = time.time() result = func(*arguments) end_time = time.time() - check_return_values(func_invoker, result) # throws an exception if result is invalid _logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time)) return result def func_invoker(*args, **kwargs): @@ -1148,8 +1106,6 @@ def remote(arg_types, return_types, worker=global_worker): raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name)) func_invoker.remote = func_call func_invoker.executor = func_executor - func_invoker.arg_types = arg_types - func_invoker.return_types = return_types func_invoker.is_remote = True func_name = "{}.{}".format(func.__module__, func.__name__) func_invoker.func_name = func_name @@ -1168,7 +1124,7 @@ def remote(arg_types, return_types, worker=global_worker): # Set the function globally to make it refer to itself func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable try: - to_export = pickling.dumps((func, arg_types, return_types, func.__module__)) + to_export = pickling.dumps((func, num_return_vals, func.__module__)) finally: # Undo our changes if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value @@ -1205,109 +1161,12 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]): raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name) - -def check_return_values(function, result): - """Check the types and number of return values. - - Args: - function (Callable): The remote function whose outputs are being checked. - result: The value returned by an invocation of the remote function. The - expected types and number are defined in the remote decorator. - - Raises: - Exception: An exception is raised if the return values have incorrect types - or the function returned the wrong number of return values. - """ - # If the @remote decorator declares that the function has no return values, - # then all we do is check that there were in fact no return values. - if len(function.return_types) == 0: - if result is not None: - raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__)) - return - # If a function has multiple return values, Python returns a tuple of the - # values. If there is a single return value, then Python does not return a - # tuple, it simply returns the value. That is why we place result with - # (result,) when there is only one return value, so we can treat these two - # cases similarly. - if len(function.return_types) == 1: - result = (result,) - # Below we check that the number of values returned by the function match the - # number of return values declared in the @remote decorator. - if len(result) != len(function.return_types): - raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result))) - # Here we do some limited type checking to make sure the return values have - # the right types. - for i in range(len(result)): - if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)): - raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i])) - -def typecheck_arg(arg, expected_type, i, name): - """Check that an argument has the expected type. - - Args: - arg: An argument to function. - expected_type (type): The expected type of arg. - i (int): The position of the argument to the function. - name (str): The name of the function. - - Raises: - RayGetArgumentTypeError: An exception is raised if arg does not have the - expected type. - """ - if issubclass(type(arg), expected_type): - # Passed the type-checck - # TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True - pass - elif isinstance(arg, long) and issubclass(int, expected_type): - # TODO(mehrdadn): Should long really be convertible to int? - pass - else: - raise RayGetArgumentTypeError(name, i, type(arg), expected_type) - -def check_arguments(arg_types, has_vararg_param, name, args): - """Check that the arguments to the remote function have the right types. - - This is called by the worker that calls the remote function (not the worker - that executes the remote function). - - Args: - arg_types (List[type]): A list of the types of the arguments to the function - being checked. - has_vararg_param (bool): True if the function being checked has a *args - argument. - name (str): The name of the function. - args (List): The arguments to the function. - - Raises: - Exception: An exception is raised the args do not all have the right types. - """ - # check the number of args - if len(args) != len(arg_types) and not has_vararg_param: - raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args))) - elif len(args) < len(arg_types) - 1 and has_vararg_param: - raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args))) - - for (i, arg) in enumerate(args): - if i <= len(arg_types) - 1: - expected_type = arg_types[i] - elif has_vararg_param: - expected_type = arg_types[-1] - else: - assert False, "This code should be unreachable." - - if isinstance(arg, raylib.ObjectID): - # TODO(rkn): When we have type information in the ObjectID, do type checking here. - pass - else: - typecheck_arg(arg, expected_type, i, name) - def get_arguments_for_execution(function, args, worker=global_worker): """Retrieve the arguments for the remote function. This retrieves the values for the arguments to the remote function that were passed in as object IDs. Argumens that were passed by value are not changed. - This also does some type checking. This is called by the worker that is - executing the remote function. + This is called by the worker that is executing the remote function. Args: function (Callable): The remote function whose arguments are being @@ -1321,25 +1180,9 @@ def get_arguments_for_execution(function, args, worker=global_worker): Raises: RayGetArgumentError: This exception is raised if a task that created one of the arguments failed. - RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one - of the arguments does not have the expected type. """ - # TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function. arguments = [] - # # check the number of args - # if len(args) != len(function.arg_types) and function.arg_types[-1] is not None: - # raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args))) - # elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None: - # raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args))) - for (i, arg) in enumerate(args): - if i <= len(function.arg_types) - 1: - expected_type = function.arg_types[i] - elif function.has_vararg_param and len(function.arg_types) >= 1: - expected_type = function.arg_types[-1] - else: - assert False, "This code should be unreachable." - if isinstance(arg, raylib.ObjectID): # get the object from the local object store _logger().info("Getting argument {} for function {}.".format(i, function.__name__)) @@ -1353,7 +1196,6 @@ def get_arguments_for_execution(function, args, worker=global_worker): # pass the argument by value argument = arg - typecheck_arg(argument, expected_type, i, function.__name__) arguments.append(argument) return arguments diff --git a/test/failure_test.py b/test/failure_test.py index c30546cfb..c2d068954 100644 --- a/test/failure_test.py +++ b/test/failure_test.py @@ -140,7 +140,7 @@ class TaskStatusTest(unittest.TestCase): def reinitializer(foo): raise Exception("The reinitializer failed.") ray.reusables.foo = ray.Reusable(initializer, reinitializer) - @ray.remote([], []) + @ray.remote() def use_foo(): ray.reusables.foo use_foo.remote() diff --git a/test/runtest.py b/test/runtest.py index 0c00923cd..9c85b58ee 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -262,42 +262,42 @@ class APITest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=2) # Test that we can define a remote function in the shell. - @ray.remote([int], [int]) + @ray.remote() def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) # Test that we can redefine the remote function. - @ray.remote([int], [int]) + @ray.remote() def f(x): return x + 10 self.assertEqual(ray.get(f.remote(0)), 10) # Test that we can close over plain old data. data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}] - @ray.remote([], [list]) + @ray.remote() def g(): return data ray.get(g.remote()) # Test that we can close over modules. - @ray.remote([], [np.ndarray]) + @ray.remote() def h(): return np.zeros([3, 5]) assert_equal(ray.get(h.remote()), np.zeros([3, 5])) - @ray.remote([], [float]) + @ray.remote() def j(): return time.time() ray.get(j.remote()) # Test that we can define remote functions that call other remote functions. - @ray.remote([int], [int]) + @ray.remote() def k(x): return x + 1 - @ray.remote([int], [int]) + @ray.remote() def l(x): return k.remote(x) - @ray.remote([int], [int]) + @ray.remote() def m(x): return ray.get(l.remote(x)) self.assertEqual(ray.get(k.remote(1)), 2) @@ -309,7 +309,7 @@ class APITest(unittest.TestCase): def testSelect(self): ray.init(start_ray_local=True, num_workers=4) - @ray.remote([float], [int]) + @ray.remote() def f(delay): time.sleep(delay) return 1 @@ -345,10 +345,10 @@ class APITest(unittest.TestCase): ray.reusables.foo = ray.Reusable(foo_initializer) ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer) - @ray.remote([], [int]) + @ray.remote() def use_foo(): return ray.reusables.foo - @ray.remote([], [list]) + @ray.remote() def use_bar(): ray.reusables.bar.append(1) return ray.reusables.bar @@ -368,7 +368,7 @@ class APITest(unittest.TestCase): def f(): sys.path.append("fake_directory") ray.worker.global_worker.run_function_on_all_workers(f) - @ray.remote([], [list]) + @ray.remote() def get_path(): return sys.path self.assertEqual("fake_directory", ray.get(get_path.remote())[-1]) @@ -509,7 +509,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote([], [int]) + @ray.remote() def f(): return sys.getrefcount(None) first_count = ray.get(f.remote()) @@ -522,7 +522,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote([], [int]) + @ray.remote() def f(): return sys.getrefcount(True) first_count = ray.get(f.remote()) @@ -535,7 +535,7 @@ class PythonCExtensionTest(unittest.TestCase): ray.init(start_ray_local=True, num_workers=1) # Make sure that we aren't accidentally messing up Python's reference counts. - @ray.remote([], [int]) + @ray.remote() def f(): return sys.getrefcount(False) first_count = ray.get(f.remote()) @@ -559,7 +559,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer) self.assertEqual(ray.reusables.foo, 1) - @ray.remote([], [int]) + @ray.remote() def use_foo(): return ray.reusables.foo self.assertEqual(ray.get(use_foo.remote()), 1) @@ -573,7 +573,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.bar = ray.Reusable(bar_initializer) - @ray.remote([], [list]) + @ray.remote() def use_bar(): ray.reusables.bar.append(4) return ray.reusables.bar @@ -592,7 +592,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer) - @ray.remote([int], [np.ndarray]) + @ray.remote() def use_baz(i): baz = ray.reusables.baz baz[i] = 1 @@ -613,7 +613,7 @@ class ReusablesTest(unittest.TestCase): ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer) - @ray.remote([], [int]) + @ray.remote() def use_qux(): return ray.reusables.qux self.assertEqual(ray.get(use_qux.remote()), 0) @@ -634,7 +634,7 @@ class ClusterAttachingTest(unittest.TestCase): ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) - @ray.remote([int], [int]) + @ray.remote() def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) @@ -653,7 +653,7 @@ class ClusterAttachingTest(unittest.TestCase): ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address) - @ray.remote([int], [int]) + @ray.remote() def f(x): return x + 1 self.assertEqual(ray.get(f.remote(0)), 1) diff --git a/test/test_functions.py b/test/test_functions.py index 9fb1aca21..cc0a75c7a 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -4,60 +4,60 @@ import numpy as np # Test simple functionality -@ray.remote([int, int], [int, int]) +@ray.remote(num_return_vals=2) def handle_int(a, b): return a + 1, b + 1 # Test aliasing -@ray.remote([], [np.ndarray]) +@ray.remote() def test_alias_f(): return np.ones([3, 4, 5]) -@ray.remote([], [np.ndarray]) +@ray.remote() def test_alias_g(): return test_alias_f.remote() -@ray.remote([], [np.ndarray]) +@ray.remote() def test_alias_h(): return test_alias_g.remote() # Test timing -@ray.remote([], []) +@ray.remote() def empty_function(): pass -@ray.remote([], [int]) +@ray.remote() def trivial_function(): return 1 # Test keyword arguments -@ray.remote([int, str], [str]) +@ray.remote() def keyword_fct1(a, b="hello"): return "{} {}".format(a, b) -@ray.remote([str, str], [str]) +@ray.remote() def keyword_fct2(a="hello", b="world"): return "{} {}".format(a, b) -@ray.remote([int, int, str, str], [str]) +@ray.remote() def keyword_fct3(a, b, c="hello", d="world"): return "{} {} {} {}".format(a, b, c, d) # Test variable numbers of arguments -@ray.remote([int], [str]) +@ray.remote() def varargs_fct1(*a): return " ".join(map(str, a)) -@ray.remote([int, int], [str]) +@ray.remote() def varargs_fct2(a, *b): return " ".join(map(str, b)) try: - @ray.remote([int], []) + @ray.remote() def kwargs_throw_exception(**c): return () kwargs_exception_thrown = False @@ -65,7 +65,7 @@ except: kwargs_exception_thrown = True try: - @ray.remote([int, str, int], [str]) + @ray.remote() def varargs_and_kwargs_throw_exception(a, b="hi", *c): return "{} {} {}".format(a, b, c) varargs_and_kwargs_exception_thrown = False @@ -74,46 +74,46 @@ except: # test throwing an exception -@ray.remote([], []) +@ray.remote() def throw_exception_fct1(): raise Exception("Test function 1 intentionally failed.") -@ray.remote([], [int]) +@ray.remote() def throw_exception_fct2(): raise Exception("Test function 2 intentionally failed.") -@ray.remote([float], [int, str, np.ndarray]) +@ray.remote(num_return_vals=3) def throw_exception_fct3(x): raise Exception("Test function 3 intentionally failed.") # test Python mode -@ray.remote([], [np.ndarray]) +@ray.remote() def python_mode_f(): return np.array([0, 0]) -@ray.remote([np.ndarray], [np.ndarray]) +@ray.remote() def python_mode_g(x): x[0] = 1 return x # test no return values -@ray.remote([], []) +@ray.remote() def no_op(): pass -@ray.remote([], []) +@ray.remote() def no_op_fail(): return 0 # test wrong return types -@ray.remote([], [int]) +@ray.remote() def test_return1(): return 0.0 -@ray.remote([], [int, float]) +@ray.remote(num_return_vals=2) def test_return2(): return 2.0, 3.0 @@ -121,6 +121,6 @@ class TestClass(object): def __init__(self): self.a = 5 -@ray.remote([], [TestClass]) +@ray.remote() def test_unknown_type(): return TestClass()