mirror of
https://github.com/wassname/ray.git
synced 2026-07-01 20:40:56 +08:00
Refactor pytest fixtures for ray core (#4390)
This commit is contained in:
@@ -33,23 +33,7 @@ from ray.utils import _random_string
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ray_start():
|
||||
# Start the Ray processes.
|
||||
ray.init(num_cpus=1)
|
||||
yield None
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def shutdown_only():
|
||||
yield None
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
|
||||
def test_simple_serialization(ray_start):
|
||||
def test_simple_serialization(ray_start_regular):
|
||||
primitive_objects = [
|
||||
# Various primitive types.
|
||||
0,
|
||||
@@ -116,7 +100,7 @@ def test_simple_serialization(ray_start):
|
||||
assert type(obj) == type(new_obj_2)
|
||||
|
||||
|
||||
def test_complex_serialization(ray_start):
|
||||
def test_complex_serialization(ray_start_regular):
|
||||
def assert_equal(obj1, obj2):
|
||||
module_numpy = (type(obj1).__module__ == np.__name__
|
||||
or type(obj2).__module__ == np.__name__)
|
||||
@@ -319,7 +303,7 @@ def test_complex_serialization(ray_start):
|
||||
assert_equal(obj, ray.get(ray.put(obj)))
|
||||
|
||||
|
||||
def test_ray_recursive_objects(ray_start):
|
||||
def test_ray_recursive_objects(ray_start_regular):
|
||||
class ClassA(object):
|
||||
pass
|
||||
|
||||
@@ -347,7 +331,7 @@ def test_ray_recursive_objects(ray_start):
|
||||
ray.put(obj)
|
||||
|
||||
|
||||
def test_passing_arguments_by_value_out_of_the_box(ray_start):
|
||||
def test_passing_arguments_by_value_out_of_the_box(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(x):
|
||||
return x
|
||||
@@ -379,7 +363,7 @@ def test_passing_arguments_by_value_out_of_the_box(ray_start):
|
||||
ray.get(ray.put(Foo))
|
||||
|
||||
|
||||
def test_putting_object_that_closes_over_object_id(ray_start):
|
||||
def test_putting_object_that_closes_over_object_id(ray_start_regular):
|
||||
# This test is here to prevent a regression of
|
||||
# https://github.com/ray-project/ray/issues/1317.
|
||||
|
||||
@@ -422,9 +406,7 @@ def test_put_get(shutdown_only):
|
||||
assert value_before == value_after
|
||||
|
||||
|
||||
def test_custom_serializers(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_custom_serializers(ray_start_regular):
|
||||
class Foo(object):
|
||||
def __init__(self):
|
||||
self.x = 3
|
||||
@@ -454,7 +436,7 @@ def test_custom_serializers(shutdown_only):
|
||||
assert ray.get(f.remote()) == ((3, "string1", Bar.__name__), "string2")
|
||||
|
||||
|
||||
def test_serialization_final_fallback(ray_start):
|
||||
def test_serialization_final_fallback(ray_start_regular):
|
||||
pytest.importorskip("catboost")
|
||||
# This test will only run when "catboost" is installed.
|
||||
from catboost import CatBoostClassifier
|
||||
@@ -471,9 +453,7 @@ def test_serialization_final_fallback(ray_start):
|
||||
reconstructed_model.get_params().items())
|
||||
|
||||
|
||||
def test_register_class(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def test_register_class(ray_start_2_cpus):
|
||||
# Check that putting an object of a class that has not been registered
|
||||
# throws an exception.
|
||||
class TempClass(object):
|
||||
@@ -616,7 +596,7 @@ def test_register_class(shutdown_only):
|
||||
assert not hasattr(c2, "method1")
|
||||
|
||||
|
||||
def test_keyword_args(shutdown_only):
|
||||
def test_keyword_args(ray_start_regular):
|
||||
@ray.remote
|
||||
def keyword_fct1(a, b="hello"):
|
||||
return "{} {}".format(a, b)
|
||||
@@ -629,8 +609,6 @@ def test_keyword_args(shutdown_only):
|
||||
def keyword_fct3(a, b, c="hello", d="world"):
|
||||
return "{} {} {} {}".format(a, b, c, d)
|
||||
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
x = keyword_fct1.remote(1)
|
||||
assert ray.get(x) == "1 hello"
|
||||
x = keyword_fct1.remote(1, "hi")
|
||||
@@ -886,8 +864,7 @@ def test_submit_api(shutdown_only):
|
||||
assert ray.get([id1, id2, id3, id4]) == [0, 1, "test", 2]
|
||||
|
||||
|
||||
def test_get_multiple(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
def test_get_multiple(ray_start_regular):
|
||||
object_ids = [ray.put(i) for i in range(10)]
|
||||
assert ray.get(object_ids) == list(range(10))
|
||||
|
||||
@@ -898,8 +875,7 @@ def test_get_multiple(shutdown_only):
|
||||
assert results == indices
|
||||
|
||||
|
||||
def test_get_multiple_experimental(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
def test_get_multiple_experimental(ray_start_regular):
|
||||
object_ids = [ray.put(i) for i in range(10)]
|
||||
|
||||
object_ids_tuple = tuple(object_ids)
|
||||
@@ -909,8 +885,7 @@ def test_get_multiple_experimental(shutdown_only):
|
||||
assert ray.experimental.get(object_ids_nparray) == list(range(10))
|
||||
|
||||
|
||||
def test_get_dict(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
def test_get_dict(ray_start_regular):
|
||||
d = {str(i): ray.put(i) for i in range(5)}
|
||||
for i in range(5, 10):
|
||||
d[str(i)] = i
|
||||
@@ -919,9 +894,7 @@ def test_get_dict(shutdown_only):
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_wait(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_wait(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
time.sleep(delay)
|
||||
@@ -976,9 +949,7 @@ def test_wait(shutdown_only):
|
||||
ray.wait([1])
|
||||
|
||||
|
||||
def test_wait_iterables(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_wait_iterables(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(delay):
|
||||
time.sleep(delay)
|
||||
@@ -1075,9 +1046,7 @@ def test_caching_functions_to_run(shutdown_only):
|
||||
ray.worker.global_worker.run_function_on_all_workers(f)
|
||||
|
||||
|
||||
def test_running_function_on_all_workers(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_running_function_on_all_workers(ray_start_regular):
|
||||
def f(worker_info):
|
||||
sys.path.append("fake_directory")
|
||||
|
||||
@@ -1104,9 +1073,7 @@ def test_running_function_on_all_workers(shutdown_only):
|
||||
assert "fake_directory" not in ray.get(get_path2.remote())
|
||||
|
||||
|
||||
def test_profiling_api(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def test_profiling_api(ray_start_2_cpus):
|
||||
@ray.remote
|
||||
def f():
|
||||
with ray.profile(
|
||||
@@ -1150,16 +1117,6 @@ def test_profiling_api(shutdown_only):
|
||||
break
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def ray_start_cluster():
|
||||
cluster = ray.tests.cluster_utils.Cluster()
|
||||
yield cluster
|
||||
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
cluster.shutdown()
|
||||
|
||||
|
||||
def test_wait_cluster(ray_start_cluster):
|
||||
cluster = ray_start_cluster
|
||||
cluster.add_node(num_cpus=1, resources={"RemoteResource": 1})
|
||||
@@ -1227,10 +1184,9 @@ def test_object_transfer_dump(ray_start_cluster):
|
||||
}) == num_nodes
|
||||
|
||||
|
||||
def test_identical_function_names(shutdown_only):
|
||||
def test_identical_function_names(ray_start_regular):
|
||||
# Define a bunch of remote functions and make sure that we don't
|
||||
# accidentally call an older version.
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
num_calls = 200
|
||||
|
||||
@@ -1294,8 +1250,7 @@ def test_identical_function_names(shutdown_only):
|
||||
assert result_values == num_calls * [5]
|
||||
|
||||
|
||||
def test_illegal_api_calls(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
def test_illegal_api_calls(ray_start_regular):
|
||||
|
||||
# Verify that we cannot call put on an ObjectID.
|
||||
x = ray.put(1)
|
||||
@@ -1310,10 +1265,9 @@ def test_illegal_api_calls(shutdown_only):
|
||||
# because plasma client isn't thread-safe. This needs to be fixed from the
|
||||
# Arrow side. See #4107 for relevant discussions.
|
||||
@pytest.mark.skipif(six.PY2, reason="Doesn't work in Python 2.")
|
||||
def test_multithreading(shutdown_only):
|
||||
def test_multithreading(ray_start_2_cpus):
|
||||
# This test requires at least 2 CPUs to finish since the worker does not
|
||||
# release resources when joining the threads.
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def run_test_in_multi_threads(test_case, num_threads=10, num_repeats=25):
|
||||
"""A helper function that runs test cases in multiple threads."""
|
||||
@@ -2273,9 +2227,7 @@ def test_specific_gpus(save_gpu_ids_shutdown_only):
|
||||
ray.get([g.remote() for _ in range(100)])
|
||||
|
||||
|
||||
def test_blocking_tasks(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_blocking_tasks(ray_start_regular):
|
||||
@ray.remote
|
||||
def f(i, j):
|
||||
return (i, j)
|
||||
@@ -2310,9 +2262,7 @@ def test_blocking_tasks(shutdown_only):
|
||||
ray.get(sleep.remote())
|
||||
|
||||
|
||||
def test_max_call_tasks(shutdown_only):
|
||||
ray.init(num_cpus=1)
|
||||
|
||||
def test_max_call_tasks(ray_start_regular):
|
||||
@ray.remote(max_calls=1)
|
||||
def f():
|
||||
return os.getpid()
|
||||
@@ -2692,9 +2642,7 @@ def test_wait_reconstruction(shutdown_only):
|
||||
assert len(ready_ids) == 1
|
||||
|
||||
|
||||
def test_ray_setproctitle(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def test_ray_setproctitle(ray_start_2_cpus):
|
||||
@ray.remote
|
||||
class UniqueName(object):
|
||||
def __init__(self):
|
||||
@@ -2739,9 +2687,7 @@ def test_duplicate_error_messages(shutdown_only):
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("TRAVIS") is None,
|
||||
reason="This test should only be run on Travis.")
|
||||
def test_ray_stack(shutdown_only):
|
||||
ray.init(num_cpus=2)
|
||||
|
||||
def test_ray_stack(ray_start_2_cpus):
|
||||
def unique_name_1():
|
||||
time.sleep(1000)
|
||||
|
||||
@@ -2797,9 +2743,7 @@ def test_socket_dir_not_existing(shutdown_only):
|
||||
ray.init(num_cpus=1, raylet_socket_name=temp_raylet_socket_name)
|
||||
|
||||
|
||||
def test_raylet_is_robust_to_random_messages(shutdown_only):
|
||||
|
||||
ray.init(num_cpus=1)
|
||||
def test_raylet_is_robust_to_random_messages(ray_start_regular):
|
||||
node_manager_address = None
|
||||
node_manager_port = None
|
||||
for client in ray.global_state.client_table():
|
||||
@@ -2820,7 +2764,7 @@ def test_raylet_is_robust_to_random_messages(shutdown_only):
|
||||
assert ray.get(f.remote()) == 1
|
||||
|
||||
|
||||
def test_non_ascii_comment(ray_start):
|
||||
def test_non_ascii_comment(ray_start_regular):
|
||||
@ray.remote
|
||||
def f():
|
||||
# 日本語 Japanese comment
|
||||
|
||||
Reference in New Issue
Block a user