mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 20:47:56 +08:00
Refactor pytest fixtures for ray core (#4390)
This commit is contained in:
@@ -2,7 +2,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import json
|
||||
import logging
|
||||
import pytest
|
||||
import time
|
||||
@@ -10,46 +9,11 @@ import time
|
||||
import ray
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.tests.cluster_utils import Cluster
|
||||
from ray.tests.conftest import generate_internal_config_map
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def start_connected_cluster():
|
||||
# Start the Ray processes.
|
||||
g = Cluster(
|
||||
initialize_head=True,
|
||||
connect=True,
|
||||
head_node_args={
|
||||
"num_cpus": 1,
|
||||
"_internal_config": json.dumps({
|
||||
"num_heartbeats_timeout": 10
|
||||
})
|
||||
})
|
||||
yield g
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
g.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def start_connected_longer_cluster():
|
||||
"""Creates a cluster with a longer timeout."""
|
||||
g = Cluster(
|
||||
initialize_head=True,
|
||||
connect=True,
|
||||
head_node_args={
|
||||
"num_cpus": 1,
|
||||
"_internal_config": json.dumps({
|
||||
"num_heartbeats_timeout": 20
|
||||
})
|
||||
})
|
||||
yield g
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
g.shutdown()
|
||||
|
||||
|
||||
def test_cluster():
|
||||
"""Basic test for adding and removing nodes in cluster."""
|
||||
g = Cluster(initialize_head=False)
|
||||
@@ -70,7 +34,11 @@ def test_shutdown():
|
||||
assert not any(n.any_processes_alive() for n in [node, node2])
|
||||
|
||||
|
||||
def test_internal_config(start_connected_longer_cluster):
|
||||
@pytest.mark.parametrize(
|
||||
"ray_start_cluster_head",
|
||||
[generate_internal_config_map(num_heartbeats_timeout=20)],
|
||||
indirect=True)
|
||||
def test_internal_config(ray_start_cluster_head):
|
||||
"""Checks that the internal configuration setting works.
|
||||
|
||||
We set the cluster to timeout nodes after 2 seconds of no timeouts. We
|
||||
@@ -78,7 +46,7 @@ def test_internal_config(start_connected_longer_cluster):
|
||||
of sync, then wait another 2 seconds (giving 1 second of leeway) to check
|
||||
that the client has timed out.
|
||||
"""
|
||||
cluster = start_connected_longer_cluster
|
||||
cluster = ray_start_cluster_head
|
||||
worker = cluster.add_node()
|
||||
cluster.wait_for_nodes()
|
||||
|
||||
@@ -90,13 +58,13 @@ def test_internal_config(start_connected_longer_cluster):
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
|
||||
|
||||
def test_wait_for_nodes(start_connected_cluster):
|
||||
def test_wait_for_nodes(ray_start_cluster_head):
|
||||
"""Unit test for `Cluster.wait_for_nodes`.
|
||||
|
||||
Adds 4 workers, waits, then removes 4 workers, waits,
|
||||
then adds 1 worker, waits, and removes 1 worker, waits.
|
||||
"""
|
||||
cluster = start_connected_cluster
|
||||
cluster = ray_start_cluster_head
|
||||
workers = [cluster.add_node() for i in range(4)]
|
||||
cluster.wait_for_nodes()
|
||||
[cluster.remove_node(w) for w in workers]
|
||||
@@ -110,8 +78,8 @@ def test_wait_for_nodes(start_connected_cluster):
|
||||
assert ray.global_state.cluster_resources()["CPU"] == 1
|
||||
|
||||
|
||||
def test_worker_plasma_store_failure(start_connected_cluster):
|
||||
cluster = start_connected_cluster
|
||||
def test_worker_plasma_store_failure(ray_start_cluster_head):
|
||||
cluster = ray_start_cluster_head
|
||||
worker = cluster.add_node()
|
||||
cluster.wait_for_nodes()
|
||||
# Log monitor doesn't die for some reason
|
||||
|
||||
Reference in New Issue
Block a user