Files
ray/python/ray/tests/test_basic.py
T
2020-02-23 21:26:49 -08:00

1737 lines
47 KiB
Python

# coding: utf-8
import collections
import io
import json
import logging
import os
import pickle
import re
import string
import sys
import tempfile
import threading
import time
import uuid
import weakref
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.exceptions import RayTimeoutError
logger = logging.getLogger(__name__)
# https://github.com/ray-project/ray/issues/6662
def test_ignore_http_proxy(shutdown_only):
    """Check that a remote task still runs when HTTP proxy variables are set.

    Regression test for https://github.com/ray-project/ray/issues/6662.

    The proxy variables are written to ``os.environ`` after ``ray.init`` and
    are restored when the test finishes, so the setting does not leak into
    other tests running in the same process.
    """
    proxy_vars = ("http_proxy", "https_proxy")
    # Remember the pre-existing values (None means "was not set").
    saved = {name: os.environ.get(name) for name in proxy_vars}

    ray.init(num_cpus=1)
    try:
        for name in proxy_vars:
            os.environ[name] = "http://example.com"

        @ray.remote
        def f():
            return 1

        assert ray.get(f.remote()) == 1
    finally:
        # Put the environment back exactly as it was found.
        for name, previous in saved.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
# https://github.com/ray-project/ray/issues/7287
def test_omp_threads_set(shutdown_only):
    """Verify that OMP_NUM_THREADS reads "1" in the driver after ray.init."""
    ray.init(num_cpus=1)
    # The variable is expected to have been set automatically by ray.init.
    omp_threads = os.environ["OMP_NUM_THREADS"]
    assert omp_threads == "1"
def test_simple_serialization(ray_start_regular):
    """Round-trip simple values through a remote task and the object store.

    Each value (and each value wrapped in a list, a tuple and a dict) is
    passed to an identity task and to ``ray.put``; the results must compare
    equal to the input and, for non-numpy values, have the same type.
    """
    primitive_objects = [
        # Various primitive types.
        0,
        0.0,
        0.9,
        1 << 62,
        1 << 999,
        b"",
        b"a",
        "a",
        string.printable,
        "\u262F",
        u"hello world",
        u"\xff\xfe\x9c\x001\x000\x00",
        None,
        True,
        False,
        [],
        (),
        {},
        type,
        int,
        set(),
        # Collections types.
        collections.Counter([np.random.randint(0, 10) for _ in range(100)]),
        collections.OrderedDict([("hello", 1), ("world", 2)]),
        collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]),
        collections.defaultdict(lambda: [], [("hello", 1), ("world", 2)]),
        collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
        # Numpy dtypes.
        np.int8(3),
        np.int32(4),
        np.int64(5),
        np.uint8(3),
        np.uint32(4),
        np.uint64(5),
        np.float32(1.9),
        np.float64(1.9),
    ]

    # The same values, each wrapped in a one-element container.
    wrapped_in_list = [[obj] for obj in primitive_objects]
    wrapped_in_tuple = [(obj, ) for obj in primitive_objects]
    wrapped_in_dict = [{(): obj} for obj in primitive_objects]
    composite_objects = wrapped_in_list + wrapped_in_tuple + wrapped_in_dict

    @ray.remote
    def identity(x):
        return x

    # Arguments passed by value to remote functions must come back
    # uncorrupted, and so must values stored with ray.put.
    for obj in primitive_objects + composite_objects:
        via_task = ray.get(identity.remote(obj))
        via_put = ray.get(ray.put(obj))
        assert obj == via_task
        assert obj == via_put
        # TODO(rkn): The numpy dtypes currently come back as regular integers
        # or floats.
        if type(obj).__module__ == "numpy":
            continue
        assert type(obj) is type(via_task)
        assert type(obj) is type(via_put)
def test_background_tasks_with_max_calls(shutdown_only):
    """Results of tasks started by a max_calls=1 task remain retrievable."""
    ray.init(num_cpus=2)

    @ray.remote
    def g():
        time.sleep(.1)
        return 0

    @ray.remote(max_calls=1, max_retries=0)
    def f():
        return [g.remote()]

    outer_ids = [f.remote() for _ in range(10)]
    nested = ray.get(outer_ids)

    # Should still be able to retrieve these objects, since f's workers will
    # wait for g to finish before exiting.
    inner_ids = [id_list[0] for id_list in nested]
    ray.get(inner_ids)
def test_fair_queueing(shutdown_only):
    """Run 1000 three-level nested task chains on a single CPU.

    ray.init is given the "fair_queueing_enabled" internal config flag.
    """
    internal_config = json.dumps({"fair_queueing_enabled": 1})
    ray.init(num_cpus=1, _internal_config=internal_config)

    @ray.remote
    def h():
        return 0

    @ray.remote
    def g():
        return ray.get(h.remote())

    @ray.remote
    def f():
        return ray.get(g.remote())

    # This will never finish without fair queueing of {f, g, h}:
    # https://github.com/ray-project/ray/issues/3644
    num_tasks = 1000
    pending = [f.remote() for _ in range(num_tasks)]
    ready, _ = ray.wait(pending, timeout=60.0, num_returns=num_tasks)
    assert len(ready) == num_tasks, len(ready)
