mirror of
https://github.com/wassname/ray.git
synced 2026-07-04 14:48:25 +08:00
* [tune] Add _switch_working_directory to RayTrialExecutor (#6228)
* Make _switch_working_directory before warn_if_slow
* Rename _switch_working_directory to _change_working_directory
This commit is contained in:
committed by
Richard Liaw
parent
5e43b25e8c
commit
aaeb3c44a5
@@ -8,6 +8,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray
|
||||
from ray.exceptions import RayTimeoutError
|
||||
@@ -96,8 +97,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
if self._cached_actor:
|
||||
logger.debug("Cannot reuse cached runner {} for new trial".format(
|
||||
self._cached_actor))
|
||||
self._cached_actor.stop.remote()
|
||||
self._cached_actor.__ray_terminate__.remote()
|
||||
with self._change_working_directory(trial):
|
||||
self._cached_actor.stop.remote()
|
||||
self._cached_actor.__ray_terminate__.remote()
|
||||
self._cached_actor = None
|
||||
|
||||
cls = ray.remote(
|
||||
@@ -127,7 +129,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
}
|
||||
if issubclass(trial.get_trainable_cls(), DurableTrainable):
|
||||
kwargs["remote_checkpoint_dir"] = trial.remote_checkpoint_dir
|
||||
return cls.remote(**kwargs)
|
||||
|
||||
with self._change_working_directory(trial):
|
||||
return cls.remote(**kwargs)
|
||||
|
||||
def _train(self, trial):
|
||||
"""Start one iteration of training and save remote id."""
|
||||
@@ -147,7 +151,8 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return
|
||||
|
||||
assert trial.status == Trial.RUNNING, trial.status
|
||||
remote = trial.runner.train.remote()
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.train.remote()
|
||||
|
||||
# Local Mode
|
||||
if isinstance(remote, dict):
|
||||
@@ -215,8 +220,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._cached_actor = trial.runner
|
||||
else:
|
||||
logger.debug("Trial %s: Destroying actor.", trial)
|
||||
trial.runner.stop.remote()
|
||||
trial.runner.__ray_terminate__.remote()
|
||||
with self._change_working_directory(trial):
|
||||
trial.runner.stop.remote()
|
||||
trial.runner.__ray_terminate__.remote()
|
||||
except Exception:
|
||||
logger.exception("Trial %s: Error stopping runner.", trial)
|
||||
self.set_status(trial, Trial.ERROR)
|
||||
@@ -298,14 +304,16 @@ class RayTrialExecutor(TrialExecutor):
|
||||
trial.experiment_tag = new_experiment_tag
|
||||
trial.config = new_config
|
||||
trainable = trial.runner
|
||||
with warn_if_slow("reset_config"):
|
||||
try:
|
||||
reset_val = ray.get(
|
||||
trainable.reset_config.remote(new_config),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
except RayTimeoutError:
|
||||
logger.exception("Trial %s: reset_config timed out.", trial)
|
||||
return False
|
||||
with self._change_working_directory(trial):
|
||||
with warn_if_slow("reset_config"):
|
||||
try:
|
||||
reset_val = ray.get(
|
||||
trainable.reset_config.remote(new_config),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
except RayTimeoutError:
|
||||
logger.exception("Trial %s: reset_config timed out.",
|
||||
trial)
|
||||
return False
|
||||
return reset_val
|
||||
|
||||
def get_running_trials(self):
|
||||
@@ -561,14 +569,16 @@ class RayTrialExecutor(TrialExecutor):
|
||||
Checkpoint future, or None if an Exception occurs.
|
||||
"""
|
||||
result = result or trial.last_result
|
||||
if storage == Checkpoint.MEMORY:
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
else:
|
||||
with warn_if_slow("save_checkpoint_to_storage"):
|
||||
# TODO(ujvl): Make this asynchronous.
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
|
||||
with self._change_working_directory(trial):
|
||||
if storage == Checkpoint.MEMORY:
|
||||
value = trial.runner.save_to_object.remote()
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
else:
|
||||
with warn_if_slow("save_checkpoint_to_storage"):
|
||||
# TODO(ujvl): Make this asynchronous.
|
||||
value = ray.get(trial.runner.save.remote())
|
||||
checkpoint = Checkpoint(storage, value, result)
|
||||
with warn_if_slow("on_checkpoint", DEFAULT_GET_TIMEOUT) as profile:
|
||||
try:
|
||||
trial.on_checkpoint(checkpoint)
|
||||
@@ -604,18 +614,21 @@ class RayTrialExecutor(TrialExecutor):
|
||||
logger.debug("Trial %s: Attempting restore from object", trial)
|
||||
# Note that we don't store the remote since in-memory checkpoints
|
||||
# don't guarantee fault tolerance and don't need to be waited on.
|
||||
trial.runner.restore_from_object.remote(value)
|
||||
with self._change_working_directory(trial):
|
||||
trial.runner.restore_from_object.remote(value)
|
||||
else:
|
||||
logger.debug("Trial %s: Attempting restore from %s", trial, value)
|
||||
if issubclass(trial.get_trainable_cls(), DurableTrainable):
|
||||
remote = trial.runner.restore.remote(value)
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore.remote(value)
|
||||
elif trial.sync_on_checkpoint:
|
||||
# This provides FT backwards compatibility in the
|
||||
# case where a DurableTrainable is not provided.
|
||||
logger.warning("Trial %s: Reading checkpoint into memory.",
|
||||
trial)
|
||||
data_dict = TrainableUtil.pickle_checkpoint(value)
|
||||
remote = trial.runner.restore_from_object.remote(data_dict)
|
||||
with self._change_working_directory(trial):
|
||||
remote = trial.runner.restore_from_object.remote(data_dict)
|
||||
else:
|
||||
raise AbortTrialExecution(
|
||||
"Pass in `sync_on_checkpoint=True` for driver-based trial"
|
||||
@@ -632,9 +645,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
A dict that maps ExportFormats to successfully exported models.
|
||||
"""
|
||||
if trial.export_formats and len(trial.export_formats) > 0:
|
||||
return ray.get(
|
||||
trial.runner.export_model.remote(trial.export_formats),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
with self._change_working_directory(trial):
|
||||
return ray.get(
|
||||
trial.runner.export_model.remote(trial.export_formats),
|
||||
DEFAULT_GET_TIMEOUT)
|
||||
return {}
|
||||
|
||||
def has_gpus(self):
|
||||
@@ -642,6 +656,23 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._update_avail_resources()
|
||||
return self._avail_resources.gpu > 0
|
||||
|
||||
@contextmanager
def _change_working_directory(self, trial):
    """Temporarily switch the working directory to ``trial.logdir``.

    The switch only happens when ``ray.worker._mode()`` reports
    ``ray.worker.LOCAL_MODE``. In any other mode the managed body runs
    with the working directory left untouched.

    Args:
        trial: Trial whose ``logdir`` becomes the working directory
            while the ``with`` body executes.
    """
    # Guard clause: outside local mode this context manager does nothing.
    if ray.worker._mode() != ray.worker.LOCAL_MODE:
        yield
        return

    # Remember where we started so it can be restored afterwards.
    previous_dir = os.getcwd()
    try:
        os.chdir(trial.logdir)
        yield
    finally:
        # Runs whether the body completed or raised (including a failed
        # chdir above), so the original directory is always restored.
        os.chdir(previous_dir)
|
||||
|
||||
|
||||
def _to_gb(n_bytes):
|
||||
return round(n_bytes / (1024**3), 2)
|
||||
|
||||
Reference in New Issue
Block a user