mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 17:55:15 +08:00
[tune] Tune Pausing (#1136)
* fix yaml bug * add ext agent * gpus * update * tuning * docs * Sun Oct 15 21:09:25 PDT 2017 * lint * update * Sun Oct 15 22:39:55 PDT 2017 * Sun Oct 15 22:40:17 PDT 2017 * Sun Oct 15 22:43:06 PDT 2017 * Sun Oct 15 22:46:06 PDT 2017 * Sun Oct 15 22:46:21 PDT 2017 * Sun Oct 15 22:48:11 PDT 2017 * Sun Oct 15 22:48:44 PDT 2017 * Sun Oct 15 22:49:23 PDT 2017 * Sun Oct 15 22:50:21 PDT 2017 * Sun Oct 15 22:53:00 PDT 2017 * Sun Oct 15 22:53:34 PDT 2017 * Sun Oct 15 22:54:33 PDT 2017 * Sun Oct 15 22:54:50 PDT 2017 * Sun Oct 15 22:55:20 PDT 2017 * Sun Oct 15 22:56:56 PDT 2017 * Sun Oct 15 22:59:03 PDT 2017 * fix * Update tune_mnist_ray.py * remove script trial * fix * reorder * fix ex * py2 support * upd * comments * comments * cleanup readme * fix trial * annotate * Update rllib.rst * init pausing * Docs, Lint * fix danglings and restore endpoint moved to trialrunner * renaming * nit * start always starts from checkpoint * smalls * nits * lint * last change
This commit is contained in:
@@ -314,13 +314,31 @@ class _MockAgent(Agent):
|
||||
_default_config = {}
|
||||
|
||||
def _init(self):
|
||||
pass
|
||||
self.info = None
|
||||
|
||||
def _train(self):
    """Return a fixed, synthetic result for one training iteration.

    The mock agent does no real work: every call reports the same
    reward, episode length and timestep count, with an empty info dict.
    """
    mock_metrics = dict(
        episode_reward_mean=10,
        episode_len_mean=10,
        timesteps_this_iter=10,
        info={})
    return TrainingResult(**mock_metrics)
|
||||
|
||||
def _save(self):
    """Pickle ``self.info`` into ``mock_agent.pkl`` under ``self.logdir``.

    Returns:
        The path of the checkpoint file that was written.
    """
    checkpoint_file = os.path.join(self.logdir, "mock_agent.pkl")
    with open(checkpoint_file, 'wb') as out:
        pickle.dump(self.info, out)
    return checkpoint_file
|
||||
|
||||
def _restore(self, checkpoint_path):
    """Load ``self.info`` back from a pickle file.

    Args:
        checkpoint_path (str): File to unpickle the info value from.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only restore
    # from checkpoint files that come from a trusted source.
    with open(checkpoint_path, 'rb') as src:
        self.info = pickle.load(src)
|
||||
|
||||
def set_info(self, info):
    """Store ``info`` on the agent and hand the same object back.

    Args:
        info: Arbitrary value to remember; it is what ``_save`` pickles.

    Returns:
        The ``info`` argument, unchanged.
    """
    stored = info
    self.info = stored
    return stored
|
||||
|
||||
def get_info(self):
    """Return the value currently held in ``self.info``."""
    current = self.info
    return current
|
||||
|
||||
|
||||
def get_agent_class(alg):
|
||||
"""Returns the class of an known agent given its name."""
|
||||
|
||||
@@ -57,9 +57,8 @@ def main(argv):
|
||||
runner.add_trial(
|
||||
Trial(
|
||||
args.env, args.alg, args.config, args.local_dir, None,
|
||||
args.resources, args.stop, args.checkpoint_freq,
|
||||
args.restore, args.upload_dir))
|
||||
|
||||
args.resources, args.stop, args.checkpoint_freq, args.restore,
|
||||
args.upload_dir))
|
||||
ray.init(
|
||||
redis_address=args.redis_address, num_cpus=args.num_cpus,
|
||||
num_gpus=args.num_gpus)
|
||||
|
||||
+47
-14
@@ -26,6 +26,7 @@ class Trial(object):
|
||||
|
||||
PENDING = "PENDING"
|
||||
RUNNING = "RUNNING"
|
||||
PAUSED = "PAUSED"
|
||||
TERMINATED = "TERMINATED"
|
||||
ERROR = "ERROR"
|
||||
|
||||
@@ -56,12 +57,11 @@ class Trial(object):
|
||||
self.resources = resources
|
||||
self.stopping_criterion = stopping_criterion
|
||||
self.checkpoint_freq = checkpoint_freq
|
||||
self.restore_path = restore_path
|
||||
self.upload_dir = upload_dir
|
||||
|
||||
# Local trial state that is updated during the run
|
||||
self.last_result = None
|
||||
self.checkpoint_path = None
|
||||
self._checkpoint_path = restore_path
|
||||
self.agent = None
|
||||
self.status = Trial.PENDING
|
||||
self.location = None
|
||||
@@ -73,16 +73,9 @@ class Trial(object):
|
||||
be thrown.
|
||||
"""
|
||||
|
||||
self.status = Trial.RUNNING
|
||||
agent_cls = get_agent_class(self.alg)
|
||||
cls = ray.remote(
|
||||
num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)(
|
||||
agent_cls)
|
||||
self.agent = cls.remote(
|
||||
self.env_creator, self.config, self.local_dir, self.upload_dir,
|
||||
experiment_tag=self.experiment_tag)
|
||||
if self.restore_path:
|
||||
ray.get(self.agent.restore.remote(self.restore_path))
|
||||
self._setup_agent()
|
||||
if self._checkpoint_path:
|
||||
self.restore_from_path(path=self._checkpoint_path)
|
||||
|
||||
def stop(self, error=False):
|
||||
"""Stops this trial.
|
||||
@@ -111,6 +104,21 @@ class Trial(object):
|
||||
finally:
|
||||
self.agent = None
|
||||
|
||||
def pause(self):
    """Checkpoint this trial, stop it, and mark it as PAUSED.

    Pausing is meant to release the trial's resources (specifically
    GPUs), so the trial is stopped just as if it had terminated; only
    the final status differs.

    Raises:
        AssertionError: If the trial is not currently RUNNING.
    """
    current = self.status
    assert current == Trial.RUNNING, current

    # Save state first so the trial can later be restarted from it.
    self.checkpoint()
    self.stop()
    # Override the status left behind by stop().
    self.status = Trial.PAUSED
|
||||
|
||||
def resume(self):
    """Restart a PAUSED trial by calling ``start()``.

    This is a blocking call.

    Raises:
        AssertionError: If the trial is not currently PAUSED.
    """
    current = self.status
    assert current == Trial.PAUSED, current
    self.start()
|
||||
|
||||
def train_remote(self):
|
||||
"""Returns Ray future for one iteration of training."""
|
||||
|
||||
@@ -174,11 +182,36 @@ class Trial(object):
|
||||
"""
|
||||
|
||||
path = ray.get(self.agent.save.remote())
|
||||
self.checkpoint_path = path
|
||||
self._checkpoint_path = path
|
||||
print("Saved checkpoint to:", path)
|
||||
|
||||
return path
|
||||
|
||||
def restore_from_path(self, path):
    """Restores agent state from specified path.

    If the trial has no agent, a message is printed and nothing else
    happens. If the remote restore fails, the traceback is printed and
    the trial status is set to ``Trial.ERROR``; the error is not
    re-raised.

    Args:
        path (str): A path where state will be restored.
    """
    if self.agent is None:
        print("Unable to restore - no agent")
        return

    try:
        # Block until the remote agent has finished restoring.
        ray.get(self.agent.restore.remote(path))
    except Exception:
        # Deliberately not a bare ``except:`` so that KeyboardInterrupt
        # and SystemExit propagate instead of being recorded as a
        # trial error.
        print("Error restoring agent:", traceback.format_exc())
        self.status = Trial.ERROR