def complex_serialization(use_pickle):
    """Round-trip a large collection of nested and custom objects through Ray.

    Every object is sent through a remote identity task and through
    ``ray.put``/``ray.get``; the result is compared with the original using
    a structural comparison helper.

    Args:
        use_pickle: Supplied by the callers; this function body does not
            read it.
    """

    def assert_equal(obj1, obj2):
        # Structural equality that special-cases numpy values, objects with
        # a __dict__, dicts, lists, tuples and named tuples.

        def compare_elementwise(length_message):
            # Same length first, then recurse position by position.
            assert len(obj1) == len(obj2), length_message.format(obj1, obj2)
            for idx in range(len(obj1)):
                assert_equal(obj1[idx], obj2[idx])

        numpy_name = np.__name__
        involves_numpy = (type(obj1).__module__ == numpy_name
                          or type(obj2).__module__ == numpy_name)
        if involves_numpy:
            has_empty_shape = (
                (hasattr(obj1, "shape") and obj1.shape == ())
                or (hasattr(obj2, "shape") and obj2.shape == ()))
            if not has_empty_shape:
                np.testing.assert_equal(obj1, obj2)
            else:
                # This is a special case because currently
                # np.testing.assert_equal fails because we do not properly
                # handle different numerical types.
                assert obj1 == obj2, ("Objects {} and {} are "
                                      "different.".format(obj1, obj2))
            return

        if hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
            special_keys = ["_pytype_"]
            keys1 = set(list(obj1.__dict__.keys()) + special_keys)
            keys2 = set(list(obj2.__dict__.keys()) + special_keys)
            assert keys1 == keys2, (
                "Objects {} and {} are different.".format(obj1, obj2))
            for key in obj1.__dict__.keys():
                if key in special_keys:
                    continue
                assert_equal(obj1.__dict__[key], obj2.__dict__[key])
            return

        if type(obj1) is dict or type(obj2) is dict:
            assert_equal(obj1.keys(), obj2.keys())
            for key in obj1.keys():
                assert_equal(obj1[key], obj2[key])
            return

        if type(obj1) is list or type(obj2) is list:
            compare_elementwise("Objects {} and {} are lists with "
                                "different lengths.")
            return

        if type(obj1) is tuple or type(obj2) is tuple:
            compare_elementwise("Objects {} and {} are tuples "
                                "with different lengths.")
            return

        if (ray.serialization.is_named_tuple(type(obj1))
                or ray.serialization.is_named_tuple(type(obj2))):
            compare_elementwise("Objects {} and {} are named "
                                "tuples with different lengths.")
            return

        assert obj1 == obj2, "Objects {} and {} are different.".format(
            obj1, obj2)

    long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]

    PRIMITIVE_OBJECTS = [
        0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
        string.printable, "\u262F", u"hello world",
        u"\xff\xfe\x9c\x001\x000\x00", None, True, False, [], (), {},
        np.int8(3),
        np.int32(4),
        np.int64(5),
        np.uint8(3),
        np.uint32(4),
        np.uint64(5),
        np.float32(1.9),
        np.float64(1.9),
        np.zeros([100, 100]),
        np.random.normal(size=[100, 100]),
        np.array(["hi", 3]),
        np.array(["hi", 3], dtype=object)
    ] + long_extras

    COMPLEX_OBJECTS = [
        [[[[[[[[[[[[]]]]]]]]]]]],
        {
            "obj{}".format(i): np.random.normal(size=[100, 100])
            for i in range(10)
        },
        # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
        #      (): {(): {}}}}}}}}}}}}},
        (
            (((((((((), ), ), ), ), ), ), ), ), ),
        {
            "a": {
                "b": {
                    "c": {
                        "d": {}
                    }
                }
            }
        },
    ]

    class Foo:
        # Value wrapper whose equality and hash follow ``value``.
        def __init__(self, value=0):
            self.value = value

        def __hash__(self):
            return hash(self.value)

        def __eq__(self, other):
            return other.value == self.value

    class Bar:
        # Carries one attribute per primitive/complex test object.
        def __init__(self):
            for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS):
                setattr(self, "field{}".format(i), val)

    class Baz:
        def __init__(self):
            self.foo = Foo()
            self.bar = Bar()

        def method(self, arg):
            pass

    class Qux:
        def __init__(self):
            self.objs = [Foo(), Bar(), Baz()]

    class SubQux(Qux):
        def __init__(self):
            Qux.__init__(self)

    class CustomError(Exception):
        pass

    Point = collections.namedtuple("Point", ["x", "y"])
    NamedTupleExample = collections.namedtuple(
        "Example", "field1, field2, field3, field4, field5")

    CUSTOM_OBJECTS = [
        Exception("Test object."),
        CustomError(),
        Point(11, y=22),
        Foo(),
        Bar(),
        Baz(),  # Qux(), SubQux(),
        NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
    ]

    # Test dataclasses in Python 3.7.
    if sys.version_info >= (3, 7):
        from dataclasses import make_dataclass

        DataClass0 = make_dataclass("DataClass0", [("number", int)])
        CUSTOM_OBJECTS.append(DataClass0(number=3))

        class CustomClass:
            def __init__(self, value):
                self.value = value

        DataClass1 = make_dataclass("DataClass1", [("custom", CustomClass)])

        class DataClass2(DataClass1):
            @classmethod
            def from_custom(cls, data):
                custom = CustomClass(data)
                return cls(custom)

            def __reduce__(self):
                return (self.from_custom, (self.custom.value, ))

        CUSTOM_OBJECTS.append(DataClass2(custom=CustomClass(43)))

    BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
    LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
    TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS]

    # The check that type(obj).__module__ != "numpy" should be unnecessary, but
    # otherwise this seems to fail on Mac OS X on Travis.
    self_keyed_dicts = [{
        obj: obj
    } for obj in PRIMITIVE_OBJECTS if (
        obj.__hash__ is not None and type(obj).__module__ != "numpy")]
    zero_keyed_dicts = [{0: obj} for obj in BASE_OBJECTS]
    foo_keyed_dicts = [{Foo(123): Foo(456)}]
    DICT_OBJECTS = self_keyed_dicts + zero_keyed_dicts + foo_keyed_dicts

    RAY_TEST_OBJECTS = (
        BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS)

    @ray.remote
    def identity(x):
        return x

    # Arguments passed by value to remote functions, and values stored with
    # ray.put, must come back uncorrupted.
    for obj in RAY_TEST_OBJECTS:
        assert_equal(obj, ray.get(identity.remote(obj)))
        assert_equal(obj, ray.get(ray.put(obj)))

    # Test StringIO serialization
    stream = io.StringIO(u"Hello, world!\n")
    stream.seek(0)
    first_line = stream.readline()
    stream.seek(0)
    assert ray.get(ray.put(stream)).readline() == first_line
def test_complex_serialization(ray_start_regular):
    """Run the shared complex-serialization checks with use_pickle=False."""
    pickle_enabled = False
    complex_serialization(use_pickle=pickle_enabled)
def test_complex_serialization_with_pickle(shutdown_only):
    """Run the shared checks after starting Ray with use_pickle=True."""
    pickle_enabled = True
    ray.init(use_pickle=pickle_enabled)
    complex_serialization(use_pickle=pickle_enabled)
def test_function_descriptor():
    """Check pickling, equality and hashing of function descriptors."""

    def pickle_roundtrip(descriptor):
        return pickle.loads(pickle.dumps(descriptor))

    python_descriptor = ray._raylet.PythonFunctionDescriptor(
        "module_name", "function_name", "class_name", "function_hash")
    python_copy = pickle_roundtrip(python_descriptor)
    assert python_descriptor == python_copy
    assert hash(python_descriptor) == hash(python_copy)
    assert python_descriptor.function_id == python_copy.function_id

    java_descriptor = ray._raylet.JavaFunctionDescriptor(
        "class_name", "function_name", "signature")
    java_copy = pickle_roundtrip(java_descriptor)
    assert java_descriptor == java_copy

    # Descriptors of different kinds, or unrelated objects, never compare
    # equal.
    assert python_descriptor != java_descriptor
    assert python_descriptor != object()

    # The unpickled copy must find the entry stored under the original key.
    lookup = {python_descriptor: 123}
    assert lookup.get(python_copy) == 123
def test_nested_functions(ray_start_regular):
    """Exercise remote functions that reference later or recursive names."""

    # Make sure that remote functions can use other values that are defined
    # after the remote function but before the first function invocation.
    @ray.remote
    def f():
        return g(), ray.get(h.remote())

    def g():
        return 1

    @ray.remote
    def h():
        return 2

    assert ray.get(f.remote()) == (1, 2)

    # Test a remote function that recursively calls itself.
    @ray.remote
    def factorial(n):
        if n == 0:
            return 1
        return n * ray.get(factorial.remote(n - 1))

    # expected_values[n] is n! for n = 0..5.
    expected_values = (1, 1, 2, 6, 24, 120)
    for n, expected in enumerate(expected_values):
        assert ray.get(factorial.remote(n)) == expected

    # Test remote functions that recursively call each other.
    @ray.remote
    def factorial_even(n):
        assert n % 2 == 0
        if n == 0:
            return 1
        return n * ray.get(factorial_odd.remote(n - 1))

    @ray.remote
    def factorial_odd(n):
        assert n % 2 == 1
        return n * ray.get(factorial_even.remote(n - 1))

    assert ray.get(factorial_even.remote(4)) == 24
    assert ray.get(factorial_odd.remote(5)) == 120
