From b84fe56bedeed846848c40596e7e37b82b0d6948 Mon Sep 17 00:00:00 2001 From: Edward Oakes Date: Tue, 12 May 2020 10:18:21 -0500 Subject: [PATCH] Split test_basic to avoid timeouts in CI (#8405) --- python/ray/tests/BUILD | 16 + python/ray/tests/test_basic.py | 1167 ------------------------ python/ray/tests/test_basic_2.py | 678 ++++++++++++++ python/ray/tests/test_serialization.py | 526 +++++++++++ 4 files changed, 1220 insertions(+), 1167 deletions(-) create mode 100644 python/ray/tests/test_basic_2.py create mode 100644 python/ray/tests/test_serialization.py diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index f2fefb044..eb4bc2bac 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -55,6 +55,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_serialization", + size = "small", + srcs = ["test_serialization.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) + py_test( name = "test_basic", size = "medium", @@ -63,6 +71,14 @@ py_test( deps = ["//:ray_lib"], ) +py_test( + name = "test_basic_2", + size = "medium", + srcs = ["test_basic_2.py"], + tags = ["exclusive"], + deps = ["//:ray_lib"], +) + py_test( name = "test_advanced", size = "medium", diff --git a/python/ray/tests/test_basic.py b/python/ray/tests/test_basic.py index 44ae291ba..ded9b6123 100644 --- a/python/ray/tests/test_basic.py +++ b/python/ray/tests/test_basic.py @@ -1,14 +1,10 @@ # coding: utf-8 -import collections import io import json import logging import os import pickle -import re -import string import sys -import threading import time import weakref @@ -18,22 +14,10 @@ import pytest import ray import ray.cluster_utils import ray.test_utils -from ray.exceptions import RayTimeoutError logger = logging.getLogger(__name__) -def is_named_tuple(cls): - """Return True if cls is a namedtuple and False otherwise.""" - b = cls.__bases__ - if len(b) != 1 or b[0] != tuple: - return False - f = getattr(cls, "_fields", None) - if not isinstance(f, tuple): - 
return False - return all(type(n) == str for n in f) - - # https://github.com/ray-project/ray/issues/6662 def test_ignore_http_proxy(shutdown_only): ray.init(num_cpus=1) @@ -191,79 +175,6 @@ def test_many_fractional_resources(shutdown_only): assert False, "Did not get correct available resources." -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_simple_serialization(ray_start_regular): - primitive_objects = [ - # Various primitive types. - 0, - 0.0, - 0.9, - 1 << 62, - 1 << 999, - b"", - b"a", - "a", - string.printable, - "\u262F", - u"hello world", - u"\xff\xfe\x9c\x001\x000\x00", - None, - True, - False, - [], - (), - {}, - type, - int, - set(), - # Collections types. - collections.Counter([np.random.randint(0, 10) for _ in range(100)]), - collections.OrderedDict([("hello", 1), ("world", 2)]), - collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]), - collections.defaultdict(lambda: [], [("hello", 1), ("world", 2)]), - collections.deque([1, 2, 3, "a", "b", "c", 3.5]), - # Numpy dtypes. - np.int8(3), - np.int32(4), - np.int64(5), - np.uint8(3), - np.uint32(4), - np.uint64(5), - np.float32(1.9), - np.float64(1.9), - ] - - composite_objects = ( - [[obj] - for obj in primitive_objects] + [(obj, ) - for obj in primitive_objects] + [{ - (): obj - } for obj in primitive_objects]) - - @ray.remote - def f(x): - return x - - # Check that we can pass arguments by value to remote functions and - # that they are uncorrupted. - for obj in primitive_objects + composite_objects: - new_obj_1 = ray.get(f.remote(obj)) - new_obj_2 = ray.get(ray.put(obj)) - assert obj == new_obj_1 - assert obj == new_obj_2 - # TODO(rkn): The numpy dtypes currently come back as regular integers - # or floats. 
- if type(obj).__module__ != "numpy": - assert type(obj) == type(new_obj_1) - assert type(obj) == type(new_obj_2) - - def test_background_tasks_with_max_calls(shutdown_only): ray.init(num_cpus=2) @@ -319,260 +230,6 @@ def test_fair_queueing(shutdown_only): assert len(ready) == 1000, len(ready) -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_complex_serialization(ray_start_regular): - def assert_equal(obj1, obj2): - module_numpy = (type(obj1).__module__ == np.__name__ - or type(obj2).__module__ == np.__name__) - if module_numpy: - empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ()) - or (hasattr(obj2, "shape") and obj2.shape == ())) - if empty_shape: - # This is a special case because currently - # np.testing.assert_equal fails because we do not properly - # handle different numerical types. - assert obj1 == obj2, ("Objects {} and {} are " - "different.".format(obj1, obj2)) - else: - np.testing.assert_equal(obj1, obj2) - elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"): - special_keys = ["_pytype_"] - assert (set(list(obj1.__dict__.keys()) + special_keys) == set( - list(obj2.__dict__.keys()) + special_keys)), ( - "Objects {} and {} are different.".format(obj1, obj2)) - for key in obj1.__dict__.keys(): - if key not in special_keys: - assert_equal(obj1.__dict__[key], obj2.__dict__[key]) - elif type(obj1) is dict or type(obj2) is dict: - assert_equal(obj1.keys(), obj2.keys()) - for key in obj1.keys(): - assert_equal(obj1[key], obj2[key]) - elif type(obj1) is list or type(obj2) is list: - assert len(obj1) == len(obj2), ("Objects {} and {} are lists with " - "different lengths.".format( - obj1, obj2)) - for i in range(len(obj1)): - assert_equal(obj1[i], obj2[i]) - elif type(obj1) is tuple or type(obj2) is tuple: - assert len(obj1) == len(obj2), ("Objects {} and {} are tuples " - "with different lengths.".format( - obj1, obj2)) - for i in range(len(obj1)): - 
assert_equal(obj1[i], obj2[i]) - elif (is_named_tuple(type(obj1)) or is_named_tuple(type(obj2))): - assert len(obj1) == len(obj2), ( - "Objects {} and {} are named " - "tuples with different lengths.".format(obj1, obj2)) - for i in range(len(obj1)): - assert_equal(obj1[i], obj2[i]) - else: - assert obj1 == obj2, "Objects {} and {} are different.".format( - obj1, obj2) - - long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])] - - PRIMITIVE_OBJECTS = [ - 0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a", - string.printable, "\u262F", u"hello world", - u"\xff\xfe\x9c\x001\x000\x00", None, True, False, [], (), {}, - np.int8(3), - np.int32(4), - np.int64(5), - np.uint8(3), - np.uint32(4), - np.uint64(5), - np.float32(1.9), - np.float64(1.9), - np.zeros([100, 100]), - np.random.normal(size=[100, 100]), - np.array(["hi", 3]), - np.array(["hi", 3], dtype=object) - ] + long_extras - - COMPLEX_OBJECTS = [ - [[[[[[[[[[[[]]]]]]]]]]]], - { - "obj{}".format(i): np.random.normal(size=[100, 100]) - for i in range(10) - }, - # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): { - # (): {(): {}}}}}}}}}}}}}, - ( - (((((((((), ), ), ), ), ), ), ), ), ), - { - "a": { - "b": { - "c": { - "d": {} - } - } - } - }, - ] - - class Foo: - def __init__(self, value=0): - self.value = value - - def __hash__(self): - return hash(self.value) - - def __eq__(self, other): - return other.value == self.value - - class Bar: - def __init__(self): - for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS): - setattr(self, "field{}".format(i), val) - - class Baz: - def __init__(self): - self.foo = Foo() - self.bar = Bar() - - def method(self, arg): - pass - - class Qux: - def __init__(self): - self.objs = [Foo(), Bar(), Baz()] - - class SubQux(Qux): - def __init__(self): - Qux.__init__(self) - - class CustomError(Exception): - pass - - Point = collections.namedtuple("Point", ["x", "y"]) - NamedTupleExample = collections.namedtuple( - "Example", "field1, field2, field3, field4, 
field5") - - CUSTOM_OBJECTS = [ - Exception("Test object."), - CustomError(), - Point(11, y=22), - Foo(), - Bar(), - Baz(), # Qux(), SubQux(), - NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]), - ] - - # Test dataclasses in Python 3.7. - if sys.version_info >= (3, 7): - from dataclasses import make_dataclass - - DataClass0 = make_dataclass("DataClass0", [("number", int)]) - - CUSTOM_OBJECTS.append(DataClass0(number=3)) - - class CustomClass: - def __init__(self, value): - self.value = value - - DataClass1 = make_dataclass("DataClass1", [("custom", CustomClass)]) - - class DataClass2(DataClass1): - @classmethod - def from_custom(cls, data): - custom = CustomClass(data) - return cls(custom) - - def __reduce__(self): - return (self.from_custom, (self.custom.value, )) - - CUSTOM_OBJECTS.append(DataClass2(custom=CustomClass(43))) - - BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS - - LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS] - TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS] - # The check that type(obj).__module__ != "numpy" should be unnecessary, but - # otherwise this seems to fail on Mac OS X on Travis. - DICT_OBJECTS = ([{ - obj: obj - } for obj in PRIMITIVE_OBJECTS if ( - obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{ - 0: obj - } for obj in BASE_OBJECTS] + [{ - Foo(123): Foo(456) - }]) - - RAY_TEST_OBJECTS = ( - BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS) - - @ray.remote - def f(x): - return x - - # Check that we can pass arguments by value to remote functions and - # that they are uncorrupted. 
- for obj in RAY_TEST_OBJECTS: - assert_equal(obj, ray.get(f.remote(obj))) - assert_equal(obj, ray.get(ray.put(obj))) - - # Test StringIO serialization - s = io.StringIO(u"Hello, world!\n") - s.seek(0) - line = s.readline() - s.seek(0) - assert ray.get(ray.put(s)).readline() == line - - -def test_numpy_serialization(ray_start_regular): - array = np.zeros(314) - from ray.cloudpickle import dumps - buffers = [] - inband = dumps(array, protocol=5, buffer_callback=buffers.append) - assert len(inband) < array.nbytes - assert len(buffers) == 1 - - -def test_numpy_subclass_serialization(ray_start_regular): - class MyNumpyConstant(np.ndarray): - def __init__(self, value): - super().__init__() - self.constant = value - - def __str__(self): - print(self.constant) - - constant = MyNumpyConstant(123) - - def explode(x): - raise RuntimeError("Expected error.") - - ray.register_custom_serializer( - type(constant), serializer=explode, deserializer=explode) - - try: - ray.put(constant) - assert False, "Should never get here!" 
- except (RuntimeError, IndexError): - print("Correct behavior, proof that customer serializer was used.") - - -def test_numpy_subclass_serialization_pickle(ray_start_regular): - class MyNumpyConstant(np.ndarray): - def __init__(self, value): - super().__init__() - self.constant = value - - def __str__(self): - print(self.constant) - - constant = MyNumpyConstant(123) - repr_orig = repr(constant) - repr_ser = repr(ray.get(ray.put(constant))) - assert repr_orig == repr_ser - - def test_function_descriptor(): python_descriptor = ray._raylet.PythonFunctionDescriptor( "module_name", "function_name", "class_name", "function_hash") @@ -864,173 +521,6 @@ def test_custom_serializers(ray_start_regular): assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2") -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_serialization_final_fallback(ray_start_regular): - pytest.importorskip("catboost") - # This test will only run when "catboost" is installed. - from catboost import CatBoostClassifier - - model = CatBoostClassifier( - iterations=2, - depth=2, - learning_rate=1, - loss_function="Logloss", - logging_level="Verbose") - - reconstructed_model = ray.get(ray.put(model)) - assert set(model.get_params().items()) == set( - reconstructed_model.get_params().items()) - - -def test_register_class(ray_start_2_cpus): - # Check that putting an object of a class that has not been registered - # throws an exception. - class TempClass: - pass - - ray.get(ray.put(TempClass())) - - # Test passing custom classes into remote functions from the driver. 
- @ray.remote - def f(x): - return x - - class Foo: - def __init__(self, value=0): - self.value = value - - def __hash__(self): - return hash(self.value) - - def __eq__(self, other): - return other.value == self.value - - foo = ray.get(f.remote(Foo(7))) - assert foo == Foo(7) - - regex = re.compile(r"\d+\.\d*") - new_regex = ray.get(f.remote(regex)) - # This seems to fail on the system Python 3 that comes with - # Ubuntu, so it is commented out for now: - # assert regex == new_regex - # Instead, we do this: - assert regex.pattern == new_regex.pattern - - class TempClass1: - def __init__(self): - self.value = 1 - - # Test returning custom classes created on workers. - @ray.remote - def g(): - class TempClass2: - def __init__(self): - self.value = 2 - - return TempClass1(), TempClass2() - - object_1, object_2 = ray.get(g.remote()) - assert object_1.value == 1 - assert object_2.value == 2 - - # Test exporting custom class definitions from one worker to another - # when the worker is blocked in a get. - class NewTempClass: - def __init__(self, value): - self.value = value - - @ray.remote - def h1(x): - return NewTempClass(x) - - @ray.remote - def h2(x): - return ray.get(h1.remote(x)) - - assert ray.get(h2.remote(10)).value == 10 - - # Test registering multiple classes with the same name. 
- @ray.remote(num_return_vals=3) - def j(): - class Class0: - def method0(self): - pass - - c0 = Class0() - - class Class0: - def method1(self): - pass - - c1 = Class0() - - class Class0: - def method2(self): - pass - - c2 = Class0() - - return c0, c1, c2 - - results = [] - for _ in range(5): - results += j.remote() - for i in range(len(results) // 3): - c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))]) - - c0.method0() - c1.method1() - c2.method2() - - assert not hasattr(c0, "method1") - assert not hasattr(c0, "method2") - assert not hasattr(c1, "method0") - assert not hasattr(c1, "method2") - assert not hasattr(c2, "method0") - assert not hasattr(c2, "method1") - - @ray.remote - def k(): - class Class0: - def method0(self): - pass - - c0 = Class0() - - class Class0: - def method1(self): - pass - - c1 = Class0() - - class Class0: - def method2(self): - pass - - c2 = Class0() - - return c0, c1, c2 - - results = ray.get([k.remote() for _ in range(5)]) - for c0, c1, c2 in results: - c0.method0() - c1.method1() - c2.method2() - - assert not hasattr(c0, "method1") - assert not hasattr(c0, "method2") - assert not hasattr(c1, "method0") - assert not hasattr(c1, "method2") - assert not hasattr(c2, "method0") - assert not hasattr(c2, "method1") - - @pytest.mark.parametrize( "ray_start_regular", [{ "local_mode": True @@ -1238,663 +728,6 @@ def test_args_stars_after(ray_start_regular): ray.get(remote_test_function.remote(local_method, actor_method)) -@pytest.mark.parametrize( - "shutdown_only", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_variable_number_of_args(shutdown_only): - @ray.remote - def varargs_fct1(*a): - return " ".join(map(str, a)) - - @ray.remote - def varargs_fct2(a, *b): - return " ".join(map(str, b)) - - ray.init(num_cpus=1) - - x = varargs_fct1.remote(0, 1, 2) - assert ray.get(x) == "0 1 2" - x = varargs_fct2.remote(0, 1, 2) - assert ray.get(x) == "1 2" - - @ray.remote - def f1(*args): - return args - - 
@ray.remote - def f2(x, y, *args): - return x, y, args - - assert ray.get(f1.remote()) == () - assert ray.get(f1.remote(1)) == (1, ) - assert ray.get(f1.remote(1, 2, 3)) == (1, 2, 3) - with pytest.raises(Exception): - f2.remote() - with pytest.raises(Exception): - f2.remote(1) - assert ray.get(f2.remote(1, 2)) == (1, 2, ()) - assert ray.get(f2.remote(1, 2, 3)) == (1, 2, (3, )) - assert ray.get(f2.remote(1, 2, 3, 4)) == (1, 2, (3, 4)) - - def testNoArgs(self): - @ray.remote - def no_op(): - pass - - self.ray_start() - - ray.get(no_op.remote()) - - -@pytest.mark.parametrize( - "shutdown_only", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_defining_remote_functions(shutdown_only): - ray.init(num_cpus=3) - - # Test that we can close over plain old data. - data = [ - np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, { - "a": np.zeros(3) - } - ] - - @ray.remote - def g(): - return data - - ray.get(g.remote()) - - # Test that we can close over modules. - @ray.remote - def h(): - return np.zeros([3, 5]) - - assert np.alltrue(ray.get(h.remote()) == np.zeros([3, 5])) - - @ray.remote - def j(): - return time.time() - - ray.get(j.remote()) - - # Test that we can define remote functions that call other remote - # functions. - @ray.remote - def k(x): - return x + 1 - - @ray.remote - def k2(x): - return ray.get(k.remote(x)) - - @ray.remote - def m(x): - return ray.get(k2.remote(x)) - - assert ray.get(k.remote(1)) == 2 - assert ray.get(k2.remote(1)) == 2 - assert ray.get(m.remote(1)) == 2 - - -@pytest.mark.parametrize( - "shutdown_only", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_redefining_remote_functions(shutdown_only): - ray.init(num_cpus=1) - - # Test that we can define a remote function in the shell. - @ray.remote - def f(x): - return x + 1 - - assert ray.get(f.remote(0)) == 1 - - # Test that we can redefine the remote function. 
- @ray.remote - def f(x): - return x + 10 - - while True: - val = ray.get(f.remote(0)) - assert val in [1, 10] - if val == 10: - break - else: - logger.info("Still using old definition of f, trying again.") - - # Check that we can redefine functions even when the remote function source - # doesn't change (see https://github.com/ray-project/ray/issues/6130). - @ray.remote - def g(): - return nonexistent() - - with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"): - ray.get(g.remote()) - - def nonexistent(): - return 1 - - # Redefine the function and make sure it succeeds. - @ray.remote - def g(): - return nonexistent() - - assert ray.get(g.remote()) == 1 - - # Check the same thing but when the redefined function is inside of another - # task. - @ray.remote - def h(i): - @ray.remote - def j(): - return i - - return j.remote() - - for i in range(20): - assert ray.get(ray.get(h.remote(i))) == i - - -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_get_multiple(ray_start_regular): - object_ids = [ray.put(i) for i in range(10)] - assert ray.get(object_ids) == list(range(10)) - - # Get a random choice of object IDs with duplicates. 
- indices = list(np.random.choice(range(10), 5)) - indices += indices - results = ray.get([object_ids[i] for i in indices]) - assert results == indices - - -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_get_multiple_experimental(ray_start_regular): - object_ids = [ray.put(i) for i in range(10)] - - object_ids_tuple = tuple(object_ids) - assert ray.experimental.get(object_ids_tuple) == list(range(10)) - - object_ids_nparray = np.array(object_ids) - assert ray.experimental.get(object_ids_nparray) == list(range(10)) - - -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -def test_get_dict(ray_start_regular): - d = {str(i): ray.put(i) for i in range(5)} - for i in range(5, 10): - d[str(i)] = i - result = ray.experimental.get(d) - expected = {str(i): i for i in range(10)} - assert result == expected - - -def test_get_with_timeout(ray_start_regular): - signal = ray.test_utils.SignalActor.remote() - - # Check that get() returns early if object is ready. - start = time.time() - ray.get(signal.wait.remote(should_wait=False), timeout=30) - assert time.time() - start < 30 - - # Check that get() raises a TimeoutError after the timeout if the object - # is not ready yet. - result_id = signal.wait.remote() - with pytest.raises(RayTimeoutError): - ray.get(result_id, timeout=0.1) - - # Check that a subsequent get() returns early. 
- ray.get(signal.send.remote()) - start = time.time() - ray.get(result_id, timeout=30) - assert time.time() - start < 30 - - -@pytest.mark.parametrize( - "ray_start_regular", [{ - "local_mode": True - }, { - "local_mode": False - }], - indirect=True) -# https://github.com/ray-project/ray/issues/6329 -def test_call_actors_indirect_through_tasks(ray_start_regular): - @ray.remote - class Counter: - def __init__(self, value): - self.value = int(value) - - def increase(self, delta): - self.value += int(delta) - return self.value - - @ray.remote - def foo(object): - return ray.get(object.increase.remote(1)) - - @ray.remote - def bar(object): - return ray.get(object.increase.remote(1)) - - @ray.remote - def zoo(object): - return ray.get(object[0].increase.remote(1)) - - c = Counter.remote(0) - for _ in range(0, 100): - ray.get(foo.remote(c)) - ray.get(bar.remote(c)) - ray.get(zoo.remote([c])) - - -def test_call_matrix(shutdown_only): - ray.init(object_store_memory=1000 * 1024 * 1024) - - @ray.remote - class Actor: - def small_value(self): - return 0 - - def large_value(self): - return np.zeros(10 * 1024 * 1024) - - def echo(self, x): - if isinstance(x, list): - x = ray.get(x[0]) - return x - - @ray.remote - def small_value(): - return 0 - - @ray.remote - def large_value(): - return np.zeros(10 * 1024 * 1024) - - @ray.remote - def echo(x): - if isinstance(x, list): - x = ray.get(x[0]) - return x - - def check(source_actor, dest_actor, is_large, out_of_band): - print("CHECKING", "actor" if source_actor else "task", "to", "actor" - if dest_actor else "task", "large_object" - if is_large else "small_object", "out_of_band" - if out_of_band else "in_band") - if source_actor: - a = Actor.remote() - if is_large: - x_id = a.large_value.remote() - else: - x_id = a.small_value.remote() - else: - if is_large: - x_id = large_value.remote() - else: - x_id = small_value.remote() - if out_of_band: - x_id = [x_id] - if dest_actor: - b = Actor.remote() - x = ray.get(b.echo.remote(x_id)) - 
else: - x = ray.get(echo.remote(x_id)) - if is_large: - assert isinstance(x, np.ndarray) - else: - assert isinstance(x, int) - - for is_large in [False, True]: - for source_actor in [False, True]: - for dest_actor in [False, True]: - for out_of_band in [False, True]: - check(source_actor, dest_actor, is_large, out_of_band) - - -@pytest.mark.parametrize( - "ray_start_cluster", [{ - "num_cpus": 1, - "num_nodes": 1, - }, { - "num_cpus": 1, - "num_nodes": 2, - }], - indirect=True) -def test_call_chain(ray_start_cluster): - @ray.remote - def g(x): - return x + 1 - - x = 0 - for _ in range(100): - x = g.remote(x) - assert ray.get(x) == 100 - - -def test_inline_arg_memory_corruption(ray_start_regular): - @ray.remote - def f(): - return np.zeros(1000, dtype=np.uint8) - - @ray.remote - class Actor: - def __init__(self): - self.z = [] - - def add(self, x): - self.z.append(x) - for prev in self.z: - assert np.sum(prev) == 0, ("memory corruption detected", prev) - - a = Actor.remote() - for i in range(100): - ray.get(a.add.remote(f.remote())) - - -def test_skip_plasma(ray_start_regular): - @ray.remote - class Actor: - def __init__(self): - pass - - def f(self, x): - return x * 2 - - a = Actor.remote() - obj_id = a.f.remote(1) - # it is not stored in plasma - assert not ray.worker.global_worker.core_worker.object_exists(obj_id) - assert ray.get(obj_id) == 2 - - -def test_actor_call_order(shutdown_only): - ray.init(num_cpus=4) - - @ray.remote - def small_value(): - time.sleep(0.01 * np.random.randint(0, 10)) - return 0 - - @ray.remote - class Actor: - def __init__(self): - self.count = 0 - - def inc(self, count, dependency): - assert count == self.count - self.count += 1 - return count - - a = Actor.remote() - assert ray.get([a.inc.remote(i, small_value.remote()) - for i in range(100)]) == list(range(100)) - - -def test_actor_large_objects(ray_start_regular): - @ray.remote - class Actor: - def __init__(self): - pass - - def f(self): - time.sleep(1) - return np.zeros(10000000) - 
- a = Actor.remote() - obj_id = a.f.remote() - assert not ray.worker.global_worker.core_worker.object_exists(obj_id) - done, _ = ray.wait([obj_id]) - assert len(done) == 1 - assert ray.worker.global_worker.core_worker.object_exists(obj_id) - assert isinstance(ray.get(obj_id), np.ndarray) - - -def test_actor_pass_by_ref(ray_start_regular): - @ray.remote - class Actor: - def __init__(self): - pass - - def f(self, x): - return x * 2 - - @ray.remote - def f(x): - return x - - @ray.remote - def error(): - sys.exit(0) - - a = Actor.remote() - assert ray.get(a.f.remote(f.remote(1))) == 2 - - fut = [a.f.remote(f.remote(i)) for i in range(100)] - assert ray.get(fut) == [i * 2 for i in range(100)] - - # propagates errors for pass by ref - with pytest.raises(Exception): - ray.get(a.f.remote(error.remote())) - - -def test_actor_pass_by_ref_order_optimization(shutdown_only): - ray.init(num_cpus=4) - - @ray.remote - class Actor: - def __init__(self): - pass - - def f(self, x): - pass - - a = Actor.remote() - - @ray.remote - def fast_value(): - print("fast value") - pass - - @ray.remote - def slow_value(): - print("start sleep") - time.sleep(30) - - @ray.remote - def runner(f): - print("runner", a, f) - return ray.get(a.f.remote(f.remote())) - - runner.remote(slow_value) - time.sleep(1) - x2 = runner.remote(fast_value) - start = time.time() - ray.get(x2) - delta = time.time() - start - assert delta < 10, "did not skip slow value" - - -def test_actor_recursive(ray_start_regular): - @ray.remote - class Actor: - def __init__(self, delegate=None): - self.delegate = delegate - - def f(self, x): - if self.delegate: - return ray.get(self.delegate.f.remote(x)) - return x * 2 - - a = Actor.remote() - b = Actor.remote(a) - c = Actor.remote(b) - - result = ray.get([c.f.remote(i) for i in range(100)]) - assert result == [x * 2 for x in range(100)] - - result, _ = ray.wait([c.f.remote(i) for i in range(100)], num_returns=100) - result = ray.get(result) - assert result == [x * 2 for x in 
range(100)] - - -def test_actor_concurrent(ray_start_regular): - @ray.remote - class Batcher: - def __init__(self): - self.batch = [] - self.event = threading.Event() - - def add(self, x): - self.batch.append(x) - if len(self.batch) >= 3: - self.event.set() - else: - self.event.wait() - return sorted(self.batch) - - a = Batcher.options(max_concurrency=3).remote() - x1 = a.add.remote(1) - x2 = a.add.remote(2) - x3 = a.add.remote(3) - r1 = ray.get(x1) - r2 = ray.get(x2) - r3 = ray.get(x3) - assert r1 == [1, 2, 3] - assert r1 == r2 == r3 - - -def test_wait(ray_start_regular): - @ray.remote - def f(delay): - time.sleep(delay) - return - - object_ids = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)] - ready_ids, remaining_ids = ray.wait(object_ids) - assert len(ready_ids) == 1 - assert len(remaining_ids) == 3 - ready_ids, remaining_ids = ray.wait(object_ids, num_returns=4) - assert set(ready_ids) == set(object_ids) - assert remaining_ids == [] - - object_ids = [f.remote(0), f.remote(5)] - ready_ids, remaining_ids = ray.wait(object_ids, timeout=0.5, num_returns=2) - assert len(ready_ids) == 1 - assert len(remaining_ids) == 1 - - # Verify that calling wait with duplicate object IDs throws an - # exception. - x = ray.put(1) - with pytest.raises(Exception): - ray.wait([x, x]) - - # Make sure it is possible to call wait with an empty list. - ready_ids, remaining_ids = ray.wait([]) - assert ready_ids == [] - assert remaining_ids == [] - - # Test semantics of num_returns with no timeout. - oids = [ray.put(i) for i in range(10)] - (found, rest) = ray.wait(oids, num_returns=2) - assert len(found) == 2 - assert len(rest) == 8 - - # Verify that incorrect usage raises a TypeError. 
- x = ray.put(1) - with pytest.raises(TypeError): - ray.wait(x) - with pytest.raises(TypeError): - ray.wait(1) - with pytest.raises(TypeError): - ray.wait([1]) - - -def test_duplicate_args(ray_start_regular): - @ray.remote - def f(arg1, - arg2, - arg1_duplicate, - kwarg1=None, - kwarg2=None, - kwarg1_duplicate=None): - assert arg1 == kwarg1 - assert arg1 != arg2 - assert arg1 == arg1_duplicate - assert kwarg1 != kwarg2 - assert kwarg1 == kwarg1_duplicate - - # Test by-value arguments. - arg1 = [1] - arg2 = [2] - ray.get( - f.remote( - arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) - - # Test by-reference arguments. - arg1 = ray.put([1]) - arg2 = ray.put([2]) - ray.get( - f.remote( - arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1)) - - -def test_internal_config_when_connecting(ray_start_cluster): - config = json.dumps({ - "object_pinning_enabled": 0, - "initial_reconstruction_timeout_milliseconds": 200 - }) - cluster = ray.cluster_utils.Cluster() - cluster.add_node( - _internal_config=config, object_store_memory=100 * 1024 * 1024) - cluster.wait_for_nodes() - - # Specifying _internal_config when connecting to a cluster is disallowed. - with pytest.raises(ValueError): - ray.init(address=cluster.address, _internal_config=config) - - # Check that the config was picked up (object pinning is disabled). - ray.init(address=cluster.address) - oid = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8)) - - for _ in range(5): - ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8)) - - # This would not raise an exception if object pinning was enabled. 
- with pytest.raises(ray.exceptions.UnreconstructableError): - ray.get(oid) - - if __name__ == "__main__": import pytest sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/tests/test_basic_2.py b/python/ray/tests/test_basic_2.py new file mode 100644 index 000000000..a0a4c89a3 --- /dev/null +++ b/python/ray/tests/test_basic_2.py @@ -0,0 +1,678 @@ +# coding: utf-8 +import json +import logging +import sys +import threading +import time + +import numpy as np +import pytest + +import ray +import ray.cluster_utils +import ray.test_utils +from ray.exceptions import RayTimeoutError + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "shutdown_only", [{ + "local_mode": True + }, { + "local_mode": False + }], + indirect=True) +def test_variable_number_of_args(shutdown_only): + @ray.remote + def varargs_fct1(*a): + return " ".join(map(str, a)) + + @ray.remote + def varargs_fct2(a, *b): + return " ".join(map(str, b)) + + ray.init(num_cpus=1) + + x = varargs_fct1.remote(0, 1, 2) + assert ray.get(x) == "0 1 2" + x = varargs_fct2.remote(0, 1, 2) + assert ray.get(x) == "1 2" + + @ray.remote + def f1(*args): + return args + + @ray.remote + def f2(x, y, *args): + return x, y, args + + assert ray.get(f1.remote()) == () + assert ray.get(f1.remote(1)) == (1, ) + assert ray.get(f1.remote(1, 2, 3)) == (1, 2, 3) + with pytest.raises(Exception): + f2.remote() + with pytest.raises(Exception): + f2.remote(1) + assert ray.get(f2.remote(1, 2)) == (1, 2, ()) + assert ray.get(f2.remote(1, 2, 3)) == (1, 2, (3, )) + assert ray.get(f2.remote(1, 2, 3, 4)) == (1, 2, (3, 4)) + + def testNoArgs(self): + @ray.remote + def no_op(): + pass + + self.ray_start() + + ray.get(no_op.remote()) + + +@pytest.mark.parametrize( + "shutdown_only", [{ + "local_mode": True + }, { + "local_mode": False + }], + indirect=True) +def test_defining_remote_functions(shutdown_only): + ray.init(num_cpus=3) + + # Test that we can close over plain old data. 
+ data = [ + np.zeros([3, 5]), (1, 2, "a"), [0.0, 1.0, 1 << 62], 1 << 60, { + "a": np.zeros(3) + } + ] + + @ray.remote + def g(): + return data + + ray.get(g.remote()) + + # Test that we can close over modules. + @ray.remote + def h(): + return np.zeros([3, 5]) + + assert np.alltrue(ray.get(h.remote()) == np.zeros([3, 5])) + + @ray.remote + def j(): + return time.time() + + ray.get(j.remote()) + + # Test that we can define remote functions that call other remote + # functions. + @ray.remote + def k(x): + return x + 1 + + @ray.remote + def k2(x): + return ray.get(k.remote(x)) + + @ray.remote + def m(x): + return ray.get(k2.remote(x)) + + assert ray.get(k.remote(1)) == 2 + assert ray.get(k2.remote(1)) == 2 + assert ray.get(m.remote(1)) == 2 + + +@pytest.mark.parametrize( + "shutdown_only", [{ + "local_mode": True + }, { + "local_mode": False + }], + indirect=True) +def test_redefining_remote_functions(shutdown_only): + ray.init(num_cpus=1) + + # Test that we can define a remote function in the shell. + @ray.remote + def f(x): + return x + 1 + + assert ray.get(f.remote(0)) == 1 + + # Test that we can redefine the remote function. + @ray.remote + def f(x): + return x + 10 + + while True: + val = ray.get(f.remote(0)) + assert val in [1, 10] + if val == 10: + break + else: + logger.info("Still using old definition of f, trying again.") + + # Check that we can redefine functions even when the remote function source + # doesn't change (see https://github.com/ray-project/ray/issues/6130). + @ray.remote + def g(): + return nonexistent() + + with pytest.raises(ray.exceptions.RayTaskError, match="nonexistent"): + ray.get(g.remote()) + + def nonexistent(): + return 1 + + # Redefine the function and make sure it succeeds. + @ray.remote + def g(): + return nonexistent() + + assert ray.get(g.remote()) == 1 + + # Check the same thing but when the redefined function is inside of another + # task. 
+    @ray.remote
+    def h(i):
+        @ray.remote
+        def j():
+            return i
+
+        return j.remote()
+
+    for i in range(20):  # each h(i) call defines a new j closing over i
+        assert ray.get(ray.get(h.remote(i))) == i
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_get_multiple(ray_start_regular):
+    object_ids = [ray.put(i) for i in range(10)]
+    assert ray.get(object_ids) == list(range(10))
+
+    # Get a random choice of object IDs with duplicates.
+    indices = list(np.random.choice(range(10), 5))
+    indices += indices  # doubling makes every chosen index appear twice
+    results = ray.get([object_ids[i] for i in indices])
+    assert results == indices
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_get_multiple_experimental(ray_start_regular):
+    object_ids = [ray.put(i) for i in range(10)]
+
+    object_ids_tuple = tuple(object_ids)  # a tuple instead of a list
+    assert ray.experimental.get(object_ids_tuple) == list(range(10))
+
+    object_ids_nparray = np.array(object_ids)  # a numpy array of IDs
+    assert ray.experimental.get(object_ids_nparray) == list(range(10))
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_get_dict(ray_start_regular):
+    d = {str(i): ray.put(i) for i in range(5)}
+    for i in range(5, 10):  # plain values alongside the object IDs
+        d[str(i)] = i
+    result = ray.experimental.get(d)
+    expected = {str(i): i for i in range(10)}
+    assert result == expected
+
+
+def test_get_with_timeout(ray_start_regular):
+    signal = ray.test_utils.SignalActor.remote()
+
+    # Check that get() returns early if object is ready.
+    start = time.time()
+    ray.get(signal.wait.remote(should_wait=False), timeout=30)
+    assert time.time() - start < 30
+
+    # Check that get() raises RayTimeoutError after the timeout if the object
+    # is not ready yet.
+    result_id = signal.wait.remote()
+    with pytest.raises(RayTimeoutError):
+        ray.get(result_id, timeout=0.1)
+
+    # Check that a subsequent get() returns early.
+    ray.get(signal.send.remote())
+    start = time.time()
+    ray.get(result_id, timeout=30)
+    assert time.time() - start < 30
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+# See https://github.com/ray-project/ray/issues/6329
+def test_call_actors_indirect_through_tasks(ray_start_regular):
+    @ray.remote
+    class Counter:
+        def __init__(self, value):
+            self.value = int(value)
+
+        def increase(self, delta):
+            self.value += int(delta)
+            return self.value
+
+    @ray.remote
+    def foo(object):
+        return ray.get(object.increase.remote(1))
+
+    @ray.remote
+    def bar(object):
+        return ray.get(object.increase.remote(1))
+
+    @ray.remote
+    def zoo(object):
+        return ray.get(object[0].increase.remote(1))  # handle inside a list
+
+    c = Counter.remote(0)
+    for _ in range(0, 100):
+        ray.get(foo.remote(c))
+        ray.get(bar.remote(c))
+        ray.get(zoo.remote([c]))
+
+
+def test_call_matrix(shutdown_only):
+    ray.init(object_store_memory=1000 * 1024 * 1024)
+
+    @ray.remote
+    class Actor:
+        def small_value(self):
+            return 0
+
+        def large_value(self):
+            return np.zeros(10 * 1024 * 1024)
+
+        def echo(self, x):
+            if isinstance(x, list):  # out-of-band case: ID wrapped in a list
+                x = ray.get(x[0])
+            return x
+
+    @ray.remote
+    def small_value():
+        return 0
+
+    @ray.remote
+    def large_value():
+        return np.zeros(10 * 1024 * 1024)
+
+    @ray.remote
+    def echo(x):
+        if isinstance(x, list):  # out-of-band case: ID wrapped in a list
+            x = ray.get(x[0])
+        return x
+
+    def check(source_actor, dest_actor, is_large, out_of_band):
+        print("CHECKING", "actor" if source_actor else "task", "to", "actor"
+              if dest_actor else "task", "large_object"
+              if is_large else "small_object", "out_of_band"
+              if out_of_band else "in_band")
+        if source_actor:
+            a = Actor.remote()
+            if is_large:
+                x_id = a.large_value.remote()
+            else:
+                x_id = a.small_value.remote()
+        else:
+            if is_large:
+                x_id = large_value.remote()
+            else:
+                x_id = small_value.remote()
+        if out_of_band:
+            x_id = [x_id]  # echo() must then call ray.get() itself
+        if dest_actor:
+            b = Actor.remote()
+            x = ray.get(b.echo.remote(x_id))
+        else:
+            x = ray.get(echo.remote(x_id))
+        if is_large:
+            assert isinstance(x, np.ndarray)
+        else:
+            assert isinstance(x, int)
+
+    # Exercise all 16 combinations of the four flags.
+    for is_large in [False, True]:
+        for source_actor in [False, True]:
+            for dest_actor in [False, True]:
+                for out_of_band in [False, True]:
+                    check(source_actor, dest_actor, is_large, out_of_band)
+
+
+@pytest.mark.parametrize(
+    "ray_start_cluster", [{
+        "num_cpus": 1,
+        "num_nodes": 1,
+    }, {
+        "num_cpus": 1,
+        "num_nodes": 2,
+    }],
+    indirect=True)
+def test_call_chain(ray_start_cluster):
+    @ray.remote
+    def g(x):
+        return x + 1
+
+    x = 0
+    for _ in range(100):  # each call consumes the previous call's result
+        x = g.remote(x)
+    assert ray.get(x) == 100
+
+
+def test_inline_arg_memory_corruption(ray_start_regular):
+    @ray.remote
+    def f():
+        return np.zeros(1000, dtype=np.uint8)
+
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            self.z = []
+
+        def add(self, x):
+            self.z.append(x)
+            for prev in self.z:  # re-check every array received so far
+                assert np.sum(prev) == 0, ("memory corruption detected", prev)
+
+    a = Actor.remote()
+    for i in range(100):
+        ray.get(a.add.remote(f.remote()))
+
+
+def test_skip_plasma(ray_start_regular):
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            pass
+
+        def f(self, x):
+            return x * 2
+
+    a = Actor.remote()
+    obj_id = a.f.remote(1)
+    # The result of the actor call is not stored in plasma.
+    assert not ray.worker.global_worker.core_worker.object_exists(obj_id)
+    assert ray.get(obj_id) == 2
+
+
+def test_actor_call_order(shutdown_only):
+    ray.init(num_cpus=4)
+
+    @ray.remote
+    def small_value():
+        time.sleep(0.01 * np.random.randint(0, 10))
+        return 0
+
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            self.count = 0
+
+        def inc(self, count, dependency):  # dependency is never read
+            assert count == self.count
+            self.count += 1
+            return count
+
+    a = Actor.remote()
+    assert ray.get([a.inc.remote(i, small_value.remote())
+                    for i in range(100)]) == list(range(100))
+
+
+def test_actor_large_objects(ray_start_regular):
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            pass
+
+        def f(self):
+            time.sleep(1)
+            return np.zeros(10000000)
+
+    a = Actor.remote()
+    obj_id = a.f.remote()
+    assert not ray.worker.global_worker.core_worker.object_exists(obj_id)
+    done, _ = ray.wait([obj_id])
+    assert len(done) == 1
+    assert ray.worker.global_worker.core_worker.object_exists(obj_id)
+    assert isinstance(ray.get(obj_id), np.ndarray)
+
+
+def test_actor_pass_by_ref(ray_start_regular):
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            pass
+
+        def f(self, x):
+            return x * 2
+
+    @ray.remote
+    def f(x):
+        return x
+
+    @ray.remote
+    def error():
+        sys.exit(0)
+
+    a = Actor.remote()
+    assert ray.get(a.f.remote(f.remote(1))) == 2
+
+    fut = [a.f.remote(f.remote(i)) for i in range(100)]
+    assert ray.get(fut) == [i * 2 for i in range(100)]
+
+    # Errors from a by-reference argument propagate to the caller.
+    with pytest.raises(Exception):
+        ray.get(a.f.remote(error.remote()))
+
+
+def test_actor_pass_by_ref_order_optimization(shutdown_only):
+    ray.init(num_cpus=4)
+
+    @ray.remote
+    class Actor:
+        def __init__(self):
+            pass
+
+        def f(self, x):
+            pass
+
+    a = Actor.remote()
+
+    @ray.remote
+    def fast_value():
+        print("fast value")
+        pass
+
+    @ray.remote
+    def slow_value():
+        print("start sleep")
+        time.sleep(30)
+
+    @ray.remote
+    def runner(f):
+        print("runner", a, f)
+        return ray.get(a.f.remote(f.remote()))
+
+    runner.remote(slow_value)
+    time.sleep(1)  # give the slow call a head start
+    x2 = runner.remote(fast_value)
+    start = time.time()
+    ray.get(x2)
+    delta = time.time() - start
+    assert delta < 10, "did not skip slow value"
+
+
+def test_actor_recursive(ray_start_regular):
+    @ray.remote
+    class Actor:
+        def __init__(self, delegate=None):
+            self.delegate = delegate
+
+        def f(self, x):
+            if self.delegate:  # forward the call down the chain
+                return ray.get(self.delegate.f.remote(x))
+            return x * 2
+
+    a = Actor.remote()
+    b = Actor.remote(a)
+    c = Actor.remote(b)  # c delegates to b, which delegates to a
+
+    result = ray.get([c.f.remote(i) for i in range(100)])
+    assert result == [x * 2 for x in range(100)]
+
+    result, _ = ray.wait([c.f.remote(i) for i in range(100)], num_returns=100)
+    result = ray.get(result)
+    assert result == [x * 2 for x in range(100)]
+
+
+def test_actor_concurrent(ray_start_regular):
+    @ray.remote
+    class Batcher:
+        def __init__(self):
+            self.batch = []
+            self.event = threading.Event()
+
+        def add(self, x):
+            self.batch.append(x)
+            if len(self.batch) >= 3:
+                self.event.set()
+            else:
+                self.event.wait()  # block until the batch holds 3 items
+            return sorted(self.batch)
+
+    a = Batcher.options(max_concurrency=3).remote()
+    x1 = a.add.remote(1)
+    x2 = a.add.remote(2)
+    x3 = a.add.remote(3)
+    r1 = ray.get(x1)
+    r2 = ray.get(x2)
+    r3 = ray.get(x3)
+    assert r1 == [1, 2, 3]
+    assert r1 == r2 == r3
+
+
+def test_wait(ray_start_regular):
+    @ray.remote
+    def f(delay):
+        time.sleep(delay)
+        return
+
+    object_ids = [f.remote(0), f.remote(0), f.remote(0), f.remote(0)]
+    ready_ids, remaining_ids = ray.wait(object_ids)
+    assert len(ready_ids) == 1
+    assert len(remaining_ids) == 3
+    ready_ids, remaining_ids = ray.wait(object_ids, num_returns=4)
+    assert set(ready_ids) == set(object_ids)
+    assert remaining_ids == []
+
+    object_ids = [f.remote(0), f.remote(5)]  # one fast and one slow task
+    ready_ids, remaining_ids = ray.wait(object_ids, timeout=0.5, num_returns=2)
+    assert len(ready_ids) == 1
+    assert len(remaining_ids) == 1
+
+    # Verify that calling wait with duplicate object IDs throws an
+    # exception.
+    x = ray.put(1)
+    with pytest.raises(Exception):
+        ray.wait([x, x])
+
+    # Make sure it is possible to call wait with an empty list.
+    ready_ids, remaining_ids = ray.wait([])
+    assert ready_ids == []
+    assert remaining_ids == []
+
+    # Test semantics of num_returns with no timeout.
+    oids = [ray.put(i) for i in range(10)]
+    (found, rest) = ray.wait(oids, num_returns=2)
+    assert len(found) == 2
+    assert len(rest) == 8
+
+    # Verify that incorrect usage raises a TypeError.
+    x = ray.put(1)
+    with pytest.raises(TypeError):
+        ray.wait(x)  # a single ID instead of a list
+    with pytest.raises(TypeError):
+        ray.wait(1)  # not a list at all
+    with pytest.raises(TypeError):
+        ray.wait([1])  # a list whose element came from neither put nor remote
+
+
+def test_duplicate_args(ray_start_regular):
+    @ray.remote
+    def f(arg1,
+          arg2,
+          arg1_duplicate,
+          kwarg1=None,
+          kwarg2=None,
+          kwarg1_duplicate=None):
+        assert arg1 == kwarg1
+        assert arg1 != arg2
+        assert arg1 == arg1_duplicate
+        assert kwarg1 != kwarg2
+        assert kwarg1 == kwarg1_duplicate
+
+    # Test by-value arguments.
+    arg1 = [1]
+    arg2 = [2]
+    ray.get(
+        f.remote(
+            arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
+
+    # Test by-reference arguments.
+    arg1 = ray.put([1])
+    arg2 = ray.put([2])
+    ray.get(
+        f.remote(
+            arg1, arg2, arg1, kwarg1=arg1, kwarg2=arg2, kwarg1_duplicate=arg1))
+
+
+def test_internal_config_when_connecting(ray_start_cluster):
+    config = json.dumps({
+        "object_pinning_enabled": 0,
+        "initial_reconstruction_timeout_milliseconds": 200
+    })
+    cluster = ray.cluster_utils.Cluster()
+    cluster.add_node(
+        _internal_config=config, object_store_memory=100 * 1024 * 1024)
+    cluster.wait_for_nodes()
+
+    # Specifying _internal_config when connecting to a cluster is disallowed.
+    with pytest.raises(ValueError):
+        ray.init(address=cluster.address, _internal_config=config)
+
+    # Check that the config was picked up (object pinning is disabled).
+    ray.init(address=cluster.address)
+    oid = ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
+
+    for _ in range(5):  # 5 x 40 MiB puts into a 100 MiB object store
+        ray.put(np.zeros(40 * 1024 * 1024, dtype=np.uint8))
+
+    # This would not raise an exception if object pinning were enabled.
+    with pytest.raises(ray.exceptions.UnreconstructableError):
+        ray.get(oid)
+
+
+if __name__ == "__main__":
+    import pytest
+    sys.exit(pytest.main(["-v", __file__]))
diff --git a/python/ray/tests/test_serialization.py b/python/ray/tests/test_serialization.py
new file mode 100644
index 000000000..10a9ccdaa
--- /dev/null
+++ b/python/ray/tests/test_serialization.py
@@ -0,0 +1,526 @@
+# coding: utf-8
+import collections
+import io
+import logging
+import re
+import string
+import sys
+
+import numpy as np
+import pytest
+
+import ray
+import ray.cluster_utils
+import ray.test_utils
+
+logger = logging.getLogger(__name__)
+
+
+def is_named_tuple(cls):
+    """Return True if cls is a namedtuple and False otherwise."""
+    b = cls.__bases__
+    if len(b) != 1 or b[0] != tuple:  # must derive directly from tuple only
+        return False
+    f = getattr(cls, "_fields", None)
+    if not isinstance(f, tuple):
+        return False
+    return all(isinstance(n, str) for n in f)  # every field name is a str
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_simple_serialization(ray_start_regular):
+    primitive_objects = [
+        # Various primitive types.
+        0,
+        0.0,
+        0.9,
+        1 << 62,
+        1 << 999,
+        b"",
+        b"a",
+        "a",
+        string.printable,
+        "\u262F",
+        u"hello world",
+        u"\xff\xfe\x9c\x001\x000\x00",
+        None,
+        True,
+        False,
+        [],
+        (),
+        {},
+        type,
+        int,
+        set(),
+        # Collections types.
+        collections.Counter([np.random.randint(0, 10) for _ in range(100)]),
+        collections.OrderedDict([("hello", 1), ("world", 2)]),
+        collections.defaultdict(lambda: 0, [("hello", 1), ("world", 2)]),
+        collections.defaultdict(lambda: [], [("hello", 1), ("world", 2)]),
+        collections.deque([1, 2, 3, "a", "b", "c", 3.5]),
+        # Numpy dtypes.
+        np.int8(3),
+        np.int32(4),
+        np.int64(5),
+        np.uint8(3),
+        np.uint32(4),
+        np.uint64(5),
+        np.float32(1.9),
+        np.float64(1.9),
+    ]
+
+    composite_objects = (
+        [[obj]
+         for obj in primitive_objects] + [(obj, )
+                                          for obj in primitive_objects] + [{
+                                              (): obj
+                                          } for obj in primitive_objects])
+
+    @ray.remote
+    def f(x):
+        return x
+
+    # Round-trip every object through a remote call and through put/get and
+    # check that it comes back equal.
+    for obj in primitive_objects + composite_objects:
+        new_obj_1 = ray.get(f.remote(obj))
+        new_obj_2 = ray.get(ray.put(obj))
+        assert obj == new_obj_1
+        assert obj == new_obj_2
+        # TODO(rkn): The numpy dtypes currently come back as regular integers
+        # or floats.
+        if type(obj).__module__ != "numpy":
+            assert type(obj) == type(new_obj_1)
+            assert type(obj) == type(new_obj_2)
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_complex_serialization(ray_start_regular):
+    def assert_equal(obj1, obj2):  # recursive, type-aware equality check
+        module_numpy = (type(obj1).__module__ == np.__name__
+                        or type(obj2).__module__ == np.__name__)
+        if module_numpy:
+            empty_shape = ((hasattr(obj1, "shape") and obj1.shape == ())
+                           or (hasattr(obj2, "shape") and obj2.shape == ()))
+            if empty_shape:
+                # This is a special case because currently
+                # np.testing.assert_equal fails because we do not properly
+                # handle different numerical types.
+                assert obj1 == obj2, ("Objects {} and {} are "
+                                      "different.".format(obj1, obj2))
+            else:
+                np.testing.assert_equal(obj1, obj2)
+        elif hasattr(obj1, "__dict__") and hasattr(obj2, "__dict__"):
+            special_keys = ["_pytype_"]  # attribute names that are not compared
+            assert (set(list(obj1.__dict__.keys()) + special_keys) == set(
+                list(obj2.__dict__.keys()) + special_keys)), (
+                    "Objects {} and {} are different.".format(obj1, obj2))
+            for key in obj1.__dict__.keys():
+                if key not in special_keys:
+                    assert_equal(obj1.__dict__[key], obj2.__dict__[key])
+        elif type(obj1) is dict or type(obj2) is dict:
+            assert_equal(obj1.keys(), obj2.keys())
+            for key in obj1.keys():
+                assert_equal(obj1[key], obj2[key])
+        elif type(obj1) is list or type(obj2) is list:
+            assert len(obj1) == len(obj2), ("Objects {} and {} are lists with "
+                                            "different lengths.".format(
+                                                obj1, obj2))
+            for i in range(len(obj1)):
+                assert_equal(obj1[i], obj2[i])
+        elif type(obj1) is tuple or type(obj2) is tuple:
+            assert len(obj1) == len(obj2), ("Objects {} and {} are tuples "
+                                            "with different lengths.".format(
+                                                obj1, obj2))
+            for i in range(len(obj1)):
+                assert_equal(obj1[i], obj2[i])
+        elif (is_named_tuple(type(obj1)) or is_named_tuple(type(obj2))):
+            assert len(obj1) == len(obj2), (
+                "Objects {} and {} are named "
+                "tuples with different lengths.".format(obj1, obj2))
+            for i in range(len(obj1)):
+                assert_equal(obj1[i], obj2[i])
+        else:  # fall back to plain equality for everything else
+            assert obj1 == obj2, "Objects {} and {} are different.".format(
+                obj1, obj2)
+
+    long_extras = [0, np.array([["hi", u"hi"], [1.3, 1]])]
+
+    PRIMITIVE_OBJECTS = [
+        0, 0.0, 0.9, 1 << 62, 1 << 100, 1 << 999, [1 << 100, [1 << 100]], "a",
+        string.printable, "\u262F", u"hello world",
+        u"\xff\xfe\x9c\x001\x000\x00", None, True, False, [], (), {},
+        np.int8(3),
+        np.int32(4),
+        np.int64(5),
+        np.uint8(3),
+        np.uint32(4),
+        np.uint64(5),
+        np.float32(1.9),
+        np.float64(1.9),
+        np.zeros([100, 100]),
+        np.random.normal(size=[100, 100]),
+        np.array(["hi", 3]),
+        np.array(["hi", 3], dtype=object)
+    ] + long_extras
+
+    COMPLEX_OBJECTS = [
+        [[[[[[[[[[[[]]]]]]]]]]]],
+        {
+            "obj{}".format(i): np.random.normal(size=[100, 100])
+            for i in range(10)
+        },
+        # {(): {(): {(): {(): {(): {(): {(): {(): {(): {(): {
+        #      (): {(): {}}}}}}}}}}}}},
+        (
+            (((((((((), ), ), ), ), ), ), ), ), ),
+        {
+            "a": {
+                "b": {
+                    "c": {
+                        "d": {}
+                    }
+                }
+            }
+        },
+    ]
+
+    class Foo:
+        def __init__(self, value=0):
+            self.value = value
+
+        def __hash__(self):
+            return hash(self.value)
+
+        def __eq__(self, other):  # only works if other has a .value
+            return other.value == self.value
+
+    class Bar:
+        def __init__(self):
+            # One attribute (field0, field1, ...) per test object.
+            for i, val in enumerate(PRIMITIVE_OBJECTS + COMPLEX_OBJECTS):
+                setattr(self, "field{}".format(i), val)
+
+    class Baz:
+        def __init__(self):
+            self.foo = Foo()
+            self.bar = Bar()
+
+        def method(self, arg):
+            pass
+
+    class Qux:
+        def __init__(self):
+            self.objs = [Foo(), Bar(), Baz()]
+
+    class SubQux(Qux):
+        def __init__(self):
+            Qux.__init__(self)
+
+    class CustomError(Exception):
+        pass
+
+    Point = collections.namedtuple("Point", ["x", "y"])
+    NamedTupleExample = collections.namedtuple(
+        "Example", "field1, field2, field3, field4, field5")
+
+    CUSTOM_OBJECTS = [
+        Exception("Test object."),
+        CustomError(),
+        Point(11, y=22),
+        Foo(),
+        Bar(),
+        Baz(),  # Qux(), SubQux(),
+        NamedTupleExample(1, 1.0, "hi", np.zeros([3, 5]), [1, 2, 3]),
+    ]
+
+    # Test dataclasses in Python 3.7.
+    if sys.version_info >= (3, 7):
+        from dataclasses import make_dataclass
+
+        DataClass0 = make_dataclass("DataClass0", [("number", int)])
+
+        CUSTOM_OBJECTS.append(DataClass0(number=3))
+
+        class CustomClass:
+            def __init__(self, value):
+                self.value = value
+
+        DataClass1 = make_dataclass("DataClass1", [("custom", CustomClass)])
+
+        class DataClass2(DataClass1):
+            @classmethod
+            def from_custom(cls, data):
+                custom = CustomClass(data)
+                return cls(custom)
+
+            def __reduce__(self):  # pickle by rebuilding via from_custom
+                return (self.from_custom, (self.custom.value, ))
+
+        CUSTOM_OBJECTS.append(DataClass2(custom=CustomClass(43)))
+
+    BASE_OBJECTS = PRIMITIVE_OBJECTS + COMPLEX_OBJECTS + CUSTOM_OBJECTS
+
+    LIST_OBJECTS = [[obj] for obj in BASE_OBJECTS]
+    TUPLE_OBJECTS = [(obj, ) for obj in BASE_OBJECTS]
+    # The check that type(obj).__module__ != "numpy" should be unnecessary, but
+    # otherwise this seems to fail on Mac OS X on Travis.
+    DICT_OBJECTS = ([{
+        obj: obj
+    } for obj in PRIMITIVE_OBJECTS if (
+        obj.__hash__ is not None and type(obj).__module__ != "numpy")] + [{
+            0: obj
+        } for obj in BASE_OBJECTS] + [{
+            Foo(123): Foo(456)
+        }])
+
+    RAY_TEST_OBJECTS = (
+        BASE_OBJECTS + LIST_OBJECTS + TUPLE_OBJECTS + DICT_OBJECTS)
+
+    @ray.remote
+    def f(x):
+        return x
+
+    # Check that we can pass arguments by value to remote functions and
+    # that they are uncorrupted.
+    for obj in RAY_TEST_OBJECTS:
+        assert_equal(obj, ray.get(f.remote(obj)))
+        assert_equal(obj, ray.get(ray.put(obj)))
+
+    # Test StringIO serialization
+    s = io.StringIO(u"Hello, world!\n")
+    s.seek(0)
+    line = s.readline()
+    s.seek(0)
+    assert ray.get(ray.put(s)).readline() == line
+
+
+def test_numpy_serialization(ray_start_regular):
+    array = np.zeros(314)
+    from ray.cloudpickle import dumps
+    buffers = []
+    inband = dumps(array, protocol=5, buffer_callback=buffers.append)
+    assert len(inband) < array.nbytes  # array data went out-of-band
+    assert len(buffers) == 1
+
+
+def test_numpy_subclass_serialization(ray_start_regular):
+    class MyNumpyConstant(np.ndarray):
+        def __init__(self, value):
+            super().__init__()
+            self.constant = value
+
+        def __str__(self):  # must return a str; printing returned None
+            return str(self.constant)
+
+    constant = MyNumpyConstant(123)
+
+    def explode(x):
+        raise RuntimeError("Expected error.")
+
+    ray.register_custom_serializer(
+        type(constant), serializer=explode, deserializer=explode)
+
+    try:
+        ray.put(constant)
+        assert False, "Should never get here!"
+    except (RuntimeError, IndexError):
+        print("Correct behavior, proof that custom serializer was used.")
+
+
+def test_numpy_subclass_serialization_pickle(ray_start_regular):
+    class MyNumpyConstant(np.ndarray):
+        def __init__(self, value):
+            super().__init__()
+            self.constant = value
+
+        def __str__(self):  # must return a str; printing returned None
+            return str(self.constant)
+
+    constant = MyNumpyConstant(123)
+    repr_orig = repr(constant)
+    repr_ser = repr(ray.get(ray.put(constant)))
+    assert repr_orig == repr_ser
+
+
+@pytest.mark.parametrize(
+    "ray_start_regular", [{
+        "local_mode": True
+    }, {
+        "local_mode": False
+    }],
+    indirect=True)
+def test_serialization_final_fallback(ray_start_regular):
+    pytest.importorskip("catboost")
+    # This test will only run when "catboost" is installed.
+    from catboost import CatBoostClassifier
+
+    model = CatBoostClassifier(
+        iterations=2,
+        depth=2,
+        learning_rate=1,
+        loss_function="Logloss",
+        logging_level="Verbose")
+
+    reconstructed_model = ray.get(ray.put(model))
+    assert set(model.get_params().items()) == set(
+        reconstructed_model.get_params().items())
+
+
+def test_register_class(ray_start_2_cpus):
+    # Check that an object of a class that was never explicitly registered
+    # can still be put and fetched without raising.
+    class TempClass:
+        pass
+
+    ray.get(ray.put(TempClass()))
+
+    # Test passing custom classes into remote functions from the driver.
+    @ray.remote
+    def f(x):
+        return x
+
+    class Foo:
+        def __init__(self, value=0):
+            self.value = value
+
+        def __hash__(self):
+            return hash(self.value)
+
+        def __eq__(self, other):
+            return other.value == self.value
+
+    foo = ray.get(f.remote(Foo(7)))
+    assert foo == Foo(7)
+
+    regex = re.compile(r"\d+\.\d*")
+    new_regex = ray.get(f.remote(regex))
+    # This seems to fail on the system Python 3 that comes with
+    # Ubuntu, so it is commented out for now:
+    # assert regex == new_regex
+    # Instead, we do this:
+    assert regex.pattern == new_regex.pattern
+
+    class TempClass1:
+        def __init__(self):
+            self.value = 1
+
+    # Test returning custom classes created on workers.
+    @ray.remote
+    def g():
+        class TempClass2:
+            def __init__(self):
+                self.value = 2
+
+        return TempClass1(), TempClass2()  # defined outside g / inside g
+
+    object_1, object_2 = ray.get(g.remote())
+    assert object_1.value == 1
+    assert object_2.value == 2
+
+    # Test exporting custom class definitions from one worker to another
+    # when the worker is blocked in a get.
+    class NewTempClass:
+        def __init__(self, value):
+            self.value = value
+
+    @ray.remote
+    def h1(x):
+        return NewTempClass(x)
+
+    @ray.remote
+    def h2(x):
+        return ray.get(h1.remote(x))  # h2 blocks in get() while h1 runs
+
+    assert ray.get(h2.remote(10)).value == 10
+
+    # Test registering multiple classes with the same name.
+    @ray.remote(num_return_vals=3)
+    def j():
+        class Class0:
+            def method0(self):
+                pass
+
+        c0 = Class0()
+
+        class Class0:  # same name, different method
+            def method1(self):
+                pass
+
+        c1 = Class0()
+
+        class Class0:  # same name again, third variant
+            def method2(self):
+                pass
+
+        c2 = Class0()
+
+        return c0, c1, c2
+
+    results = []
+    for _ in range(5):
+        results += j.remote()  # adds three entries per call
+    for i in range(len(results) // 3):
+        c0, c1, c2 = ray.get(results[(3 * i):(3 * (i + 1))])
+
+        c0.method0()
+        c1.method1()
+        c2.method2()
+
+        # Each instance must expose only the method of its own variant.
+        assert not hasattr(c0, "method1")
+        assert not hasattr(c0, "method2")
+        assert not hasattr(c1, "method0")
+        assert not hasattr(c1, "method2")
+        assert not hasattr(c2, "method0")
+        assert not hasattr(c2, "method1")
+
+    @ray.remote
+    def k():
+        class Class0:
+            def method0(self):
+                pass
+
+        c0 = Class0()
+
+        class Class0:  # same name, different method
+            def method1(self):
+                pass
+
+        c1 = Class0()
+
+        class Class0:  # same name again, third variant
+            def method2(self):
+                pass
+
+        c2 = Class0()
+
+        return c0, c1, c2
+
+    results = ray.get([k.remote() for _ in range(5)])  # one tuple per call
+    for c0, c1, c2 in results:
+        c0.method0()
+        c1.method1()
+        c2.method2()
+
+        # Each instance must expose only the method of its own variant.
+        assert not hasattr(c0, "method1")
+        assert not hasattr(c0, "method2")
+        assert not hasattr(c1, "method0")
+        assert not hasattr(c1, "method2")
+        assert not hasattr(c2, "method0")
+        assert not hasattr(c2, "method1")
+
+
+if __name__ == "__main__":
+    import pytest
+    sys.exit(pytest.main(["-v", __file__]))