|
||||
|
||||
def _setup_agent(self):
    """Mark the trial RUNNING and create its remote agent.

    The agent class for ``self.alg`` is wrapped with ``ray.remote`` using
    the trial's CPU and GPU resource counts, then instantiated remotely
    and stored in ``self.agent``.
    """
    self.status = Trial.RUNNING
    trainable = get_agent_class(self.alg)
    as_actor = ray.remote(
        num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)
    remote_cls = as_actor(trainable)
    self.agent = remote_cls.remote(
        self.env_creator,
        self.config,
        self.local_dir,
        self.upload_dir,
        experiment_tag=self.experiment_tag)
|
||||
|
||||
def __str__(self):
|
||||
identifier = '{}_{}'.format(self.alg, self.env_name)
|
||||
if self.experiment_tag:
|
||||
|
||||
@@ -35,7 +35,7 @@ class TrialRunner(object):
|
||||
"""Initializes a new TrialRunner."""
|
||||
|
||||
self._trials = []
|
||||
self._pending = {}
|
||||
self._running = {}
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
self._committed_resources = Resources(cpu=0, gpu=0)
|
||||
|
||||
@@ -43,7 +43,7 @@ class TrialRunner(object):
|
||||
"""Returns whether all trials have finished running."""
|
||||
|
||||
for t in self._trials:
|
||||
if t.status in [Trial.PENDING, Trial.RUNNING]:
|
||||
if t.status in [Trial.PENDING, Trial.RUNNING, Trial.PAUSED]:
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -56,7 +56,7 @@ class TrialRunner(object):
|
||||
|
||||
if self._can_launch_more():
|
||||
self._launch_trial()
|
||||
elif self._pending:
|
||||
elif self._running:
|
||||
self._process_events()
|
||||
else:
|
||||
for trial in self._trials:
|
||||
@@ -64,6 +64,9 @@ class TrialRunner(object):
|
||||
assert self._has_resources(trial.resources), \
|
||||
("Insufficient cluster resources to launch trial",
|
||||
(trial.resources, self._avail_resources))
|
||||
elif trial.status == Trial.PAUSED:
|
||||
assert False, "There are paused trials, but no more \
|
||||
pending trials with sufficient resources."
|
||||
assert False, "Called step when all trials finished?"
|
||||
|
||||
def get_trials(self):
|
||||
@@ -110,14 +113,14 @@ class TrialRunner(object):
|
||||
self._commit_resources(trial.resources)
|
||||
try:
|
||||
trial.start()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
self._running[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error starting agent, retrying:", traceback.format_exc())
|
||||
time.sleep(2)
|
||||
trial.stop(error=True)
|
||||
try:
|
||||
trial.start()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
self._running[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error starting agent, abort:", traceback.format_exc())
|
||||
trial.stop(error=True)
|
||||
@@ -125,27 +128,25 @@ class TrialRunner(object):
|
||||
# have been lost
|
||||
|
||||
def _process_events(self):
|
||||
[result_id], _ = ray.wait(self._pending.keys())
|
||||
trial = self._pending[result_id]
|
||||
del self._pending[result_id]
|
||||
[result_id], _ = ray.wait(self._running.keys())
|
||||
trial = self._running[result_id]
|
||||
del self._running[result_id]
|
||||
try:
|
||||
result = ray.get(result_id)
|
||||
print("result", result)
|
||||
trial.last_result = result
|
||||
|
||||
if trial.should_stop(result):
|
||||
self._return_resources(trial.resources)
|
||||
trial.stop()
|
||||
self._stop_trial(trial)
|
||||
else:
|
||||
# TODO(rliaw): This implements checkpoint in a blocking manner
|
||||
if trial.should_checkpoint():
|
||||
trial.checkpoint()
|
||||
self._pending[trial.train_remote()] = trial
|
||||
self._running[trial.train_remote()] = trial
|
||||
except:
|
||||
print("Error processing event:", traceback.format_exc())
|
||||
if trial.status == Trial.RUNNING:
|
||||
self._return_resources(trial.resources)
|
||||
trial.stop(error=True)
|
||||
self._stop_trial(trial, error=True)
|
||||
|
||||
def _get_runnable(self):
|
||||
for trial in self._trials:
|
||||
@@ -172,6 +173,10 @@ class TrialRunner(object):
|
||||
assert self._committed_resources.cpu >= 0
|
||||
assert self._committed_resources.gpu >= 0
|
||||
|
||||
def _stop_trial(self, trial, error=False):
    """Give a trial's resources back to the runner, then stop the trial.

    Args:
        trial (Trial): Trial to stop.
        error (bool): Forwarded to ``trial.stop`` as its ``error`` flag.
    """
    released = trial.resources
    self._return_resources(released)
    trial.stop(error=error)
|
||||
|
||||
def _update_avail_resources(self):
|
||||
clients = ray.global_state.client_table()
|
||||
local_schedulers = [
|
||||
|
||||
Reference in New Issue
Block a user