[tune] Tune Pausing (#1136)

* fix yaml bug

* add ext agent

* gpus

* update

* tuning

* docs

* Sun Oct 15 21:09:25 PDT 2017

* lint

* update

* Sun Oct 15 22:39:55 PDT 2017

* Sun Oct 15 22:40:17 PDT 2017

* Sun Oct 15 22:43:06 PDT 2017

* Sun Oct 15 22:46:06 PDT 2017

* Sun Oct 15 22:46:21 PDT 2017

* Sun Oct 15 22:48:11 PDT 2017

* Sun Oct 15 22:48:44 PDT 2017

* Sun Oct 15 22:49:23 PDT 2017

* Sun Oct 15 22:50:21 PDT 2017

* Sun Oct 15 22:53:00 PDT 2017

* Sun Oct 15 22:53:34 PDT 2017

* Sun Oct 15 22:54:33 PDT 2017

* Sun Oct 15 22:54:50 PDT 2017

* Sun Oct 15 22:55:20 PDT 2017

* Sun Oct 15 22:56:56 PDT 2017

* Sun Oct 15 22:59:03 PDT 2017

* fix

* Update tune_mnist_ray.py

* remove script trial

* fix

* reorder

* fix ex

* py2 support

* upd

* comments

* comments

* cleanup readme

* fix trial

* annotate

* Update rllib.rst

* init pausing

* Docs, Lint

* fix danglings and restore endpoint moved to trialrunner

* renaming

* nit

* start always starts from checkpoint

* smalls

* nits

* lint

* last change
This commit is contained in:
Richard Liaw
2017-10-22 23:04:15 -07:00
committed by GitHub
parent 81ca27dc08
commit 0c9817fa76
5 changed files with 146 additions and 31 deletions
+19 -1
View File
@@ -314,13 +314,31 @@ class _MockAgent(Agent):
_default_config = {}
def _init(self):
pass
self.info = None
def _train(self):
return TrainingResult(
episode_reward_mean=10, episode_len_mean=10,
timesteps_this_iter=10, info={})
def _save(self):
path = os.path.join(self.logdir, "mock_agent.pkl")
with open(path, 'wb') as f:
pickle.dump(self.info, f)
return path
def _restore(self, checkpoint_path):
with open(checkpoint_path, 'rb') as f:
info = pickle.load(f)
self.info = info
def set_info(self, info):
self.info = info
return info
def get_info(self):
return self.info
def get_agent_class(alg):
"""Returns the class of an known agent given its name."""
+2 -3
View File
@@ -57,9 +57,8 @@ def main(argv):
runner.add_trial(
Trial(
args.env, args.alg, args.config, args.local_dir, None,
args.resources, args.stop, args.checkpoint_freq,
args.restore, args.upload_dir))
args.resources, args.stop, args.checkpoint_freq, args.restore,
args.upload_dir))
ray.init(
redis_address=args.redis_address, num_cpus=args.num_cpus,
num_gpus=args.num_gpus)
+47 -14
View File
@@ -26,6 +26,7 @@ class Trial(object):
PENDING = "PENDING"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
TERMINATED = "TERMINATED"
ERROR = "ERROR"
@@ -56,12 +57,11 @@ class Trial(object):
self.resources = resources
self.stopping_criterion = stopping_criterion
self.checkpoint_freq = checkpoint_freq
self.restore_path = restore_path
self.upload_dir = upload_dir
# Local trial state that is updated during the run
self.last_result = None
self.checkpoint_path = None
self._checkpoint_path = restore_path
self.agent = None
self.status = Trial.PENDING
self.location = None
@@ -73,16 +73,9 @@ class Trial(object):
be thrown.
"""
self.status = Trial.RUNNING
agent_cls = get_agent_class(self.alg)
cls = ray.remote(
num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)(
agent_cls)
self.agent = cls.remote(
self.env_creator, self.config, self.local_dir, self.upload_dir,
experiment_tag=self.experiment_tag)
if self.restore_path:
ray.get(self.agent.restore.remote(self.restore_path))
self._setup_agent()
if self._checkpoint_path:
self.restore_from_path(path=self._checkpoint_path)
def stop(self, error=False):
"""Stops this trial.
@@ -111,6 +104,21 @@ class Trial(object):
finally:
self.agent = None
def pause(self):
"""We want to release resources (specifically GPUs) when pausing an
experiment. This results in a state similar to TERMINATED."""
assert self.status == Trial.RUNNING, self.status
self.checkpoint()
self.stop()
self.status = Trial.PAUSED
def resume(self):
"""Resume PAUSED tasks. This is a blocking call."""
assert self.status == Trial.PAUSED, self.status
self.start()
def train_remote(self):
"""Returns Ray future for one iteration of training."""
@@ -174,11 +182,36 @@ class Trial(object):
"""
path = ray.get(self.agent.save.remote())
self.checkpoint_path = path
self._checkpoint_path = path
print("Saved checkpoint to:", path)
return path
def restore_from_path(self, path):
"""Restores agent state from specified path.
Args:
path (str): A path where state will be restored.
"""
if self.agent is None:
print("Unable to restore - no agent")
else:
try:
ray.get(self.agent.restore.remote(path))
except:
print("Error restoring agent:", traceback.format_exc())
self.status = Trial.ERROR
def _setup_agent(self):
self.status = Trial.RUNNING
agent_cls = get_agent_class(self.alg)
cls = ray.remote(
num_cpus=self.resources.cpu, num_gpus=self.resources.gpu)(
agent_cls)
self.agent = cls.remote(
self.env_creator, self.config, self.local_dir, self.upload_dir,
experiment_tag=self.experiment_tag)
def __str__(self):
identifier = '{}_{}'.format(self.alg, self.env_name)
if self.experiment_tag:
+18 -13
View File
@@ -35,7 +35,7 @@ class TrialRunner(object):
"""Initializes a new TrialRunner."""
self._trials = []
self._pending = {}
self._running = {}
self._avail_resources = Resources(cpu=0, gpu=0)
self._committed_resources = Resources(cpu=0, gpu=0)
@@ -43,7 +43,7 @@ class TrialRunner(object):
"""Returns whether all trials have finished running."""
for t in self._trials:
if t.status in [Trial.PENDING, Trial.RUNNING]:
if t.status in [Trial.PENDING, Trial.RUNNING, Trial.PAUSED]:
return False
return True
@@ -56,7 +56,7 @@ class TrialRunner(object):
if self._can_launch_more():
self._launch_trial()
elif self._pending:
elif self._running:
self._process_events()
else:
for trial in self._trials:
@@ -64,6 +64,9 @@ class TrialRunner(object):
assert self._has_resources(trial.resources), \
("Insufficient cluster resources to launch trial",
(trial.resources, self._avail_resources))
elif trial.status == Trial.PAUSED:
assert False, "There are paused trials, but no more \
pending trials with sufficient resources."
assert False, "Called step when all trials finished?"
def get_trials(self):
@@ -110,14 +113,14 @@ class TrialRunner(object):
self._commit_resources(trial.resources)
try:
trial.start()
self._pending[trial.train_remote()] = trial
self._running[trial.train_remote()] = trial
except:
print("Error starting agent, retrying:", traceback.format_exc())
time.sleep(2)
trial.stop(error=True)
try:
trial.start()
self._pending[trial.train_remote()] = trial
self._running[trial.train_remote()] = trial
except:
print("Error starting agent, abort:", traceback.format_exc())
trial.stop(error=True)
@@ -125,27 +128,25 @@ class TrialRunner(object):
# have been lost
def _process_events(self):
[result_id], _ = ray.wait(self._pending.keys())
trial = self._pending[result_id]
del self._pending[result_id]
[result_id], _ = ray.wait(self._running.keys())
trial = self._running[result_id]
del self._running[result_id]
try:
result = ray.get(result_id)
print("result", result)
trial.last_result = result
if trial.should_stop(result):
self._return_resources(trial.resources)
trial.stop()
self._stop_trial(trial)
else:
# TODO(rliaw): This implements checkpoint in a blocking manner
if trial.should_checkpoint():
trial.checkpoint()
self._pending[trial.train_remote()] = trial
self._running[trial.train_remote()] = trial
except:
print("Error processing event:", traceback.format_exc())
if trial.status == Trial.RUNNING:
self._return_resources(trial.resources)
trial.stop(error=True)
self._stop_trial(trial, error=True)
def _get_runnable(self):
for trial in self._trials:
@@ -172,6 +173,10 @@ class TrialRunner(object):
assert self._committed_resources.cpu >= 0
assert self._committed_resources.gpu >= 0
def _stop_trial(self, trial, error=False):
self._return_resources(trial.resources)
trial.stop(error=error)
def _update_avail_resources(self):
clients = ray.global_state.client_table()
local_schedulers = [