mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 11:37:28 +08:00
[tune] use dated experiment dir per default (#11104)
This commit is contained in:
@@ -106,7 +106,7 @@ class AutoMLSearcher(SearchAlgorithm):
|
||||
deep_insert(path.split("."), value, new_spec["config"])
|
||||
|
||||
trial = create_trial_from_spec(
|
||||
new_spec, exp.name, self._parser, experiment_tag=tag)
|
||||
new_spec, exp.dir_name, self._parser, experiment_tag=tag)
|
||||
|
||||
# AutoML specific fields set in Trial
|
||||
trial.results = []
|
||||
|
||||
@@ -10,7 +10,7 @@ from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.sample import Domain
|
||||
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, \
|
||||
TimeoutStopper
|
||||
from ray.tune.utils import detect_checkpoint_function
|
||||
from ray.tune.utils import date_str, detect_checkpoint_function
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -137,8 +137,18 @@ class Experiment:
|
||||
"within your trainable function.")
|
||||
self._run_identifier = Experiment.register_if_needed(run)
|
||||
self.name = name or self._run_identifier
|
||||
|
||||
# If the name has been set explicitly, we don't want to create
|
||||
# dated directories. The same is true for string run identifiers.
|
||||
if int(os.environ.get("TUNE_DISABLE_DATED_SUBDIR", 0)) == 1 or name \
|
||||
or isinstance(run, str):
|
||||
self.dir_name = self.name
|
||||
else:
|
||||
self.dir_name = "{}_{}".format(self.name, date_str())
|
||||
|
||||
if upload_dir:
|
||||
self.remote_checkpoint_dir = os.path.join(upload_dir, self.name)
|
||||
self.remote_checkpoint_dir = os.path.join(upload_dir,
|
||||
self.dir_name)
|
||||
else:
|
||||
self.remote_checkpoint_dir = None
|
||||
|
||||
@@ -249,8 +259,16 @@ class Experiment:
|
||||
return run_object
|
||||
elif isinstance(run_object, type) or callable(run_object):
|
||||
name = "DEFAULT"
|
||||
if hasattr(run_object, "__name__"):
|
||||
name = run_object.__name__
|
||||
if hasattr(run_object, "_name"):
|
||||
name = run_object._name
|
||||
elif hasattr(run_object, "__name__"):
|
||||
fn_name = run_object.__name__
|
||||
if fn_name == "<lambda>":
|
||||
name = "lambda"
|
||||
elif fn_name.startswith("<"):
|
||||
name = "DEFAULT"
|
||||
else:
|
||||
name = fn_name
|
||||
else:
|
||||
logger.warning(
|
||||
"No name detected on trainable. Using {}.".format(name))
|
||||
@@ -287,7 +305,7 @@ class Experiment:
|
||||
@property
|
||||
def checkpoint_dir(self):
|
||||
if self.local_dir:
|
||||
return os.path.join(self.local_dir, self.name)
|
||||
return os.path.join(self.local_dir, self.dir_name)
|
||||
|
||||
@property
|
||||
def run_identifier(self):
|
||||
|
||||
@@ -77,7 +77,7 @@ class BasicVariantGenerator(SearchAlgorithm):
|
||||
self._trial_generator,
|
||||
self._generate_trials(
|
||||
experiment.spec.get("num_samples", 1), experiment.spec,
|
||||
experiment.name))
|
||||
experiment.dir_name))
|
||||
|
||||
def next_trial(self):
|
||||
"""Provides one Trial object to be queued into the TrialRunner.
|
||||
|
||||
@@ -111,7 +111,7 @@ class SearchGenerator(SearchAlgorithm):
|
||||
"""
|
||||
if not self.is_finished():
|
||||
return self.create_trial_if_possible(self._experiment.spec,
|
||||
self._experiment.name)
|
||||
self._experiment.dir_name)
|
||||
return None
|
||||
|
||||
def create_trial_if_possible(self, experiment_spec: Dict,
|
||||
|
||||
@@ -1194,6 +1194,50 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
# With strict metric checking disabled, this should not raise
|
||||
tune.run(train, metric="acc")
|
||||
|
||||
def testTrialDirCreation(self):
|
||||
def test_trial_dir(config):
|
||||
return 1.0
|
||||
|
||||
# Per default, the directory should be named `test_trial_dir_{date}`
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tune.run(test_trial_dir, local_dir=tmp_dir)
|
||||
|
||||
subdirs = list(os.listdir(tmp_dir))
|
||||
self.assertNotIn("test_trial_dir", subdirs)
|
||||
found = False
|
||||
for subdir in subdirs:
|
||||
if subdir.startswith("test_trial_dir_"): # Date suffix
|
||||
found = True
|
||||
break
|
||||
self.assertTrue(found)
|
||||
|
||||
# If we set an explicit name, no date should be appended
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tune.run(test_trial_dir, local_dir=tmp_dir, name="my_test_exp")
|
||||
|
||||
subdirs = list(os.listdir(tmp_dir))
|
||||
self.assertIn("my_test_exp", subdirs)
|
||||
found = False
|
||||
for subdir in subdirs:
|
||||
if subdir.startswith("my_test_exp_"): # Date suffix
|
||||
found = True
|
||||
break
|
||||
self.assertFalse(found)
|
||||
|
||||
# Don't append date if we set the env variable
|
||||
os.environ["TUNE_DISABLE_DATED_SUBDIR"] = "1"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tune.run(test_trial_dir, local_dir=tmp_dir)
|
||||
|
||||
subdirs = list(os.listdir(tmp_dir))
|
||||
self.assertIn("test_trial_dir", subdirs)
|
||||
found = False
|
||||
for subdir in subdirs:
|
||||
if subdir.startswith("test_trial_dir_"): # Date suffix
|
||||
found = True
|
||||
break
|
||||
self.assertFalse(found)
|
||||
|
||||
|
||||
class ShimCreationTest(unittest.TestCase):
|
||||
def testCreateScheduler(self):
|
||||
|
||||
@@ -3,7 +3,6 @@ from typing import Sequence
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from collections import deque
|
||||
import copy
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import platform
|
||||
import shutil
|
||||
@@ -22,7 +21,7 @@ from ray.tune.registry import get_trainable_cls, validate_trainable
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR, DONE, TRAINING_ITERATION
|
||||
from ray.tune.resources import Resources, json_to_resources, resources_to_json
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.utils import flatten_dict
|
||||
from ray.tune.utils import date_str, flatten_dict
|
||||
from ray.utils import binary_to_hex, hex_to_binary
|
||||
|
||||
DEBUG_PRINT_INTERVAL = 5
|
||||
@@ -36,10 +35,6 @@ MAX_LEN_IDENTIFIER = int(
|
||||
os.environ.get("MAX_LEN_IDENTIFIER", 130)))
|
||||
|
||||
|
||||
def date_str():
|
||||
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
class Location:
|
||||
"""Describes the location at which Trial is placed to run."""
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
from ray.tune.utils.util import (
|
||||
deep_update, flatten_dict, get_pinned_object, merge_dicts,
|
||||
deep_update, date_str, flatten_dict, get_pinned_object, merge_dicts,
|
||||
pin_in_object_store, unflattened_lookup, UtilMonitor,
|
||||
validate_save_restore, warn_if_slow, diagnose_serialization,
|
||||
detect_checkpoint_function, detect_reporter, detect_config_single,
|
||||
env_integer)
|
||||
|
||||
__all__ = [
|
||||
"deep_update", "flatten_dict", "get_pinned_object", "merge_dicts",
|
||||
"pin_in_object_store", "unflattened_lookup", "UtilMonitor",
|
||||
"deep_update", "date_str", "flatten_dict", "get_pinned_object",
|
||||
"merge_dicts", "pin_in_object_store", "unflattened_lookup", "UtilMonitor",
|
||||
"validate_save_restore", "warn_if_slow", "diagnose_serialization",
|
||||
"detect_checkpoint_function", "detect_reporter", "detect_config_single",
|
||||
"env_integer"
|
||||
|
||||
@@ -5,6 +5,7 @@ import inspect
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict, deque, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
@@ -153,6 +154,10 @@ class Tee(object):
|
||||
self.stream2.flush(*args, **kwargs)
|
||||
|
||||
|
||||
def date_str():
|
||||
return datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
|
||||
def env_integer(key, default):
|
||||
# TODO(rliaw): move into ray.constants
|
||||
if key in os.environ:
|
||||
|
||||
Reference in New Issue
Block a user