[tune] use dated experiment dir per default (#11104)

This commit is contained in:
Kai Fricke
2020-09-30 22:43:59 +01:00
committed by GitHub
parent f0dba6bd2b
commit c77cfaa5ad
9 changed files with 82 additions and 17 deletions
+1 -1
View File
@@ -106,7 +106,7 @@ class AutoMLSearcher(SearchAlgorithm):
deep_insert(path.split("."), value, new_spec["config"])
trial = create_trial_from_spec(
new_spec, exp.name, self._parser, experiment_tag=tag)
new_spec, exp.dir_name, self._parser, experiment_tag=tag)
# AutoML specific fields set in Trial
trial.results = []
+23 -5
View File
@@ -10,7 +10,7 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
from ray.tune.sample import Domain
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
TimeoutStopper
from ray.tune.utils import detect_checkpoint_function
from ray.tune.utils import date_str, detect_checkpoint_function
logger = logging.getLogger(__name__)
@@ -137,8 +137,18 @@ class Experiment:
"within your trainable function.")
self._run_identifier = Experiment.register_if_needed(run)
self.name = name or self._run_identifier
# If the name has been set explicitly, we don't want to create
# dated directories. The same is true for string run identifiers.
if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
or isinstance(run, str):
self.dir_name = self.name
else:
self.dir_name = "{}_{}".format(self.name, date_str())
if upload_dir:
self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
self.remote_checkpoint_dir = os.path.join(upload_dir,
self.dir_name)
else:
self.remote_checkpoint_dir = None
@@ -249,8 +259,16 @@ class Experiment:
return run_object
elif isinstance(run_object, type) or callable(run_object):
name = "DEFAULT"
if hasattr(run_object, "__name__"):
name = run_object.__name__
if hasattr(run_object, "_name"):
name = run_object._name
elif hasattr(run_object, "__name__"):
fn_name = run_object.__name__
if fn_name == "<lambda>":
name = "lambda"
elif fn_name.startswith("<"):
name = "DEFAULT"
else:
name = fn_name
else:
logger.warning(
"No name detected on trainable. Using {}.".format(name))
@@ -287,7 +305,7 @@ class Experiment:
@property
def checkpoint_dir(self):
if self.local_dir:
return os.path.join(self.local_dir, self.name)
return os.path.join(self.local_dir, self.dir_name)
@property
def run_identifier(self):
+1 -1
View File
@@ -77,7 +77,7 @@ class BasicVariantGenerator(SearchAlgorithm):
self._trial_generator,
self._generate_trials(
experiment.spec.get("num_samples", 1), experiment.spec,
experiment.name))
experiment.dir_name))
def next_trial(self):
"""Provides one Trial object to be queued into the TrialRunner.
+1 -1
View File
@@ -111,7 +111,7 @@ class SearchGenerator(SearchAlgorithm):
"""
if not self.is_finished():
return self.create_trial_if_possible(self._experiment.spec,
self._experiment.name)
self._experiment.dir_name)
return None
def create_trial_if_possible(self, experiment_spec: Dict,
+44
View File
@@ -1194,6 +1194,50 @@ class TrainableFunctionApiTest(unittest.TestCase):
# With strict metric checking disabled, this should not raise
tune.run(train, metric="acc")
def testTrialDirCreation(self):
def test_trial_dir(config):
return 1.0
# Per default, the directory should be named `test_trial_dir_{date}`
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir)
subdirs = list(os.listdir(tmp_dir))
self.assertNotIn("test_trial_dir", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("test_trial_dir_"): # Date suffix
found = True
break
self.assertTrue(found)
# If we set an explicit name, no date should be appended
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir, name="my_test_exp")
subdirs = list(os.listdir(tmp_dir))
self.assertIn("my_test_exp", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("my_test_exp_"): # Date suffix
found = True
break
self.assertFalse(found)
# Don't append date if we set the env variable
os.environ["TUNE_DISABLE_DATED_SUBDIR"] = "1"
with tempfile.TemporaryDirectory() as tmp_dir:
tune.run(test_trial_dir, local_dir=tmp_dir)
subdirs = list(os.listdir(tmp_dir))
self.assertIn("test_trial_dir", subdirs)
found = False
for subdir in subdirs:
if subdir.startswith("test_trial_dir_"): # Date suffix
found = True
break
self.assertFalse(found)
class ShimCreationTest(unittest.TestCase):
def testCreateScheduler(self):
+1 -6
View File
@@ -3,7 +3,6 @@ from typing import Sequence
import ray.cloudpickle as cloudpickle
from collections import deque
import copy
from datetime import datetime
import logging
import platform
import shutil
@@ -22,7 +21,7 @@ from ray.tune.registry import get_trainable_cls, validate_trainable
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
from ray.tune.resources import Resources, json_to_resources, resources_to_json
from ray.tune.trainable import TrainableUtil
from ray.tune.utils import flatten_dict
from ray.tune.utils import date_str, flatten_dict
from ray.utils import binary_to_hex, hex_to_binary
DEBUG_PRINT_INTERVAL = 5
@@ -36,10 +35,6 @@ MAX_LEN_IDENTIFIER = int(
os.environ.get("MAX_LEN_IDENTIFIER", 130)))
def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
class Location:
"""Describes the location at which Trial is placed to run."""
+3 -3
View File
@@ -1,13 +1,13 @@
from ray.tune.utils.util import (
deep_update, flatten_dict, get_pinned_object, merge_dicts,
deep_update, date_str, flatten_dict, get_pinned_object, merge_dicts,
pin_in_object_store, unflattened_lookup, UtilMonitor,
validate_save_restore, warn_if_slow, diagnose_serialization,
detect_checkpoint_function, detect_reporter, detect_config_single,
env_integer)
__all__ = [
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
"deep_update", "date_str", "flatten_dict", "get_pinned_object",
"merge_dicts", "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
"detect_checkpoint_function", "detect_reporter", "detect_config_single",
"env_integer"
+5
View File
@@ -5,6 +5,7 @@ import inspect
import threading
import time
from collections import defaultdict, deque, Mapping, Sequence
from datetime import datetime
from threading import Thread
import numpy as np
@@ -153,6 +154,10 @@ class Tee(object):
self.stream2.flush(*args, **kwargs)
def date_str():
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
def env_integer(key, default):
# TODO(rliaw): move into ray.constants
if key in os.environ: