Refactor pytest fixtures for ray core (#4390)

This commit is contained in:
Yuhong Guo
2019-03-20 11:48:32 +08:00
committed by Hao Chen
parent c6f15a0057
commit 8ce7565530
22 changed files with 378 additions and 681 deletions
+25 -81
View File
@@ -33,23 +33,7 @@ from ray.utils import _random_string
logger = logging.getLogger(__name__)
@pytest.fixture
def ray_start():
# Start the Ray processes.
ray.init(num_cpus=1)
yield None
# The code after the yield will run as teardown code.
ray.shutdown()
@pytest.fixture
def shutdown_only():
yield None
# The code after the yield will run as teardown code.
ray.shutdown()
def test_simple_serialization(ray_start):
def test_simple_serialization(ray_start_regular):
primitive_objects = [
# Various primitive types.
0,
@@ -116,7 +100,7 @@ def test_simple_serialization(ray_start):
assert type(obj) == type(new_obj_2)
def test_complex_serialization(ray_start):
def test_complex_serialization(ray_start_regular):
def assert_equal(obj1, obj2):
module_numpy = (type(obj1).__module__ == np.__name__
or type(obj2).__module__ == np.__name__)
@@ -319,7 +303,7 @@ def test_complex_serialization(ray_start):
assert_equal(obj, ray.get(ray.put(obj)))
def test_ray_recursive_objects(ray_start):
def test_ray_recursive_objects(ray_start_regular):
class ClassA(object):
pass
@@ -347,7 +331,7 @@ def test_ray_recursive_objects(ray_start):
ray.put(obj)
def test_passing_arguments_by_value_out_of_the_box(ray_start):
def test_passing_arguments_by_value_out_of_the_box(ray_start_regular):
@ray.remote
def f(x):
return x
@@ -379,7 +363,7 @@ def test_passing_arguments_by_value_out_of_the_box(ray_start):
ray.get(ray.put(Foo))
def test_putting_object_that_closes_over_object_id(ray_start):
def test_putting_object_that_closes_over_object_id(ray_start_regular):
# This test is here to prevent a regression of
# https://github.com/ray-project/ray/issues/1317.
@@ -422,9 +406,7 @@ def test_put_get(shutdown_only):
assert value_before == value_after
def test_custom_serializers(shutdown_only):
ray.init(num_cpus=1)
def test_custom_serializers(ray_start_regular):
class Foo(object):
def __init__(self):
self.x = 3
@@ -454,7 +436,7 @@ def test_custom_serializers(shutdown_only):
assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2")
def test_serialization_final_fallback(ray_start):
def test_serialization_final_fallback(ray_start_regular):
pytest.importorskip("catboost")
# This test will only run when "catboost" is installed.
from catboost import CatBoostClassifier
@@ -471,9 +453,7 @@ def test_serialization_final_fallback(ray_start):
reconstructed_model.get_params().items())
def test_register_class(shutdown_only):
ray.init(num_cpus=2)
def test_register_class(ray_start_2_cpus):
# Check that putting an object of a class that has not been registered
# throws an exception.
class TempClass(object):
@@ -616,7 +596,7 @@ def test_register_class(shutdown_only):
assert not hasattr(c2, "method1")
def test_keyword_args(shutdown_only):
def test_keyword_args(ray_start_regular):
@ray.remote
def keyword_fct1(a, b="hello"):
return "{} {}".format(a, b)
@@ -629,8 +609,6 @@ def test_keyword_args(shutdown_only):
def keyword_fct3(a, b, c="hello", d="world"):
return "{} {} {} {}".format(a, b, c, d)
ray.init(num_cpus=1)
x = keyword_fct1.remote(1)
assert ray.get(x) == "1 hello"
x = keyword_fct1.remote(1, "hi")
@@ -886,8 +864,7 @@ def test_submit_api(shutdown_only):
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
def test_get_multiple(shutdown_only):
ray.init(num_cpus=1)
def test_get_multiple(ray_start_regular):
object_ids = [ray.put(i) for i in range(10)]
assert ray.get(object_ids) == list(range(10))
@@ -898,8 +875,7 @@ def test_get_multiple(shutdown_only):
assert results == indices
def test_get_multiple_experimental(shutdown_only):
ray.init(num_cpus=1)
def test_get_multiple_experimental(ray_start_regular):
object_ids = [ray.put(i) for i in range(10)]
object_ids_tuple = tuple(object_ids)
@@ -909,8 +885,7 @@ def test_get_multiple_experimental(shutdown_only):
assert ray.experimental.get(object_ids_nparray) == list(range(10))
def test_get_dict(shutdown_only):
ray.init(num_cpus=1)
def test_get_dict(ray_start_regular):
d = {str(i): ray.put(i) for i in range(5)}
for i in range(5, 10):
d[str(i)] = i
@@ -919,9 +894,7 @@ def test_get_dict(shutdown_only):
assert result == expected
def test_wait(shutdown_only):
ray.init(num_cpus=1)
def test_wait(ray_start_regular):
@ray.remote
def f(delay):
time.sleep(delay)
@@ -976,9 +949,7 @@ def test_wait(shutdown_only):
ray.wait([1])
def test_wait_iterables(shutdown_only):
ray.init(num_cpus=1)
def test_wait_iterables(ray_start_regular):
@ray.remote
def f(delay):
time.sleep(delay)
@@ -1075,9 +1046,7 @@ def test_caching_functions_to_run(shutdown_only):
ray.worker.global_worker.run_function_on_all_workers(f)
def test_running_function_on_all_workers(shutdown_only):
ray.init(num_cpus=1)
def test_running_function_on_all_workers(ray_start_regular):
def f(worker_info):
sys.path.append("fake_directory")
@@ -1104,9 +1073,7 @@ def test_running_function_on_all_workers(shutdown_only):
assert "fake_directory" not in ray.get(get_path2.remote())
def test_profiling_api(shutdown_only):
ray.init(num_cpus=2)
def test_profiling_api(ray_start_2_cpus):
@ray.remote
def f():
with ray.profile(
@@ -1150,16 +1117,6 @@ def test_profiling_api(shutdown_only):
break
@pytest.fixture()
def ray_start_cluster():
cluster = ray.tests.cluster_utils.Cluster()
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
cluster.shutdown()
def test_wait_cluster(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=1, resources={"RemoteResource": 1})
@@ -1227,10 +1184,9 @@ def test_object_transfer_dump(ray_start_cluster):
}) == num_nodes
def test_identical_function_names(shutdown_only):
def test_identical_function_names(ray_start_regular):
# Define a bunch of remote functions and make sure that we don't
# accidentally call an older version.
ray.init(num_cpus=1)
num_calls = 200
@@ -1294,8 +1250,7 @@ def test_identical_function_names(shutdown_only):
assert result_values == num_calls * [5]
def test_illegal_api_calls(shutdown_only):
ray.init(num_cpus=1)
def test_illegal_api_calls(ray_start_regular):
# Verify that we cannot call put on an ObjectID.
x = ray.put(1)
@@ -1310,10 +1265,9 @@ def test_illegal_api_calls(shutdown_only):
# because plasma client isn't thread-safe. This needs to be fixed from the
# Arrow side. See #4107 for relevant discussions.
@pytest.mark.skipif(six.PY2, reason="Doesn't work in Python 2.")
def test_multithreading(shutdown_only):
def test_multithreading(ray_start_2_cpus):
# This test requires at least 2 CPUs to finish since the worker does not
# release resources when joining the threads.
ray.init(num_cpus=2)
def run_test_in_multi_threads(test_case, num_threads=10, num_repeats=25):
"""A helper function that runs test cases in multiple threads."""
@@ -2273,9 +2227,7 @@ def test_specific_gpus(save_gpu_ids_shutdown_only):
ray.get([g.remote() for _ in range(100)])
def test_blocking_tasks(shutdown_only):
ray.init(num_cpus=1)
def test_blocking_tasks(ray_start_regular):
@ray.remote
def f(i, j):
return (i, j)
@@ -2310,9 +2262,7 @@ def test_blocking_tasks(shutdown_only):
ray.get(sleep.remote())
def test_max_call_tasks(shutdown_only):
ray.init(num_cpus=1)
def test_max_call_tasks(ray_start_regular):
@ray.remote(max_calls=1)
def f():
return os.getpid()
@@ -2692,9 +2642,7 @@ def test_wait_reconstruction(shutdown_only):
assert len(ready_ids) == 1
def test_ray_setproctitle(shutdown_only):
ray.init(num_cpus=2)
def test_ray_setproctitle(ray_start_2_cpus):
@ray.remote
class UniqueName(object):
def __init__(self):
@@ -2739,9 +2687,7 @@ def test_duplicate_error_messages(shutdown_only):
@pytest.mark.skipif(
os.getenv("TRAVIS") is None,
reason="This test should only be run on Travis.")
def test_ray_stack(shutdown_only):
ray.init(num_cpus=2)
def test_ray_stack(ray_start_2_cpus):
def unique_name_1():
time.sleep(1000)
@@ -2797,9 +2743,7 @@ def test_socket_dir_not_existing(shutdown_only):
ray.init(num_cpus=1, raylet_socket_name=temp_raylet_socket_name)
def test_raylet_is_robust_to_random_messages(shutdown_only):
ray.init(num_cpus=1)
def test_raylet_is_robust_to_random_messages(ray_start_regular):
node_manager_address = None
node_manager_port = None
for client in ray.global_state.client_table():
@@ -2820,7 +2764,7 @@ def test_raylet_is_robust_to_random_messages(shutdown_only):
assert ray.get(f.remote()) == 1
def test_non_ascii_comment(ray_start):
def test_non_ascii_comment(ray_start_regular):
@ray.remote
def f():
# 日本語 Japanese comment