mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 21:53:18 +08:00
[Tune] Parametrize Cloud Syncing Frequency (#8771)
This commit is contained in:
+41
-10
@@ -5,6 +5,7 @@ import time
|
||||
|
||||
from shlex import quote
|
||||
|
||||
from ray import ray_constants
|
||||
from ray import services
|
||||
from ray.tune.cluster_info import get_ssh_key, get_ssh_user
|
||||
from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
|
||||
@@ -12,7 +13,13 @@ from ray.tune.sync_client import (CommandBasedClient, get_sync_client,
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SYNC_PERIOD = 300
|
||||
# Syncing period for syncing local checkpoints to cloud.
|
||||
# In env variable is not set, sync happens every 300 seconds.
|
||||
CLOUD_SYNC_PERIOD = ray_constants.env_integer(
|
||||
key="TUNE_CLOUD_SYNC_S", default=300)
|
||||
|
||||
# Syncing period for syncing worker logs to driver.
|
||||
NODE_SYNC_PERIOD = 300
|
||||
|
||||
_log_sync_warned = False
|
||||
_syncers = {}
|
||||
@@ -70,12 +77,23 @@ class Syncer:
|
||||
self.last_sync_down_time = float("-inf")
|
||||
self.sync_client = sync_client
|
||||
|
||||
def sync_up_if_needed(self):
|
||||
if time.time() - self.last_sync_up_time > SYNC_PERIOD:
|
||||
def sync_up_if_needed(self, sync_period):
|
||||
"""Syncs up if time since last sync up is greather than sync_period.
|
||||
|
||||
Arguments:
|
||||
sync_period (int): Time period between subsequent syncs.
|
||||
"""
|
||||
|
||||
if time.time() - self.last_sync_up_time > sync_period:
|
||||
self.sync_up()
|
||||
|
||||
def sync_down_if_needed(self):
|
||||
if time.time() - self.last_sync_down_time > SYNC_PERIOD:
|
||||
def sync_down_if_needed(self, sync_period):
|
||||
"""Syncs down if time since last sync down is greather than sync_period.
|
||||
|
||||
Arguments:
|
||||
sync_period (int): Time period between subsequent syncs.
|
||||
"""
|
||||
if time.time() - self.last_sync_down_time > sync_period:
|
||||
self.sync_down()
|
||||
|
||||
def sync_up(self):
|
||||
@@ -131,6 +149,19 @@ class Syncer:
|
||||
return self._remote_dir
|
||||
|
||||
|
||||
class CloudSyncer(Syncer):
|
||||
"""Syncer for syncing files to/from the cloud."""
|
||||
|
||||
def __init__(self, local_dir, remote_dir, sync_client):
|
||||
super(CloudSyncer, self).__init__(local_dir, remote_dir, sync_client)
|
||||
|
||||
def sync_up_if_needed(self):
|
||||
return super(CloudSyncer, self).sync_up_if_needed(CLOUD_SYNC_PERIOD)
|
||||
|
||||
def sync_down_if_needed(self):
|
||||
return super(CloudSyncer, self).sync_down_if_needed(CLOUD_SYNC_PERIOD)
|
||||
|
||||
|
||||
class NodeSyncer(Syncer):
|
||||
"""Syncer for syncing files to/from a remote dir to a local dir."""
|
||||
|
||||
@@ -158,12 +189,12 @@ class NodeSyncer(Syncer):
|
||||
def sync_up_if_needed(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncer, self).sync_up_if_needed()
|
||||
return super(NodeSyncer, self).sync_up_if_needed(NODE_SYNC_PERIOD)
|
||||
|
||||
def sync_down_if_needed(self):
|
||||
if not self.has_remote_target():
|
||||
return True
|
||||
return super(NodeSyncer, self).sync_down_if_needed()
|
||||
return super(NodeSyncer, self).sync_down_if_needed(NODE_SYNC_PERIOD)
|
||||
|
||||
def sync_up_to_new_location(self, worker_ip):
|
||||
if worker_ip != self.worker_ip:
|
||||
@@ -226,16 +257,16 @@ def get_cloud_syncer(local_dir, remote_dir=None, sync_function=None):
|
||||
return _syncers[key]
|
||||
|
||||
if not remote_dir:
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, NOOP)
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, NOOP)
|
||||
return _syncers[key]
|
||||
|
||||
client = get_sync_client(sync_function)
|
||||
|
||||
if client:
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, client)
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, client)
|
||||
return _syncers[key]
|
||||
sync_client = get_cloud_sync_client(remote_dir)
|
||||
_syncers[key] = Syncer(local_dir, remote_dir, sync_client)
|
||||
_syncers[key] = CloudSyncer(local_dir, remote_dir, sync_client)
|
||||
return _syncers[key]
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ray.tune.error import TuneError
|
||||
from ray.tune.ray_trial_executor import RayTrialExecutor
|
||||
from ray.tune.resources import Resources
|
||||
from ray.tune.suggest import BasicVariantGenerator
|
||||
from ray.tune.syncer import Syncer
|
||||
from ray.tune.syncer import CloudSyncer
|
||||
from ray.tune.trainable import TrainableUtil
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
@@ -521,7 +521,7 @@ def test_cluster_down_full(start_connected_cluster, tmpdir, trainable_id):
|
||||
|
||||
mock_get_client = "ray.tune.trial_runner.get_cloud_syncer"
|
||||
with patch(mock_get_client) as mock_get_cloud_syncer:
|
||||
mock_syncer = Syncer(local_dir, upload_dir, mock_storage_client())
|
||||
mock_syncer = CloudSyncer(local_dir, upload_dir, mock_storage_client())
|
||||
mock_get_cloud_syncer.return_value = mock_syncer
|
||||
|
||||
tune.run_experiments(all_experiments, raise_on_failed_trial=False)
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -151,6 +152,38 @@ class TestSyncFunctionality(unittest.TestCase):
|
||||
shutil.rmtree(tmpdir)
|
||||
shutil.rmtree(tmpdir2)
|
||||
|
||||
@patch("ray.tune.sync_client.S3_PREFIX", "test")
|
||||
def testCloudSyncPeriod(self):
|
||||
"""Tests that changing CLOUD_SYNC_PERIOD affects syncing frequency."""
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
|
||||
def trainable(config):
|
||||
for i in range(10):
|
||||
time.sleep(1)
|
||||
tune.report(score=i)
|
||||
|
||||
mock = unittest.mock.Mock()
|
||||
|
||||
def counter(local, remote):
|
||||
mock()
|
||||
|
||||
tune.syncer.CLOUD_SYNC_PERIOD = 1
|
||||
[trial] = tune.run(
|
||||
trainable,
|
||||
name="foo",
|
||||
max_failures=0,
|
||||
local_dir=tmpdir,
|
||||
upload_dir="test",
|
||||
sync_to_cloud=counter,
|
||||
stop={
|
||||
"training_iteration": 10
|
||||
},
|
||||
global_checkpoint_period=0.5,
|
||||
).trials
|
||||
|
||||
self.assertEqual(mock.call_count, 12)
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
def testClusterSyncFunction(self):
|
||||
def sync_func_driver(source, target):
|
||||
assert ":" in source, "Source {} not a remote path.".format(source)
|
||||
|
||||
@@ -144,7 +144,9 @@ def run(run_or_experiment,
|
||||
from upload_dir. If string, then it must be a string template that
|
||||
includes `{source}` and `{target}` for the syncer to run. If not
|
||||
provided, the sync command defaults to standard S3 or gsutil sync
|
||||
commands.
|
||||
commands. By default local_dir is synced to remote_dir every 300
|
||||
seconds. To change this, set the TUNE_CLOUD_SYNC_S
|
||||
environment variable in the driver machine.
|
||||
sync_to_driver (func|str|bool): Function for syncing trial logdir from
|
||||
remote node to local. If string, then it must be a string template
|
||||
that includes `{source}` and `{target}` for the syncer to run.
|
||||
|
||||
Reference in New Issue
Block a user