Split test_basic to avoid timeouts in CI (#8405)

This commit is contained in:
Edward Oakes
2020-05-12 10:18:21 -05:00
committed by GitHub
parent a593fde606
commit b84fe56bed
4 changed files with 1220 additions and 1167 deletions
+16
View File
@@ -55,6 +55,14 @@ py_test(
deps = ["//:ray_lib"],
)
py_test(
name = "test_serialization",
size = "small",
srcs = ["test_serialization.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_basic",
size = "medium",
@@ -63,6 +71,14 @@ py_test(
deps = ["//:ray_lib"],
)
py_test(
name = "test_basic_2",
size = "medium",
srcs = ["test_basic_2.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_advanced",
size = "medium",
File diff suppressed because it is too large Load Diff
+678
View File
@@ -0,0 +1,678 @@
# coding: utf-8
import json
import logging
import sys
import threading
import time
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
from ray.exceptions import RayTimeoutError
logger = logging.getLogger(__name__)
@pytest.mark.parametrize(
"shutdown_only", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_variable_number_of_args(shutdown_only):
@ray.remote
def varargs_fct1(*a):
return " ".join(map(str, a))
@ray.remote
def varargs_fct2(a, *b):
return " ".join(map(str, b))
ray.init(num_cpus=1)
x = varargs_fct1.remote(0, 1, 2)
assert ray.get(x) == "0 1 2"
x = varargs_fct2.remote(0, 1, 2)
assert ray.get(x) == "1 2"
@ray.remote
def f1(*args):
return args
@ray.remote
def f2(x, y, *args):
return x, y, args
assert ray.get(f1.remote()) == ()
assert ray.get(f1.remote(1)) == (1, )
assert ray.get(f1.remote(1, 2, 3)) == (1, 2, 3)
with pytest.raises(Exception):
f2.remote()
with pytest.raises(Exception):
f2.remote(1)
assert ray.get(f2.remote(1, 2)) == (1, 2, ())
assert ray.get(f2.remote(1, 2, 3)) == (1, 2, (3, ))
assert ray.get(f2.remote(1, 2, 3, 4)) == (1, 2, (3, 4))
def testNoArgs(self):
@ray.remote
def no_op():
pass
self.ray_start()
ray.get(no_op.remote())
@pytest.mark.parametrize(
"shutdown_only", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_defining_remote_functions(shutdown_only):
ray.init(num_cpus=3)
# Test that we can close over plain old data.
data = [
np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, {
"a": np.zeros(3)
}
]
@ray.remote
def g():
return data
ray.get(g.remote())
# Test that we can close over modules.
@ray.remote
def h():
return np.zeros([3, 5])
assert np.alltrue(ray.get(h.remote()) == np.zeros([3, 5]))
@ray.remote
def j():
return time.time()
ray.get(j.remote())
# Test that we can define remote functions that call other remote
# functions.
@ray.remote
def k(x):
return x + 1
@ray.remote
def k2(x):
return ray.get(k.remote(x))
@ray.remote
def m(x):
return ray.get(k2.remote(x))
assert ray.get(k.remote(1)) == 2
assert ray.get(k2.remote(1)) == 2
assert ray.get(m.remote(1)) == 2
@pytest.mark.parametrize(
"shutdown_only", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_redefining_remote_functions(shutdown_only):
ray.init(num_cpus=1)
# Test that we can define a remote function in the shell.
@ray.remote
def f(x):
return x + 1
assert ray.get(f.remote(0)) == 1
# Test that we can redefine the remote function.
@ray.remote
def f(x):
return x + 10
while True:
val = ray.get(f.remote(0))
assert val in [1, 10]
if val == 10:
break
else:
logger.info("Still using old definition of f, trying again.")
# Check that we can redefine functions even when the remote function source
# doesn't change (see https://github.com/ray-project/ray/issues/6130).
@ray.remote
def g():
return nonexistent()
with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"):
ray.get(g.remote())
def nonexistent():
return 1
# Redefine the function and make sure it succeeds.
@ray.remote
def g():
return nonexistent()
assert ray.get(g.remote()) == 1
# Check the same thing but when the redefined function is inside of another
# task.
@ray.remote
def h(i):
@ray.remote
def j():
return i
return j.remote()
for i in range(20):
assert ray.get(ray.get(h.remote(i))) == i
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_get_multiple(ray_start_regular):
object_ids = [ray.put(i) for i in range(10)]
assert ray.get(object_ids) == list(range(10))
# Get a random choice of object IDs with duplicates.
indices = list(np.random.choice(range(10), 5))
indices += indices
results = ray.get([object_ids[i] for i in indices])
assert results == indices
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_get_multiple_experimental(ray_start_regular):
object_ids = [ray.put(i) for i in range(10)]
object_ids_tuple = tuple(object_ids)
assert ray.experimental.get(object_ids_tuple) == list(range(10))
object_ids_nparray = np.array(object_ids)
assert ray.experimental.get(object_ids_nparray) == list(range(10))
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_get_dict(ray_start_regular):
d = {str(i): ray.put(i) for i in range(5)}
for i in range(5, 10):
d[str(i)] = i
result = ray.experimental.get(d)
expected = {str(i): i for i in range(10)}
assert result == expected
def test_get_with_timeout(ray_start_regular):
signal = ray.test_utils.SignalActor.remote()
# Check that get() returns early if object is ready.
start = time.time()
ray.get(signal.wait.remote(should_wait=False), timeout=30)
assert time.time() - start < 30
# Check that get() raises a TimeoutError after the timeout if the object
# is not ready yet.
result_id = signal.wait.remote()
with pytest.raises(RayTimeoutError):
ray.get(result_id, timeout=0.1)
# Check that a subsequent get() returns early.
ray.get(signal.send.remote())
start = time.time()
ray.get(result_id, timeout=30)
assert time.time() - start < 30
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
# https://github.com/ray-project/ray/issues/6329
def test_call_actors_indirect_through_tasks(ray_start_regular):
@ray.remote
class Counter:
def __init__(self, value):
self.value = int(value)
def increase(self, delta):
self.value += int(delta)
return self.value
@ray.remote
def foo(object):
return ray.get(object.increase.remote(1))
@ray.remote
def bar(object):
return ray.get(object.increase.remote(1))
@ray.remote
def zoo(object):
return ray.get(object[0].increase.remote(1))
c = Counter.remote(0)
for _ in range(0, 100):
ray.get(foo.remote(c))
ray.get(bar.remote(c))
ray.get(zoo.remote([c]))
def test_call_matrix(shutdown_only):
ray.init(object_store_memory=1000 * 1024 * 1024)
@ray.remote
class Actor:
def small_value(self):
return 0
def large_value(self):
return np.zeros(10 * 1024 * 1024)
def echo(self, x):
if isinstance(x, list):
x = ray.get(x[0])
return x
@ray.remote
def small_value():
return 0
@ray.remote
def large_value():
return np.zeros(10 * 1024 * 1024)
@ray.remote
def echo(x):
if isinstance(x, list):
x = ray.get(x[0])
return x
def check(source_actor, dest_actor, is_large, out_of_band):
print("CHECKING", "actor" if source_actor else "task", "to", "actor"
if dest_actor else "task", "large_object"
if is_large else "small_object", "out_of_band"
if out_of_band else "in_band")
if source_actor:
a = Actor.remote()
if is_large:
x_id = a.large_value.remote()
else:
x_id = a.small_value.remote()
else:
if is_large:
x_id = large_value.remote()
else:
x_id = small_value.remote()
if out_of_band:
x_id = [x_id]
if dest_actor:
b = Actor.remote()
x = ray.get(b.echo.remote(x_id))
else:
x = ray.get(echo.remote(x_id))
if is_large:
assert isinstance(x, np.ndarray)
else:
assert isinstance(x, int)
for is_large in [False, True]:
for source_actor in [False, True]:
for dest_actor in [False, True]:
for out_of_band in [False, True]:
check(source_actor, dest_actor, is_large, out_of_band)
@pytest.mark.parametrize(
"ray_start_cluster", [{
"num_cpus": 1,
"num_nodes": 1,
}, {
"num_cpus": 1,
"num_nodes": 2,
}],
indirect=True)
def test_call_chain(ray_start_cluster):
@ray.remote
def g(x):
return x + 1
x = 0
for _ in range(100):
x = g.remote(x)
assert ray.get(x) == 100
def test_inline_arg_memory_corruption(ray_start_regular):
@ray.remote
def f():
return np.zeros(1000, dtype=np.uint8)
@ray.remote
class Actor:
def __init__(self):
self.z = []
def add(self, x):
self.z.append(x)
for prev in self.z:
assert np.sum(prev) == 0, ("memory corruption detected", prev)
a = Actor.remote()
for i in range(100):
ray.get(a.add.remote(f.remote()))
def test_skip_plasma(ray_start_regular):
@ray.remote
class Actor:
def __init__(self):
pass
def f(self, x):
return x * 2
a = Actor.remote()
obj_id = a.f.remote(1)
# it is not stored in plasma
assert not ray.worker.global_worker.core_worker.object_exists(obj_id)
assert ray.get(obj_id) == 2
def test_actor_call_order(shutdown_only):
ray.init(num_cpus=4)
@ray.remote
def small_value():
time.sleep(0.01 * np.random.randint(0, 10))
return 0
@ray.remote
class Actor:
def __init__(self):
self.count = 0
def inc(self, count, dependency):
assert count == self.count
self.count += 1
return count
a = Actor.remote()
assert ray.get([a.inc.remote(i, small_value.remote())
for i in range(100)]) == list(range(100))
def test_actor_large_objects(ray_start_regular):
@ray.remote
class Actor:
def __init__(self):
pass
def f(self):
time.sleep(1)
return np.zeros(10000000)
a = Actor.remote()
obj_id = a.f.remote()
assert not ray.worker.global_worker.core_worker.object_exists(obj_id)
done, _ = ray.wait([obj_id])
assert len(done) == 1
assert ray.worker.global_worker.core_worker.object_exists(obj_id)
assert isinstance(ray.get(obj_id), np.ndarray)
def test_actor_pass_by_ref(ray_start_regular):
@ray.remote
class Actor:
def __init__(self):
pass
def f(self, x):
return x * 2
@ray.remote
def f(x):
return x
@ray.remote
def error():
sys.exit(0)
a = Actor.remote()
assert ray.get(a.f.remote(f.remote(1))) == 2
fut = [a.f.remote(f.remote(i)) for i in range(100)]
assert ray.get(fut) == [i * 2 for i in range(100)]
# propagates errors for pass by ref
with pytest.raises(Exception):
ray.get(a.f.remote(error.remote()))
def test_actor_pass_by_ref_order_optimization(shutdown_only):
ray.init(num_cpus=4)
@ray.remote
class Actor:
def __init__(self):
pass
def f(self, x):
pass
a = Actor.remote()
@ray.remote
def fast_value():
print("fast value")
pass
@ray.remote
def slow_value():
print("start sleep")
time.sleep(30)
@ray.remote
def runner(f):
print("runner", a, f)
return ray.get(a.f.remote(f.remote()))
runner.remote(slow_value)
time.sleep(1)
x2 = runner.remote(fast_value)
start = time.time()
ray.get(x2)
delta = time.time() - start
assert delta < 10, "did not skip slow value"
def test_actor_recursive(ray_start_regular):
@ray.remote
class Actor:
def __init__(self, delegate=None):
self.delegate = delegate
def f(self, x):
if self.delegate:
return ray.get(self.delegate.f.remote(x))
return x * 2
a = Actor.remote()
b = Actor.remote(a)
c = Actor.remote(b)
result = ray.get([c.f.remote(i) for i in range(100)])
assert result == [x * 2 for x in range(100)]
result, _ = ray.wait([c.f.remote(i) for i in range(100)], num_returns=100)
result = ray.get(result)
assert result == [x * 2 for x in range(100)]
def test_actor_concurrent(ray_start_regular):
@ray.remote
class Batcher:
def __init__(self):
self.batch = []
self.event = threading.Event()
def add(self, x):
self.batch.append(x)
if len(self.batch) >= 3:
self.event.set()
else:
self.event.wait()
return sorted(self.batch)
a = Batcher.options(max_concurrency=3).remote()
x1 = a.add.remote(1)
x2 = a.add.remote(2)
x3 = a.add.remote(3)
r1 = ray.get(x1)
r2 = ray.get(x2)
r3 = ray.get(x3)
assert r1 == [1, 2, 3]
assert r1 == r2 == r3
def test_wait(ray_start_regular):
@ray.remote
def f(delay):
time.sleep(delay)
return
object_ids = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)]
ready_ids, remaining_ids = ray.wait(object_ids)
assert len(ready_ids) == 1
assert len(remaining_ids) == 3
ready_ids, remaining_ids = ray.wait(object_ids, num_returns=4)
assert set(ready_ids) == set(object_ids)
assert remaining_ids == []
object_ids = [f.remote(0), f.remote(5)]
ready_ids, remaining_ids = ray.wait(object_ids, timeout=0.5, num_returns=2)
assert len(ready_ids) == 1
assert len(remaining_ids) == 1
# Verify that calling wait with duplicate object IDs throws an
# exception.
x = ray.put(1)
with pytest.raises(Exception):
ray.wait([x, x])
# Make sure it is possible to call wait with an empty list.
ready_ids, remaining_ids = ray.wait([])
assert ready_ids == []
assert remaining_ids == []
# Test semantics of num_returns with no timeout.
oids = [ray.put(i) for i in range(10)]
(found, rest) = ray.wait(oids, num_returns=2)
assert len(found) == 2
assert len(rest) == 8
# Verify that incorrect usage raises a TypeError.
x = ray.put(1)
with pytest.raises(TypeError):
ray.wait(x)
with pytest.raises(TypeError):
ray.wait(1)
with pytest.raises(TypeError):
ray.wait([1])
def test_duplicate_args(ray_start_regular):
@ray.remote
def f(arg1,
arg2,
arg1_duplicate,
kwarg1=None,
kwarg2=None,
kwarg1_duplicate=None):
assert arg1 == kwarg1
assert arg1 != arg2
assert arg1 == arg1_duplicate
assert kwarg1 != kwarg2
assert kwarg1 == kwarg1_duplicate
# Test by-value arguments.
arg1 = [1]
arg2 = [2]
ray.get(
f.remote(
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
# Test by-reference arguments.
arg1 = ray.put([1])
arg2 = ray.put([2])
ray.get(
f.remote(
arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
def test_internal_config_when_connecting(ray_start_cluster):
config = json.dumps({
"object_pinning_enabled": 0,
"initial_reconstruction_timeout_milliseconds": 200
})
cluster = ray.cluster_utils.Cluster()
cluster.add_node(
_internal_config=config, object_store_memory=100 * 1024 * 1024)
cluster.wait_for_nodes()
# Specifying _internal_config when connecting to a cluster is disallowed.
with pytest.raises(ValueError):
ray.init(address=cluster.address, _internal_config=config)
# Check that the config was picked up (object pinning is disabled).
ray.init(address=cluster.address)
oid = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
for _ in range(5):
ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
# This would not raise an exception if object pinning was enabled.
with pytest.raises(ray.exceptions.UnreconstructableError):
ray.get(oid)
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))
+526
View File
@@ -0,0 +1,526 @@
# coding: utf-8
import collections
import io
import logging
import re
import string
import sys
import numpy as np
import pytest
import ray
import ray.cluster_utils
import ray.test_utils
logger = logging.getLogger(__name__)
def is_named_tuple(cls):
"""Return True if cls is a namedtuple and False otherwise."""
b = cls.__bases__
if len(b) != 1 or b[0] != tuple:
return False
f = getattr(cls, "_fields", None)
if not isinstance(f, tuple):
return False
return all(type(n) == str for n in f)
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_simple_serialization(ray_start_regular):
primitive_objects = [
# Various primitive types.
0,
0.0,
0.9,
1 << 62,
1 << 999,
b"",
b"a",
"a",
string.printable,
"\u262F",
u"hello world",
u"\xff\xfe\x9c\x001\x000\x00",
None,
True,
False,
[],
(),
{},
type,
int,
set(),
# Collections types.
collections.Counter([np.random.randint(0, 10) for _ in range(100)]),
collections.OrderedDict([("hello", 1), ("world", 2)]),
collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]),
collections.defaultdict(lambda: [], [("hello", 1), ("world", 2)]),
collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
# Numpy dtypes.
np.int8(3),
np.int32(4),
np.int64(5),
np.uint8(3),
np.uint32(4),
np.uint64(5),
np.float32(1.9),
np.float64(1.9),
]
composite_objects = (
[[obj]
for obj in primitive_objects] + [(obj, )
for obj in primitive_objects] + [{
(): obj
} for obj in primitive_objects])
@ray.remote
def f(x):
return x
# Check that we can pass arguments by value to remote functions and
# that they are uncorrupted.
for obj in primitive_objects + composite_objects:
new_obj_1 = ray.get(f.remote(obj))
new_obj_2 = ray.get(ray.put(obj))
assert obj == new_obj_1
assert obj == new_obj_2
# TODO(rkn): The numpy dtypes currently come back as regular integers
# or floats.
if type(obj).__module__ != "numpy":
assert type(obj) == type(new_obj_1)
assert type(obj) == type(new_obj_2)
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_complex_serialization(ray_start_regular):
def assert_equal(obj1, obj2):
module_numpy = (type(obj1).__module__ == np.__name__
or type(obj2).__module__ == np.__name__)
if module_numpy:
empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ())
or (hasattr(obj2, "shape") and obj2.shape == ()))
if empty_shape:
# This is a special case because currently
# np.testing.assert_equal fails because we do not properly
# handle different numerical types.
assert obj1 == obj2, ("Objects {} and {} are "
"different.".format(obj1, obj2))
else:
np.testing.assert_equal(obj1, obj2)
elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
special_keys = ["_pytype_"]
assert (set(list(obj1.__dict__.keys()) + special_keys) == set(
list(obj2.__dict__.keys()) + special_keys)), (
"Objects {} and {} are different.".format(obj1, obj2))
for key in obj1.__dict__.keys():
if key not in special_keys:
assert_equal(obj1.__dict__[key], obj2.__dict__[key])
elif type(obj1) is dict or type(obj2) is dict:
assert_equal(obj1.keys(), obj2.keys())
for key in obj1.keys():
assert_equal(obj1[key], obj2[key])
elif type(obj1) is list or type(obj2) is list:
assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
"different lengths.".format(
obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif type(obj1) is tuple or type(obj2) is tuple:
assert len(obj1) == len(obj2), ("Objects {} and {} are tuples "
"with different lengths.".format(
obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
elif (is_named_tuple(type(obj1)) or is_named_tuple(type(obj2))):
assert len(obj1) == len(obj2), (
"Objects {} and {} are named "
"tuples with different lengths.".format(obj1, obj2))
for i in range(len(obj1)):
assert_equal(obj1[i], obj2[i])
else:
assert obj1 == obj2, "Objects {} and {} are different.".format(
obj1, obj2)
long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
PRIMITIVE_OBJECTS = [
0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
string.printable, "\u262F", u"hello world",
u"\xff\xfe\x9c\x001\x000\x00", None, True, False, [], (), {},
np.int8(3),
np.int32(4),
np.int64(5),
np.uint8(3),
np.uint32(4),
np.uint64(5),
np.float32(1.9),
np.float64(1.9),
np.zeros([100, 100]),
np.random.normal(size=[100, 100]),
np.array(["hi", 3]),
np.array(["hi", 3], dtype=object)
] + long_extras
COMPLEX_OBJECTS = [
[[[[[[[[[[[[]]]]]]]]]]]],
{
"obj{}".format(i): np.random.normal(size=[100, 100])
for i in range(10)
},
# {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
# (): {(): {}}}}}}}}}}}}},
(
(((((((((), ), ), ), ), ), ), ), ), ),
{
"a": {
"b": {
"c": {
"d": {}
}
}
}
},
]
class Foo:
def __init__(self, value=0):
self.value = value
def __hash__(self):
return hash(self.value)
def __eq__(self, other):
return other.value == self.value
class Bar:
def __init__(self):
for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS):
setattr(self, "field{}".format(i), val)
class Baz:
def __init__(self):
self.foo = Foo()
self.bar = Bar()
def method(self, arg):
pass
class Qux:
def __init__(self):
self.objs = [Foo(), Bar(), Baz()]
class SubQux(Qux):
def __init__(self):
Qux.__init__(self)
class CustomError(Exception):
pass
Point = collections.namedtuple("Point", ["x", "y"])
NamedTupleExample = collections.namedtuple(
"Example", "field1, field2, field3, field4, field5")
CUSTOM_OBJECTS = [
Exception("Test object."),
CustomError(),
Point(11, y=22),
Foo(),
Bar(),
Baz(), # Qux(), SubQux(),
NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
]
# Test dataclasses in Python 3.7.
if sys.version_info >= (3, 7):
from dataclasses import make_dataclass
DataClass0 = make_dataclass("DataClass0", [("number", int)])
CUSTOM_OBJECTS.append(DataClass0(number=3))
class CustomClass:
def __init__(self, value):
self.value = value
DataClass1 = make_dataclass("DataClass1", [("custom", CustomClass)])
class DataClass2(DataClass1):
@classmethod
def from_custom(cls, data):
custom = CustomClass(data)
return cls(custom)
def __reduce__(self):
return (self.from_custom, (self.custom.value, ))
CUSTOM_OBJECTS.append(DataClass2(custom=CustomClass(43)))
BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS]
# The check that type(obj).__module__ != "numpy" should be unnecessary, but
# otherwise this seems to fail on Mac OS X on Travis.
DICT_OBJECTS = ([{
obj: obj
} for obj in PRIMITIVE_OBJECTS if (
obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
0: obj
} for obj in BASE_OBJECTS] + [{
Foo(123): Foo(456)
}])
RAY_TEST_OBJECTS = (
BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS)
@ray.remote
def f(x):
return x
# Check that we can pass arguments by value to remote functions and
# that they are uncorrupted.
for obj in RAY_TEST_OBJECTS:
assert_equal(obj, ray.get(f.remote(obj)))
assert_equal(obj, ray.get(ray.put(obj)))
# Test StringIO serialization
s = io.StringIO(u"Hello, world!\n")
s.seek(0)
line = s.readline()
s.seek(0)
assert ray.get(ray.put(s)).readline() == line
def test_numpy_serialization(ray_start_regular):
array = np.zeros(314)
from ray.cloudpickle import dumps
buffers = []
inband = dumps(array, protocol=5, buffer_callback=buffers.append)
assert len(inband) < array.nbytes
assert len(buffers) == 1
def test_numpy_subclass_serialization(ray_start_regular):
class MyNumpyConstant(np.ndarray):
def __init__(self, value):
super().__init__()
self.constant = value
def __str__(self):
print(self.constant)
constant = MyNumpyConstant(123)
def explode(x):
raise RuntimeError("Expected error.")
ray.register_custom_serializer(
type(constant), serializer=explode, deserializer=explode)
try:
ray.put(constant)
assert False, "Should never get here!"
except (RuntimeError, IndexError):
print("Correct behavior, proof that customer serializer was used.")
def test_numpy_subclass_serialization_pickle(ray_start_regular):
class MyNumpyConstant(np.ndarray):
def __init__(self, value):
super().__init__()
self.constant = value
def __str__(self):
print(self.constant)
constant = MyNumpyConstant(123)
repr_orig = repr(constant)
repr_ser = repr(ray.get(ray.put(constant)))
assert repr_orig == repr_ser
@pytest.mark.parametrize(
"ray_start_regular", [{
"local_mode": True
}, {
"local_mode": False
}],
indirect=True)
def test_serialization_final_fallback(ray_start_regular):
pytest.importorskip("catboost")
# This test will only run when "catboost" is installed.
from catboost import CatBoostClassifier
model = CatBoostClassifier(
iterations=2,
depth=2,
learning_rate=1,
loss_function="Logloss",
logging_level="Verbose")
reconstructed_model = ray.get(ray.put(model))
assert set(model.get_params().items()) == set(
reconstructed_model.get_params().items())
def test_register_class(ray_start_2_cpus):
# Check that putting an object of a class that has not been registered
# throws an exception.
class TempClass:
pass
ray.get(ray.put(TempClass()))
# Test passing custom classes into remote functions from the driver.
@ray.remote
def f(x):
return x
class Foo:
def __init__(self, value=0):
self.value = value
def __hash__(self):
return hash(self.value)
def __eq__(self, other):
return other.value == self.value
foo = ray.get(f.remote(Foo(7)))
assert foo == Foo(7)
regex = re.compile(r"\d+\.\d*")
new_regex = ray.get(f.remote(regex))
# This seems to fail on the system Python 3 that comes with
# Ubuntu, so it is commented out for now:
# assert regex == new_regex
# Instead, we do this:
assert regex.pattern == new_regex.pattern
class TempClass1:
def __init__(self):
self.value = 1
# Test returning custom classes created on workers.
@ray.remote
def g():
class TempClass2:
def __init__(self):
self.value = 2
return TempClass1(), TempClass2()
object_1, object_2 = ray.get(g.remote())
assert object_1.value == 1
assert object_2.value == 2
# Test exporting custom class definitions from one worker to another
# when the worker is blocked in a get.
class NewTempClass:
def __init__(self, value):
self.value = value
@ray.remote
def h1(x):
return NewTempClass(x)
@ray.remote
def h2(x):
return ray.get(h1.remote(x))
assert ray.get(h2.remote(10)).value == 10
# Test registering multiple classes with the same name.
@ray.remote(num_return_vals=3)
def j():
class Class0:
def method0(self):
pass
c0 = Class0()
class Class0:
def method1(self):
pass
c1 = Class0()
class Class0:
def method2(self):
pass
c2 = Class0()
return c0, c1, c2
results = []
for _ in range(5):
results += j.remote()
for i in range(len(results) // 3):
c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))])
c0.method0()
c1.method1()
c2.method2()
assert not hasattr(c0, "method1")
assert not hasattr(c0, "method2")
assert not hasattr(c1, "method0")
assert not hasattr(c1, "method2")
assert not hasattr(c2, "method0")
assert not hasattr(c2, "method1")
@ray.remote
def k():
class Class0:
def method0(self):
pass
c0 = Class0()
class Class0:
def method1(self):
pass
c1 = Class0()
class Class0:
def method2(self):
pass
c2 = Class0()
return c0, c1, c2
results = ray.get([k.remote() for _ in range(5)])
for c0, c1, c2 in results:
c0.method0()
c1.method1()
c2.method2()
assert not hasattr(c0, "method1")
assert not hasattr(c0, "method2")
assert not hasattr(c1, "method0")
assert not hasattr(c1, "method2")
assert not hasattr(c2, "method0")
assert not hasattr(c2, "method1")
if __name__ == "__main__":
import pytest
sys.exit(pytest.main(["-v", __file__]))