def test_ray_recursive_objects(ray_start_regular):
    """Put self-referential objects into the object store.

    When the global worker has ``use_pickle`` set, the puts are expected to
    succeed; otherwise each put is expected to raise.
    """

    class ClassA:
        pass

    # A list that contains itself.
    self_list = []
    self_list.append(self_list)

    # An object that contains itself as a field.
    a1 = ClassA()
    a1.field = a1

    # Two objects that contain each other as fields.
    a2 = ClassA()
    a3 = ClassA()
    a2.field = a3
    a3.field = a2

    # A dictionary that contains itself.
    self_dict = {}
    self_dict["key"] = self_dict

    recursive_objects = [self_list, a1, a2, a3, self_dict]

    if not ray.worker.global_worker.use_pickle:
        # Check that exceptions are thrown when we serialize the recursive
        # objects.
        for obj in recursive_objects:
            with pytest.raises(Exception):
                ray.put(obj)
        return

    # Serialize the recursive objects.
    for obj in recursive_objects:
        ray.put(obj)
def test_reducer_override_no_reference_cycle(ray_start_regular):
    """Pickled dynamic functions and objects must be collectable at once.

    bpo-39492: reducer_override used to induce a spurious reference cycle
    inside the Pickler object, that could prevent all serialized objects
    from being garbage-collected without explicitly invoking gc.collect.
    """
    from ray.cloudpickle import CloudPickler, loads, dumps

    # Case 1: a dynamically defined function.
    def f():
        return 4669201609102990671853203821578

    func_ref = weakref.ref(f)

    buffer = io.BytesIO()
    pickler = CloudPickler(buffer, protocol=5)
    pickler.dump(f)
    restored = loads(buffer.getvalue())
    assert restored() == 4669201609102990671853203821578

    # Once the pickler and the function name are dropped, nothing should
    # keep the original function alive.
    del pickler
    del f
    assert func_ref() is None

    # Case 2: an instance of a dynamically defined class.
    class ShortlivedObject:
        def __del__(self):
            print("Went out of scope!")

    obj = ShortlivedObject()
    obj_ref = weakref.ref(obj)

    dumps(obj)
    del obj
    assert obj_ref() is None
def test_deserialized_from_buffer_immutable(ray_start_regular):
    """An array fetched from the object store must reject in-place writes."""
    original = np.full((2, 2), 1.)
    object_id = ray.put(original)
    fetched = ray.get(object_id)

    expected_message = "assignment destination is read-only"
    with pytest.raises(ValueError, match=expected_message):
        fetched[0, 0] = 9.
def test_passing_arguments_by_value_out_of_the_box(ray_start_regular):
    """Pass functions, sets, types and a custom class through Ray."""

    @ray.remote
    def f(x):
        return x

    # Test passing lambdas.
    def temp():
        return 1

    returned_temp = ray.get(f.remote(temp))
    assert returned_temp() == 1
    returned_lambda = ray.get(f.remote(lambda x: x + 1))
    assert returned_lambda(3) == 4

    # Test sets.
    assert ray.get(f.remote(set())) == set()
    s = {1, (1, 2, "hi")}
    assert ray.get(f.remote(s)) == s

    # Test types.
    for builtin_type in (int, float, str):
        assert ray.get(f.remote(builtin_type)) == builtin_type

    class Foo:
        def __init__(self):
            pass

    # Make sure that we can put and get a custom type. Note that the result
    # won't be "equal" to Foo.
    foo_id = ray.put(Foo)
    ray.get(foo_id)
def test_putting_object_that_closes_over_object_id(ray_start_regular):
    """Put an object that holds an object ID and whose class closes over it.

    This test is here to prevent a regression of
    https://github.com/ray-project/ray/issues/1317.
    """

    class Foo:
        def __init__(self):
            self.val = ray.put(0)

        def method(self):
            # The bare reference makes this method close over the instance
            # that is created below.
            instance

    instance = Foo()
    ray.put(instance)
def test_put_get(shutdown_only):
    """Store and fetch 100 ints, floats, strings and lists, in that order."""
    ray.init(num_cpus=0)

    # One builder per value family; each is applied to i = 0..99.
    value_builders = (
        lambda i: i * 10**6,
        lambda i: i * 10**6 * 1.0,
        lambda i: "h" * i,
        lambda i: [1] * i,
    )

    for build_value in value_builders:
        for i in range(100):
            value_before = build_value(i)
            object_id = ray.put(value_before)
            value_after = ray.get(object_id)
            assert value_before == value_after
def custom_serializers():
    """Register a custom serializer pair for two classes and round-trip them.

    One class is checked through ``ray.put``, the other as the return value
    of a remote task.
    """

    def custom_serializer(obj):
        return 3, "string1", type(obj).__name__

    def custom_deserializer(serialized_obj):
        return serialized_obj, "string2"

    class Foo:
        def __init__(self):
            self.x = 3

    ray.register_custom_serializer(
        Foo, serializer=custom_serializer, deserializer=custom_deserializer)

    expected_foo = ((3, "string1", Foo.__name__), "string2")
    assert ray.get(ray.put(Foo())) == expected_foo

    class Bar:
        def __init__(self):
            self.x = 3

    ray.register_custom_serializer(
        Bar, serializer=custom_serializer, deserializer=custom_deserializer)

    @ray.remote
    def make_bar():
        return Bar()

    expected_bar = ((3, "string1", Bar.__name__), "string2")
    assert ray.get(make_bar.remote()) == expected_bar
def test_custom_serializers(ray_start_regular):
    """Run the shared custom-serializer checks on the regular fixture."""
    run_checks = custom_serializers
    run_checks()
def test_custom_serializers_with_pickle(shutdown_only):
    """Run the custom-serializer checks with ray.init(use_pickle=True).

    Afterwards a class is registered with ``use_pickle=True`` and no
    serializer, and an instance returned from a task must keep its type.
    """
    ray.init(use_pickle=True)
    custom_serializers()

    class Foo:
        def __init__(self):
            self.x = 4

    # Test the pickle serialization backend without serializer.
    # NOTE: 'use_pickle' here is different from 'use_pickle' in
    # ray.init
    ray.register_custom_serializer(Foo, use_pickle=True)

    @ray.remote
    def make_foo():
        return Foo()

    result = ray.get(make_foo.remote())
    assert type(result) is Foo
def test_serialization_final_fallback(ray_start_regular):
    """Round-trip a CatBoost classifier; skipped without catboost."""
    # This test will only run when "catboost" is installed.
    catboost = pytest.importorskip("catboost")

    model_options = dict(
        iterations=2,
        depth=2,
        learning_rate=1,
        loss_function="Logloss",
        logging_level="Verbose")
    model = catboost.CatBoostClassifier(**model_options)

    reconstructed_model = ray.get(ray.put(model))

    original_params = set(model.get_params().items())
    reconstructed_params = set(reconstructed_model.get_params().items())
    assert original_params == reconstructed_params
