mirror of
https://github.com/wassname/ray.git
synced 2026-07-06 01:07:38 +08:00
Remove type information from remote decorator. (#394)
* Remove type information from remote decorator. * Remove typing module. * Fix failure_test.py. * Allow remote decorator to be used with no parentheses. * Fix documentation.
This commit is contained in:
@@ -21,7 +21,7 @@ import numpy as np
|
||||
ray.init(start_ray_local=True, num_workers=10)
|
||||
|
||||
# Define a remote function for estimating pi.
|
||||
@ray.remote([int], [float])
|
||||
@ray.remote
|
||||
def estimate_pi(n):
|
||||
x = np.random.uniform(size=n)
|
||||
y = np.random.uniform(size=n)
|
||||
|
||||
+3
-3
@@ -10,15 +10,15 @@ However, to provide a more flexible API, we allow tasks to not only return
|
||||
values, but to also return object ids to values. As an examples, consider
|
||||
the following code.
|
||||
```python
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def f()
|
||||
return np.zeros(5)
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def g()
|
||||
return f()
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def h()
|
||||
return g()
|
||||
```
|
||||
|
||||
+1
-1
@@ -18,7 +18,7 @@ import shlex
|
||||
|
||||
# These 4 lines added to enable ReadTheDocs to work.
|
||||
import mock
|
||||
MOCK_MODULES = ["libraylib", "IPython", "numpy", "typing", "funcsigs", "subprocess32", "protobuf", "colorama", "graphviz", "cloudpickle", "ray.internal.graph_pb2"]
|
||||
MOCK_MODULES = ["libraylib", "IPython", "numpy", "funcsigs", "subprocess32", "protobuf", "colorama", "graphviz", "cloudpickle", "ray.internal.graph_pb2"]
|
||||
for mod_name in MOCK_MODULES:
|
||||
sys.modules[mod_name] = mock.Mock()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ brew update
|
||||
brew install git cmake automake autoconf libtool boost graphviz
|
||||
sudo easy_install pip
|
||||
sudo pip install ipython --user
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
|
||||
sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
@@ -15,7 +15,7 @@ First install the dependencies. We currently do not support Python 3.
|
||||
```
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
```
|
||||
|
||||
## Build
|
||||
|
||||
@@ -5,7 +5,7 @@ functions. Remote functions are written like regular Python functions, but with
|
||||
the `@ray.remote` decorator on top.
|
||||
|
||||
```python
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def increment(n):
|
||||
return n + 1
|
||||
```
|
||||
@@ -68,7 +68,7 @@ class ExampleClass(object):
|
||||
# This example assumes that field1 and field2 are serializable types.
|
||||
self.field1 = field1
|
||||
self.field2 = field2
|
||||
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
(field1, field2) = primitives
|
||||
|
||||
+7
-14
@@ -107,18 +107,11 @@ def add(a, b):
|
||||
```
|
||||
A remote function in Ray looks like this.
|
||||
```python
|
||||
@ray.remote([int, int], [int])
|
||||
@ray.remote
|
||||
def add(a, b):
|
||||
return a + b
|
||||
```
|
||||
|
||||
The information passed to the `@ray.remote` decorator includes type information
|
||||
for the arguments and for the return values of the function. Because of the
|
||||
distinction that we make between *submitting a task* and *executing the task*,
|
||||
we require type information so that we can catch type errors when the remote
|
||||
function is called instead of catching them when the task is actually executed
|
||||
(which could be much later and could be on a different machine).
|
||||
|
||||
### Remote functions
|
||||
|
||||
Whereas in regular Python, calling `add(1, 2)` would return `3`, in Ray, calling
|
||||
@@ -156,7 +149,7 @@ passed into the actual execution of the remote function.
|
||||
Note that a remote function can return multiple object IDs.
|
||||
|
||||
```python
|
||||
@ray.remote([], [int, float, str])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def return_multiple():
|
||||
return 0, 0.0, "zero"
|
||||
|
||||
@@ -194,7 +187,7 @@ around `time.sleep`.
|
||||
```python
|
||||
import time
|
||||
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def sleep(n):
|
||||
time.sleep(n)
|
||||
return 0
|
||||
@@ -245,11 +238,11 @@ Computation graphs encode dependencies. For example, suppose we define
|
||||
```python
|
||||
import numpy as np
|
||||
|
||||
@ray.remote([list], [np.ndarray])
|
||||
@ray.remote
|
||||
def zeros(shape):
|
||||
return np.zeros(shape)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
```
|
||||
@@ -282,12 +275,12 @@ processes can also call remote functions. To illustrate this, consider the
|
||||
following example.
|
||||
|
||||
```python
|
||||
@ray.remote([int, int], [int])
|
||||
@ray.remote
|
||||
def sub_experiment(i, j):
|
||||
# Run the jth sub-experiment for the ith experiment.
|
||||
return i + j
|
||||
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def run_experiment(i):
|
||||
sub_results = []
|
||||
# Launch tasks to perform 10 sub-experiments in parallel.
|
||||
|
||||
@@ -6,7 +6,7 @@ RUN apt-get update
|
||||
RUN apt-get -y install apt-utils
|
||||
RUN apt-get -y install sudo
|
||||
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user
|
||||
RUN adduser ray-user sudo
|
||||
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
|
||||
|
||||
@@ -6,7 +6,7 @@ RUN apt-get update
|
||||
RUN apt-get -y install apt-utils
|
||||
RUN apt-get -y install sudo
|
||||
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user --uid 500
|
||||
RUN adduser ray-user sudo
|
||||
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
|
||||
|
||||
@@ -7,7 +7,7 @@ RUN apt-get update
|
||||
RUN apt-get -y install apt-utils
|
||||
RUN apt-get -y install sudo
|
||||
RUN apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
RUN pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
RUN adduser --gecos --ingroup ray-user --disabled-login --gecos ray-user
|
||||
RUN adduser ray-user sudo
|
||||
RUN sed -i "s|%sudo\tALL=(ALL:ALL) ALL|%sudo\tALL=NOPASSWD: ALL|" /etc/sudoers
|
||||
|
||||
@@ -59,7 +59,7 @@ workers. At the core of our loading code is the remote function
|
||||
retrieves the appropriate object.
|
||||
|
||||
```python
|
||||
@ray.remote([str, str, List], [np.ndarray, List])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def load_tarfile_from_s3(bucket, s3_key, size=[]):
|
||||
# Pull the object with the given key and bucket from S3, untar the contents,
|
||||
# and return it.
|
||||
@@ -85,7 +85,7 @@ The other parallel component of this application is the training procedure. This
|
||||
is built on top of the remote function `compute_grad`.
|
||||
|
||||
```python
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List])
|
||||
@ray.remote
|
||||
def compute_grad(X, Y, mean, weights):
|
||||
# Load the weights into the network.
|
||||
# Subtract the mean and crop the images.
|
||||
|
||||
@@ -7,7 +7,6 @@ import tarfile, io
|
||||
import boto3
|
||||
import PIL.Image as Image
|
||||
import tensorflow as tf
|
||||
from typing import List, Tuple
|
||||
|
||||
import ray.array.remote as ra
|
||||
|
||||
@@ -39,7 +38,7 @@ def load_chunk(tarfile, size=None):
|
||||
filenames.append(filename)
|
||||
return np.concatenate(result), filenames
|
||||
|
||||
@ray.remote([str, str, List], [np.ndarray, List])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def load_tarfile_from_s3(bucket, s3_key, size=[]):
|
||||
"""Load an imagenet .tar file.
|
||||
|
||||
@@ -231,7 +230,7 @@ def net_initialization():
|
||||
def net_reinitialization(net_vars):
|
||||
return net_vars
|
||||
|
||||
@ray.remote([List], [int])
|
||||
@ray.remote
|
||||
def num_images(batches):
|
||||
"""Counts number of images in batches.
|
||||
|
||||
@@ -244,7 +243,7 @@ def num_images(batches):
|
||||
shape_ids = [ra.shape.remote(batch) for batch in batches]
|
||||
return sum([ray.get(shape_id)[0] for shape_id in shape_ids])
|
||||
|
||||
@ray.remote([List], [np.ndarray])
|
||||
@ray.remote
|
||||
def compute_mean_image(batches):
|
||||
"""Computes the mean image given a list of batches of images.
|
||||
|
||||
@@ -261,7 +260,7 @@ def compute_mean_image(batches):
|
||||
n_images = num_images.remote(batches)
|
||||
return np.sum(sum_images, axis=0).astype("float64") / ray.get(n_images)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray, np.ndarray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=4)
|
||||
def shuffle_arrays(first_images, first_labels, second_images, second_labels):
|
||||
"""Shuffles the images and labels from two batches.
|
||||
|
||||
@@ -306,7 +305,7 @@ def shuffle_pair(first_batch, second_batch):
|
||||
images1, labels1, images2, labels2 = shuffle_arrays.remote(first_batch[0], first_batch[1], second_batch[0], second_batch[1])
|
||||
return (images1, labels1), (images2, labels2)
|
||||
|
||||
@ray.remote([list, dict], [np.ndarray])
|
||||
@ray.remote
|
||||
def filenames_to_labels(filenames, filename_label_dict):
|
||||
"""Converts filename strings to integer labels.
|
||||
|
||||
@@ -381,7 +380,7 @@ def shuffle(batches):
|
||||
new_batches.append(permuted_batches[-1])
|
||||
return new_batches
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, List], [List])
|
||||
@ray.remote
|
||||
def compute_grad(X, Y, mean, weights):
|
||||
"""Computes the gradient of the network.
|
||||
|
||||
@@ -406,7 +405,7 @@ def compute_grad(X, Y, mean, weights):
|
||||
# Compute the gradients.
|
||||
return sess.run([g for (g, v) in comp_grads], feed_dict={images: subset_X, y_true: subset_Y, dropout: 0.5})
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, List], [np.float32])
|
||||
@ray.remote
|
||||
def compute_accuracy(X, Y, weights):
|
||||
"""Returns the accuracy of the network
|
||||
|
||||
|
||||
@@ -82,15 +82,13 @@ complicated version of this remote function is defined in
|
||||
[hyperopt.py](hyperopt.py).
|
||||
|
||||
```python
|
||||
@ray.remote([dict, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float])
|
||||
@ray.remote
|
||||
def train_cnn_and_compute_accuracy(hyperparameters, train_images, train_labels, validation_images, validation_labels):
|
||||
# Actual work omitted.
|
||||
return validation_accuracy
|
||||
```
|
||||
|
||||
The only difference is that we added the `@ray.remote` decorator specifying a
|
||||
little bit of type information (the input is a dictionary along with some numpy
|
||||
arrays, and the return value is a float).
|
||||
The only difference is that we added the `@ray.remote` decorator.
|
||||
|
||||
Now a call to `train_cnn_and_compute_accuracy` does not execute the function. It
|
||||
submits the task to the scheduler and returns an object ID for the output
|
||||
|
||||
@@ -51,7 +51,7 @@ def cnn_setup(x, y, keep_prob, lr, stddev):
|
||||
|
||||
# Define a remote function that takes a set of hyperparameters as well as the
|
||||
# data, consructs and trains a network, and returns the validation accuracy.
|
||||
@ray.remote([dict, int, np.ndarray, np.ndarray, np.ndarray, np.ndarray], [float])
|
||||
@ray.remote
|
||||
def train_cnn_and_compute_accuracy(params, steps, train_images, train_labels, validation_images, validation_labels):
|
||||
# Extract the hyperparameters from the params dictionary.
|
||||
learning_rate = params["learning_rate"]
|
||||
|
||||
@@ -91,20 +91,18 @@ use remote functions to distribute the loading of the data.
|
||||
Now, lets turn `loss` and `grad` into remote functions.
|
||||
|
||||
```python
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [float])
|
||||
@ray.remote
|
||||
def loss(theta, xs, ys):
|
||||
# compute the loss
|
||||
return loss
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def grad(theta, xs, ys):
|
||||
# compute the gradient
|
||||
return grad
|
||||
```
|
||||
|
||||
The only difference is that we added the `@ray.remote` decorator specifying a
|
||||
little bit of type information (the inputs consist of numpy arrays, `loss`
|
||||
returns a float, and `grad` returns a numpy array).
|
||||
The only difference is that we added the `@ray.remote` decorator.
|
||||
|
||||
Now, it is easy to speed up the computation of the full loss and the full
|
||||
gradient.
|
||||
|
||||
@@ -74,14 +74,14 @@ if __name__ == "__main__":
|
||||
sess.run([update_w, update_b], feed_dict={w_new: theta[:w_size].reshape(w_shape), b_new: theta[w_size:]})
|
||||
|
||||
# Compute the loss on a batch of data.
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [float])
|
||||
@ray.remote
|
||||
def loss(theta, xs, ys):
|
||||
sess, _, _, cross_entropy, _, x, y_, _, _ = ray.reusables.net_vars
|
||||
load_weights(theta)
|
||||
return float(sess.run(cross_entropy, feed_dict={x: xs, y_: ys}))
|
||||
|
||||
# Compute the gradient of the loss on a batch of data.
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def grad(theta, xs, ys):
|
||||
sess, _, _, _, cross_entropy_grads, x, y_, _, _ = ray.reusables.net_vars
|
||||
load_weights(theta)
|
||||
|
||||
@@ -32,7 +32,7 @@ estimate of the gradient. Below is a simplified pseudocode version of this
|
||||
function.
|
||||
|
||||
```python
|
||||
@ray.remote([dict], [dict, float])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def compute_gradient(model):
|
||||
# Retrieve the game environment.
|
||||
env = ray.reusables.env
|
||||
|
||||
@@ -72,7 +72,7 @@ def policy_backward(eph, epx, epdlogp, model):
|
||||
dW1 = np.dot(dh.T, epx)
|
||||
return {"W1": dW1, "W2": dW2}
|
||||
|
||||
@ray.remote([dict], [dict, float])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def compute_gradient(model):
|
||||
env = ray.reusables.env
|
||||
observation = env.reset()
|
||||
|
||||
@@ -79,7 +79,7 @@ we use reusable variables to store the gym environment and the neural network po
|
||||
then used in the remote `do_rollout` function to do a remote rollout:
|
||||
|
||||
```python
|
||||
@ray.remote([np.ndarray, int, int], [dict])
|
||||
@ray.remote
|
||||
def do_rollout(policy, timestep_limit, seed):
|
||||
# Retrieve the game environment.
|
||||
env = ray.reusables.env
|
||||
|
||||
@@ -31,11 +31,11 @@ if [[ $platform == "linux" ]]; then
|
||||
# These commands must be kept in sync with the installation instructions.
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y git cmake build-essential autoconf curl libtool python-dev python-numpy python-pip libboost-all-dev unzip graphviz
|
||||
sudo pip install ipython typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
sudo pip install ipython funcsigs subprocess32 protobuf colorama graphviz cloudpickle
|
||||
elif [[ $platform == "macosx" ]]; then
|
||||
# These commands must be kept in sync with the installation instructions.
|
||||
brew install git cmake automake autoconf libtool boost graphviz
|
||||
sudo easy_install pip
|
||||
sudo pip install ipython --user
|
||||
sudo pip install numpy typing funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
|
||||
sudo pip install numpy funcsigs subprocess32 protobuf colorama graphviz cloudpickle --ignore-installed six
|
||||
fi
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import ray.array.remote as ra
|
||||
import ray
|
||||
@@ -66,12 +65,12 @@ class DistArray(object):
|
||||
a = self.assemble()
|
||||
return a[sliced]
|
||||
|
||||
@ray.remote([DistArray], [np.ndarray])
|
||||
@ray.remote
|
||||
def assemble(a):
|
||||
return a.assemble()
|
||||
|
||||
# TODO(rkn): what should we call this method
|
||||
@ray.remote([np.ndarray], [DistArray])
|
||||
@ray.remote
|
||||
def numpy_to_dist(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
@@ -80,28 +79,28 @@ def numpy_to_dist(a):
|
||||
result.objectids[index] = ray.put(a[[slice(l, u) for (l, u) in zip(lower, upper)]])
|
||||
return result
|
||||
|
||||
@ray.remote([List, str], [DistArray])
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.zeros.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([List, str], [DistArray])
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float"):
|
||||
result = DistArray(shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = ra.ones.remote(DistArray.compute_block_shape(index, shape), dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def copy(a):
|
||||
result = DistArray(a.shape)
|
||||
for index in np.ndindex(*result.num_blocks):
|
||||
result.objectids[index] = a.objectids[index] # We don't need to actually copy the objects because cluster-level objects are assumed to be immutable.
|
||||
return result
|
||||
|
||||
@ray.remote([int, int, str], [DistArray])
|
||||
@ray.remote
|
||||
def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
dim2 = dim1 if dim2 == -1 else dim2
|
||||
shape = [dim1, dim2]
|
||||
@@ -114,7 +113,7 @@ def eye(dim1, dim2=-1, dtype_name="float"):
|
||||
result.objectids[i, j] = ra.zeros.remote(block_shape, dtype_name=dtype_name)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def triu(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -128,7 +127,7 @@ def triu(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def tril(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("Input must have 2 dimensions, but a.ndim is " + str(a.ndim))
|
||||
@@ -142,7 +141,7 @@ def tril(a):
|
||||
result.objectids[i, j] = ra.zeros_like.remote(a.objectids[i, j])
|
||||
return result
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def blockwise_dot(*matrices):
|
||||
n = len(matrices)
|
||||
if n % 2 != 0:
|
||||
@@ -153,7 +152,7 @@ def blockwise_dot(*matrices):
|
||||
result += np.dot(matrices[i], matrices[n / 2 + i])
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
if a.ndim != 2:
|
||||
raise Exception("dot expects its arguments to be 2-dimensional, but a.ndim = {}.".format(a.ndim))
|
||||
@@ -168,7 +167,7 @@ def dot(a, b):
|
||||
result.objectids[i, j] = blockwise_dot.remote(*args)
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray, List], [DistArray])
|
||||
@ray.remote
|
||||
def subblocks(a, *ranges):
|
||||
"""
|
||||
This function produces a distributed array from a subset of the blocks in the `a`. The result and `a` will have the same number of dimensions.For example,
|
||||
@@ -198,7 +197,7 @@ def subblocks(a, *ranges):
|
||||
result.objectids[index] = a.objectids[tuple([ranges[i][index[i]] for i in range(a.ndim)])]
|
||||
return result
|
||||
|
||||
@ray.remote([DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def transpose(a):
|
||||
if a.ndim != 2:
|
||||
raise Exception("transpose expects its argument to be 2-dimensional, but a.ndim = {}, a.shape = {}.".format(a.ndim, a.shape))
|
||||
@@ -209,7 +208,7 @@ def transpose(a):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("add expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
@@ -219,7 +218,7 @@ def add(x1, x2):
|
||||
return result
|
||||
|
||||
# TODO(rkn): support broadcasting?
|
||||
@ray.remote([DistArray, DistArray], [DistArray])
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
if x1.shape != x2.shape:
|
||||
raise Exception("subtract expects arguments `x1` and `x2` to have the same shape, but x1.shape = {}, and x2.shape = {}.".format(x1.shape, x2.shape))
|
||||
|
||||
@@ -6,7 +6,7 @@ from core import *
|
||||
|
||||
__all__ = ["tsqr", "modified_lu", "tsqr_hr", "qr"]
|
||||
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr(a):
|
||||
"""
|
||||
arguments:
|
||||
@@ -75,7 +75,7 @@ def tsqr(a):
|
||||
return q_result, r
|
||||
|
||||
# TODO(rkn): This is unoptimized, we really want a block version of this.
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def modified_lu(q):
|
||||
"""
|
||||
Algorithm 5 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf
|
||||
@@ -105,19 +105,19 @@ def modified_lu(q):
|
||||
U = np.triu(q_work)[:b, :]
|
||||
return numpy_to_dist.remote(ray.put(L)), U, S # TODO(rkn): get rid of put
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, int], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tsqr_hr_helper1(u, s, y_top_block, b):
|
||||
y_top = y_top_block[:b, :b]
|
||||
s_full = np.diag(s)
|
||||
t = -1 * np.dot(u, np.dot(s_full, np.linalg.inv(y_top).T))
|
||||
return t, y_top
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def tsqr_hr_helper2(s, r_temp):
|
||||
s_full = np.diag(s)
|
||||
return np.dot(s_full, r_temp)
|
||||
|
||||
@ray.remote([DistArray], [DistArray, np.ndarray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=4)
|
||||
def tsqr_hr(a):
|
||||
"""Algorithm 6 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
q, r_temp = tsqr.remote(a)
|
||||
@@ -127,15 +127,15 @@ def tsqr_hr(a):
|
||||
r = tsqr_hr_helper2.remote(s, r_temp)
|
||||
return y, t, y_top, r
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray, np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def qr_helper1(a_rc, y_ri, t, W_c):
|
||||
return a_rc - np.dot(y_ri, np.dot(t.T, W_c))
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def qr_helper2(y_ri, a_rc):
|
||||
return np.dot(y_ri.T, a_rc)
|
||||
|
||||
@ray.remote([DistArray], [DistArray, DistArray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
"""Algorithm 7 from http://www.eecs.berkeley.edu/Pubs/TechRpts/2013/EECS-2013-175.pdf"""
|
||||
m, n = a.shape[0], a.shape[1]
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import ray.array.remote as ra
|
||||
import ray
|
||||
|
||||
from core import *
|
||||
|
||||
@ray.remote([List], [DistArray])
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
num_blocks = DistArray.compute_num_blocks(shape)
|
||||
objectids = np.empty(num_blocks, dtype=object)
|
||||
|
||||
@@ -1,83 +1,82 @@
|
||||
from typing import List, Any
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
__all__ = ["zeros", "zeros_like", "ones", "eye", "dot", "vstack", "hstack", "subarray", "copy", "tril", "triu", "diag", "transpose", "add", "subtract", "sum", "shape", "sum_list"]
|
||||
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
@ray.remote
|
||||
def zeros(shape, dtype_name="float", order="C"):
|
||||
return np.zeros(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote([np.ndarray, str, str, bool], [np.ndarray])
|
||||
@ray.remote
|
||||
def zeros_like(a, dtype_name="None", order="K", subok=True):
|
||||
dtype_val = None if dtype_name == "None" else np.dtype(dtype_name)
|
||||
return np.zeros_like(a, dtype=dtype_val, order=order, subok=subok)
|
||||
|
||||
@ray.remote([List, str, str], [np.ndarray])
|
||||
@ray.remote
|
||||
def ones(shape, dtype_name="float", order="C"):
|
||||
return np.ones(shape, dtype=np.dtype(dtype_name), order=order)
|
||||
|
||||
@ray.remote([int, int, int, str], [np.ndarray])
|
||||
@ray.remote
|
||||
def eye(N, M=-1, k=0, dtype_name="float"):
|
||||
M = N if M == -1 else M
|
||||
return np.eye(N, M=M, k=k, dtype=np.dtype(dtype_name))
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def dot(a, b):
|
||||
return np.dot(a, b)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def vstack(*xs):
|
||||
return np.vstack(xs)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def hstack(*xs):
|
||||
return np.hstack(xs)
|
||||
|
||||
# TODO(rkn): instead of this, consider implementing slicing
|
||||
@ray.remote([np.ndarray, List, List], [np.ndarray])
|
||||
@ray.remote
|
||||
def subarray(a, lower_indices, upper_indices): # TODO(rkn): be consistent about using "index" versus "indices"
|
||||
return a[[slice(l, u) for (l, u) in zip(lower_indices, upper_indices)]]
|
||||
|
||||
@ray.remote([np.ndarray, str], [np.ndarray])
|
||||
@ray.remote
|
||||
def copy(a, order="K"):
|
||||
return np.copy(a, order=order)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote
|
||||
def tril(m, k=0):
|
||||
return np.tril(m, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote
|
||||
def triu(m, k=0):
|
||||
return np.triu(m, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote
|
||||
def diag(v, k=0):
|
||||
return np.diag(v, k=k)
|
||||
|
||||
@ray.remote([np.ndarray, List], [np.ndarray])
|
||||
@ray.remote
|
||||
def transpose(a, axes=[]):
|
||||
axes = None if axes == [] else axes
|
||||
return np.transpose(a, axes=axes)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def add(x1, x2):
|
||||
return np.add(x1, x2)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def subtract(x1, x2):
|
||||
return np.subtract(x1, x2)
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote
|
||||
def sum(x, axis=-1):
|
||||
return np.sum(x, axis=axis if axis != -1 else None)
|
||||
|
||||
@ray.remote([np.ndarray], [tuple])
|
||||
@ray.remote
|
||||
def shape(a):
|
||||
return np.shape(a)
|
||||
|
||||
# We use Any to allow different numerical types as well as numpy arrays.
|
||||
# TODO(rkn):this isn't in the numpy API, so be careful about exposing this.
|
||||
@ray.remote([Any], [Any])
|
||||
@ray.remote
|
||||
def sum_list(*xs):
|
||||
return np.sum(xs, axis=0)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@@ -7,82 +6,82 @@ __all__ = ["matrix_power", "solve", "tensorsolve", "tensorinv", "inv",
|
||||
"svd", "eig", "eigh", "lstsq", "norm", "qr", "cond", "matrix_rank",
|
||||
"LinAlgError", "multi_dot"]
|
||||
|
||||
@ray.remote([np.ndarray, int], [np.ndarray])
|
||||
@ray.remote
|
||||
def matrix_power(M, n):
|
||||
return np.linalg.matrix_power(M, n)
|
||||
|
||||
@ray.remote([np.ndarray, np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def solve(a, b):
|
||||
return np.linalg.solve(a, b)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorsolve(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def tensorinv(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def inv(a):
|
||||
return np.linalg.inv(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def cholesky(a):
|
||||
return np.linalg.cholesky(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def eigvals(a):
|
||||
return np.linalg.eigvals(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def eigvalsh(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def pinv(a):
|
||||
return np.linalg.pinv(a)
|
||||
|
||||
@ray.remote([np.ndarray], [int])
|
||||
@ray.remote
|
||||
def slogdet(a):
|
||||
raise NotImplementedError
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote
|
||||
def det(a):
|
||||
return np.linalg.det(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def svd(a):
|
||||
return np.linalg.svd(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eig(a):
|
||||
return np.linalg.eig(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def eigh(a):
|
||||
return np.linalg.eigh(a)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray, int, np.ndarray])
|
||||
@ray.remote(num_return_vals=4)
|
||||
def lstsq(a, b):
|
||||
return np.linalg.lstsq(a)
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote
|
||||
def norm(x):
|
||||
return np.linalg.norm(x)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray, np.ndarray])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def qr(a):
|
||||
return np.linalg.qr(a)
|
||||
|
||||
@ray.remote([np.ndarray], [float])
|
||||
@ray.remote
|
||||
def cond(x):
|
||||
return np.linalg.cond(x)
|
||||
|
||||
@ray.remote([np.ndarray], [int])
|
||||
@ray.remote
|
||||
def matrix_rank(M):
|
||||
return np.linalg.matrix_rank(M)
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def multi_dot(*a):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
import numpy as np
|
||||
import ray
|
||||
|
||||
@ray.remote([List], [np.ndarray])
|
||||
@ray.remote
|
||||
def normal(shape):
|
||||
return np.random.normal(size=shape)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Note that a little bit of code here is taken and slightly modified from the pickler because it was not possible to change its behavior otherwise.
|
||||
|
||||
import sys
|
||||
import typing
|
||||
from ctypes import c_void_p
|
||||
from cloudpickle import pickle, cloudpickle, CloudPickler, load, loads
|
||||
|
||||
@@ -38,9 +37,6 @@ def _fill_function(func, globals, defaults, closure, dict):
|
||||
pythonapi.PyCell_Set(c_void_p(id(result.__closure__[i])), c_void_p(id(v)))
|
||||
return result
|
||||
|
||||
def _create_type(type_repr):
|
||||
return eval(type_repr.replace("~", ""), None, (lambda d: d.setdefault("typing", typing) and None or d)(dict(typing.__dict__)))
|
||||
|
||||
class BetterPickler(CloudPickler):
|
||||
def save_function_tuple(self, func):
|
||||
code, f_globals, defaults, closure, dct, base_globals = self.extract_func_data(func)
|
||||
@@ -63,10 +59,5 @@ class BetterPickler(CloudPickler):
|
||||
self.save(cloudpickle._make_cell)
|
||||
self.save((obj.cell_contents,))
|
||||
self.write(pickle.REDUCE)
|
||||
def save_type(self, obj):
|
||||
self.save(_create_type)
|
||||
self.save((repr(obj),))
|
||||
self.write(pickle.REDUCE)
|
||||
dispatch = CloudPickler.dispatch.copy()
|
||||
dispatch[(lambda _: lambda: _)(0).__closure__[0].__class__] = save_cell
|
||||
# dispatch[typing.GenericMeta] = save_type
|
||||
|
||||
@@ -28,7 +28,7 @@ class Tuple(tuple):
|
||||
|
||||
class Str(str):
|
||||
pass
|
||||
|
||||
|
||||
class Unicode(unicode):
|
||||
pass
|
||||
|
||||
|
||||
+85
-226
@@ -4,8 +4,6 @@ import time
|
||||
import traceback
|
||||
import copy
|
||||
import logging
|
||||
from types import ModuleType
|
||||
import typing
|
||||
import funcsigs
|
||||
import numpy as np
|
||||
import colorama
|
||||
@@ -45,7 +43,7 @@ class RayTaskError(Exception):
|
||||
def __init__(self, function_name, exception, traceback_str):
|
||||
"""Initialize a RayTaskError."""
|
||||
self.function_name = function_name
|
||||
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError) or isinstance(exception, RayGetArgumentTypeError):
|
||||
if isinstance(exception, RayGetError) or isinstance(exception, RayGetArgumentError):
|
||||
self.exception = exception
|
||||
else:
|
||||
self.exception = None
|
||||
@@ -59,8 +57,6 @@ class RayTaskError(Exception):
|
||||
exception = RayGetError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentError":
|
||||
exception = RayGetArgumentError.deserialize(exception[1])
|
||||
elif exception[0] == "RayGetArgumentTypeError":
|
||||
exception = RayGetArgumentTypeError.deserialize(exception[1])
|
||||
elif exception[0] == "None":
|
||||
exception = None
|
||||
else:
|
||||
@@ -73,8 +69,6 @@ class RayTaskError(Exception):
|
||||
serialized_exception = ("RayGetError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentError):
|
||||
serialized_exception = ("RayGetArgumentError", self.exception.serialize())
|
||||
elif isinstance(self.exception, RayGetArgumentTypeError):
|
||||
serialized_exception = ("RayGetArgumentTypeError", self.exception.serialize())
|
||||
elif self.exception is None:
|
||||
serialized_exception = ("None",)
|
||||
else:
|
||||
@@ -151,41 +145,6 @@ class RayGetArgumentError(Exception):
|
||||
"""Format a RayGetArgumentError as a string."""
|
||||
return "Failed to get objectid {} as argument {} for remote function {}{}{}. It was created by remote function {}{}{} which failed with:\n{}".format(self.objectid, self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, colorama.Fore.RED, self.task_error.function_name, colorama.Fore.RESET, self.task_error)
|
||||
|
||||
class RayGetArgumentTypeError(Exception):
|
||||
"""An exception used when a task's argument doesn't type check.
|
||||
|
||||
Attributes:
|
||||
function_name (str): The name of the function for the current task.
|
||||
argument_index (int): The index (zero indexed) of the argument in the
|
||||
present task's remote function call.
|
||||
received_type: The type of the argument that was passed in.
|
||||
expected_type: The type that was expected. This is determined by the remote
|
||||
decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, function_name, argument_index, received_type, expected_type):
|
||||
"""Initialize a RayGetArgumentTypeError object."""
|
||||
self.function_name = function_name
|
||||
self.argument_index = argument_index
|
||||
# TODO(rkn): when we support the serialization of types, then we should
|
||||
# remove the string conversions below.
|
||||
self.received_type = str(received_type)
|
||||
self.expected_type = str(expected_type)
|
||||
|
||||
@staticmethod
|
||||
def deserialize(primitives):
|
||||
"""Create a RayGetArgumentTypeError from a primitive object."""
|
||||
function_name, argument_index, received_type, expected_type = primitives
|
||||
return RayGetArgumentTypeError(function_name, argument_index, received_type, expected_type)
|
||||
|
||||
def serialize(self):
|
||||
"""Turn a RayGetArgumentTypeError into a primitive object."""
|
||||
return (self.function_name, self.argument_index, self.received_type, self.expected_type)
|
||||
|
||||
def __str__(self):
|
||||
"""Format a RayGetArgumentTypeError as a string."""
|
||||
return "Argument {} for remote function {}{}{} has type {} but an argument of type {} was expected.".format(self.argument_index, colorama.Fore.RED, self.function_name, colorama.Fore.RESET, self.received_type, self.expected_type)
|
||||
|
||||
class RayDealloc(object):
|
||||
"""An object used internally to properly implement reference counting.
|
||||
|
||||
@@ -232,13 +191,13 @@ class Reusable(object):
|
||||
|
||||
def __init__(self, initializer, reinitializer=None):
|
||||
"""Initialize a Reusable object."""
|
||||
if not isinstance(initializer, typing.Callable):
|
||||
if not callable(initializer):
|
||||
raise Exception("When creating a RayReusable, initializer must be a function.")
|
||||
self.initializer = initializer
|
||||
if reinitializer is None:
|
||||
# If no reinitializer is passed in, use a wrapped version of the initializer.
|
||||
reinitializer = lambda value: initializer()
|
||||
if not isinstance(reinitializer, typing.Callable):
|
||||
if not callable(reinitializer):
|
||||
raise Exception("When creating a RayReusable, reinitializer must be a function.")
|
||||
self.reinitializer = reinitializer
|
||||
|
||||
@@ -993,7 +952,7 @@ def main_loop(worker=global_worker):
|
||||
def process_remote_function(function_name, serialized_function):
|
||||
"""Import a remote function."""
|
||||
try:
|
||||
(function, arg_types, return_types, module) = pickling.loads(serialized_function)
|
||||
function, num_return_vals, module = pickling.loads(serialized_function)
|
||||
except:
|
||||
# If an exception was thrown when the remote function was imported, we
|
||||
# record the traceback and notify the scheduler of the failure.
|
||||
@@ -1005,11 +964,11 @@ def main_loop(worker=global_worker):
|
||||
# TODO(rkn): Why is the below line necessary?
|
||||
function.__module__ = module
|
||||
assert function_name == "{}.{}".format(function.__module__, function.__name__), "The remote function name does not match the name that was passed in."
|
||||
worker.functions[function_name] = remote(arg_types, return_types, worker)(function)
|
||||
worker.functions[function_name] = remote(num_return_vals=num_return_vals)(function)
|
||||
_logger().info("Successfully imported remote function {}.".format(function_name))
|
||||
# Noify the scheduler that the remote function imported successfully.
|
||||
# We pass an empty error message string because the import succeeded.
|
||||
raylib.register_remote_function(worker.handle, function_name, len(return_types))
|
||||
raylib.register_remote_function(worker.handle, function_name, num_return_vals)
|
||||
|
||||
def process_reusable_variable(reusable_variable_name, initializer_str, reinitializer_str):
|
||||
"""Import a reusable variable."""
|
||||
@@ -1110,75 +1069,89 @@ def _export_reusable_variable(name, reusable, worker=global_worker):
|
||||
raise Exception("_export_reusable_variable can only be called on a driver.")
|
||||
raylib.export_reusable_variable(worker.handle, name, pickling.dumps(reusable.initializer), pickling.dumps(reusable.reinitializer))
|
||||
|
||||
def remote(arg_types, return_types, worker=global_worker):
|
||||
def remote(*args, **kwargs):
|
||||
"""This decorator is used to create remote functions.
|
||||
|
||||
Args:
|
||||
arg_types (List[type]): List of Python types of the function arguments.
|
||||
return_types (List[type]): List of Python types of the return values.
|
||||
num_return_vals (int): The number of object IDs that a call to this function
|
||||
should return.
|
||||
"""
|
||||
def remote_decorator(func):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if _mode() == raylib.PYTHON_MODE:
|
||||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
check_arguments(arg_types, has_vararg_param, func_name, args) # throws an exception if args are invalid
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
return objectids
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
_logger().info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
check_return_values(func_invoker, result) # throws an exception if result is invalid
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is returned by the decorator and used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.arg_types = arg_types
|
||||
func_invoker.return_types = return_types
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
func_invoker.func_doc = func.func_doc
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_invoker.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
worker = global_worker
|
||||
def make_remote_decorator(num_return_vals):
|
||||
def remote_decorator(func):
|
||||
def func_call(*args, **kwargs):
|
||||
"""This gets run immediately when a worker calls a remote function."""
|
||||
check_connected()
|
||||
args = list(args)
|
||||
args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in keyword_defaults[len(args):]]) # fill in the remaining arguments
|
||||
if any([arg is funcsigs._empty for arg in args]):
|
||||
raise Exception("Not enough arguments were provided to {}.".format(func_name))
|
||||
if _mode() == raylib.PYTHON_MODE:
|
||||
# In raylib.PYTHON_MODE, remote calls simply execute the function. We copy the
|
||||
# arguments to prevent the function call from mutating them and to match
|
||||
# the usual behavior of immutable remote objects.
|
||||
return func(*copy.deepcopy(args))
|
||||
objectids = _submit_task(func_name, args)
|
||||
if len(objectids) == 1:
|
||||
return objectids[0]
|
||||
elif len(objectids) > 1:
|
||||
return objectids
|
||||
def func_executor(arguments):
|
||||
"""This gets run when the remote function is executed."""
|
||||
_logger().info("Calling function {}".format(func.__name__))
|
||||
start_time = time.time()
|
||||
result = func(*arguments)
|
||||
end_time = time.time()
|
||||
_logger().info("Finished executing function {}, it took {} seconds".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
def func_invoker(*args, **kwargs):
|
||||
"""This is returned by the decorator and used to invoke the function."""
|
||||
raise Exception("Remote functions cannot be called directly. Instead of running '{}()', try '{}.remote()'.".format(func_name, func_name))
|
||||
func_invoker.remote = func_call
|
||||
func_invoker.executor = func_executor
|
||||
func_invoker.is_remote = True
|
||||
func_name = "{}.{}".format(func.__module__, func.__name__)
|
||||
func_invoker.func_name = func_name
|
||||
func_invoker.func_doc = func.func_doc
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, arg_types, return_types, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.export_remote_function(worker.handle, func_name, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((func_name, to_export))
|
||||
return func_invoker
|
||||
return remote_decorator
|
||||
sig_params = [(k, v) for k, v in funcsigs.signature(func).parameters.iteritems()]
|
||||
keyword_defaults = [(k, v.default) for k, v in sig_params]
|
||||
has_vararg_param = any([v.kind == v.VAR_POSITIONAL for k, v in sig_params])
|
||||
func_invoker.has_vararg_param = has_vararg_param
|
||||
has_kwargs_param = any([v.kind == v.VAR_KEYWORD for k, v in sig_params])
|
||||
check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, func_name)
|
||||
|
||||
# Everything ready - export the function
|
||||
if worker.mode in [None, raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
func_name_global_valid = func.__name__ in func.__globals__
|
||||
func_name_global_value = func.__globals__.get(func.__name__)
|
||||
# Set the function globally to make it refer to itself
|
||||
func.__globals__[func.__name__] = func_invoker # Allow the function to reference itself as a global variable
|
||||
try:
|
||||
to_export = pickling.dumps((func, num_return_vals, func.__module__))
|
||||
finally:
|
||||
# Undo our changes
|
||||
if func_name_global_valid: func.__globals__[func.__name__] = func_name_global_value
|
||||
else: del func.__globals__[func.__name__]
|
||||
if worker.mode in [raylib.SCRIPT_MODE, raylib.SILENT_MODE]:
|
||||
raylib.export_remote_function(worker.handle, func_name, to_export)
|
||||
elif worker.mode is None:
|
||||
worker.cached_remote_functions.append((func_name, to_export))
|
||||
return func_invoker
|
||||
|
||||
return remote_decorator
|
||||
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
num_return_vals = 1
|
||||
func = args[0]
|
||||
return make_remote_decorator(num_return_vals)(func)
|
||||
else:
|
||||
# This is the case where the decorator is something like
|
||||
# @ray.remote(num_return_vals=2).
|
||||
assert len(args) == 0 and "num_return_vals" in kwargs.keys(), "The @ray.remote decorator must be applied either with no arguments and no parentheses, for example '@ray.remote', or it must be applied with only the argument num_return_vals, like '@ray.remote(num_return_vals=2)'."
|
||||
num_return_vals = kwargs["num_return_vals"]
|
||||
return make_remote_decorator(num_return_vals)
|
||||
|
||||
def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaults, name):
|
||||
"""Check if we support the signature of this function.
|
||||
@@ -1205,109 +1178,12 @@ def check_signature_supported(has_kwargs_param, has_vararg_param, keyword_defaul
|
||||
if has_vararg_param and any([d != funcsigs._empty for _, d in keyword_defaults]):
|
||||
raise "Function {} has a *args argument as well as a keyword argument, which is currently not supported.".format(name)
|
||||
|
||||
|
||||
def check_return_values(function, result):
|
||||
"""Check the types and number of return values.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose outputs are being checked.
|
||||
result: The value returned by an invocation of the remote function. The
|
||||
expected types and number are defined in the remote decorator.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised if the return values have incorrect types
|
||||
or the function returned the wrong number of return values.
|
||||
"""
|
||||
# If the @remote decorator declares that the function has no return values,
|
||||
# then all we do is check that there were in fact no return values.
|
||||
if len(function.return_types) == 0:
|
||||
if result is not None:
|
||||
raise Exception("The @remote decorator for function {} has 0 return values, but {} returned more than 0 values.".format(function.__name__, function.__name__))
|
||||
return
|
||||
# If a function has multiple return values, Python returns a tuple of the
|
||||
# values. If there is a single return value, then Python does not return a
|
||||
# tuple, it simply returns the value. That is why we place result with
|
||||
# (result,) when there is only one return value, so we can treat these two
|
||||
# cases similarly.
|
||||
if len(function.return_types) == 1:
|
||||
result = (result,)
|
||||
# Below we check that the number of values returned by the function match the
|
||||
# number of return values declared in the @remote decorator.
|
||||
if len(result) != len(function.return_types):
|
||||
raise Exception("The @remote decorator for function {} has {} return values with types {}, but {} returned {} values.".format(function.__name__, len(function.return_types), function.return_types, function.__name__, len(result)))
|
||||
# Here we do some limited type checking to make sure the return values have
|
||||
# the right types.
|
||||
for i in range(len(result)):
|
||||
if (not issubclass(type(result[i]), function.return_types[i])) and (not isinstance(result[i], raylib.ObjectID)):
|
||||
raise Exception("The {}th return value for function {} has type {}, but the @remote decorator expected a return value of type {} or an ObjectID.".format(i, function.__name__, type(result[i]), function.return_types[i]))
|
||||
|
||||
def typecheck_arg(arg, expected_type, i, name):
|
||||
"""Check that an argument has the expected type.
|
||||
|
||||
Args:
|
||||
arg: An argument to function.
|
||||
expected_type (type): The expected type of arg.
|
||||
i (int): The position of the argument to the function.
|
||||
name (str): The name of the function.
|
||||
|
||||
Raises:
|
||||
RayGetArgumentTypeError: An exception is raised if arg does not have the
|
||||
expected type.
|
||||
"""
|
||||
if issubclass(type(arg), expected_type):
|
||||
# Passed the type-checck
|
||||
# TODO(rkn): This check doesn't really work, e.g., issubclass(type([1, 2, 3]), typing.List[str]) == True
|
||||
pass
|
||||
elif isinstance(arg, long) and issubclass(int, expected_type):
|
||||
# TODO(mehrdadn): Should long really be convertible to int?
|
||||
pass
|
||||
else:
|
||||
raise RayGetArgumentTypeError(name, i, type(arg), expected_type)
|
||||
|
||||
def check_arguments(arg_types, has_vararg_param, name, args):
|
||||
"""Check that the arguments to the remote function have the right types.
|
||||
|
||||
This is called by the worker that calls the remote function (not the worker
|
||||
that executes the remote function).
|
||||
|
||||
Args:
|
||||
arg_types (List[type]): A list of the types of the arguments to the function
|
||||
being checked.
|
||||
has_vararg_param (bool): True if the function being checked has a *args
|
||||
argument.
|
||||
name (str): The name of the function.
|
||||
args (List): The arguments to the function.
|
||||
|
||||
Raises:
|
||||
Exception: An exception is raised the args do not all have the right types.
|
||||
"""
|
||||
# check the number of args
|
||||
if len(args) != len(arg_types) and not has_vararg_param:
|
||||
raise Exception("Function {} expects {} arguments, but received {}.".format(name, len(arg_types), len(args)))
|
||||
elif len(args) < len(arg_types) - 1 and has_vararg_param:
|
||||
raise Exception("Function {} expects at least {} arguments, but received {}.".format(name, len(arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i <= len(arg_types) - 1:
|
||||
expected_type = arg_types[i]
|
||||
elif has_vararg_param:
|
||||
expected_type = arg_types[-1]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
# TODO(rkn): When we have type information in the ObjectID, do type checking here.
|
||||
pass
|
||||
else:
|
||||
typecheck_arg(arg, expected_type, i, name)
|
||||
|
||||
def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
"""Retrieve the arguments for the remote function.
|
||||
|
||||
This retrieves the values for the arguments to the remote function that were
|
||||
passed in as object IDs. Argumens that were passed by value are not changed.
|
||||
This also does some type checking. This is called by the worker that is
|
||||
executing the remote function.
|
||||
This is called by the worker that is executing the remote function.
|
||||
|
||||
Args:
|
||||
function (Callable): The remote function whose arguments are being
|
||||
@@ -1321,25 +1197,9 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
Raises:
|
||||
RayGetArgumentError: This exception is raised if a task that created one of
|
||||
the arguments failed.
|
||||
RayGetArgumentTypeError: This exception is raised (via typecheck_arg) if one
|
||||
of the arguments does not have the expected type.
|
||||
"""
|
||||
# TODO(rkn): Eventually, all of the type checking can be put in `check_arguments` above so that the error will happen immediately when calling a remote function.
|
||||
arguments = []
|
||||
# # check the number of args
|
||||
# if len(args) != len(function.arg_types) and function.arg_types[-1] is not None:
|
||||
# raise Exception("Function {} expects {} arguments, but received {}.".format(function.__name__, len(function.arg_types), len(args)))
|
||||
# elif len(args) < len(function.arg_types) - 1 and function.arg_types[-1] is None:
|
||||
# raise Exception("Function {} expects at least {} arguments, but received {}.".format(function.__name__, len(function.arg_types) - 1, len(args)))
|
||||
|
||||
for (i, arg) in enumerate(args):
|
||||
if i <= len(function.arg_types) - 1:
|
||||
expected_type = function.arg_types[i]
|
||||
elif function.has_vararg_param and len(function.arg_types) >= 1:
|
||||
expected_type = function.arg_types[-1]
|
||||
else:
|
||||
assert False, "This code should be unreachable."
|
||||
|
||||
if isinstance(arg, raylib.ObjectID):
|
||||
# get the object from the local object store
|
||||
_logger().info("Getting argument {} for function {}.".format(i, function.__name__))
|
||||
@@ -1353,7 +1213,6 @@ def get_arguments_for_execution(function, args, worker=global_worker):
|
||||
# pass the argument by value
|
||||
argument = arg
|
||||
|
||||
typecheck_arg(argument, expected_type, i, function.__name__)
|
||||
arguments.append(argument)
|
||||
return arguments
|
||||
|
||||
|
||||
+2
-30
@@ -6,34 +6,6 @@ import test_functions
|
||||
|
||||
class FailureTest(unittest.TestCase):
|
||||
|
||||
def testNoArgs(self):
|
||||
reload(test_functions)
|
||||
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
|
||||
|
||||
test_functions.no_op_fail.remote()
|
||||
time.sleep(0.2)
|
||||
task_info = ray.task_info()
|
||||
self.assertEqual(len(task_info["failed_tasks"]), 1)
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testTypeChecking(self):
|
||||
reload(test_functions)
|
||||
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
|
||||
|
||||
# Make sure that these functions throw exceptions because there return
|
||||
# values do not type check.
|
||||
test_functions.test_return1.remote()
|
||||
test_functions.test_return2.remote()
|
||||
time.sleep(0.2)
|
||||
task_info = ray.task_info()
|
||||
self.assertEqual(len(task_info["failed_tasks"]), 2)
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
|
||||
ray.worker.cleanup()
|
||||
|
||||
def testUnknownSerialization(self):
|
||||
reload(test_functions)
|
||||
ray.init(start_ray_local=True, num_workers=1, driver_mode=ray.SILENT_MODE)
|
||||
@@ -104,7 +76,7 @@ class TaskStatusTest(unittest.TestCase):
|
||||
return reducer, ()
|
||||
def __call__(self):
|
||||
return
|
||||
ray.remote([], [])(Foo())
|
||||
ray.remote(Foo())
|
||||
for _ in range(100): # Retry if we need to wait longer.
|
||||
if len(ray.task_info()["failed_remote_function_imports"]) >= 1:
|
||||
break
|
||||
@@ -140,7 +112,7 @@ class TaskStatusTest(unittest.TestCase):
|
||||
def reinitializer(foo):
|
||||
raise Exception("The reinitializer failed.")
|
||||
ray.reusables.foo = ray.Reusable(initializer, reinitializer)
|
||||
@ray.remote([], [])
|
||||
@ray.remote
|
||||
def use_foo():
|
||||
ray.reusables.foo
|
||||
use_foo.remote()
|
||||
|
||||
+21
-21
@@ -262,42 +262,42 @@ class APITest(unittest.TestCase):
|
||||
ray.init(start_ray_local=True, num_workers=2)
|
||||
|
||||
# Test that we can define a remote function in the shell.
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 1
|
||||
self.assertEqual(ray.get(f.remote(0)), 1)
|
||||
|
||||
# Test that we can redefine the remote function.
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 10
|
||||
self.assertEqual(ray.get(f.remote(0)), 10)
|
||||
|
||||
# Test that we can close over plain old data.
|
||||
data = [np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 2L], 2L, {"a": np.zeros(3)}]
|
||||
@ray.remote([], [list])
|
||||
@ray.remote
|
||||
def g():
|
||||
return data
|
||||
ray.get(g.remote())
|
||||
|
||||
# Test that we can close over modules.
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def h():
|
||||
return np.zeros([3, 5])
|
||||
assert_equal(ray.get(h.remote()), np.zeros([3, 5]))
|
||||
@ray.remote([], [float])
|
||||
@ray.remote
|
||||
def j():
|
||||
return time.time()
|
||||
ray.get(j.remote())
|
||||
|
||||
# Test that we can define remote functions that call other remote functions.
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def k(x):
|
||||
return x + 1
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def l(x):
|
||||
return k.remote(x)
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def m(x):
|
||||
return ray.get(l.remote(x))
|
||||
self.assertEqual(ray.get(k.remote(1)), 2)
|
||||
@@ -309,7 +309,7 @@ class APITest(unittest.TestCase):
|
||||
def testSelect(self):
|
||||
ray.init(start_ray_local=True, num_workers=4)
|
||||
|
||||
@ray.remote([float], [int])
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
time.sleep(delay)
|
||||
return 1
|
||||
@@ -345,10 +345,10 @@ class APITest(unittest.TestCase):
|
||||
ray.reusables.foo = ray.Reusable(foo_initializer)
|
||||
ray.reusables.bar = ray.Reusable(bar_initializer, bar_reinitializer)
|
||||
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def use_foo():
|
||||
return ray.reusables.foo
|
||||
@ray.remote([], [list])
|
||||
@ray.remote
|
||||
def use_bar():
|
||||
ray.reusables.bar.append(1)
|
||||
return ray.reusables.bar
|
||||
@@ -368,7 +368,7 @@ class APITest(unittest.TestCase):
|
||||
def f():
|
||||
sys.path.append("fake_directory")
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
@ray.remote([], [list])
|
||||
@ray.remote
|
||||
def get_path():
|
||||
return sys.path
|
||||
self.assertEqual("fake_directory", ray.get(get_path.remote())[-1])
|
||||
@@ -509,7 +509,7 @@ class PythonCExtensionTest(unittest.TestCase):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def f():
|
||||
return sys.getrefcount(None)
|
||||
first_count = ray.get(f.remote())
|
||||
@@ -522,7 +522,7 @@ class PythonCExtensionTest(unittest.TestCase):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def f():
|
||||
return sys.getrefcount(True)
|
||||
first_count = ray.get(f.remote())
|
||||
@@ -535,7 +535,7 @@ class PythonCExtensionTest(unittest.TestCase):
|
||||
ray.init(start_ray_local=True, num_workers=1)
|
||||
|
||||
# Make sure that we aren't accidentally messing up Python's reference counts.
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def f():
|
||||
return sys.getrefcount(False)
|
||||
first_count = ray.get(f.remote())
|
||||
@@ -559,7 +559,7 @@ class ReusablesTest(unittest.TestCase):
|
||||
ray.reusables.foo = ray.Reusable(foo_initializer, foo_reinitializer)
|
||||
self.assertEqual(ray.reusables.foo, 1)
|
||||
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def use_foo():
|
||||
return ray.reusables.foo
|
||||
self.assertEqual(ray.get(use_foo.remote()), 1)
|
||||
@@ -573,7 +573,7 @@ class ReusablesTest(unittest.TestCase):
|
||||
|
||||
ray.reusables.bar = ray.Reusable(bar_initializer)
|
||||
|
||||
@ray.remote([], [list])
|
||||
@ray.remote
|
||||
def use_bar():
|
||||
ray.reusables.bar.append(4)
|
||||
return ray.reusables.bar
|
||||
@@ -592,7 +592,7 @@ class ReusablesTest(unittest.TestCase):
|
||||
|
||||
ray.reusables.baz = ray.Reusable(baz_initializer, baz_reinitializer)
|
||||
|
||||
@ray.remote([int], [np.ndarray])
|
||||
@ray.remote
|
||||
def use_baz(i):
|
||||
baz = ray.reusables.baz
|
||||
baz[i] = 1
|
||||
@@ -613,7 +613,7 @@ class ReusablesTest(unittest.TestCase):
|
||||
|
||||
ray.reusables.qux = ray.Reusable(qux_initializer, qux_reinitializer)
|
||||
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def use_qux():
|
||||
return ray.reusables.qux
|
||||
self.assertEqual(ray.get(use_qux.remote()), 0)
|
||||
@@ -634,7 +634,7 @@ class ClusterAttachingTest(unittest.TestCase):
|
||||
|
||||
ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address)
|
||||
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 1
|
||||
self.assertEqual(ray.get(f.remote(0)), 1)
|
||||
@@ -653,7 +653,7 @@ class ClusterAttachingTest(unittest.TestCase):
|
||||
|
||||
ray.init(node_ip_address=node_ip_address, scheduler_address=scheduler_address)
|
||||
|
||||
@ray.remote([int], [int])
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x + 1
|
||||
self.assertEqual(ray.get(f.remote(0)), 1)
|
||||
|
||||
+20
-34
@@ -4,60 +4,60 @@ import numpy as np
|
||||
|
||||
# Test simple functionality
|
||||
|
||||
@ray.remote([int, int], [int, int])
|
||||
@ray.remote(num_return_vals=2)
|
||||
def handle_int(a, b):
|
||||
return a + 1, b + 1
|
||||
|
||||
# Test aliasing
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def test_alias_f():
|
||||
return np.ones([3, 4, 5])
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def test_alias_g():
|
||||
return test_alias_f.remote()
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def test_alias_h():
|
||||
return test_alias_g.remote()
|
||||
|
||||
# Test timing
|
||||
|
||||
@ray.remote([], [])
|
||||
@ray.remote
|
||||
def empty_function():
|
||||
pass
|
||||
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def trivial_function():
|
||||
return 1
|
||||
|
||||
# Test keyword arguments
|
||||
|
||||
@ray.remote([int, str], [str])
|
||||
@ray.remote
|
||||
def keyword_fct1(a, b="hello"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
@ray.remote([str, str], [str])
|
||||
@ray.remote
|
||||
def keyword_fct2(a="hello", b="world"):
|
||||
return "{} {}".format(a, b)
|
||||
|
||||
@ray.remote([int, int, str, str], [str])
|
||||
@ray.remote
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
# Test variable numbers of arguments
|
||||
|
||||
@ray.remote([int], [str])
|
||||
@ray.remote
|
||||
def varargs_fct1(*a):
|
||||
return " ".join(map(str, a))
|
||||
|
||||
@ray.remote([int, int], [str])
|
||||
@ray.remote
|
||||
def varargs_fct2(a, *b):
|
||||
return " ".join(map(str, b))
|
||||
|
||||
try:
|
||||
@ray.remote([int], [])
|
||||
@ray.remote
|
||||
def kwargs_throw_exception(**c):
|
||||
return ()
|
||||
kwargs_exception_thrown = False
|
||||
@@ -65,7 +65,7 @@ except:
|
||||
kwargs_exception_thrown = True
|
||||
|
||||
try:
|
||||
@ray.remote([int, str, int], [str])
|
||||
@ray.remote
|
||||
def varargs_and_kwargs_throw_exception(a, b="hi", *c):
|
||||
return "{} {} {}".format(a, b, c)
|
||||
varargs_and_kwargs_exception_thrown = False
|
||||
@@ -74,53 +74,39 @@ except:
|
||||
|
||||
# test throwing an exception
|
||||
|
||||
@ray.remote([], [])
|
||||
@ray.remote
|
||||
def throw_exception_fct1():
|
||||
raise Exception("Test function 1 intentionally failed.")
|
||||
|
||||
@ray.remote([], [int])
|
||||
@ray.remote
|
||||
def throw_exception_fct2():
|
||||
raise Exception("Test function 2 intentionally failed.")
|
||||
|
||||
@ray.remote([float], [int, str, np.ndarray])
|
||||
@ray.remote(num_return_vals=3)
|
||||
def throw_exception_fct3(x):
|
||||
raise Exception("Test function 3 intentionally failed.")
|
||||
|
||||
# test Python mode
|
||||
|
||||
@ray.remote([], [np.ndarray])
|
||||
@ray.remote
|
||||
def python_mode_f():
|
||||
return np.array([0, 0])
|
||||
|
||||
@ray.remote([np.ndarray], [np.ndarray])
|
||||
@ray.remote
|
||||
def python_mode_g(x):
|
||||
x[0] = 1
|
||||
return x
|
||||
|
||||
# test no return values
|
||||
|
||||
@ray.remote([], [])
|
||||
@ray.remote
|
||||
def no_op():
|
||||
pass
|
||||
|
||||
@ray.remote([], [])
|
||||
def no_op_fail():
|
||||
return 0
|
||||
|
||||
# test wrong return types
|
||||
|
||||
@ray.remote([], [int])
|
||||
def test_return1():
|
||||
return 0.0
|
||||
|
||||
@ray.remote([], [int, float])
|
||||
def test_return2():
|
||||
return 2.0, 3.0
|
||||
|
||||
class TestClass(object):
|
||||
def __init__(self):
|
||||
self.a = 5
|
||||
|
||||
@ray.remote([], [TestClass])
|
||||
@ray.remote
|
||||
def test_unknown_type():
|
||||
return TestClass()
|
||||
|
||||
Reference in New Issue
Block a user