mirror of
https://github.com/wassname/ray.git
synced 2026-06-30 02:56:05 +08:00
[tune] remove some bottlenecks in trialrunner (#12476)
This commit is contained in:
@@ -25,7 +25,7 @@ from ray.tune.utils import warn_if_slow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RESOURCE_REFRESH_PERIOD = 0.5 # Refresh resources every 500 ms
|
||||
TUNE_STATE_REFRESH_PERIOD = 10 # Refresh resources every 10 s
|
||||
BOTTLENECK_WARN_PERIOD_S = 60
|
||||
NONTRIVIAL_WAIT_TIME_THRESHOLD_S = 1e-3
|
||||
DEFAULT_GET_TIMEOUT = 60.0 # seconds
|
||||
@@ -139,7 +139,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
queue_trials=False,
|
||||
reuse_actors=False,
|
||||
ray_auto_init=None,
|
||||
refresh_period=RESOURCE_REFRESH_PERIOD):
|
||||
refresh_period=None):
|
||||
if ray_auto_init is None:
|
||||
if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
|
||||
logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
|
||||
@@ -164,8 +164,15 @@ class RayTrialExecutor(TrialExecutor):
|
||||
self._avail_resources = Resources(cpu=0, gpu=0)
|
||||
self._committed_resources = Resources(cpu=0, gpu=0)
|
||||
self._resources_initialized = False
|
||||
|
||||
if refresh_period is None:
|
||||
refresh_period = float(
|
||||
os.environ.get("TUNE_STATE_REFRESH_PERIOD",
|
||||
TUNE_STATE_REFRESH_PERIOD))
|
||||
self._refresh_period = refresh_period
|
||||
self._last_resource_refresh = float("-inf")
|
||||
self._last_ip_refresh = float("-inf")
|
||||
self._last_ip_addresses = set()
|
||||
self._last_nontrivial_wait = time.time()
|
||||
if not ray.is_initialized() and ray_auto_init:
|
||||
logger.info("Initializing Ray automatically."
|
||||
@@ -423,11 +430,17 @@ class RayTrialExecutor(TrialExecutor):
|
||||
return list(self._running.values())
|
||||
|
||||
def get_alive_node_ips(self):
|
||||
now = time.time()
|
||||
if now - self._last_ip_refresh < self._refresh_period:
|
||||
return self._last_ip_addresses
|
||||
logger.debug("Checking ips from Ray state.")
|
||||
self._last_ip_refresh = now
|
||||
nodes = ray.state.nodes()
|
||||
ip_addresses = set()
|
||||
for node in nodes:
|
||||
if node["alive"]:
|
||||
ip_addresses.add(node["NodeManagerAddress"])
|
||||
self._last_ip_addresses = ip_addresses
|
||||
return ip_addresses
|
||||
|
||||
def get_current_trial_ips(self):
|
||||
@@ -525,6 +538,9 @@ class RayTrialExecutor(TrialExecutor):
|
||||
"Resource invalid: {}".format(resources))
|
||||
|
||||
def _update_avail_resources(self, num_retries=5):
|
||||
if time.time() - self._last_resource_refresh < self._refresh_period:
|
||||
return
|
||||
logger.debug("Checking Ray cluster resources.")
|
||||
resources = None
|
||||
for i in range(num_retries):
|
||||
if i > 0:
|
||||
@@ -534,10 +550,10 @@ class RayTrialExecutor(TrialExecutor):
|
||||
time.sleep(0.5)
|
||||
try:
|
||||
resources = ray.cluster_resources()
|
||||
except Exception:
|
||||
except Exception as exc:
|
||||
# TODO(rliaw): Remove this when local mode is fixed.
|
||||
# https://github.com/ray-project/ray/issues/4147
|
||||
logger.debug("Using resources for local machine.")
|
||||
logger.debug(f"{exc}: Using resources for local machine.")
|
||||
resources = ResourceSpec().resolve(True).to_resource_dict()
|
||||
if resources:
|
||||
break
|
||||
@@ -575,9 +591,7 @@ class RayTrialExecutor(TrialExecutor):
|
||||
has exceeded self._refresh_period. This also assumes that the
|
||||
cluster is not resizing very frequently.
|
||||
"""
|
||||
if time.time() - self._last_resource_refresh > self._refresh_period:
|
||||
self._update_avail_resources()
|
||||
|
||||
self._update_avail_resources()
|
||||
currently_available = Resources.subtract(self._avail_resources,
|
||||
self._committed_resources)
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ def create_resettable_function(num_resets: defaultdict):
|
||||
class ActorReuseTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
ray.init(num_cpus=1, num_gpus=0)
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
|
||||
@@ -76,6 +76,7 @@ class _PerTrialSyncerCallback(SyncerCallback):
|
||||
def start_connected_cluster():
|
||||
# Start the Ray processes.
|
||||
cluster = _start_new_cluster()
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
@@ -98,6 +99,7 @@ def start_connected_emptyhead_cluster():
|
||||
_register_all()
|
||||
register_trainable("__fake_remote", MockRemoteTrainer)
|
||||
register_trainable("__fake_durable", MockDurableTrainer)
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
yield cluster
|
||||
# The code after the yield will run as teardown code.
|
||||
ray.shutdown()
|
||||
|
||||
@@ -13,6 +13,9 @@ from ray.tune.trial import Trial, ExportFormat
|
||||
|
||||
|
||||
class RunExperimentTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
@@ -36,6 +36,9 @@ def create_mock_components():
|
||||
|
||||
|
||||
class TrialRunnerTest2(unittest.TestCase):
|
||||
def setUp(self):
|
||||
os.environ["TUNE_STATE_REFRESH_PERIOD"] = "0.1"
|
||||
|
||||
def tearDown(self):
|
||||
ray.shutdown()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
@@ -360,7 +360,8 @@ class TrialRunner:
|
||||
trials=self._trials,
|
||||
trial=next_trial)
|
||||
elif self.trial_executor.get_running_trials():
|
||||
self._process_events() # blocking
|
||||
with warn_if_slow("process_events"):
|
||||
self._process_events() # blocking
|
||||
else:
|
||||
self.trial_executor.on_no_available_trials(self)
|
||||
|
||||
@@ -453,11 +454,13 @@ class TrialRunner:
|
||||
self._update_trial_queue(blocking=wait_for_trial)
|
||||
with warn_if_slow("choose_trial_to_run"):
|
||||
trial = self._scheduler_alg.choose_trial_to_run(self)
|
||||
logger.debug("Running trial {}".format(trial))
|
||||
if trial:
|
||||
logger.debug("Running trial {}".format(trial))
|
||||
return trial
|
||||
|
||||
def _process_events(self):
|
||||
failed_trial = self.trial_executor.get_next_failed_trial()
|
||||
with warn_if_slow("get_next_failed_trial"):
|
||||
failed_trial = self.trial_executor.get_next_failed_trial()
|
||||
if failed_trial:
|
||||
error_msg = (
|
||||
"{} (IP: {}) detected as stale. This is likely because the "
|
||||
@@ -478,14 +481,14 @@ class TrialRunner:
|
||||
trials=self._trials,
|
||||
trial=trial)
|
||||
elif trial.is_saving:
|
||||
with warn_if_slow("process_trial_save") as profile:
|
||||
with warn_if_slow("process_trial_save") as _profile:
|
||||
self._process_trial_save(trial)
|
||||
with warn_if_slow("callbacks.on_trial_save"):
|
||||
self._callbacks.on_trial_save(
|
||||
iteration=self._iteration,
|
||||
trials=self._trials,
|
||||
trial=trial)
|
||||
if profile.too_slow and trial.sync_on_checkpoint:
|
||||
if _profile.too_slow and trial.sync_on_checkpoint:
|
||||
# TODO(ujvl): Suggest using DurableTrainable once
|
||||
# API has converged.
|
||||
|
||||
|
||||
@@ -115,14 +115,14 @@ def get_pinned_object(pinned_id):
|
||||
|
||||
|
||||
class warn_if_slow:
|
||||
"""Prints a warning if a given operation is slower than 100ms.
|
||||
"""Prints a warning if a given operation is slower than 500ms.
|
||||
|
||||
Example:
|
||||
>>> with warn_if_slow("some_operation"):
|
||||
... ray.get(something)
|
||||
"""
|
||||
|
||||
DEFAULT_THRESHOLD = 0.5
|
||||
DEFAULT_THRESHOLD = float(os.environ.get("TUNE_WARN_THRESHOLD_S", 0.5))
|
||||
|
||||
def __init__(self, name, threshold=None):
|
||||
self.name = name
|
||||
@@ -137,10 +137,10 @@ class warn_if_slow:
|
||||
now = time.time()
|
||||
if now - self.start > self.threshold and now - START_OF_TIME > 60.0:
|
||||
self.too_slow = True
|
||||
_duration = now - self.start
|
||||
logger.warning(
|
||||
"The `%s` operation took %s seconds to complete, "
|
||||
"which may be a performance bottleneck.", self.name,
|
||||
now - self.start)
|
||||
f"The `{self.name}` operation took {_duration:.3f} s, "
|
||||
"which may be a performance bottleneck.")
|
||||
|
||||
|
||||
class Tee(object):
|
||||
|
||||
Reference in New Issue
Block a user