def test_register_class(ray_start_2_cpus):
    """Move instances of locally defined classes between driver and workers."""

    def check_distinct_classes(c0, c1, c2):
        # Each instance must expose only the method of the class version it
        # was created from.
        c0.method0()
        c1.method1()
        c2.method2()

        assert not hasattr(c0, "method1")
        assert not hasattr(c0, "method2")
        assert not hasattr(c1, "method0")
        assert not hasattr(c1, "method2")
        assert not hasattr(c2, "method0")
        assert not hasattr(c2, "method1")

    # An instance of a locally defined class can be put and fetched back
    # without any explicit registration call.
    class TempClass:
        pass

    temp_id = ray.put(TempClass())
    ray.get(temp_id)

    # Test passing custom classes into remote functions from the driver.
    @ray.remote
    def f(x):
        return x

    class Foo:
        def __init__(self, value=0):
            self.value = value

        def __hash__(self):
            return hash(self.value)

        def __eq__(self, other):
            return other.value == self.value

    foo = ray.get(f.remote(Foo(7)))
    assert foo == Foo(7)

    regex = re.compile(r"\d+\.\d*")
    new_regex = ray.get(f.remote(regex))
    # This seems to fail on the system Python 3 that comes with
    # Ubuntu, so it is commented out for now:
    # assert regex == new_regex
    # Instead, we do this:
    assert regex.pattern == new_regex.pattern

    class TempClass1:
        def __init__(self):
            self.value = 1

    # Test returning custom classes created on workers.
    @ray.remote
    def g():
        class TempClass2:
            def __init__(self):
                self.value = 2

        return TempClass1(), TempClass2()

    object_1, object_2 = ray.get(g.remote())
    assert object_1.value == 1
    assert object_2.value == 2

    # Test exporting custom class definitions from one worker to another
    # when the worker is blocked in a get.
    class NewTempClass:
        def __init__(self, value):
            self.value = value

    @ray.remote
    def h1(x):
        return NewTempClass(x)

    @ray.remote
    def h2(x):
        return ray.get(h1.remote(x))

    assert ray.get(h2.remote(10)).value == 10

    # Test registering multiple classes with the same name.
    @ray.remote(num_return_vals=3)
    def j():
        class Class0:
            def method0(self):
                pass

        c0 = Class0()

        class Class0:
            def method1(self):
                pass

        c1 = Class0()

        class Class0:
            def method2(self):
                pass

        c2 = Class0()

        return c0, c1, c2

    # Each call to j yields three object IDs; collect them in one flat list.
    results = []
    for _ in range(5):
        results.extend(j.remote())

    for start in range(0, 3 * (len(results) // 3), 3):
        c0, c1, c2 = ray.get(results[start:start + 3])
        check_distinct_classes(c0, c1, c2)

    # Same scenario, but the three instances travel in a single return value.
    @ray.remote
    def k():
        class Class0:
            def method0(self):
                pass

        c0 = Class0()

        class Class0:
            def method1(self):
                pass

        c1 = Class0()

        class Class0:
            def method2(self):
                pass

        c2 = Class0()

        return c0, c1, c2

    results = ray.get([k.remote() for _ in range(5)])
    for c0, c1, c2 in results:
        check_distinct_classes(c0, c1, c2)
def test_keyword_args(ray_start_regular):
    """Call remote functions with positional and keyword arguments.

    Valid combinations must produce the formatted result; invalid ones must
    raise when the task is submitted.
    """

    def expect(object_id, expected):
        # Fetch a single task result and compare it with the expectation.
        assert ray.get(object_id) == expected

    @ray.remote
    def keyword_fct1(a, b="hello"):
        return "{} {}".format(a, b)

    @ray.remote
    def keyword_fct2(a="hello", b="world"):
        return "{} {}".format(a, b)

    @ray.remote
    def keyword_fct3(a, b, c="hello", d="world"):
        return "{} {} {} {}".format(a, b, c, d)

    expect(keyword_fct1.remote(1), "1 hello")
    expect(keyword_fct1.remote(1, "hi"), "1 hi")
    expect(keyword_fct1.remote(1, b="world"), "1 world")
    expect(keyword_fct1.remote(a=1, b="world"), "1 world")

    expect(keyword_fct2.remote(a="w", b="hi"), "w hi")
    expect(keyword_fct2.remote(b="hi", a="w"), "w hi")
    expect(keyword_fct2.remote(a="w"), "w world")
    expect(keyword_fct2.remote(b="hi"), "hello hi")
    expect(keyword_fct2.remote("w"), "w world")
    expect(keyword_fct2.remote("w", "hi"), "w hi")

    expect(keyword_fct3.remote(0, 1, c="w", d="hi"), "0 1 w hi")
    expect(keyword_fct3.remote(0, b=1, c="w", d="hi"), "0 1 w hi")
    expect(keyword_fct3.remote(a=0, b=1, c="w", d="hi"), "0 1 w hi")
    expect(keyword_fct3.remote(0, 1, d="hi", c="w"), "0 1 w hi")
    expect(keyword_fct3.remote(0, 1, c="w"), "0 1 w world")
    expect(keyword_fct3.remote(0, 1, d="hi"), "0 1 hello hi")
    expect(keyword_fct3.remote(0, 1), "0 1 hello world")
    expect(keyword_fct3.remote(a=0, b=1), "0 1 hello world")

    # Check that we cannot pass invalid keyword arguments to functions.
    @ray.remote
    def f1():
        return

    @ray.remote
    def f2(x, y=0, z=0):
        return

    # Too many positional arguments, unknown keywords, and a value given
    # both positionally and by keyword must all be rejected.
    with pytest.raises(Exception):
        f1.remote(3)

    with pytest.raises(Exception):
        f1.remote(x=3)

    with pytest.raises(Exception):
        f2.remote(0, w=0)

    with pytest.raises(Exception):
        f2.remote(3, x=3)

    # Make sure we get an exception if too many arguments are passed in.
    with pytest.raises(Exception):
        f2.remote(1, 2, 3, 4)

    @ray.remote
    def f3(x):
        return x

    expect(f3.remote(4), 4)
@pytest.mark.parametrize(
    "ray_start_regular",
    [{"local_mode": True}, {"local_mode": False}],
    indirect=True)
def test_args_starkwargs(ray_start_regular):
    """Compare local and remote calls for a ``(a, b, **kwargs)`` signature.

    The comparison runs on the driver and inside a remote task, first for a
    plain function and then for an actor method.
    """

    def starkwargs(a, b, **kwargs):
        return a, b, kwargs

    class TestActor:
        def starkwargs(self, a, b, **kwargs):
            return a, b, kwargs

    def compare_calls(fn, remote_fn):
        local_result = fn(1, 2, x=3)
        assert local_result == ray.get(remote_fn.remote(1, 2, x=3))
        # A missing required positional argument must raise on submission.
        with pytest.raises(TypeError):
            remote_fn.remote(3)

    remote_compare_calls = ray.remote(compare_calls)

    # Plain function: check from the driver, then from within a task.
    remote_starkwargs = ray.remote(starkwargs)
    compare_calls(starkwargs, remote_starkwargs)
    ray.get(remote_compare_calls.remote(starkwargs, remote_starkwargs))

    # Actor method: same two checks.
    remote_actor = ray.remote(TestActor).remote()
    actor_method = remote_actor.starkwargs
    local_method = TestActor().starkwargs
    compare_calls(local_method, actor_method)
    ray.get(remote_compare_calls.remote(local_method, actor_method))
@pytest.mark.parametrize(
    "ray_start_regular",
    [{"local_mode": True}, {"local_mode": False}],
    indirect=True)
def test_args_named_and_star(ray_start_regular):
    """Compare local and remote calls for ``(a, x="hello", **kwargs)``.

    The comparison runs on the driver and inside a remote task, first for a
    plain function and then for an actor method.
    """

    def hello(a, x="hello", **kwargs):
        return a, x, kwargs

    class TestActor:
        def hello(self, a, x="hello", **kwargs):
            return a, x, kwargs

    def compare_calls(fn, remote_fn):
        assert fn(1, x=2, y=3) == ray.get(remote_fn.remote(1, x=2, y=3))
        assert fn(1, 2, y=3) == ray.get(remote_fn.remote(1, 2, y=3))
        assert fn(1, y=3) == ray.get(remote_fn.remote(1, y=3))

        assert fn(1, ) == ray.get(remote_fn.remote(1, ))
        assert fn(1) == ray.get(remote_fn.remote(1))

        # ``x`` supplied both positionally and by keyword must raise.
        with pytest.raises(TypeError):
            remote_fn.remote(1, 2, x=3)

    remote_compare_calls = ray.remote(compare_calls)

    # Plain function: check from the driver, then from within a task.
    remote_hello = ray.remote(hello)
    compare_calls(hello, remote_hello)
    ray.get(remote_compare_calls.remote(hello, remote_hello))

    # Actor method: same two checks.
    remote_actor = ray.remote(TestActor).remote()
    actor_method = remote_actor.hello
    local_method = TestActor().hello
    compare_calls(local_method, actor_method)
    ray.get(remote_compare_calls.remote(local_method, actor_method))
@pytest.mark.parametrize(
    "ray_start_regular",
    [{"local_mode": True}, {"local_mode": False}],
    indirect=True)
def test_args_stars_after(ray_start_regular):
    """Compare local and remote calls for defaults followed by star args.

    The signature under test is ``(a="hello", b="heo", *args, **kwargs)``.
    The comparison runs on the driver and inside a remote task, first for a
    plain function and then for an actor method.
    """

    def star_args_after(a="hello", b="heo", *args, **kwargs):
        return a, b, args, kwargs

    class TestActor:
        def star_args_after(self, a="hello", b="heo", *args, **kwargs):
            return a, b, args, kwargs

    def compare_calls(fn, remote_fn):
        positional_only = remote_fn.remote("hi", "hello", 2)
        assert fn("hi", "hello", 2) == ray.get(positional_only)

        with_keyword = remote_fn.remote("hi", "hello", 2, hi="hi")
        assert fn("hi", "hello", 2, hi="hi") == ray.get(with_keyword)

        keyword_only = remote_fn.remote(hi="hi")
        assert fn(hi="hi") == ray.get(keyword_only)

    remote_compare_calls = ray.remote(compare_calls)

    # Plain function: check from the driver, then from within a task.
    remote_star_args_after = ray.remote(star_args_after)
    compare_calls(star_args_after, remote_star_args_after)
    ray.get(
        remote_compare_calls.remote(star_args_after, remote_star_args_after))

    # Actor method: same two checks.
    remote_actor = ray.remote(TestActor).remote()
    actor_method = remote_actor.star_args_after
    local_method = TestActor().star_args_after
    compare_calls(local_method, actor_method)
    ray.get(remote_compare_calls.remote(local_method, actor_method))
def test_variable_number_of_args(shutdown_only):
    """Call remote functions that take ``*args``, including with no args.

    The first two remote functions are deliberately defined before
    ``ray.init`` is called, as in the original test.
    """

    @ray.remote
    def varargs_fct1(*a):
        return " ".join(map(str, a))

    @ray.remote
    def varargs_fct2(a, *b):
        return " ".join(map(str, b))

    ray.init(num_cpus=1)

    x = varargs_fct1.remote(0, 1, 2)
    assert ray.get(x) == "0 1 2"
    x = varargs_fct2.remote(0, 1, 2)
    assert ray.get(x) == "1 2"

    @ray.remote
    def f1(*args):
        return args

    @ray.remote
    def f2(x, y, *args):
        return x, y, args

    assert ray.get(f1.remote()) == ()
    assert ray.get(f1.remote(1)) == (1, )
    assert ray.get(f1.remote(1, 2, 3)) == (1, 2, 3)

    # f2 needs at least two positional arguments.
    with pytest.raises(Exception):
        f2.remote()
    with pytest.raises(Exception):
        f2.remote(1)

    assert ray.get(f2.remote(1, 2)) == (1, 2, ())
    assert ray.get(f2.remote(1, 2, 3)) == (1, 2, (3, ))
    assert ray.get(f2.remote(1, 2, 3, 4)) == (1, 2, (3, 4))

    # This check used to live in an uncalled ``testNoArgs(self)`` helper
    # that relied on a nonexistent ``self.ray_start()``, so it never ran.
    # Ray is already initialised above, so the call is made directly.
    @ray.remote
    def no_op():
        pass

    ray.get(no_op.remote())
def test_defining_remote_functions(shutdown_only):
    """Define remote functions that close over data, modules and each other."""
    ray.init(num_cpus=3)

    # Test that we can close over plain old data.
    data = [
        np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {
            "a": np.zeros(3)
        }
    ]

    @ray.remote
    def g():
        return data

    ray.get(g.remote())

    # Test that we can close over modules.
    @ray.remote
    def h():
        return np.zeros([3, 5])

    # np.all replaces np.alltrue, which was removed in NumPy 2.0; both
    # reduce the element-wise comparison to a single boolean.
    assert np.all(ray.get(h.remote()) == np.zeros([3, 5]))

    @ray.remote
    def j():
        return time.time()

    ray.get(j.remote())

    # Test that we can define remote functions that call other remote
    # functions.
    @ray.remote
    def k(x):
        return x + 1

    @ray.remote
    def k2(x):
        return ray.get(k.remote(x))

    @ray.remote
    def m(x):
        return ray.get(k2.remote(x))

    assert ray.get(k.remote(1)) == 2
    assert ray.get(k2.remote(1)) == 2
    assert ray.get(m.remote(1)) == 2
def test_redefining_remote_functions(shutdown_only):
    """Redefine remote functions under the same name and use the new one."""
    ray.init(num_cpus=1)

    # Test that we can define a remote function in the shell.
    @ray.remote
    def f(x):
        return x + 1

    assert ray.get(f.remote(0)) == 1

    # Test that we can redefine the remote function.
    @ray.remote
    def f(x):
        return x + 10

    # Keep calling until the new definition is the one being executed.
    while True:
        val = ray.get(f.remote(0))
        assert val in [1, 10]
        if val == 10:
            break
        logger.info("Still using old definition of f, trying again.")

    # Check that we can redefine functions even when the remote function source
    # doesn't change (see https://github.com/ray-project/ray/issues/6130).
    @ray.remote
    def g():
        return nonexistent()

    # ``nonexistent`` is only defined further down, so this first call fails.
    with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
        ray.get(g.remote())

    def nonexistent():
        return 1

    # Redefine the function and make sure it succeeds.
    @ray.remote
    def g():
        return nonexistent()

    assert ray.get(g.remote()) == 1

    # Check the same thing but when the redefined function is inside of another
    # task.
    @ray.remote
    def h(i):
        @ray.remote
        def j():
            return i

        return j.remote()

    for i in range(20):
        inner_id = ray.get(h.remote(i))
        assert ray.get(inner_id) == i
def test_submit_api(shutdown_only):
    """Exercise the ``_remote`` submission API for tasks and actors."""
    custom_one = {"Custom": 1}
    ray.init(num_cpus=2, num_gpus=1, resources={"Custom": 1})

    @ray.remote
    def f(n):
        return list(range(n))

    @ray.remote
    def g():
        return ray.get_gpu_ids()

    # Zero return values yields None; otherwise one object ID per value.
    assert f._remote([0], num_return_vals=0) is None

    single = f._remote(args=[1], num_return_vals=1)
    assert ray.get(single) == [0]

    first, second = f._remote(args=[2], num_return_vals=2)
    assert ray.get([first, second]) == [0, 1]

    first, second, third = f._remote(args=[3], num_return_vals=3)
    assert ray.get([first, second, third]) == [0, 1, 2]

    gpu_task = g._remote(
        args=[], num_cpus=1, num_gpus=1, resources=custom_one)
    assert ray.get(gpu_task) == [0]

    # A task asking for a resource that was never declared must stay pending.
    infeasible_id = g._remote(args=[], resources={"NonexistentCustom": 1})
    assert ray.get(g._remote()) == []
    ready_ids, remaining_ids = ray.wait([infeasible_id], timeout=0.05)
    assert len(ready_ids) == 0
    assert len(remaining_ids) == 1

    @ray.remote
    class Actor:
        def __init__(self, x, y=0):
            self.x = x
            self.y = y

        def method(self, a, b=0):
            return self.x, self.y, a, b

        def gpu_ids(self):
            return ray.get_gpu_ids()

    @ray.remote
    class Actor2:
        def __init__(self):
            pass

        def method(self):
            pass

    a = Actor._remote(
        args=[0], kwargs={"y": 1}, num_gpus=1, resources={"Custom": 1})

    a2 = Actor2._remote()
    ray.get(a2.method._remote())

    method_ids = a.method._remote(
        args=["test"], kwargs={"b": 2}, num_return_vals=4)
    id1, id2, id3, id4 = method_ids
    assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
def test_many_fractional_resources(shutdown_only):
    """Submit tasks with random fractional resources and check accounting.

    Each task compares the resources reported by ``ray.get_resource_ids``
    with the (truncated) amounts it was submitted with. Afterwards the
    cluster must report all resources as available again within 10 seconds.
    """
    ray.init(num_cpus=2, num_gpus=2, resources={"Custom": 2})

    @ray.remote
    def g():
        return 1

    @ray.remote
    def f(block, accepted_resources):
        true_resources = {
            resource: value[0][1]
            for resource, value in ray.get_resource_ids().items()
        }
        if block:
            ray.get(g.remote())
        return true_resources == accepted_resources

    def truncate(fraction):
        # Expected amounts are compared at four decimal places.
        return int(fraction * 10000) / 10000

    # Check that the resource are assigned correctly.
    result_ids = []
    for rand1, rand2, rand3 in np.random.uniform(size=(100, 3)):
        resource_set = {"CPU": truncate(rand1)}
        result_ids.append(f._remote([False, resource_set], num_cpus=rand1))

        resource_set = {"CPU": 1, "GPU": truncate(rand1)}
        result_ids.append(f._remote([False, resource_set], num_gpus=rand1))

        resource_set = {"CPU": 1, "Custom": truncate(rand1)}
        result_ids.append(
            f._remote([False, resource_set], resources={"Custom": rand1}))

        resource_set = {
            "CPU": truncate(rand1),
            "GPU": truncate(rand2),
            "Custom": truncate(rand3)
        }
        # Same request twice: once non-blocking, once blocking on g.
        for block in (False, True):
            result_ids.append(
                f._remote(
                    [block, resource_set],
                    num_cpus=rand1,
                    num_gpus=rand2,
                    resources={"Custom": rand3}))
    assert all(ray.get(result_ids))

    # Check that the available resources at the end are the same as the
    # beginning.
    expected_available = {"CPU": 2.0, "GPU": 2.0, "Custom": 2.0}
    stop_time = time.time() + 10
    correct_available_resources = False
    while time.time() < stop_time:
        # Take one snapshot per iteration so all three values come from the
        # same report; a missing key counts as "not restored yet" instead
        # of raising KeyError.
        available = ray.available_resources()
        if all(
                available.get(name) == amount
                for name, amount in expected_available.items()):
            correct_available_resources = True
            break
        # Avoid hammering the cluster in a tight loop while waiting.
        time.sleep(0.1)
    assert correct_available_resources, (
        "Did not get correct available resources.")
def test_get_multiple(ray_start_regular):
    """Fetch a list of object IDs, including one with repeated IDs."""
    object_ids = [ray.put(i) for i in range(10)]
    assert ray.get(object_ids) == list(range(10))

    # Get a random choice of object IDs with duplicates.
    sampled = list(np.random.choice(range(10), 5))
    indices = sampled + sampled
    chosen_ids = [object_ids[i] for i in indices]
    results = ray.get(chosen_ids)
    assert results == indices
def test_get_multiple_experimental(ray_start_regular):
    """ray.experimental.get must accept a tuple and a numpy array of IDs."""
    expected = list(range(10))
    object_ids = [ray.put(i) for i in range(10)]

    as_tuple = tuple(object_ids)
    assert ray.experimental.get(as_tuple) == expected

    as_array = np.array(object_ids)
    assert ray.experimental.get(as_array) == expected
def test_get_dict(ray_start_regular):
    """ray.experimental.get on a dict mixing object IDs and plain values."""
    # Keys "0".."4" map to object IDs, keys "5".."9" to plain integers.
    mixed = {str(i): ray.put(i) for i in range(5)}
    mixed.update({str(i): i for i in range(5, 10)})

    result = ray.experimental.get(mixed)

    expected = {str(i): i for i in range(10)}
    assert result == expected
def test_get_with_timeout(ray_start_regular):
    """Check ``ray.get(..., timeout=...)`` on ready and not-yet-ready objects.

    A remote task blocks until a marker file exists. The marker file is
    created in the system temp directory and removed again when the test
    finishes, whether it passed or failed.
    """

    def random_path():
        return os.path.join(tempfile.gettempdir(), uuid.uuid4().hex)

    def touch(path):
        with open(path, "w"):
            pass

    @ray.remote
    def wait_for_file(path):
        if path:
            while True:
                if os.path.exists(path):
                    break
                time.sleep(0.1)

    # Check that get() returns early if object is ready.
    start = time.time()
    ray.get(wait_for_file.remote(None), timeout=30)
    assert time.time() - start < 30

    # Check that get() raises a TimeoutError after the timeout if the object
    # is not ready yet.
    path = random_path()
    result_id = wait_for_file.remote(path)
    try:
        with pytest.raises(RayTimeoutError):
            ray.get(result_id, timeout=0.1)

        # Check that a subsequent get() returns early.
        touch(path)
        start = time.time()
        ray.get(result_id, timeout=30)
        assert time.time() - start < 30
    finally:
        # The marker may not exist if the test failed before touch().
        if os.path.exists(path):
            os.remove(path)
@pytest.mark.parametrize(
    "ray_start_cluster",
    [{"num_cpus": 1, "num_nodes": 1}, {"num_cpus": 1, "num_nodes": 2}],
    indirect=True)
def test_direct_call_simple(ray_start_cluster):
    """Run an increment task submitted with ``is_direct_call=True``."""

    @ray.remote
    def f(x):
        return x + 1

    f_direct = f.options(is_direct_call=True)
    assert ray.get(f_direct.remote(2)) == 3

    # Ten batches of 100 tasks; each batch must return 1..100 in order.
    expected = list(range(1, 101))
    for _ in range(10):
        batch = [f_direct.remote(i) for i in range(100)]
        assert ray.get(batch) == expected
# https://github.com/ray-project/ray/issues/6329
def test_call_actors_indirect_through_tasks(ray_start_regular):
    """Call an actor from tasks that receive its handle as an argument.

    The handle is passed directly to two tasks and inside a list to a third.
    """

    @ray.remote
    class Counter:
        def __init__(self, value):
            self.value = int(value)

        def increase(self, delta):
            self.value += int(delta)
            return self.value

    @ray.remote
    def foo(counter):
        return ray.get(counter.increase.remote(1))

    @ray.remote
    def bar(counter):
        return ray.get(counter.increase.remote(1))

    @ray.remote
    def zoo(counters):
        return ray.get(counters[0].increase.remote(1))

    c = Counter.remote(0)
    for _ in range(100):
        ray.get(foo.remote(c))
        ray.get(bar.remote(c))
        ray.get(zoo.remote([c]))
def test_direct_call_refcount(ray_start_regular):
    """Repeated gets and chained direct-call tasks must both complete."""

    @ray.remote
    def f(x):
        return x + 1

    @ray.remote
    def sleep():
        time.sleep(.1)
        return 1

    f_direct = f.options(is_direct_call=True)
    sleep_direct = sleep.options(is_direct_call=True)

    # Multiple gets should not hang with ref counting enabled.
    x = f_direct.remote(2)
    for _ in range(2):
        ray.get(x)

    # Temporary objects should be retained for chained callers.
    y = f_direct.remote(sleep_direct.remote())
    assert ray.get(y) == 2
def test_direct_call_matrix(shutdown_only):
    """Pass direct-call results between every task/actor combination.

    The combinations cover task or actor as producer, task or actor as
    consumer, small or large value, and the object ID passed directly or
    wrapped in a list.
    """
    ray.init(object_store_memory=1000 * 1024 * 1024)

    @ray.remote
    class Actor:
        def small_value(self):
            return 0

        def large_value(self):
            return np.zeros(10 * 1024 * 1024)

        def echo(self, x):
            if isinstance(x, list):
                x = ray.get(x[0])
            return x

    @ray.remote
    def small_value():
        return 0

    @ray.remote
    def large_value():
        return np.zeros(10 * 1024 * 1024)

    @ray.remote
    def echo(x):
        if isinstance(x, list):
            x = ray.get(x[0])
        return x

    def check(source_actor, dest_actor, is_large, out_of_band):
        source_label = "actor" if source_actor else "task"
        dest_label = "actor" if dest_actor else "task"
        size_label = "large_object" if is_large else "small_object"
        band_label = "out_of_band" if out_of_band else "in_band"
        print("CHECKING", source_label, "to", dest_label, size_label,
              band_label)

        # Produce the value from an actor method or from a plain task.
        if source_actor:
            a = Actor.options(is_direct_call=True).remote()
            producer = a.large_value if is_large else a.small_value
            x_id = producer.remote()
        else:
            producer = large_value if is_large else small_value
            x_id = producer.options(is_direct_call=True).remote()

        # Wrapping in a list makes the consumer fetch the value itself.
        if out_of_band:
            x_id = [x_id]

        if dest_actor:
            b = Actor.options(is_direct_call=True).remote()
            x = ray.get(b.echo.remote(x_id))
        else:
            x = ray.get(echo.options(is_direct_call=True).remote(x_id))

        expected_type = np.ndarray if is_large else int
        assert isinstance(x, expected_type)

    flags = [False, True]
    cases = [(source_actor, dest_actor, is_large, out_of_band)
             for is_large in flags
             for source_actor in flags
             for dest_actor in flags
             for out_of_band in flags]
    for source_actor, dest_actor, is_large, out_of_band in cases:
        check(source_actor, dest_actor, is_large, out_of_band)
@pytest.mark.parametrize(
    "ray_start_cluster",
    [
        {"num_cpus": 1, "num_nodes": 1},
        {"num_cpus": 1, "num_nodes": 2},
    ],
    indirect=True,
)
def test_direct_call_chain(ray_start_cluster):
    """Chain 100 direct calls, each taking the previous result as input."""

    @ray.remote
    def g(x):
        return x + 1

    g_direct = g.options(is_direct_call=True)

    chain_length = 100
    link = 0  # Plain int seed; every later link is an object ID.
    for _ in range(chain_length):
        link = g_direct.remote(link)

    assert ray.get(link) == 100
def test_direct_inline_arg_memory_corruption(ray_start_regular):
    """Arrays kept by an actor must stay all-zero across later calls."""

    @ray.remote
    def f():
        return np.zeros(1000, dtype=np.uint8)

    @ray.remote
    class Actor:
        def __init__(self):
            self.z = []

        def add(self, x):
            # Keep every argument ever received and re-check all of them on
            # each call, so a clobbered earlier buffer is noticed.
            self.z.append(x)
            for prev in self.z:
                assert np.sum(prev) == 0, ("memory corruption detected", prev)

    a = Actor.options(is_direct_call=True).remote()
    f_direct = f.options(is_direct_call=True)

    for _ in range(100):
        ray.get(a.add.remote(f_direct.remote()))
def test_direct_actor_enabled(ray_start_regular):
    """A direct-call actor result is retrievable though not in plasma."""

    @ray.remote
    class Actor:
        def __init__(self):
            pass

        def f(self, x):
            return x * 2

    core_worker = ray.worker.global_worker.core_worker

    a = Actor._remote(is_direct_call=True)
    obj_id = a.f.remote(1)

    # it is not stored in plasma
    in_store = core_worker.object_exists(obj_id)
    assert not in_store
    assert ray.get(obj_id) == 2
def test_direct_actor_order(shutdown_only):
    """Actor calls run in submission order despite uneven dependencies."""
    ray.init(num_cpus=4)

    @ray.remote
    def small_value():
        # Random delay of 0-90 ms so dependencies finish out of order.
        time.sleep(0.01 * np.random.randint(0, 10))
        return 0

    @ray.remote
    class Actor:
        def __init__(self):
            self.count = 0

        def inc(self, count, dependency):
            # ``dependency`` is only there to gate the call; it is not read.
            assert count == self.count
            self.count += 1
            return count

    a = Actor._remote(is_direct_call=True)

    results = ray.get([
        a.inc.remote(seq, small_value.options(is_direct_call=True).remote())
        for seq in range(100)
    ])
    assert results == list(range(100))
def test_direct_actor_large_objects(ray_start_regular):
    """A large direct-call actor result exists once ``ray.wait`` returns."""

    @ray.remote
    class Actor:
        def __init__(self):
            pass

        def f(self):
            time.sleep(1)
            return np.zeros(10000000)

    core_worker = ray.worker.global_worker.core_worker

    a = Actor._remote(is_direct_call=True)
    obj_id = a.f.remote()

    # The method sleeps first, so nothing should exist yet.
    assert not core_worker.object_exists(obj_id)

    ready, _ = ray.wait([obj_id])
    assert len(ready) == 1

    # After the wait, the object must be present and be an array.
    assert core_worker.object_exists(obj_id)
    value = ray.get(obj_id)
    assert isinstance(value, np.ndarray)
def test_direct_actor_pass_by_ref(ray_start_regular):
    """Task results can be passed by reference into direct actor calls."""

    @ray.remote
    class Actor:
        def __init__(self):
            pass

        def f(self, x):
            return x * 2

    @ray.remote
    def f(x):
        return x

    @ray.remote
    def error():
        sys.exit(0)

    a = Actor._remote(is_direct_call=True)

    # Single pass-by-ref argument.
    doubled = ray.get(a.f.remote(f.remote(1)))
    assert doubled == 2

    # Many pass-by-ref arguments in flight at once.
    futures = [a.f.remote(f.remote(n)) for n in range(100)]
    expected = list(range(0, 200, 2))
    assert ray.get(futures) == expected

    # propagates errors for pass by ref
    with pytest.raises(Exception):
        ray.get(a.f.remote(error.remote()))
def test_direct_actor_pass_by_ref_order_optimization(shutdown_only):
    """An actor call with a ready argument is not held up by a slow one.

    One caller submits an actor call whose argument takes 30 seconds to
    produce; a second caller then submits a call whose argument is ready
    almost at once. The second call must finish within 10 seconds.
    """
    ray.init(num_cpus=4)

    @ray.remote
    class Actor:
        def __init__(self):
            pass

        def f(self, x):
            pass

    a = Actor._remote(is_direct_call=True)

    @ray.remote
    def fast_value():
        print("fast value")

    @ray.remote
    def slow_value():
        print("start sleep")
        time.sleep(30)

    @ray.remote
    def runner(f):
        print("runner", a, f)
        return ray.get(a.f.remote(f.remote()))

    # Start the slow caller first and give it a head start.
    runner.remote(slow_value)
    time.sleep(1)

    x2 = runner.remote(fast_value)
    began = time.time()
    ray.get(x2)
    elapsed = time.time() - began
    assert elapsed < 10, "did not skip slow value"
def test_direct_actor_recursive(ray_start_regular):
    """Direct-call actors can forward calls down a chain of delegates."""

    @ray.remote
    class Actor:
        def __init__(self, delegate=None):
            self.delegate = delegate

        def f(self, x):
            # The actor without a delegate does the actual work.
            if not self.delegate:
                return x * 2
            return ray.get(self.delegate.f.remote(x))

    # c -> b -> a, where only ``a`` computes the answer.
    a = Actor._remote(is_direct_call=True)
    b = Actor._remote(args=[a], is_direct_call=True)
    c = Actor._remote(args=[b], is_direct_call=True)

    expected = [n * 2 for n in range(100)]

    # Fetch everything with a single get.
    values = ray.get([c.f.remote(n) for n in range(100)])
    assert values == expected

    # Same again, but go through wait() for all 100 results first.
    ready_ids, _ = ray.wait(
        [c.f.remote(n) for n in range(100)], num_returns=100)
    values = ray.get(ready_ids)
    assert values == expected
def test_direct_actor_concurrent(ray_start_regular):
    """With ``max_concurrency=3`` three actor calls can overlap.

    The first two calls block on an event that only the third call sets, so
    the test can only finish if all three run at the same time.
    """

    @ray.remote
    class Batcher:
        def __init__(self):
            self.batch = []
            self.event = threading.Event()

        def add(self, x):
            self.batch.append(x)
            if len(self.batch) < 3:
                # Batch not full yet: wait for the call that completes it.
                self.event.wait()
            else:
                self.event.set()
            return sorted(self.batch)

    a = Batcher.options(is_direct_call=True, max_concurrency=3).remote()

    pending = [a.add.remote(item) for item in (1, 2, 3)]
    results = [ray.get(object_id) for object_id in pending]

    # Every caller sees the complete batch.
    assert results[0] == [1, 2, 3]
    assert results[0] == results[1] == results[2]
def test_wait(ray_start_regular):
    """Check ``ray.wait``: counts, timeouts, duplicates and bad input."""

    @ray.remote
    def f(delay):
        time.sleep(delay)
        return

    # Default num_returns hands back a single ready ID.
    object_ids = [f.remote(0) for _ in range(4)]
    ready_ids, remaining_ids = ray.wait(object_ids)
    assert len(ready_ids) == 1
    assert len(remaining_ids) == 3

    # Asking for all four returns every ID and leaves nothing behind.
    ready_ids, remaining_ids = ray.wait(object_ids, num_returns=4)
    assert set(ready_ids) == set(object_ids)
    assert remaining_ids == []

    # With a timeout, only the task that finished in time is reported.
    object_ids = [f.remote(0), f.remote(5)]
    ready_ids, remaining_ids = ray.wait(object_ids, timeout=0.5, num_returns=2)
    assert len(ready_ids) == 1
    assert len(remaining_ids) == 1

    # Verify that calling wait with duplicate object IDs throws an
    # exception.
    x = ray.put(1)
    with pytest.raises(Exception):
        ray.wait([x, x])

    # Make sure it is possible to call wait with an empty list.
    ready_ids, remaining_ids = ray.wait([])
    assert ready_ids == []
    assert remaining_ids == []

    # Test semantics of num_returns with no timeout.
    oids = [ray.put(value) for value in range(10)]
    found, rest = ray.wait(oids, num_returns=2)
    assert len(found) == 2
    assert len(rest) == 8

    # Verify that incorrect usage raises a TypeError: a bare object ID, a
    # bare int, and a list holding something other than an object ID.
    x = ray.put(1)
    for bad_argument in (x, 1, [1]):
        with pytest.raises(TypeError):
            ray.wait(bad_argument)
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode and
    # exits with pytest's status code.
    import pytest

    exit_code = pytest.main(["-v", __file__])
    sys.exit(exit_code)