mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 11:26:24 +08:00
[tune] Trial scheduler interface (#1160)
* trial scheduler interface * remove * update
This commit is contained in:
@@ -7,6 +7,7 @@ import time
|
||||
import traceback
|
||||
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_scheduler import FIFOScheduler, TrialScheduler
|
||||
|
||||
|
||||
class TrialRunner(object):
|
||||
@@ -31,9 +32,10 @@ class TrialRunner(object):
|
||||
misleading benchmark results.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, scheduler=None):
|
||||
"""Initializes a new TrialRunner."""
|
||||
|
||||
self._scheduler_alg = scheduler or FIFOScheduler()
|
||||
self._trials = []
|
||||
self._running = {}
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
@@ -61,7 +63,7 @@ class TrialRunner(object):
|
||||
else:
|
||||
for trial in self._trials:
|
||||
if trial.status == Trial.PENDING:
|
||||
assert self._has_resources(trial.resources), \
|
||||
assert self.has_resources(trial.resources), \
|
||||
("Insufficient cluster resources to launch trial",
|
||||
(trial.resources, self._avail_resources))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
@@ -89,6 +91,7 @@ class TrialRunner(object):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
messages = ["== Status =="]
|
||||
messages.append(self._scheduler_alg.debug_string())
|
||||
messages.append(
|
||||
"Resources used: {}/{} CPUs, {}/{} GPUs".format(
|
||||
self._committed_resources.cpu,
|
||||
@@ -103,6 +106,14 @@ class TrialRunner(object):
|
||||
" - {}:\t{}".format(t, t.progress_string()))
|
||||
return "\n".join(messages) + "\n"
|
||||
|
||||
def has_resources(self, resources):
|
||||
"""Returns whether this runner has at least the specified resources."""
|
||||
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
assert cpu_avail >= 0 and gpu_avail >= 0
|
||||
return resources.cpu <= cpu_avail and resources.gpu <= gpu_avail
|
||||
|
||||
def _can_launch_more(self):
|
||||
self._update_avail_resources()
|
||||
trial = self._get_runnable()
|
||||
@@ -139,27 +150,27 @@ class TrialRunner(object):
|
||||
if trial.should_stop(result):
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
# TODO(rliaw): This implements checkpoint in a blocking manner
|
||||
if trial.should_checkpoint():
|
||||
trial.checkpoint()
|
||||
self._running[trial.train_remote()] = trial
|
||||
decision = self._scheduler_alg.on_trial_result(
|
||||
self, trial, result)
|
||||
if decision == TrialScheduler.CONTINUE:
|
||||
if trial.should_checkpoint():
|
||||
# TODO(rliaw): This is a blocking call
|
||||
trial.checkpoint()
|
||||
self._running[trial.train_remote()] = trial
|
||||
elif decision == TrialScheduler.PAUSE:
|
||||
self._pause_trial(trial)
|
||||
elif decision == TrialScheduler.STOP:
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
assert False, "Invalid scheduling decision: {}".format(
|
||||
decision)
|
||||
except Exception:
|
||||
print("Error processing event:", traceback.format_exc())
|
||||
if trial.status == Trial.RUNNING:
|
||||
self._stop_trial(trial, error=True)
|
||||
|
||||
def _get_runnable(self):
|
||||
for trial in self._trials:
|
||||
if (trial.status == Trial.PENDING and
|
||||
self._has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def _has_resources(self, resources):
|
||||
cpu_avail = self._avail_resources.cpu - self._committed_resources.cpu
|
||||
gpu_avail = self._avail_resources.gpu - self._committed_resources.gpu
|
||||
assert cpu_avail >= 0 and gpu_avail >= 0
|
||||
return resources.cpu <= cpu_avail and resources.gpu <= gpu_avail
|
||||
return self._scheduler_alg.choose_trial_to_run(self)
|
||||
|
||||
def _commit_resources(self, resources):
|
||||
self._committed_resources = Resources(
|
||||
@@ -174,8 +185,12 @@ class TrialRunner(object):
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def _stop_trial(self, trial, error=False):
|
||||
self._return_resources(trial.resources)
|
||||
trial.stop(error=error)
|
||||
self._return_resources(trial.resources)
|
||||
|
||||
def _pause_trial(self, trial):
|
||||
trial.pause()
|
||||
self._return_resources(trial.resources)
|
||||
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
||||
from ray.tune.trial import Trial
|
||||
|
||||
|
||||
class TrialScheduler(object):
|
||||
CONTINUE = "CONTINUE"
|
||||
PAUSE = "PAUSE"
|
||||
STOP = "STOP"
|
||||
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
"""Called on each intermediate result returned by a trial.
|
||||
|
||||
At this point, the trial scheduler can make a decision by returning
|
||||
one of CONTINUE, PAUSE, and STOP."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def choose_trial_to_run(self, trial_runner, trials):
|
||||
"""Called to choose a new trial to run.
|
||||
|
||||
This should return one of the trials in trial_runner that is in
|
||||
the PENDING or PAUSED state."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
def debug_string(self):
|
||||
"""Returns a human readable message for printing to the console."""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FIFOScheduler(TrialScheduler):
|
||||
def on_trial_result(self, trial_runner, trial, result):
|
||||
return TrialScheduler.CONTINUE
|
||||
|
||||
def choose_trial_to_run(self, trial_runner):
|
||||
for trial in trial_runner.get_trials():
|
||||
if (trial.status == Trial.PENDING and
|
||||
trial_runner.has_resources(trial.resources)):
|
||||
return trial
|
||||
return None
|
||||
|
||||
def debug_string(self):
|
||||
return "Using FIFO scheduling algorithm."
|
||||
Reference in New Issue
Block a user