mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 01:46:10 +08:00
[tune] Check actor start -> test_cluster (#8056)
* test * info * ok * hard_stop * codefix
This commit is contained in:
@@ -27,6 +27,17 @@ from ray.tune.utils.mock import (MockDurableTrainer, MockRemoteTrainer,
|
||||
MOCK_REMOTE_DIR)
|
||||
|
||||
|
||||
def _check_trial_running(trial):
    """Confirm that the actor backing ``trial`` has actually started.

    Returns False when the trial has no runner. Otherwise it fetches the
    result of a remote ``get_info`` call on the runner via ``ray.get`` and
    returns True once that call completes.
    """
    if not trial.runner:
        return False
    # Resolving the remote call's result only succeeds once the actor responds.
    info_ref = trial.runner.get_info.remote()
    ray.get(info_ref)
    return True
|
||||
|
||||
|
||||
def _get_running_trials(runner):
    """Collect the trials of ``runner`` whose status is ``Trial.RUNNING``.

    Returns a new list, in the order produced by ``runner.get_trials()``.
    """
    running = []
    for candidate in runner.get_trials():
        if candidate.status == Trial.RUNNING:
            running.append(candidate)
    return running
|
||||
|
||||
|
||||
def _start_new_cluster():
|
||||
cluster = Cluster(
|
||||
initialize_head=True,
|
||||
@@ -89,6 +100,10 @@ def test_counting_resources(start_connected_cluster):
|
||||
runner.add_trial(t)
|
||||
|
||||
runner.step() # run 1
|
||||
running_trials = _get_running_trials(runner)
|
||||
assert len(running_trials) == 1
|
||||
assert _check_trial_running(running_trials[0])
|
||||
assert ray.available_resources().get("CPU", 0) == 0
|
||||
nodes += [cluster.add_node(num_cpus=1)]
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.cluster_resources()["CPU"] == 2
|
||||
@@ -146,16 +161,32 @@ def test_remove_node_before_result(start_connected_emptyhead_cluster):
|
||||
trial = Trial("__fake", **kwargs)
|
||||
runner.add_trial(trial)
|
||||
|
||||
runner.step() # Start trial
|
||||
runner.step() # Start trial, call _train once
|
||||
running_trials = _get_running_trials(runner)
|
||||
assert len(running_trials) == 1
|
||||
assert _check_trial_running(running_trials[0])
|
||||
assert not trial.last_result
|
||||
assert trial.status == Trial.RUNNING
|
||||
cluster.remove_node(node)
|
||||
cluster.add_node(num_cpus=1)
|
||||
cluster.wait_for_nodes()
|
||||
assert ray.cluster_resources()["CPU"] == 1
|
||||
|
||||
# Process result (x2), process save, process result.
|
||||
for _ in range(4):
|
||||
runner.step()
|
||||
# Process result: fetch data, invoke _train again
|
||||
runner.step()
|
||||
assert trial.last_result.get("training_iteration") == 1
|
||||
|
||||
# Process result: discover failure, recover, _train (from scratch)
|
||||
runner.step()
|
||||
|
||||
runner.step() # Process result, invoke _train
|
||||
assert trial.last_result.get("training_iteration") == 1
|
||||
runner.step() # Process result, invoke _save
|
||||
assert trial.last_result.get("training_iteration") == 2
|
||||
# process save, invoke _train
|
||||
runner.step()
|
||||
# process result
|
||||
runner.step()
|
||||
assert trial.status == Trial.TERMINATED
|
||||
|
||||
with pytest.raises(TuneError):
|
||||
@@ -303,7 +334,6 @@ def test_trial_migration(start_connected_emptyhead_cluster, trainable_id):
|
||||
runner.step()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not very consistent.")
|
||||
@pytest.mark.parametrize("trainable_id", ["__fake", "__fake_durable"])
|
||||
def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
||||
"""Removing a node in full cluster causes Trial to be requeued."""
|
||||
@@ -330,6 +360,9 @@ def test_trial_requeue(start_connected_emptyhead_cluster, trainable_id):
|
||||
runner.step() # Process result, dispatch save
|
||||
runner.step() # Process save
|
||||
|
||||
running_trials = _get_running_trials(runner)
|
||||
assert len(running_trials) == 1
|
||||
assert _check_trial_running(running_trials[0])
|
||||
cluster.remove_node(node)
|
||||
cluster.wait_for_nodes()
|
||||
runner.step() # Process result, dispatch save
|
||||
|
||||
Reference in New Issue
Block a user