Remove type information from remote decorator. (#394)

* Remove type information from remote decorator.

* Remove typing module.

* Fix failure_test.py.

* Allow remote decorator to be used with no parentheses.

* Fix documentation.
This commit is contained in:
Philipp Moritz
2016-08-30 18:34:25 -07:00
committed by GitHub
32 changed files with 232 additions and 442 deletions
+1 -1
View File
@@ -21,7 +21,7 @@ import numpy as np
ray.init(start_ray_local=True, num_workers=10)
# Define a remote function for estimating pi.
@ray.remote([int], [float])
@ray.remote
def estimate_pi(n):
x = np.random.uniform(size=n)
y = np.random.uniform(size=n)
+3 -3
View File
@@ -10,15 +10,15 @@ However, to provide a more flexible API, we allow tasks to not only return
values, but to also return object ids to values. As an examples, consider
the following code.
```python
@ray.remote([], [np.ndarray])
@ray.remote
def f()
return np.zeros(5)
@ray.remote([], [np.ndarray])
@ray.remote
def g()
return f()
@ray.remote([], [np.ndarray])
@ray.remote
def h()
return g()
```
+1 -1
View File
@@ -18,7 +18,7 @@ import shlex
# These 4 lines added to enable ReadTheDocs to work.
import mock
MOCK_MODULES = ["libraylib", "IPython", "numpy", "typing", "funcsigs", "subprocess32", "protobuf", "colorama", "graphviz", "cloudpickle", "ray.internal.graph_pb2"]
MOCK_MODULES = ["libraylib", "IPython", "numpy", "funcsigs", "subprocess32", "protobuf", "colorama", "graphviz", "cloudpickle", "ray.internal.graph_pb2"]
for mod_name in MOCK_MODULES:
sys.modules[mod_name] = mock.Mock()
+1 -1
View File
@@ -19,7 +19,7 @@ brew update
brew install git cmake automake autoconf libtool boost graphviz
sudo easy_install pip
sudo pip install ipython --user
sudo pip install numpy typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
```
## Build
+1 -1
View File
@@ -15,7 +15,7 @@ First install the dependencies. We currently do not support Python 3.
```
sudo apt-get update
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
sudo pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
```
## Build
+2 -2
View File
@@ -5,7 +5,7 @@ functions. Remote functions are written like regular Python functions, but with
the `@ray.remote` decorator on top.
```python
@ray.remote([int], [int])
@ray.remote
def increment(n):
return n + 1
```
@@ -68,7 +68,7 @@ class ExampleClass(object):
# This example assumes that field1 and field2 are serializable types.
self.field1 = field1
self.field2 = field2
@staticmethod
def deserialize(primitives):
(field1, field2) = primitives
+7 -14
View File
@@ -107,18 +107,11 @@ def add(a, b):
```
A remote function in Ray looks like this.
```python
@ray.remote([int, int], [int])
@ray.remote
def add(a, b):
return a + b
```
The information passed to the `@ray.remote` decorator includes type information
for the arguments and for the return values of the function. Because of the
distinction that we make between *submitting a task* and *executing the task*,
we require type information so that we can catch type errors when the remote
function is called instead of catching them when the task is actually executed
(which could be much later and could be on a different machine).
### Remote functions
Whereas in regular Python, calling `add(1, 2)` would return `3`, in Ray, calling
@@ -156,7 +149,7 @@ passed into the actual execution of the remote function.
Note that a remote function can return multiple object IDs.
```python
@ray.remote([], [int, float, str])
@ray.remote(num_return_vals=3)
def return_multiple():
return 0, 0.0, "zero"
@@ -194,7 +187,7 @@ around `time.sleep`.
```python
import time
@ray.remote([int], [int])
@ray.remote
def sleep(n):
time.sleep(n)
return 0
@@ -245,11 +238,11 @@ Computation graphs encode dependencies. For example, suppose we define
```python
import numpy as np
@ray.remote([list], [np.ndarray])
@ray.remote
def zeros(shape):
return np.zeros(shape)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def dot(a, b):
return np.dot(a, b)
```
@@ -282,12 +275,12 @@ processes can also call remote functions. To illustrate this, consider the
following example.
```python
@ray.remote([int, int], [int])
@ray.remote
def sub_experiment(i, j):
# Run the jth sub-experiment for the ith experiment.
return i + j
@ray.remote([int], [int])
@ray.remote
def run_experiment(i):
sub_results = []
# Launch tasks to perform 10 sub-experiments in parallel.
+1 -1
View File
@@ -6,7 +6,7 @@ RUN apt-get update
RUN apt-get -y install apt-utils
RUN apt-get -y install sudo
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user
RUN adduser ray-user sudo
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
+1 -1
View File
@@ -6,7 +6,7 @@ RUN apt-get update
RUN apt-get -y install apt-utils
RUN apt-get -y install sudo
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user --uid 500
RUN adduser ray-user sudo
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
+1 -1
View File
@@ -7,7 +7,7 @@ RUN apt-get update
RUN apt-get -y install apt-utils
RUN apt-get -y install sudo
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user
RUN adduser ray-user sudo
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
+2 -2
View File
@@ -59,7 +59,7 @@ workers. At the core of our loading code is the remote function
retrieves the appropriate object.
```python
@ray.remote([str, str, List], [np.ndarray, List])
@ray.remote(num_return_vals=2)
def load_tarfile_from_s3(bucket, s3_key, size=[]):
# Pull the object with the given key and bucket from S3, untar the contents,
# and return it.
@@ -85,7 +85,7 @@ The other parallel component of this application is the training procedure. This
is built on top of the remote function `compute_grad`.
```python
@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List])
@ray.remote
def compute_grad(X, Y, mean, weights):
# Load the weights into the network.
# Subtract the mean and crop the images.
+7 -8
View File
@@ -7,7 +7,6 @@ import tarfile, io
import boto3
import PIL.Image as Image
import tensorflow as tf
from typing import List, Tuple
import ray.array.remote as ra
@@ -39,7 +38,7 @@ def load_chunk(tarfile, size=None):
filenames.append(filename)
return np.concatenate(result), filenames
@ray.remote([str, str, List], [np.ndarray, List])
@ray.remote(num_return_vals=2)
def load_tarfile_from_s3(bucket, s3_key, size=[]):
"""Load an imagenet .tar file.
@@ -231,7 +230,7 @@ def net_initialization():
def net_reinitialization(net_vars):
return net_vars
@ray.remote([List], [int])
@ray.remote
def num_images(batches):
"""Counts number of images in batches.
@@ -244,7 +243,7 @@ def num_images(batches):
shape_ids = [ra.shape.remote(batch) for batch in batches]
return sum([ray.get(shape_id)[0] for shape_id in shape_ids])
@ray.remote([List], [np.ndarray])
@ray.remote
def compute_mean_image(batches):
"""Computes the mean image given a list of batches of images.
@@ -261,7 +260,7 @@ def compute_mean_image(batches):
n_images = num_images.remote(batches)
return np.sum(sum_images, axis=0).astype("float64") / ray.get(n_images)
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=4)
def shuffle_arrays(first_images, first_labels, second_images, second_labels):
"""Shuffles the images and labels from two batches.
@@ -306,7 +305,7 @@ def shuffle_pair(first_batch, second_batch):
images1, labels1, images2, labels2 = shuffle_arrays.remote(first_batch[0], first_batch[1], second_batch[0], second_batch[1])
return (images1, labels1), (images2, labels2)
@ray.remote([list, dict], [np.ndarray])
@ray.remote
def filenames_to_labels(filenames, filename_label_dict):
"""Converts filename strings to integer labels.
@@ -381,7 +380,7 @@ def shuffle(batches):
new_batches.append(permuted_batches[-1])
return new_batches
@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List])
@ray.remote
def compute_grad(X, Y, mean, weights):
"""Computes the gradient of the network.
@@ -406,7 +405,7 @@ def compute_grad(X, Y, mean, weights):
# Compute the gradients.
return sess.run([g for (g, v) in comp_grads], feed_dict={images: subset_X, y_true: subset_Y, dropout: 0.5})
@ray.remote([np.ndarray, np.ndarray, List], [np.float32])
@ray.remote
def compute_accuracy(X, Y, weights):
"""Returns the accuracy of the network
+2 -4
View File
@@ -82,15 +82,13 @@ complicated version of this remote function is defined in
[hyperopt.py](hyperopt.py).
```python
@ray.remote([dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float])
@ray.remote
def train_cnn_and_compute_accuracy(hyperparameters, train_images, train_labels, validation_images, validation_labels):
# Actual work omitted.
return validation_accuracy
```
The only difference is that we added the `@ray.remote` decorator specifying a
little bit of type information (the input is a dictionary along with some numpy
arrays, and the return value is a float).
The only difference is that we added the `@ray.remote` decorator.
Now a call to `train_cnn_and_compute_accuracy` does not execute the function. It
submits the task to the scheduler and returns an object ID for the output
+1 -1
View File
@@ -51,7 +51,7 @@ def cnn_setup(x, y, keep_prob, lr, stddev):
# Define a remote function that takes a set of hyperparameters as well as the
# data, consructs and trains a network, and returns the validation accuracy.
@ray.remote([dict, int, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float])
@ray.remote
def train_cnn_and_compute_accuracy(params, steps, train_images, train_labels, validation_images, validation_labels):
# Extract the hyperparameters from the params dictionary.
learning_rate = params["learning_rate"]
+3 -5
View File
@@ -91,20 +91,18 @@ use remote functions to distribute the loading of the data.
Now, lets turn `loss` and `grad` into remote functions.
```python
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [float])
@ray.remote
def loss(theta, xs, ys):
# compute the loss
return loss
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def grad(theta, xs, ys):
# compute the gradient
return grad
```
The only difference is that we added the `@ray.remote` decorator specifying a
little bit of type information (the inputs consist of numpy arrays, `loss`
returns a float, and `grad` returns a numpy array).
The only difference is that we added the `@ray.remote` decorator.
Now, it is easy to speed up the computation of the full loss and the full
gradient.
+2 -2
View File
@@ -74,14 +74,14 @@ if __name__ == "__main__":
sess.run([update_w, update_b], feed_dict={w_new: theta[:w_size].reshape(w_shape), b_new: theta[w_size:]})
# Compute the loss on a batch of data.
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [float])
@ray.remote
def loss(theta, xs, ys):
sess, _, _, cross_entropy, _, x, y_, _, _ = ray.reusables.net_vars
load_weights(theta)
return float(sess.run(cross_entropy, feed_dict={x: xs, y_: ys}))
# Compute the gradient of the loss on a batch of data.
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def grad(theta, xs, ys):
sess, _, _, _, cross_entropy_grads, x, y_, _, _ = ray.reusables.net_vars
load_weights(theta)
+1 -1
View File
@@ -32,7 +32,7 @@ estimate of the gradient. Below is a simplified pseudocode version of this
function.
```python
@ray.remote([dict], [dict, float])
@ray.remote(num_return_vals=2)
def compute_gradient(model):
# Retrieve the game environment.
env = ray.reusables.env
+1 -1
View File
@@ -72,7 +72,7 @@ def policy_backward(eph, epx, epdlogp, model):
dW1 = np.dot(dh.T, epx)
return {"W1": dW1, "W2": dW2}
@ray.remote([dict], [dict, float])
@ray.remote(num_return_vals=2)
def compute_gradient(model):
env = ray.reusables.env
observation = env.reset()
+1 -1
View File
@@ -79,7 +79,7 @@ we use reusable variables to store the gym environment and the neural network po
then used in the remote `do_rollout` function to do a remote rollout:
```python
@ray.remote([np.ndarray, int, int], [dict])
@ray.remote
def do_rollout(policy, timestep_limit, seed):
# Retrieve the game environment.
env = ray.reusables.env
+2 -2
View File
@@ -31,11 +31,11 @@ if [[ $platform == "linux" ]]; then
# These commands must be kept in sync with the installation instructions.
sudo apt-get update
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
sudo pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
elif [[ $platform == "macosx" ]]; then
# These commands must be kept in sync with the installation instructions.
brew install git cmake automake autoconf libtool boost graphviz
sudo easy_install pip
sudo pip install ipython --user
sudo pip install numpy typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
fi
+14 -15
View File
@@ -1,4 +1,3 @@
from typing import List
import numpy as np
import ray.array.remote as ra
import ray
@@ -66,12 +65,12 @@ class DistArray(object):
a = self.assemble()
return a[sliced]
@ray.remote([DistArray], [np.ndarray])
@ray.remote
def assemble(a):
return a.assemble()
# TODO(rkn): what should we call this method
@ray.remote([np.ndarray], [DistArray])
@ray.remote
def numpy_to_dist(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
@@ -80,28 +79,28 @@ def numpy_to_dist(a):
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
return result
@ray.remote([List, str], [DistArray])
@ray.remote
def zeros(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([List, str], [DistArray])
@ray.remote
def ones(shape, dtype_name="float"):
result = DistArray(shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ray.remote
def copy(a):
result = DistArray(a.shape)
for index in np.ndindex(*result.num_blocks):
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
return result
@ray.remote([int, int, str], [DistArray])
@ray.remote
def eye(dim1, dim2=-1, dtype_name="float"):
dim2 = dim1 if dim2 == -1 else dim2
shape = [dim1, dim2]
@@ -114,7 +113,7 @@ def eye(dim1, dim2=-1, dtype_name="float"):
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
return result
@ray.remote([DistArray], [DistArray])
@ray.remote
def triu(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
@@ -128,7 +127,7 @@ def triu(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote([DistArray], [DistArray])
@ray.remote
def tril(a):
if a.ndim != 2:
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
@@ -142,7 +141,7 @@ def tril(a):
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
return result
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def blockwise_dot(*matrices):
n = len(matrices)
if n % 2 != 0:
@@ -153,7 +152,7 @@ def blockwise_dot(*matrices):
result += np.dot(matrices[i], matrices[n / 2 + i])
return result
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote
def dot(a, b):
if a.ndim != 2:
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
@@ -168,7 +167,7 @@ def dot(a, b):
result.objectids[i, j] = blockwise_dot.remote(*args)
return result
@ray.remote([DistArray, List], [DistArray])
@ray.remote
def subblocks(a, *ranges):
"""
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
@@ -198,7 +197,7 @@ def subblocks(a, *ranges):
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
return result
@ray.remote([DistArray], [DistArray])
@ray.remote
def transpose(a):
if a.ndim != 2:
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
@@ -209,7 +208,7 @@ def transpose(a):
return result
# TODO(rkn): support broadcasting?
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote
def add(x1, x2):
if x1.shape != x2.shape:
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
@@ -219,7 +218,7 @@ def add(x1, x2):
return result
# TODO(rkn): support broadcasting?
@ray.remote([DistArray, DistArray], [DistArray])
@ray.remote
def subtract(x1, x2):
if x1.shape != x2.shape:
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
+8 -8
View File
@@ -6,7 +6,7 @@ from core import *
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
@ray.remote([DistArray], [DistArray, np.ndarray])
@ray.remote(num_return_vals=2)
def tsqr(a):
"""
arguments:
@@ -75,7 +75,7 @@ def tsqr(a):
return q_result, r
# TODO(rkn): This is unoptimized, we really want a block version of this.
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=3)
def modified_lu(q):
"""
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
@@ -105,19 +105,19 @@ def modified_lu(q):
U = np.triu(q_work)[:b, :]
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tsqr_hr_helper1(u, s, y_top_block, b):
y_top = y_top_block[:b, :b]
s_full = np.diag(s)
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
return t, y_top
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def tsqr_hr_helper2(s, r_temp):
s_full = np.diag(s)
return np.dot(s_full, r_temp)
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=4)
def tsqr_hr(a):
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
q, r_temp = tsqr.remote(a)
@@ -127,15 +127,15 @@ def tsqr_hr(a):
r = tsqr_hr_helper2.remote(s, r_temp)
return y, t, y_top, r
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def qr_helper1(a_rc, y_ri, t, W_c):
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def qr_helper2(y_ri, a_rc):
return np.dot(y_ri.T, a_rc)
@ray.remote([DistArray], [DistArray, DistArray])
@ray.remote(num_return_vals=2)
def qr(a):
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
m, n = a.shape[0], a.shape[1]
+1 -3
View File
@@ -1,12 +1,10 @@
from typing import List
import numpy as np
import ray.array.remote as ra
import ray
from core import *
@ray.remote([List], [DistArray])
@ray.remote
def normal(shape):
num_blocks = DistArray.compute_num_blocks(shape)
objectids = np.empty(num_blocks, dtype=object)
+18 -19
View File
@@ -1,83 +1,82 @@
from typing import List, Any
import numpy as np
import ray
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
@ray.remote([List, str, str], [np.ndarray])
@ray.remote
def zeros(shape, dtype_name="float", order="C"):
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote([np.ndarray, str, str, bool], [np.ndarray])
@ray.remote
def zeros_like(a, dtype_name="None", order="K", subok=True):
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
@ray.remote([List, str, str], [np.ndarray])
@ray.remote
def ones(shape, dtype_name="float", order="C"):
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
@ray.remote([int, int, int, str], [np.ndarray])
@ray.remote
def eye(N, M=-1, k=0, dtype_name="float"):
M = N if M == -1 else M
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def dot(a, b):
return np.dot(a, b)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def vstack(*xs):
return np.vstack(xs)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def hstack(*xs):
return np.hstack(xs)
# TODO(rkn): instead of this, consider implementing slicing
@ray.remote([np.ndarray, List, List], [np.ndarray])
@ray.remote
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
@ray.remote([np.ndarray, str], [np.ndarray])
@ray.remote
def copy(a, order="K"):
return np.copy(a, order=order)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote
def tril(m, k=0):
return np.tril(m, k=k)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote
def triu(m, k=0):
return np.triu(m, k=k)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote
def diag(v, k=0):
return np.diag(v, k=k)
@ray.remote([np.ndarray, List], [np.ndarray])
@ray.remote
def transpose(a, axes=[]):
axes = None if axes == [] else axes
return np.transpose(a, axes=axes)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def add(x1, x2):
return np.add(x1, x2)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def subtract(x1, x2):
return np.subtract(x1, x2)
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote
def sum(x, axis=-1):
return np.sum(x, axis=axis if axis != -1 else None)
@ray.remote([np.ndarray], [tuple])
@ray.remote
def shape(a):
return np.shape(a)
# We use Any to allow different numerical types as well as numpy arrays.
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
@ray.remote([Any], [Any])
@ray.remote
def sum_list(*xs):
return np.sum(xs, axis=0)
+20 -21
View File
@@ -1,4 +1,3 @@
from typing import List
import numpy as np
import ray
@@ -7,82 +6,82 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
"LinAlgError", "multi_dot"]
@ray.remote([np.ndarray, int], [np.ndarray])
@ray.remote
def matrix_power(M, n):
return np.linalg.matrix_power(M, n)
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
@ray.remote
def solve(a, b):
return np.linalg.solve(a, b)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tensorsolve(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def tensorinv(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def inv(a):
return np.linalg.inv(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def cholesky(a):
return np.linalg.cholesky(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def eigvals(a):
return np.linalg.eigvals(a)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def eigvalsh(a):
raise NotImplementedError
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def pinv(a):
return np.linalg.pinv(a)
@ray.remote([np.ndarray], [int])
@ray.remote
def slogdet(a):
raise NotImplementedError
@ray.remote([np.ndarray], [float])
@ray.remote
def det(a):
return np.linalg.det(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, np.ndarray])
@ray.remote(num_return_vals=3)
def svd(a):
return np.linalg.svd(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def eig(a):
return np.linalg.eig(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def eigh(a):
return np.linalg.eigh(a)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, int, np.ndarray])
@ray.remote(num_return_vals=4)
def lstsq(a, b):
return np.linalg.lstsq(a)
@ray.remote([np.ndarray], [float])
@ray.remote
def norm(x):
return np.linalg.norm(x)
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
@ray.remote(num_return_vals=2)
def qr(a):
return np.linalg.qr(a)
@ray.remote([np.ndarray], [float])
@ray.remote
def cond(x):
return np.linalg.cond(x)
@ray.remote([np.ndarray], [int])
@ray.remote
def matrix_rank(M):
return np.linalg.matrix_rank(M)
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def multi_dot(*a):
raise NotImplementedError
+1 -2
View File
@@ -1,7 +1,6 @@
from typing import List
import numpy as np
import ray
@ray.remote([List], [np.ndarray])
@ray.remote
def normal(shape):
return np.random.normal(size=shape)
-9
View File
@@ -1,7 +1,6 @@
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
import sys
import typing
from ctypes import c_void_p
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
@@ -38,9 +37,6 @@ def _fill_function(func, globals, defaults, closure, dict):
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v)))
return result
def _create_type(type_repr):
return eval(type_repr.replace("~", ""), None, (lambda d: d.setdefault("typing", typing) and None or d)(dict(typing.__dict__)))
class BetterPickler(CloudPickler):
def save_function_tuple(self, func):
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
@@ -63,10 +59,5 @@ class BetterPickler(CloudPickler):
self.save(cloudpickle._make_cell)
self.save((obj.cell_contents,))
self.write(pickle.REDUCE)
def save_type(self, obj):
self.save(_create_type)
self.save((repr(obj),))
self.write(pickle.REDUCE)
dispatch = CloudPickler.dispatch.copy()
dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
# dispatch[typing.GenericMeta] = save_type
+1 -1
View File
@@ -28,7 +28,7 @@ class Tuple(tuple):
class Str(str):
pass
class Unicode(unicode):
pass
+85 -226
View File
@@ -4,8 +4,6 @@ import time
import traceback
import copy
import logging
from types import ModuleType
import typing
import funcsigs
import numpy as np
import colorama
@@ -45,7 +43,7 @@ class RayTaskError(Exception):
def __init__(self, function_name, exception, traceback_str):
"""Initialize a RayTaskError."""
self.function_name = function_name
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError):
self.exception = exception
else:
self.exception = None
@@ -59,8 +57,6 @@ class RayTaskError(Exception):
exception = RayGetError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentError":
exception = RayGetArgumentError.deserialize(exception[1])
elif exception[0] == "RayGetArgumentTypeError":
exception = RayGetArgumentTypeError.deserialize(exception[1])
elif exception[0] == "None":
exception = None
else:
@@ -73,8 +69,6 @@ class RayTaskError(Exception):
serialized_exception = ("RayGetError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentError):
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
elif isinstance(self.exception, RayGetArgumentTypeError):
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
elif self.exception is None:
serialized_exception = ("None",)
else:
@@ -151,41 +145,6 @@ class RayGetArgumentError(Exception):
"""Format a RayGetArgumentError as a string."""
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
class RayGetArgumentTypeError(Exception):
"""An exception used when a task's argument doesn't type check.
Attributes:
function_name (str): The name of the function for the current task.
argument_index (int): The index (zero indexed) of the argument in the
present task's remote function call.
received_type: The type of the argument that was passed in.
expected_type: The type that was expected. This is determined by the remote
decorator.
"""
def __init__(self, function_name, argument_index, received_type, expected_type):
"""Initialize a RayGetArgumentTypeError object."""
self.function_name = function_name
self.argument_index = argument_index
# TODO(rkn): when we support the serialization of types, then we should
# remove the string conversions below.
self.received_type = str(received_type)
self.expected_type = str(expected_type)
@staticmethod
def deserialize(primitives):
"""Create a RayGetArgumentTypeError from a primitive object."""
function_name, argument_index, received_type, expected_type = primitives
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)
def serialize(self):
"""Turn a RayGetArgumentTypeError into a primitive object."""
return (self.function_name, self.argument_index, self.received_type, self.expected_type)
def __str__(self):
"""Format a RayGetArgumentTypeError as a string."""
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)
class RayDealloc(object):
"""An object used internally to properly implement reference counting.
@@ -232,13 +191,13 @@ class Reusable(object):
def __init__(self, initializer, reinitializer=None):
"""Initialize a Reusable object."""
if not isinstance(initializer, typing.Callable):
if not callable(initializer):
raise Exception("When creating a RayReusable, initializer must be a function.")
self.initializer = initializer
if reinitializer is None:
# If no reinitializer is passed in, use a wrapped version of the initializer.
reinitializer = lambda value: initializer()
if not isinstance(reinitializer, typing.Callable):
if not callable(reinitializer):
raise Exception("When creating a RayReusable, reinitializer must be a function.")
self.reinitializer = reinitializer
@@ -993,7 +952,7 @@ def main_loop(worker=global_worker):
def process_remote_function(function_name, serialized_function):
"""Import a remote function."""
try:
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
function, num_return_vals, module = pickling.loads(serialized_function)
except:
# If an exception was thrown when the remote function was imported, we
# record the traceback and notify the scheduler of the failure.
@@ -1005,11 +964,11 @@ def main_loop(worker=global_worker):
# TODO(rkn): Why is the below line necessary?
function.__module__ = module
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
worker.functions[function_name] = remote(num_return_vals=num_return_vals)(function)
_logger().info("Successfully imported remote function {}.".format(function_name))
# Noify the scheduler that the remote function imported successfully.
# We pass an empty error message string because the import succeeded.
raylib.register_remote_function(worker.handle, function_name, len(return_types))
raylib.register_remote_function(worker.handle, function_name, num_return_vals)
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
"""Import a reusable variable."""
@@ -1110,75 +1069,89 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
raise Exception("_export_reusable_variable can only be called on a driver.")
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
def remote(arg_types, return_types, worker=global_worker):
def remote(*args, **kwargs):
"""This decorator is used to create remote functions.
Args:
arg_types (List[type]): List of Python types of the function arguments.
return_types (List[type]): List of Python types of the return values.
num_return_vals (int): The number of object IDs that a call to this function
should return.
"""
def remote_decorator(func):
def func_call(*args, **kwargs):
"""This gets run immediately when a worker calls a remote function."""
check_connected()
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
if _mode() == raylib.PYTHON_MODE:
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
objectids = _submit_task(func_name, args)
if len(objectids) == 1:
return objectids[0]
elif len(objectids) > 1:
return objectids
def func_executor(arguments):
"""This gets run when the remote function is executed."""
_logger().info("Calling function {}".format(func.__name__))
start_time = time.time()
result = func(*arguments)
end_time = time.time()
check_return_values(func_invoker, result) # throws an exception if result is invalid
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
return result
def func_invoker(*args, **kwargs):
"""This is returned by the decorator and used to invoke the function."""
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
func_invoker.remote = func_call
func_invoker.executor = func_executor
func_invoker.arg_types = arg_types
func_invoker.return_types = return_types
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_invoker.func_name = func_name
func_invoker.func_doc = func.func_doc
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
keyword_defaults = [(k, v.default) for k, v in sig_params]
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
func_invoker.has_vararg_param = has_vararg_param
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
worker = global_worker
def make_remote_decorator(num_return_vals):
def remote_decorator(func):
def func_call(*args, **kwargs):
"""This gets run immediately when a worker calls a remote function."""
check_connected()
args = list(args)
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
if any([arg is funcsigs._empty for arg in args]):
raise Exception("Not enough arguments were provided to {}.".format(func_name))
if _mode() == raylib.PYTHON_MODE:
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
# arguments to prevent the function call from mutating them and to match
# the usual behavior of immutable remote objects.
return func(*copy.deepcopy(args))
objectids = _submit_task(func_name, args)
if len(objectids) == 1:
return objectids[0]
elif len(objectids) > 1:
return objectids
def func_executor(arguments):
"""This gets run when the remote function is executed."""
_logger().info("Calling function {}".format(func.__name__))
start_time = time.time()
result = func(*arguments)
end_time = time.time()
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
return result
def func_invoker(*args, **kwargs):
"""This is returned by the decorator and used to invoke the function."""
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
func_invoker.remote = func_call
func_invoker.executor = func_executor
func_invoker.is_remote = True
func_name = "{}.{}".format(func.__module__, func.__name__)
func_invoker.func_name = func_name
func_invoker.func_doc = func.func_doc
# Everything ready - export the function
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
finally:
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raylib.export_remote_function(worker.handle, func_name, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append((func_name, to_export))
return func_invoker
return remote_decorator
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
keyword_defaults = [(k, v.default) for k, v in sig_params]
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
func_invoker.has_vararg_param = has_vararg_param
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
# Everything ready - export the function
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
func_name_global_valid = func.__name__ in func.__globals__
func_name_global_value = func.__globals__.get(func.__name__)
# Set the function globally to make it refer to itself
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
try:
to_export = pickling.dumps((func, num_return_vals, func.__module__))
finally:
# Undo our changes
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
else: del func.__globals__[func.__name__]
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
raylib.export_remote_function(worker.handle, func_name, to_export)
elif worker.mode is None:
worker.cached_remote_functions.append((func_name, to_export))
return func_invoker
return remote_decorator
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
num_return_vals = 1
func = args[0]
return make_remote_decorator(num_return_vals)(func)
else:
# This is the case where the decorator is something like
# @ray.remote(num_return_vals=2).
assert len(args) == 0 and "num_return_vals" in kwargs.keys(), "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
num_return_vals = kwargs["num_return_vals"]
return make_remote_decorator(num_return_vals)
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
"""Check if we support the signature of this function.
@@ -1205,109 +1178,12 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
def check_return_values(function, result):
"""Check the types and number of return values.
Args:
function (Callable): The remote function whose outputs are being checked.
result: The value returned by an invocation of the remote function. The
expected types and number are defined in the remote decorator.
Raises:
Exception: An exception is raised if the return values have incorrect types
or the function returned the wrong number of return values.
"""
# If the @remote decorator declares that the function has no return values,
# then all we do is check that there were in fact no return values.
if len(function.return_types) == 0:
if result is not None:
raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__))
return
# If a function has multiple return values, Python returns a tuple of the
# values. If there is a single return value, then Python does not return a
# tuple, it simply returns the value. That is why we place result with
# (result,) when there is only one return value, so we can treat these two
# cases similarly.
if len(function.return_types) == 1:
result = (result,)
# Below we check that the number of values returned by the function match the
# number of return values declared in the @remote decorator.
if len(result) != len(function.return_types):
raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result)))
# Here we do some limited type checking to make sure the return values have
# the right types.
for i in range(len(result)):
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)):
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i]))
def typecheck_arg(arg, expected_type, i, name):
"""Check that an argument has the expected type.
Args:
arg: An argument to function.
expected_type (type): The expected type of arg.
i (int): The position of the argument to the function.
name (str): The name of the function.
Raises:
RayGetArgumentTypeError: An exception is raised if arg does not have the
expected type.
"""
if issubclass(type(arg), expected_type):
# Passed the type-checck
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
pass
elif isinstance(arg, long) and issubclass(int, expected_type):
# TODO(mehrdadn): Should long really be convertible to int?
pass
else:
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
def check_arguments(arg_types, has_vararg_param, name, args):
"""Check that the arguments to the remote function have the right types.
This is called by the worker that calls the remote function (not the worker
that executes the remote function).
Args:
arg_types (List[type]): A list of the types of the arguments to the function
being checked.
has_vararg_param (bool): True if the function being checked has a *args
argument.
name (str): The name of the function.
args (List): The arguments to the function.
Raises:
Exception: An exception is raised the args do not all have the right types.
"""
# check the number of args
if len(args) != len(arg_types) and not has_vararg_param:
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
elif len(args) < len(arg_types) - 1 and has_vararg_param:
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i <= len(arg_types) - 1:
expected_type = arg_types[i]
elif has_vararg_param:
expected_type = arg_types[-1]
else:
assert False, "This code should be unreachable."
if isinstance(arg, raylib.ObjectID):
# TODO(rkn): When we have type information in the ObjectID, do type checking here.
pass
else:
typecheck_arg(arg, expected_type, i, name)
def get_arguments_for_execution(function, args, worker=global_worker):
"""Retrieve the arguments for the remote function.
This retrieves the values for the arguments to the remote function that were
passed in as object IDs. Argumens that were passed by value are not changed.
This also does some type checking. This is called by the worker that is
executing the remote function.
This is called by the worker that is executing the remote function.
Args:
function (Callable): The remote function whose arguments are being
@@ -1321,25 +1197,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
Raises:
RayGetArgumentError: This exception is raised if a task that created one of
the arguments failed.
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
of the arguments does not have the expected type.
"""
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
arguments = []
# # check the number of args
# if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
# raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
# elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
# raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
for (i, arg) in enumerate(args):
if i <= len(function.arg_types) - 1:
expected_type = function.arg_types[i]
elif function.has_vararg_param and len(function.arg_types) >= 1:
expected_type = function.arg_types[-1]
else:
assert False, "This code should be unreachable."
if isinstance(arg, raylib.ObjectID):
# get the object from the local object store
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
@@ -1353,7 +1213,6 @@ def get_arguments_for_execution(function, args, worker=global_worker):
# pass the argument by value
argument = arg
typecheck_arg(argument, expected_type, i, function.__name__)
arguments.append(argument)
return arguments
+2 -30
View File
@@ -6,34 +6,6 @@ import test_functions
class FailureTest(unittest.TestCase):
def testNoArgs(self):
reload(test_functions)
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
test_functions.no_op_fail.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 1)
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
ray.worker.cleanup()
def testTypeChecking(self):
reload(test_functions)
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
# Make sure that these functions throw exceptions because there return
# values do not type check.
test_functions.test_return1.remote()
test_functions.test_return2.remote()
time.sleep(0.2)
task_info = ray.task_info()
self.assertEqual(len(task_info["failed_tasks"]), 2)
self.assertEqual(len(task_info["running_tasks"]), 0)
ray.worker.cleanup()
def testUnknownSerialization(self):
reload(test_functions)
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
@@ -104,7 +76,7 @@ class TaskStatusTest(unittest.TestCase):
return reducer, ()
def __call__(self):
return
ray.remote([], [])(Foo())
ray.remote(Foo())
for _ in range(100): # Retry if we need to wait longer.
if len(ray.task_info()["failed_remote_function_imports"]) >= 1:
break
@@ -140,7 +112,7 @@ class TaskStatusTest(unittest.TestCase):
def reinitializer(foo):
raise Exception("The reinitializer failed.")
ray.reusables.foo = ray.Reusable(initializer, reinitializer)
@ray.remote([], [])
@ray.remote
def use_foo():
ray.reusables.foo
use_foo.remote()
+21 -21
View File
@@ -262,42 +262,42 @@ class APITest(unittest.TestCase):
ray.init(start_ray_local=True, num_workers=2)
# Test that we can define a remote function in the shell.
@ray.remote([int], [int])
@ray.remote
def f(x):
return x + 1
self.assertEqual(ray.get(f.remote(0)), 1)
# Test that we can redefine the remote function.
@ray.remote([int], [int])
@ray.remote
def f(x):
return x + 10
self.assertEqual(ray.get(f.remote(0)), 10)
# Test that we can close over plain old data.
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
@ray.remote([], [list])
@ray.remote
def g():
return data
ray.get(g.remote())
# Test that we can close over modules.
@ray.remote([], [np.ndarray])
@ray.remote
def h():
return np.zeros([3, 5])
assert_equal(ray.get(h.remote()), np.zeros([3, 5]))
@ray.remote([], [float])
@ray.remote
def j():
return time.time()
ray.get(j.remote())
# Test that we can define remote functions that call other remote functions.
@ray.remote([int], [int])
@ray.remote
def k(x):
return x + 1
@ray.remote([int], [int])
@ray.remote
def l(x):
return k.remote(x)
@ray.remote([int], [int])
@ray.remote
def m(x):
return ray.get(l.remote(x))
self.assertEqual(ray.get(k.remote(1)), 2)
@@ -309,7 +309,7 @@ class APITest(unittest.TestCase):
def testSelect(self):
ray.init(start_ray_local=True, num_workers=4)
@ray.remote([float], [int])
@ray.remote
def f(delay):
time.sleep(delay)
return 1
@@ -345,10 +345,10 @@ class APITest(unittest.TestCase):
ray.reusables.foo = ray.Reusable(foo_initializer)
ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer)
@ray.remote([], [int])
@ray.remote
def use_foo():
return ray.reusables.foo
@ray.remote([], [list])
@ray.remote
def use_bar():
ray.reusables.bar.append(1)
return ray.reusables.bar
@@ -368,7 +368,7 @@ class APITest(unittest.TestCase):
def f():
sys.path.append("fake_directory")
ray.worker.global_worker.run_function_on_all_workers(f)
@ray.remote([], [list])
@ray.remote
def get_path():
return sys.path
self.assertEqual("fake_directory", ray.get(get_path.remote())[-1])
@@ -509,7 +509,7 @@ class PythonCExtensionTest(unittest.TestCase):
ray.init(start_ray_local=True, num_workers=1)
# Make sure that we aren't accidentally messing up Python's reference counts.
@ray.remote([], [int])
@ray.remote
def f():
return sys.getrefcount(None)
first_count = ray.get(f.remote())
@@ -522,7 +522,7 @@ class PythonCExtensionTest(unittest.TestCase):
ray.init(start_ray_local=True, num_workers=1)
# Make sure that we aren't accidentally messing up Python's reference counts.
@ray.remote([], [int])
@ray.remote
def f():
return sys.getrefcount(True)
first_count = ray.get(f.remote())
@@ -535,7 +535,7 @@ class PythonCExtensionTest(unittest.TestCase):
ray.init(start_ray_local=True, num_workers=1)
# Make sure that we aren't accidentally messing up Python's reference counts.
@ray.remote([], [int])
@ray.remote
def f():
return sys.getrefcount(False)
first_count = ray.get(f.remote())
@@ -559,7 +559,7 @@ class ReusablesTest(unittest.TestCase):
ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)
self.assertEqual(ray.reusables.foo, 1)
@ray.remote([], [int])
@ray.remote
def use_foo():
return ray.reusables.foo
self.assertEqual(ray.get(use_foo.remote()), 1)
@@ -573,7 +573,7 @@ class ReusablesTest(unittest.TestCase):
ray.reusables.bar = ray.Reusable(bar_initializer)
@ray.remote([], [list])
@ray.remote
def use_bar():
ray.reusables.bar.append(4)
return ray.reusables.bar
@@ -592,7 +592,7 @@ class ReusablesTest(unittest.TestCase):
ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer)
@ray.remote([int], [np.ndarray])
@ray.remote
def use_baz(i):
baz = ray.reusables.baz
baz[i] = 1
@@ -613,7 +613,7 @@ class ReusablesTest(unittest.TestCase):
ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer)
@ray.remote([], [int])
@ray.remote
def use_qux():
return ray.reusables.qux
self.assertEqual(ray.get(use_qux.remote()), 0)
@@ -634,7 +634,7 @@ class ClusterAttachingTest(unittest.TestCase):
ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address)
@ray.remote([int], [int])
@ray.remote
def f(x):
return x + 1
self.assertEqual(ray.get(f.remote(0)), 1)
@@ -653,7 +653,7 @@ class ClusterAttachingTest(unittest.TestCase):
ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address)
@ray.remote([int], [int])
@ray.remote
def f(x):
return x + 1
self.assertEqual(ray.get(f.remote(0)), 1)
+20 -34
View File
@@ -4,60 +4,60 @@ import numpy as np
# Test simple functionality
@ray.remote([int, int], [int, int])
@ray.remote(num_return_vals=2)
def handle_int(a, b):
return a + 1, b + 1
# Test aliasing
@ray.remote([], [np.ndarray])
@ray.remote
def test_alias_f():
return np.ones([3, 4, 5])
@ray.remote([], [np.ndarray])
@ray.remote
def test_alias_g():
return test_alias_f.remote()
@ray.remote([], [np.ndarray])
@ray.remote
def test_alias_h():
return test_alias_g.remote()
# Test timing
@ray.remote([], [])
@ray.remote
def empty_function():
pass
@ray.remote([], [int])
@ray.remote
def trivial_function():
return 1
# Test keyword arguments
@ray.remote([int, str], [str])
@ray.remote
def keyword_fct1(a, b="hello"):
return "{} {}".format(a, b)
@ray.remote([str, str], [str])
@ray.remote
def keyword_fct2(a="hello", b="world"):
return "{} {}".format(a, b)
@ray.remote([int, int, str, str], [str])
@ray.remote
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)
# Test variable numbers of arguments
@ray.remote([int], [str])
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
@ray.remote([int, int], [str])
@ray.remote
def varargs_fct2(a, *b):
return " ".join(map(str, b))
try:
@ray.remote([int], [])
@ray.remote
def kwargs_throw_exception(**c):
return ()
kwargs_exception_thrown = False
@@ -65,7 +65,7 @@ except:
kwargs_exception_thrown = True
try:
@ray.remote([int, str, int], [str])
@ray.remote
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
return "{} {} {}".format(a, b, c)
varargs_and_kwargs_exception_thrown = False
@@ -74,53 +74,39 @@ except:
# test throwing an exception
@ray.remote([], [])
@ray.remote
def throw_exception_fct1():
raise Exception("Test function 1 intentionally failed.")
@ray.remote([], [int])
@ray.remote
def throw_exception_fct2():
raise Exception("Test function 2 intentionally failed.")
@ray.remote([float], [int, str, np.ndarray])
@ray.remote(num_return_vals=3)
def throw_exception_fct3(x):
raise Exception("Test function 3 intentionally failed.")
# test Python mode
@ray.remote([], [np.ndarray])
@ray.remote
def python_mode_f():
return np.array([0, 0])
@ray.remote([np.ndarray], [np.ndarray])
@ray.remote
def python_mode_g(x):
x[0] = 1
return x
# test no return values
@ray.remote([], [])
@ray.remote
def no_op():
pass
@ray.remote([], [])
def no_op_fail():
return 0
# test wrong return types
@ray.remote([], [int])
def test_return1():
return 0.0
@ray.remote([], [int, float])
def test_return2():
return 2.0, 3.0
class TestClass(object):
def __init__(self):
self.a = 5
@ray.remote([], [TestClass])
@ray.remote
def test_unknown_type():
return TestClass()