From c3e9d94b18bc33ea022de5f30864a2bfc683b433 Mon Sep 17 00:00:00 2001 From: Richard Liaw Date: Sat, 6 Jul 2019 00:54:24 -0700 Subject: [PATCH] [tune][minor] Reduce checkpointing frequency (#4859) --- python/ray/tune/tests/test_cluster.py | 27 +++++++------- python/ray/tune/tests/test_commands.py | 34 ++++++++----------- .../tune/tests/test_experiment_analysis.py | 1 + python/ray/tune/tests/test_trial_runner.py | 9 +++-- python/ray/tune/trial_runner.py | 11 ++++-- python/ray/tune/tune.py | 5 +++ 6 files changed, 45 insertions(+), 42 deletions(-) diff --git a/python/ray/tune/tests/test_cluster.py b/python/ray/tune/tests/test_cluster.py index 8c8274bcb..0eb8c8ab8 100644 --- a/python/ray/tune/tests/test_cluster.py +++ b/python/ray/tune/tests/test_cluster.py @@ -272,7 +272,7 @@ def test_cluster_down_simple(start_connected_cluster, tmpdir): cluster.wait_for_nodes() dirpath = str(tmpdir) - runner = TrialRunner(BasicVariantGenerator(), local_checkpoint_dir=dirpath) + runner = TrialRunner(local_checkpoint_dir=dirpath, checkpoint_period=0) kwargs = { "stopping_criterion": { "training_iteration": 2 @@ -359,15 +359,15 @@ from ray import tune ray.init(redis_address="{redis_address}") -kwargs = dict( - run="PG", - env="CartPole-v1", + +tune.run( + "PG", + name="experiment", + config=dict(env="CartPole-v1"), stop=dict(training_iteration=10), local_dir="{checkpoint_dir}", + global_checkpoint_period=0, checkpoint_freq=1, - max_failures=1) - -tune.run_experiments( + max_failures=1, - dict(experiment=kwargs), raise_on_failed_trial=False) """.format( @@ -449,15 +449,14 @@ ray.init(redis_address="{redis_address}") {fail_class_code} -kwargs = dict( - run={fail_class}, +tune.run( + {fail_class}, + name="experiment", stop=dict(training_iteration=5), local_dir="{checkpoint_dir}", checkpoint_freq=1, - max_failures=1) - -tune.run_experiments( - dict(experiment=kwargs), + global_checkpoint_period=0, + max_failures=1, raise_on_failed_trial=False) """.format( redis_address=cluster.redis_address, 
diff --git a/python/ray/tune/tests/test_commands.py b/python/ray/tune/tests/test_commands.py index ee6452127..277d9ed0b 100644 --- a/python/ray/tune/tests/test_commands.py +++ b/python/ray/tune/tests/test_commands.py @@ -67,16 +67,13 @@ def test_ls(start_ray, tmpdir): experiment_name = "test_ls" experiment_path = os.path.join(str(tmpdir), experiment_name) num_samples = 3 - tune.run_experiments({ - experiment_name: { - "run": "__fake", - "stop": { - "training_iteration": 1 - }, - "num_samples": num_samples, - "local_dir": str(tmpdir) - } - }) + tune.run( + "__fake", + name=experiment_name, + stop={"training_iteration": 1}, + num_samples=num_samples, + local_dir=str(tmpdir), + global_checkpoint_period=0) columns = ["status", "episode_reward_mean", "training_iteration"] limit = 2 @@ -104,16 +101,13 @@ def test_lsx(start_ray, tmpdir): num_experiments = 3 for i in range(num_experiments): experiment_name = "test_lsx{}".format(i) - tune.run_experiments({ - experiment_name: { - "run": "__fake", - "stop": { - "training_iteration": 1 - }, - "num_samples": 1, - "local_dir": project_path - } - }) + tune.run( + "__fake", + name=experiment_name, + stop={"training_iteration": 1}, + num_samples=1, + local_dir=project_path, + global_checkpoint_period=0) limit = 2 with Capturing() as output: diff --git a/python/ray/tune/tests/test_experiment_analysis.py b/python/ray/tune/tests/test_experiment_analysis.py index 2697e7838..7a2c70f84 100644 --- a/python/ray/tune/tests/test_experiment_analysis.py +++ b/python/ray/tune/tests/test_experiment_analysis.py @@ -32,6 +32,7 @@ class ExperimentAnalysisSuite(unittest.TestCase): def run_test_exp(self): self.ea = run( MyTrainableClass, + global_checkpoint_period=0, name=self.test_name, local_dir=self.test_dir, return_trials=False, diff --git a/python/ray/tune/tests/test_trial_runner.py b/python/ray/tune/tests/test_trial_runner.py index 5c4e2616a..b1b7b6633 100644 --- a/python/ray/tune/tests/test_trial_runner.py +++ 
b/python/ray/tune/tests/test_trial_runner.py @@ -2086,7 +2086,7 @@ class TrialRunnerTest(unittest.TestCase): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(local_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) trials = [ Trial( "__fake", @@ -2145,8 +2145,7 @@ class TrialRunnerTest(unittest.TestCase): ray.init(num_cpus=3) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(local_checkpoint_dir=tmpdir) - + runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) runner.add_trial( Trial( "__fake", @@ -2200,7 +2199,7 @@ class TrialRunnerTest(unittest.TestCase): }, checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(local_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) runner.add_trial(trial) for i in range(5): runner.step() @@ -2221,7 +2220,7 @@ class TrialRunnerTest(unittest.TestCase): ray.init() trial = Trial("__fake", checkpoint_freq=1) tmpdir = tempfile.mkdtemp() - runner = TrialRunner(local_checkpoint_dir=tmpdir) + runner = TrialRunner(local_checkpoint_dir=tmpdir, checkpoint_period=0) runner.add_trial(trial) for i in range(5): runner.step() diff --git a/python/ray/tune/trial_runner.py b/python/ray/tune/trial_runner.py index 5b20afe3d..c512c80d3 100644 --- a/python/ray/tune/trial_runner.py +++ b/python/ray/tune/trial_runner.py @@ -111,6 +111,7 @@ class TrialRunner(object): resume=False, server_port=TuneServer.DEFAULT_PORT, verbose=True, + checkpoint_period=10, trial_executor=None): """Initializes a new TrialRunner. 
@@ -174,6 +175,8 @@ class TrialRunner(object): logger.info("Starting a new experiment.") self._start_time = time.time() + self._last_checkpoint_time = -float("inf") + self._checkpoint_period = checkpoint_period self._session_str = datetime.fromtimestamp( self._start_time).strftime("%Y-%m-%d_%H-%M-%S") @@ -235,18 +238,20 @@ class TrialRunner(object): """Saves execution state to `self._local_checkpoint_dir`. Overwrites the current session checkpoint, which starts when self - is instantiated. + is instantiated. Throttle depends on self._checkpoint_period. """ if not self._local_checkpoint_dir: return - + if time.time() - self._last_checkpoint_time < self._checkpoint_period: + return + self._last_checkpoint_time = time.time() runner_state = { "checkpoints": list( self.trial_executor.get_checkpoints().values()), "runner_data": self.__getstate__(), "stats": { "start_time": self._start_time, - "timestamp": time.time() + "timestamp": self._last_checkpoint_time } } tmp_file_name = os.path.join(self._local_checkpoint_dir, diff --git a/python/ray/tune/tune.py b/python/ray/tune/tune.py index 375021094..d0c126227 100644 --- a/python/ray/tune/tune.py +++ b/python/ray/tune/tune.py @@ -49,6 +49,7 @@ def run(run_or_experiment, sync_to_driver=None, checkpoint_freq=0, checkpoint_at_end=False, + global_checkpoint_period=10, export_formats=None, max_failures=3, restore=None, @@ -113,6 +114,9 @@ def run(run_or_experiment, checkpoints. A value of 0 (default) disables checkpointing. checkpoint_at_end (bool): Whether to checkpoint at the end of the experiment regardless of the checkpoint_freq. Default is False. + global_checkpoint_period (int): Seconds between global checkpointing. + This does not affect `checkpoint_freq`, which specifies frequency + for individual trials. export_formats (list): List of formats that exported at the end of the experiment. Default is None. 
max_failures (int): Try to recover a trial from its last @@ -212,6 +216,7 @@ def run(run_or_experiment, local_checkpoint_dir=experiment.checkpoint_dir, remote_checkpoint_dir=experiment.remote_checkpoint_dir, sync_to_cloud=sync_to_cloud, + checkpoint_period=global_checkpoint_period, resume=resume, launch_web_server=with_server, server_port=server_port,