[tune] remove some bottlenecks in trialrunner (#12476)

This commit is contained in:
Richard Liaw
2020-11-30 14:54:25 -08:00
committed by GitHub
parent f5fe3794c8
commit 9ce7ad17fd
8 changed files with 45 additions and 17 deletions
+21 -7
View File
@@ -25,7 +25,7 @@ from ray.tune.utils import warn_if_slow
logger = logging.getLogger(__name__)
RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
TUNE_STATE_REFRESH_PERIOD = 10  # Refresh cached cluster state (resources, node IPs) every 10 s
BOTTLENECK_WARN_PERIOD_S = 60
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
DEFAULT_GET_TIMEOUT = 60.0 # seconds
@@ -139,7 +139,7 @@ class RayTrialExecutor(TrialExecutor):
queue_trials=False,
reuse_actors=False,
ray_auto_init=None,
refresh_period=RESOURCE_REFRESH_PERIOD):
refresh_period=None):
if ray_auto_init is None:
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
@@ -164,8 +164,15 @@ class RayTrialExecutor(TrialExecutor):
self._avail_resources = Resources(cpu=0, gpu=0)
self._committed_resources = Resources(cpu=0, gpu=0)
self._resources_initialized = False
if refresh_period is None:
refresh_period = float(
os.environ.get("TUNE_STATE_REFRESH_PERIOD",
TUNE_STATE_REFRESH_PERIOD))
self._refresh_period = refresh_period
self._last_resource_refresh = float("-inf")
self._last_ip_refresh = float("-inf")
self._last_ip_addresses = set()
self._last_nontrivial_wait = time.time()
if not ray.is_initialized() and ray_auto_init:
logger.info("Initializing Ray automatically."
@@ -423,11 +430,17 @@ class RayTrialExecutor(TrialExecutor):
return list(self._running.values())
def get_alive_node_ips(self):
    """Return the IP addresses of the nodes Ray currently reports as alive.

    The result is cached. If fewer than ``self._refresh_period`` seconds
    have passed since the last query, the set produced by that query is
    returned without asking Ray again.

    Returns:
        set: The ``NodeManagerAddress`` of every node whose ``alive`` entry
            is truthy. The cached set object itself is handed back (not a
            copy), so callers should treat it as read-only.
    """
    now = time.time()
    if now - self._last_ip_refresh < self._refresh_period:
        # Cache is still fresh: skip the cluster-state query.
        return self._last_ip_addresses

    logger.debug("Checking ips from Ray state.")
    # Stamp the refresh before querying, using the same clock reading the
    # freshness check above was made with.
    self._last_ip_refresh = now
    ip_addresses = {
        node["NodeManagerAddress"]
        for node in ray.state.nodes() if node["alive"]
    }
    self._last_ip_addresses = ip_addresses
    return ip_addresses
def get_current_trial_ips(self):
@@ -525,6 +538,9 @@ class RayTrialExecutor(TrialExecutor):
"Resource invalid: {}".format(resources))
def _update_avail_resources(self, num_retries=5):
if time.time() - self._last_resource_refresh < self._refresh_period:
return
logger.debug("Checking Ray cluster resources.")
resources = None
for i in range(num_retries):
if i > 0:
@@ -534,10 +550,10 @@ class RayTrialExecutor(TrialExecutor):
time.sleep(0.5)
try:
resources = ray.cluster_resources()
except Exception:
except Exception as exc:
# TODO(rliaw): Remove this when local mode is fixed.
# https://github.com/ray-project/ray/issues/4147
logger.debug("Using resources for local machine.")
logger.debug(f"{exc}: Using resources for local machine.")
resources = ResourceSpec().resolve(True).to_resource_dict()
if resources:
break
@@ -575,9 +591,7 @@ class RayTrialExecutor(TrialExecutor):
has exceeded self._refresh_period. This also assumes that the
cluster is not resizing very frequently.
"""
if time.time() - self._last_resource_refresh > self._refresh_period:
self._update_avail_resources()
self._update_avail_resources()
currently_available = Resources.subtract(self._avail_resources,
self._committed_resources)
@@ -87,6 +87,7 @@ def create_resettable_function(num_resets: defaultdict):
class ActorReuseTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1, num_gpus=0)
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
def tearDown(self):
ray.shutdown()
+2
View File
@@ -76,6 +76,7 @@ class _PerTrialSyncerCallback(SyncerCallback):
def start_connected_cluster():
# Start the Ray processes.
cluster = _start_new_cluster()
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
@@ -98,6 +99,7 @@ def start_connected_emptyhead_cluster():
_register_all()
register_trainable("__fake_remote", MockRemoteTrainer)
register_trainable("__fake_durable", MockDurableTrainer)
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
yield cluster
# The code after the yield will run as teardown code.
ray.shutdown()
@@ -13,6 +13,9 @@ from ray.tune.trial import Trial, ExportFormat
class RunExperimentTest(unittest.TestCase):
def setUp(self):
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
@@ -36,6 +36,9 @@ def create_mock_components():
class TrialRunnerTest2(unittest.TestCase):
def setUp(self):
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
def tearDown(self):
ray.shutdown()
_register_all() # re-register the evicted objects
+8 -5
View File
@@ -360,7 +360,8 @@ class TrialRunner:
trials=self._trials,
trial=next_trial)
elif self.trial_executor.get_running_trials():
self._process_events() # blocking
with warn_if_slow("process_events"):
self._process_events() # blocking
else:
self.trial_executor.on_no_available_trials(self)
@@ -453,11 +454,13 @@ class TrialRunner:
self._update_trial_queue(blocking=wait_for_trial)
with warn_if_slow("choose_trial_to_run"):
trial = self._scheduler_alg.choose_trial_to_run(self)
logger.debug("Running trial {}".format(trial))
if trial:
logger.debug("Running trial {}".format(trial))
return trial
def _process_events(self):
failed_trial = self.trial_executor.get_next_failed_trial()
with warn_if_slow("get_next_failed_trial"):
failed_trial = self.trial_executor.get_next_failed_trial()
if failed_trial:
error_msg = (
"{} (IP: {}) detected as stale. This is likely because the "
@@ -478,14 +481,14 @@ class TrialRunner:
trials=self._trials,
trial=trial)
elif trial.is_saving:
with warn_if_slow("process_trial_save") as profile:
with warn_if_slow("process_trial_save") as _profile:
self._process_trial_save(trial)
with warn_if_slow("callbacks.on_trial_save"):
self._callbacks.on_trial_save(
iteration=self._iteration,
trials=self._trials,
trial=trial)
if profile.too_slow and trial.sync_on_checkpoint:
if _profile.too_slow and trial.sync_on_checkpoint:
# TODO(ujvl): Suggest using DurableTrainable once
# API has converged.
+5 -5
View File
@@ -115,14 +115,14 @@ def get_pinned_object(pinned_id):
class warn_if_slow:
"""Prints a warning if a given operation is slower than 100ms.
"""Prints a warning if a given operation is slower than 500ms.
Example:
>>> with warn_if_slow("some_operation"):
... ray.get(something)
"""
DEFAULT_THRESHOLD = 0.5
DEFAULT_THRESHOLD = float(os.environ.get("TUNE_WARN_THRESHOLD_S", 0.5))
def __init__(self, name, threshold=None):
self.name = name
@@ -137,10 +137,10 @@ class warn_if_slow:
now = time.time()
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
self.too_slow = True
_duration = now - self.start
logger.warning(
"The `%s` operation took %s seconds to complete, "
"which may be a performance bottleneck.", self.name,
now - self.start)
f"The `{self.name}` operation took {_duration:.3f} s, "
"which may be a performance bottleneck.")
class Tee(object):