From 35b1d6189b1420ea6b4d9aec46fcae2e9bb42c7d Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 29 Jan 2018 18:48:45 -0800 Subject: [PATCH] [tune] save error msg, cleanup after object checkpoints --- python/ray/rllib/a3c/a3c.py | 4 +-- python/ray/rllib/agent.py | 4 +-- python/ray/rllib/dqn/dqn.py | 4 +-- python/ray/rllib/es/es.py | 4 +-- python/ray/rllib/ppo/ppo.py | 4 +-- python/ray/tune/examples/hyperband_example.py | 4 +-- python/ray/tune/trainable.py | 13 +++++++--- python/ray/tune/trial.py | 16 +++++++++--- python/ray/tune/trial_runner.py | 25 +++++++++++-------- 9 files changed, 49 insertions(+), 29 deletions(-) diff --git a/python/ray/rllib/a3c/a3c.py b/python/ray/rllib/a3c/a3c.py index 546938edb..f356c789d 100644 --- a/python/ray/rllib/a3c/a3c.py +++ b/python/ray/rllib/a3c/a3c.py @@ -114,9 +114,9 @@ class A3CAgent(Agent): return result - def _save(self): + def _save(self, checkpoint_dir): checkpoint_path = os.path.join( - self.logdir, "checkpoint-{}".format(self.iteration)) + checkpoint_dir, "checkpoint-{}".format(self.iteration)) agent_state = ray.get( [a.save.remote() for a in self.remote_evaluators]) extra_data = { diff --git a/python/ray/rllib/agent.py b/python/ray/rllib/agent.py index f6fe143d0..4624e20ef 100644 --- a/python/ray/rllib/agent.py +++ b/python/ray/rllib/agent.py @@ -147,8 +147,8 @@ class _MockAgent(Agent): episode_reward_mean=10, episode_len_mean=10, timesteps_this_iter=10, info={}) - def _save(self): - path = os.path.join(self.logdir, "mock_agent.pkl") + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "mock_agent.pkl") with open(path, 'wb') as f: pickle.dump(self.info, f) return path diff --git a/python/ray/rllib/dqn/dqn.py b/python/ray/rllib/dqn/dqn.py index 5441f11e1..825cba2a7 100644 --- a/python/ray/rllib/dqn/dqn.py +++ b/python/ray/rllib/dqn/dqn.py @@ -218,10 +218,10 @@ class DQNAgent(Agent): else: self.local_evaluator.sample(no_replay=True) - def _save(self): + def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, - os.path.join(self.logdir, "checkpoint"), + os.path.join(checkpoint_dir, "checkpoint"), global_step=self.iteration) extra_data = [ self.local_evaluator.save(), diff --git a/python/ray/rllib/es/es.py b/python/ray/rllib/es/es.py index 651982940..82f9786dc 100644 --- a/python/ray/rllib/es/es.py +++ b/python/ray/rllib/es/es.py @@ -300,9 +300,9 @@ class ESAgent(Agent): return result - def _save(self): + def _save(self, checkpoint_dir): checkpoint_path = os.path.join( - self.logdir, "checkpoint-{}".format(self.iteration)) + checkpoint_dir, "checkpoint-{}".format(self.iteration)) weights = self.policy.get_weights() objects = [ weights, diff --git a/python/ray/rllib/ppo/ppo.py b/python/ray/rllib/ppo/ppo.py index 3491d6cca..ad1773ead 100644 --- a/python/ray/rllib/ppo/ppo.py +++ b/python/ray/rllib/ppo/ppo.py @@ -244,10 +244,10 @@ class PPOAgent(Agent): return result - def _save(self): + def _save(self, checkpoint_dir): checkpoint_path = self.saver.save( self.local_evaluator.sess, - os.path.join(self.logdir, "checkpoint"), + os.path.join(checkpoint_dir, "checkpoint"), global_step=self.iteration) agent_state = ray.get( [a.save.remote() for a in self.remote_evaluators]) diff --git a/python/ray/tune/examples/hyperband_example.py b/python/ray/tune/examples/hyperband_example.py index ac4fdbc80..6196cef2b 100755 --- a/python/ray/tune/examples/hyperband_example.py +++ b/python/ray/tune/examples/hyperband_example.py @@ -35,8 +35,8 @@ class MyTrainableClass(Trainable): # objectives such as loss or accuracy (see tune/result.py). return TrainingResult(episode_reward_mean=v, timesteps_this_iter=1) - def _save(self): - path = os.path.join(self.logdir, "checkpoint") + def _save(self, checkpoint_dir): + path = os.path.join(checkpoint_dir, "checkpoint") with open(path, "w") as f: f.write(json.dumps({"timestep": self.timestep})) return path diff --git a/python/ray/tune/trainable.py b/python/ray/tune/trainable.py index 4d7f88b69..91a456add 100644 --- a/python/ray/tune/trainable.py +++ b/python/ray/tune/trainable.py @@ -141,17 +141,20 @@ class Trainable(object): return result - def save(self): + def save(self, checkpoint_dir=None): """Saves the current model state to a checkpoint. Subclasses should override ``_save()`` instead to save state. This method dumps additional metadata alongside the saved path. + Args: + checkpoint_dir (str): Optional dir to place the checkpoint. + Returns: Checkpoint path that may be passed to restore(). """ - checkpoint_path = self._save() + checkpoint_path = self._save(checkpoint_dir or self.logdir) pickle.dump( [self._experiment_id, self._iteration, self._timesteps_total, self._time_total], @@ -166,7 +169,8 @@ class Trainable(object): Object holding checkpoint data. """ - checkpoint_prefix = self.save() + tmpdir = tempfile.mkdtemp("save_to_object", dir=self.logdir) + checkpoint_prefix = self.save(tmpdir) data = {} base_dir = os.path.dirname(checkpoint_prefix) @@ -185,6 +189,7 @@ class Trainable(object): len(compressed))) f.write(compressed) + shutil.rmtree(tmpdir) return out.getvalue() def restore(self, checkpoint_path): @@ -234,7 +239,7 @@ class Trainable(object): raise NotImplementedError - def _save(self): + def _save(self, checkpoint_dir): """Subclasses should override this to implement save().""" raise NotImplementedError diff --git a/python/ray/tune/trial.py b/python/ray/tune/trial.py index a3621454f..4765bd6a3 100644 --- a/python/ray/tune/trial.py +++ b/python/ray/tune/trial.py @@ -20,6 +20,10 @@ DEBUG_PRINT_INTERVAL = 5 MAX_LEN_IDENTIFIER = 130 +def date_str(): + return datetime.today().strftime("%Y-%m-%d_%H-%M-%S") + + class Resources( namedtuple("Resources", [ "cpu", "gpu", "driver_cpu_limit", "driver_gpu_limit"])): @@ -126,7 +130,7 @@ class Trial(object): elif self._checkpoint_obj: self.restore_from_obj(self._checkpoint_obj) - def stop(self, error=False, stop_logger=True): + def stop(self, error=False, error_msg=None, stop_logger=True): """Stops this trial. Stops this trial, releasing all allocating resources. If stopping the @@ -135,6 +139,8 @@ class Trial(object): Args: error (bool): Whether to mark this trial as terminated in error. + error_msg (str): Optional error message. + stop_logger (bool): Whether to shut down the trial logger. """ if error: @@ -143,6 +149,11 @@ class Trial(object): self.status = Trial.TERMINATED try: + if error_msg and self.logdir: + error_file = os.path.join( + self.logdir, "error_{}.txt".format(date_str())) + with open(error_file, "w") as f: + f.write(error_msg) if self.runner: stop_tasks = [] stop_tasks.append(self.runner.stop.remote()) @@ -317,8 +328,7 @@ class Trial(object): os.makedirs(self.local_dir) self.logdir = tempfile.mkdtemp( prefix="{}_{}".format( - self, - datetime.today().strftime("%Y-%m-%d_%H-%M-%S")), + str(self)[:MAX_LEN_IDENTIFIER], date_str()), dir=self.local_dir) self.result_logger = UnifiedLogger( self.config, self.logdir, self.upload_dir) diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index f80f9bd1b..c1ffc135e 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -195,15 +195,17 @@ class TrialRunner(object): trial.start() self._running[trial.train_remote()] = trial except Exception: - print("Error starting runner, retrying:", traceback.format_exc()) + error_msg = traceback.format_exc() + print("Error starting runner, retrying:", error_msg) time.sleep(2) - trial.stop(error=True) + trial.stop(error=True, error_msg=error_msg) try: trial.start() self._running[trial.train_remote()] = trial except Exception: - print("Error starting runner, abort:", traceback.format_exc()) - trial.stop(error=True) + error_msg = traceback.format_exc() + print("Error starting runner, abort:", error_msg) + trial.stop(error=True, error_msg=error_msg) # note that we don't return the resources, since they may # have been lost @@ -236,10 +238,11 @@ class TrialRunner(object): assert False, "Invalid scheduling decision: {}".format( decision) except Exception: - print("Error processing event:", traceback.format_exc()) + error_msg = traceback.format_exc() + print("Error processing event:", error_msg) if trial.status == Trial.RUNNING: self._scheduler_alg.on_trial_error(self, trial) - self._stop_trial(trial, error=True) + self._stop_trial(trial, error=True, error_msg=error_msg) def _get_runnable(self): return self._scheduler_alg.choose_trial_to_run(self) @@ -272,6 +275,7 @@ class TrialRunner(object): result for the trial and calls `scheduler.on_trial_complete` if RUNNING.""" error = False + error_msg = None if trial.status in [Trial.ERROR, Trial.TERMINATED]: return @@ -287,16 +291,17 @@ class TrialRunner(object): trial.update_last_result(result, terminate=True) self._scheduler_alg.on_trial_complete(self, trial, result) except Exception: - print("Error processing event:", traceback.format_exc()) + error_msg = traceback.format_exc() + print("Error processing event:", error_msg) self._scheduler_alg.on_trial_error(self, trial) error = True - self._stop_trial(trial, error=error) + self._stop_trial(trial, error=error, error_msg=error_msg) - def _stop_trial(self, trial, error=False): + def _stop_trial(self, trial, error=False, error_msg=None): """Only returns resources if resources allocated.""" prior_status = trial.status - trial.stop(error=error) + trial.stop(error=error, error_msg=error_msg) if prior_status == Trial.RUNNING: self._return_resources(trial.resources)