Files
ray/python/ray/tune/tests/test_api.py
T
2020-09-14 17:11:20 +00:00

1178 lines
39 KiB
Python

from collections import Counter
import shutil
import tempfile
import copy
import os
import time
import unittest
from unittest.mock import patch
import ray
from ray.rllib import _register_all
from ray import tune
from ray.tune import (DurableTrainable, Trainable, TuneError, Stopper,
EarlyStopping)
from ray.tune import register_env, register_trainable, run_experiments
from ray.tune.schedulers import (TrialScheduler, FIFOScheduler,
AsyncHyperBandScheduler)
from ray.tune.trial import Trial
from ray.tune.result import (TIMESTEPS_TOTAL, DONE, HOSTNAME, NODE_IP, PID,
EPISODES_TOTAL, TRAINING_ITERATION,
TIMESTEPS_THIS_ITER, TIME_THIS_ITER_S,
TIME_TOTAL_S, TRIAL_ID, EXPERIMENT_TAG)
from ray.tune.logger import Logger
from ray.tune.experiment import Experiment
from ray.tune.resources import Resources
from ray.tune.suggest import grid_search
from ray.tune.suggest.hyperopt import HyperOptSearch
from ray.tune.suggest.ax import AxSearch
from ray.tune.suggest._mock import _MockSuggestionAlgorithm
from ray.tune.utils import (flatten_dict, get_pinned_object,
pin_in_object_store)
from ray.tune.utils.mock import mock_storage_client, MOCK_REMOTE_DIR
class TrainableFunctionApiTest(unittest.TestCase):
    """End-to-end tests for Tune's function and class Trainable APIs.

    Covers trainable/env registration, parameter validation, stopping
    criteria, result logging, checkpointing and log redirection. Every test
    runs against a fresh local Ray instance (4 CPUs, 0 GPUs) started in
    ``setUp`` and torn down in ``tearDown``.
    """

    def setUp(self):
        """Start a local Ray instance and create a scratch directory."""
        ray.init(num_cpus=4, num_gpus=0, object_store_memory=150 * 1024 * 1024)
        self.tmpdir = tempfile.mkdtemp()

    def tearDown(self):
        """Shut Ray down, re-register trainables, remove the scratch dir."""
        ray.shutdown()
        _register_all()  # re-register the evicted objects
        shutil.rmtree(self.tmpdir)

    def checkAndReturnConsistentLogs(self, results, sleep_per_iter=None):
        """Checks logging is the same between APIs.

        Runs the same sequence of results through a class-API Trainable and
        a function-API trainable, then asserts that both logged the same
        results. "DONE" is ignored for logging, but the scheduler must be
        notified properly with the last result.

        Args:
            results: List of result dicts, emitted in order by both
                trainables.
            sleep_per_iter: Optional number of seconds to sleep before each
                result is emitted.

        Returns:
            Tuple ``(function_output, trials)``: the results logged by the
            function-API trainable, and the trials returned by
            ``run_experiments``.
        """
        class_results = copy.deepcopy(results)
        function_results = copy.deepcopy(results)

        class_output = []
        function_output = []
        scheduler_notif = []

        class MockScheduler(FIFOScheduler):
            def on_trial_complete(self, runner, trial, result):
                scheduler_notif.append(result)

        class ClassAPILogger(Logger):
            def on_result(self, result):
                class_output.append(result)

        class FunctionAPILogger(Logger):
            def on_result(self, result):
                function_output.append(result)

        class _WrappedTrainable(Trainable):
            def setup(self, config):
                del config
                self._result_iter = copy.deepcopy(class_results)

            def step(self):
                if sleep_per_iter:
                    time.sleep(sleep_per_iter)
                res = self._result_iter.pop(0)  # This should not fail
                if not self._result_iter:  # Mark "Done" for last result
                    res[DONE] = True
                return res

        def _function_trainable(config, reporter):
            for result in function_results:
                if sleep_per_iter:
                    time.sleep(sleep_per_iter)
                reporter(**result)

        class_trainable_name = "class_trainable"
        register_trainable(class_trainable_name, _WrappedTrainable)

        trials = run_experiments(
            {
                "function_api": {
                    "run": _function_trainable,
                    "loggers": [FunctionAPILogger],
                },
                "class_api": {
                    "run": class_trainable_name,
                    "loggers": [ClassAPILogger],
                },
            },
            raise_on_failed_trial=False,
            scheduler=MockScheduler())

        # Ignore these fields
        NO_COMPARE_FIELDS = {
            HOSTNAME,
            NODE_IP,
            TRIAL_ID,
            EXPERIMENT_TAG,
            PID,
            TIME_THIS_ITER_S,
            TIME_TOTAL_S,
            DONE,  # This is ignored because FunctionAPI has different handling
            "timestamp",
            "time_since_restore",
            "experiment_id",
            "date",
        }

        self.assertEqual(len(class_output), len(results))
        self.assertEqual(len(function_output), len(results))

        def as_comparable_result(result):
            return {
                k: v
                for k, v in result.items() if k not in NO_COMPARE_FIELDS
            }

        function_comparable = [
            as_comparable_result(result) for result in function_output
        ]
        class_comparable = [
            as_comparable_result(result) for result in class_output
        ]

        self.assertEqual(function_comparable, class_comparable)

        # Coerce to bool so a result missing DONE counts as "not done" and
        # fails the assertion, instead of raising TypeError from sum(None).
        self.assertEqual(
            sum(bool(t.get(DONE)) for t in scheduler_notif), 2)
        self.assertEqual(
            as_comparable_result(scheduler_notif[0]),
            as_comparable_result(scheduler_notif[1]))

        # Make sure the last result is the same.
        self.assertEqual(
            as_comparable_result(trials[0].last_result),
            as_comparable_result(trials[1].last_result))

        return function_output, trials

    def testPinObject(self):
        """A pinned object can be fetched inside a remote task."""
        X = pin_in_object_store("hello")

        @ray.remote
        def f():
            return get_pinned_object(X)

        self.assertEqual(ray.get(f.remote()), "hello")

    def testFetchPinned(self):
        """A pinned object can be fetched inside a function trainable."""
        X = pin_in_object_store("hello")

        def train(config, reporter):
            get_pinned_object(X)
            reporter(timesteps_total=100, done=True)

        register_trainable("f1", train)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100)

    def testRegisterEnv(self):
        """register_env accepts a callable and rejects a non-callable."""
        register_env("foo", lambda: None)
        self.assertRaises(TypeError, lambda: register_env("foo", 2))

    def testRegisterEnvOverwrite(self):
        """Re-registering a trainable name replaces the earlier one."""

        def train(config, reporter):
            reporter(timesteps_total=100, done=True)

        def train2(config, reporter):
            reporter(timesteps_total=200, done=True)

        register_trainable("f1", train)
        register_trainable("f1", train2)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 200)

    def testRegisterTrainable(self):
        """Only functions and Trainable subclasses (not instances) are valid
        trainables for register_trainable and Experiment."""

        def train(config, reporter):
            pass

        class A:
            pass

        class B(Trainable):
            pass

        register_trainable("foo", train)
        Experiment("test", train)
        register_trainable("foo", B)
        Experiment("test", B)
        self.assertRaises(TypeError, lambda: register_trainable("foo", B()))
        self.assertRaises(TuneError, lambda: Experiment("foo", B()))
        self.assertRaises(TypeError, lambda: register_trainable("foo", A))
        self.assertRaises(TypeError, lambda: Experiment("foo", A))

    def testTrainableCallable(self):
        """A functools.partial works as a trainable, registered or passed
        directly to tune.run."""

        def dummy_fn(config, reporter, steps):
            reporter(timesteps_total=steps, done=True)

        from functools import partial
        steps = 500
        register_trainable("test", partial(dummy_fn, steps=steps))
        [trial] = run_experiments({
            "foo": {
                "run": "test",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)
        [trial] = tune.run(partial(dummy_fn, steps=steps)).trials
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], steps)

    def testBuiltInTrainableResources(self):
        """default_resource_request is honored. Requests larger than the
        cluster (4 CPUs, 0 GPUs) raise TuneError when not queued."""

        class B(Trainable):
            @classmethod
            def default_resource_request(cls, config):
                return Resources(cpu=config["cpu"], gpu=config["gpu"])

            def step(self):
                return {"timesteps_this_iter": 1, "done": True}

        register_trainable("B", B)

        def f(cpus, gpus, queue_trials):
            return run_experiments(
                {
                    "foo": {
                        "run": "B",
                        "config": {
                            "cpu": cpus,
                            "gpu": gpus,
                        },
                    }
                },
                queue_trials=queue_trials)[0]

        # Should all succeed
        self.assertEqual(f(0, 0, False).status, Trial.TERMINATED)
        self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)
        self.assertEqual(f(1, 0, True).status, Trial.TERMINATED)

        # Too large resource request
        self.assertRaises(TuneError, lambda: f(100, 100, False))
        self.assertRaises(TuneError, lambda: f(0, 100, False))
        self.assertRaises(TuneError, lambda: f(100, 0, False))

        # TODO(ekl) how can we test this is queued (hangs)?
        # f(100, 0, True)

    def testRewriteEnv(self):
        """A top-level "env" key ends up in the trial config."""

        def train(config, reporter):
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
                "env": "CartPole-v0",
            }
        })
        self.assertEqual(trial.config["env"], "CartPole-v0")

    def testConfigPurity(self):
        """The trainable receives exactly the user-supplied config."""

        def train(config, reporter):
            assert config == {"a": "b"}, config
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        run_experiments({
            "foo": {
                "run": "f1",
                "config": {
                    "a": "b"
                },
            }
        })

    def testLogdir(self):
        """The trainable's working directory lies under local_dir/<name>."""

        def train(config, reporter):
            assert os.path.join(ray.utils.get_user_temp_dir(), "logdir",
                                "foo") in os.getcwd(), os.getcwd()
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        run_experiments({
            "foo": {
                "run": "f1",
                "local_dir": os.path.join(ray.utils.get_user_temp_dir(),
                                          "logdir"),
                "config": {
                    "a": "b"
                },
            }
        })

    def testLogdirStartingWithTilde(self):
        """A local_dir beginning with "~" is expanded to the home dir."""
        local_dir = "~/ray_results/local_dir"

        def train(config, reporter):
            cwd = os.getcwd()
            assert cwd.startswith(os.path.expanduser(local_dir)), cwd
            assert not cwd.startswith("~"), cwd
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        run_experiments({
            "foo": {
                "run": "f1",
                "local_dir": local_dir,
                "config": {
                    "a": "b"
                },
            }
        })

    def testLongFilename(self):
        """Very long config keys/values do not break trial dir creation."""

        def train(config, reporter):
            assert os.path.join(ray.utils.get_user_temp_dir(), "logdir",
                                "foo") in os.getcwd(), os.getcwd()
            reporter(timesteps_total=1)

        register_trainable("f1", train)
        run_experiments({
            "foo": {
                "run": "f1",
                "local_dir": os.path.join(ray.utils.get_user_temp_dir(),
                                          "logdir"),
                "config": {
                    "a" * 50: tune.sample_from(lambda spec: 5.0 / 7),
                    "b" * 50: tune.sample_from(lambda spec: "long" * 40),
                },
            }
        })

    def testBadParams(self):
        """An experiment spec without "run" raises TuneError."""

        def f():
            run_experiments({"foo": {}})

        self.assertRaises(TuneError, f)

    def testBadParams2(self):
        """An unknown experiment spec key raises TuneError."""

        def f():
            run_experiments({
                "foo": {
                    "run": "asdf",
                    "bah": "this param is not allowed",
                }
            })

        self.assertRaises(TuneError, f)

    def testBadParams3(self):
        """A grid_search in place of "run" raises TuneError."""

        def f():
            run_experiments({
                "foo": {
                    "run": grid_search("invalid grid search"),
                }
            })

        self.assertRaises(TuneError, f)

    def testBadParams4(self):
        """An unregistered trainable name raises TuneError."""

        def f():
            run_experiments({
                "foo": {
                    "run": "asdf",
                }
            })

        self.assertRaises(TuneError, f)

    def testBadParams5(self):
        """A stop condition on an unknown metric raises TuneError."""

        def f():
            run_experiments({"foo": {"run": "PPO", "stop": {"asdf": 1}}})

        self.assertRaises(TuneError, f)

    def testBadParams6(self):
        """An unknown resource type raises TuneError."""

        def f():
            run_experiments({
                "foo": {
                    "run": "PPO",
                    "resources_per_trial": {
                        "asdf": 1
                    }
                }
            })

        self.assertRaises(TuneError, f)

    def testBadStoppingReturn(self):
        """Stopping on a metric the trainable never reports raises
        TuneError."""

        def train(config, reporter):
            reporter()

        register_trainable("f1", train)

        def f():
            run_experiments({
                "foo": {
                    "run": "f1",
                    "stop": {
                        "time": 10
                    },
                }
            })

        self.assertRaises(TuneError, f)

    def testNestedStoppingReturn(self):
        """Nested metrics are addressed with "/"-joined stop keys, not with
        nested stop dicts."""

        def train(config, reporter):
            for i in range(10):
                reporter(test={"test1": {"test2": i}})

        with self.assertRaises(TuneError):
            [trial] = tune.run(
                train, stop={
                    "test": {
                        "test1": {
                            "test2": 6
                        }
                    }
                }).trials
        [trial] = tune.run(train, stop={"test/test1/test2": 6}).trials
        self.assertEqual(trial.last_result["training_iteration"], 7)

    def testStoppingFunction(self):
        """A plain (trial_id, result) function works as a stop criterion."""

        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        def stop(trial_id, result):
            return result["test"] > 6

        [trial] = tune.run(train, stop=stop).trials
        self.assertEqual(trial.last_result["training_iteration"], 8)

    def testStoppingMemberFunction(self):
        """A bound (trial_id, result) method works as a stop criterion."""

        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        class Stopclass:
            def stop(self, trial_id, result):
                return result["test"] > 6

        [trial] = tune.run(train, stop=Stopclass().stop).trials
        self.assertEqual(trial.last_result["training_iteration"], 8)

    def testStopper(self):
        """A Stopper whose stop_all() returns True after more than five
        calls ends the run, so some trial never reports a result."""

        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        class CustomStopper(Stopper):
            def __init__(self):
                self._count = 0

            def __call__(self, trial_id, result):
                print("called")
                self._count += 1
                return result["test"] > 6

            def stop_all(self):
                return self._count > 5

        trials = tune.run(train, num_samples=5, stop=CustomStopper()).trials
        self.assertTrue(all(t.status == Trial.TERMINATED for t in trials))
        self.assertTrue(
            any(
                t.last_result.get("training_iteration") is None
                for t in trials))

    def testEarlyStopping(self):
        """EarlyStopping rejects invalid arguments. Runs that use it end with
        every trial TERMINATED and at most top/patience results."""

        def train(config, reporter):
            reporter(test=0)

        top = 3

        with self.assertRaises(ValueError):
            EarlyStopping("test", top=0)
        with self.assertRaises(ValueError):
            EarlyStopping("test", top="0")
        with self.assertRaises(ValueError):
            EarlyStopping("test", std=0)
        with self.assertRaises(ValueError):
            EarlyStopping("test", patience=-1)
        with self.assertRaises(ValueError):
            EarlyStopping("test", std="0")
        with self.assertRaises(ValueError):
            EarlyStopping("test", mode="0")

        stopper = EarlyStopping("test", top=top, mode="min")

        analysis = tune.run(train, num_samples=10, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(
            len(analysis.dataframe(metric="test", mode="max")) <= top)

        patience = 5
        stopper = EarlyStopping("test", top=top, mode="min", patience=patience)

        analysis = tune.run(train, num_samples=20, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(
            len(analysis.dataframe(metric="test", mode="max")) <= patience)

        stopper = EarlyStopping("test", top=top, mode="min")

        analysis = tune.run(train, num_samples=10, stop=stopper)
        self.assertTrue(
            all(t.status == Trial.TERMINATED for t in analysis.trials))
        self.assertTrue(
            len(analysis.dataframe(metric="test", mode="max")) <= top)

    def testBadStoppingFunction(self):
        """Stop callables that take only (result) raise TuneError."""

        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        class CustomStopper:
            def stop(self, result):
                return result["test"] > 6

        def stop(result):
            return result["test"] > 6

        with self.assertRaises(TuneError):
            tune.run(train, stop=CustomStopper().stop)
        with self.assertRaises(TuneError):
            tune.run(train, stop=stop)

    def testCustomTrialDir(self):
        """trial_dirname_creator controls the trial directory name."""

        def train(config):
            for i in range(10):
                tune.report(test=i)

        custom_name = "TRAIL_TRIAL"

        def custom_trial_dir(trial):
            return custom_name

        trials = tune.run(
            train,
            config={
                "t1": tune.grid_search([1, 2, 3])
            },
            trial_dirname_creator=custom_trial_dir,
            local_dir=self.tmpdir).trials
        logdirs = {t.logdir for t in trials}
        self.assertEqual(len(logdirs), 3)
        self.assertTrue(all(custom_name in dirpath for dirpath in logdirs))

    def testTrialDirRegression(self):
        """Default trial dirs contain the trainable name and config tag."""

        def train(config, reporter):
            for i in range(10):
                reporter(test=i)

        trials = tune.run(
            train,
            config={
                "t1": tune.grid_search([1, 2, 3])
            },
            local_dir=self.tmpdir).trials
        logdirs = {t.logdir for t in trials}
        for i in [1, 2, 3]:
            self.assertTrue(any(f"t1={i}" in dirpath for dirpath in logdirs))
        for t in trials:
            self.assertTrue(
                any(t.trainable_name in dirpath for dirpath in logdirs))

    def testEarlyReturn(self):
        """A trial that reports done=True terminates even if the function
        keeps running afterwards."""

        def train(config, reporter):
            reporter(timesteps_total=100, done=True)
            time.sleep(99999)

        register_trainable("f1", train)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 100)

    def testReporterNoUsage(self):
        """A function that never calls the reporter still ends with
        last_result[DONE] == True."""

        def run_task(config, reporter):
            print("hello")

        experiment = Experiment(run=run_task, name="ray_crash_repro")
        [trial] = ray.tune.run(experiment).trials
        print(trial.last_result)
        self.assertEqual(trial.last_result[DONE], True)

    def testRerun(self):
        """resume="ERRORED_ONLY" re-runs only the trials that errored.

        Trials with fail=True raise on their first attempt only. A marker
        file in ``tmpdir`` records that the attempt happened.
        """
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(lambda: shutil.rmtree(tmpdir))

        def test(config):
            tid = config["id"]
            fail = config["fail"]
            marker = os.path.join(tmpdir, f"t{tid}-{fail}.log")
            if not os.path.exists(marker) and fail:
                open(marker, "w").close()
                raise ValueError
            for i in range(10):
                time.sleep(0.1)
                tune.report(hello=123)

        config = dict(
            name="hi-2",
            config={
                "fail": tune.grid_search([True, False]),
                "id": tune.grid_search(list(range(5)))
            },
            verbose=1,
            local_dir=tmpdir,
            loggers=None)
        trials = tune.run(test, raise_on_failed_trial=False, **config).trials
        self.assertEqual(Counter(t.status for t in trials)["ERROR"], 5)
        new_trials = tune.run(test, resume="ERRORED_ONLY", **config).trials
        self.assertEqual(Counter(t.status for t in new_trials)["ERROR"], 0)
        self.assertTrue(
            all(t.last_result.get("hello") == 123 for t in new_trials))

    def testErrorReturn(self):
        """A trainable that raises makes run_experiments raise TuneError."""

        def train(config, reporter):
            raise Exception("uh oh")

        register_trainable("f1", train)

        def f():
            run_experiments({
                "foo": {
                    "run": "f1",
                }
            })

        self.assertRaises(TuneError, f)

    def testSuccess(self):
        """A function that returns normally terminates with its last
        reported result."""

        def train(config, reporter):
            for i in range(100):
                reporter(timesteps_total=i)

        register_trainable("f1", train)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TIMESTEPS_TOTAL], 99)

    def testNoRaiseFlag(self):
        """With raise_on_failed_trial=False a failing trial is returned in
        the ERROR state instead of raising."""

        def train(config, reporter):
            raise Exception()

        register_trainable("f1", train)

        [trial] = run_experiments(
            {
                "foo": {
                    "run": "f1",
                }
            }, raise_on_failed_trial=False)
        self.assertEqual(trial.status, Trial.ERROR)

    def testReportInfinity(self):
        """float("inf") metric values are reported unchanged."""

        def train(config, reporter):
            for _ in range(100):
                reporter(mean_accuracy=float("inf"))

        register_trainable("f1", train)
        [trial] = run_experiments({
            "foo": {
                "run": "f1",
            }
        })
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result["mean_accuracy"], float("inf"))

    def testTrialInfoAccess(self):
        """A class Trainable can read its trial name and id."""

        class TestTrainable(Trainable):
            def step(self):
                result = {"name": self.trial_name, "trial_id": self.trial_id}
                print(result)
                return result

        analysis = tune.run(TestTrainable, stop={TRAINING_ITERATION: 1})
        trial = analysis.trials[0]
        self.assertEqual(trial.last_result.get("name"), str(trial))
        self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)

    def testTrialInfoAccessFunction(self):
        """Function trainables can read their trial name and id, both via
        the reporter and via tune.get_trial_name/get_trial_id."""

        def train(config, reporter):
            reporter(name=reporter.trial_name, trial_id=reporter.trial_id)

        analysis = tune.run(train, stop={TRAINING_ITERATION: 1})
        trial = analysis.trials[0]
        self.assertEqual(trial.last_result.get("name"), str(trial))
        self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)

        def track_train(config):
            tune.report(
                name=tune.get_trial_name(), trial_id=tune.get_trial_id())

        analysis = tune.run(track_train, stop={TRAINING_ITERATION: 1})
        trial = analysis.trials[0]
        self.assertEqual(trial.last_result.get("name"), str(trial))
        self.assertEqual(trial.last_result.get("trial_id"), trial.trial_id)

    @patch("ray.tune.ray_trial_executor.TRIAL_CLEANUP_THRESHOLD", 3)
    def testLotsOfStops(self):
        """With TRIAL_CLEANUP_THRESHOLD patched to 3, every trial's slow
        cleanup() still completes.

        cleanup() writes a marker file, which is checked after Ray shuts
        down.
        """

        class TestTrainable(Trainable):
            def step(self):
                result = {"name": self.trial_name, "trial_id": self.trial_id}
                return result

            def cleanup(self):
                time.sleep(2)
                open(os.path.join(self.logdir, "marker"), "a").close()
                return 1

        analysis = tune.run(
            TestTrainable, num_samples=10, stop={TRAINING_ITERATION: 1})
        ray.shutdown()
        for trial in analysis.trials:
            path = os.path.join(trial.logdir, "marker")
            self.assertTrue(os.path.exists(path), path)

    def testNestedResults(self):
        """Nested results work with "/"-joined stop keys.

        The scheduler and the search algorithm must both receive results
        that contain the flattened keys.
        """

        def create_result(i):
            return {"test": {"1": {"2": {"3": i, "4": False}}}}

        flattened_keys = list(flatten_dict(create_result(0)))

        class _MockScheduler(FIFOScheduler):
            results = []

            def on_trial_result(self, trial_runner, trial, result):
                self.results += [result]
                return TrialScheduler.CONTINUE

            def on_trial_complete(self, trial_runner, trial, result):
                self.complete_result = result

        def train(config, reporter):
            for i in range(100):
                reporter(**create_result(i))

        algo = _MockSuggestionAlgorithm()
        scheduler = _MockScheduler()
        [trial] = tune.run(
            train,
            scheduler=scheduler,
            search_alg=algo,
            stop={
                "test/1/2/3": 20
            }).trials
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result["test"]["1"]["2"]["3"], 20)
        self.assertEqual(trial.last_result["test"]["1"]["2"]["4"], False)
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 21)
        self.assertEqual(len(scheduler.results), 20)
        self.assertTrue(
            all(
                set(result) >= set(flattened_keys)
                for result in scheduler.results))
        self.assertTrue(set(scheduler.complete_result) >= set(flattened_keys))
        self.assertEqual(len(algo.results), 20)
        self.assertTrue(
            all(set(result) >= set(flattened_keys) for result in algo.results))
        # Stop keys that do not resolve to a leaf metric of the (flattened)
        # result are rejected.
        with self.assertRaises(TuneError):
            [trial] = tune.run(train, stop={"1/2/3": 20}).trials
        with self.assertRaises(TuneError):
            [trial] = tune.run(train, stop={"test": 1}).trials

    def testReportTimeStep(self):
        """Timestep and episode counters are derived consistently by the
        class and function APIs."""
        # Test that no timestep count are logged if never the Trainable never
        # returns any.
        results1 = [dict(mean_accuracy=5, done=i == 99) for i in range(100)]
        logs1, _ = self.checkAndReturnConsistentLogs(results1)

        self.assertTrue(all(log[TIMESTEPS_TOTAL] is None for log in logs1))

        # Test that no timesteps_this_iter are logged if only timesteps_total
        # are returned.
        results2 = [dict(timesteps_total=5, done=i == 9) for i in range(10)]
        logs2, _ = self.checkAndReturnConsistentLogs(results2)

        # Re-run the same trials but with added delay. This is to catch some
        # inconsistent timestep counting that was present in the multi-threaded
        # FunctionRunner. This part of the test can be removed once the
        # multi-threaded FunctionRunner is removed from ray/tune.
        # TODO: remove once the multi-threaded function runner is gone.
        logs2, _ = self.checkAndReturnConsistentLogs(results2, 0.5)

        # check all timesteps_total report the same value
        self.assertTrue(all(log[TIMESTEPS_TOTAL] == 5 for log in logs2))
        # check that none of the logs report timesteps_this_iter. The logs
        # are dicts, so look up the key: hasattr() on a dict never sees
        # result keys and would make this check vacuous.
        self.assertFalse(
            any(log.get(TIMESTEPS_THIS_ITER) is not None for log in logs2))

        # Test that timesteps_total and episodes_total are reported when
        # timesteps_this_iter and episodes_this_iter despite only return zeros.
        results3 = [
            dict(timesteps_this_iter=0, episodes_this_iter=0)
            for i in range(10)
        ]
        logs3, _ = self.checkAndReturnConsistentLogs(results3)

        self.assertTrue(all(log[TIMESTEPS_TOTAL] == 0 for log in logs3))
        self.assertTrue(all(log[EPISODES_TOTAL] == 0 for log in logs3))

        # Test that timesteps_total and episodes_total are properly counted
        # when timesteps_this_iter and episodes_this_iter report non-zero
        # values.
        results4 = [
            dict(timesteps_this_iter=3, episodes_this_iter=i)
            for i in range(10)
        ]
        logs4, _ = self.checkAndReturnConsistentLogs(results4)

        # The last reported result should not be double-logged.
        self.assertEqual(logs4[-1][TIMESTEPS_TOTAL], 30)
        self.assertNotEqual(logs4[-2][TIMESTEPS_TOTAL],
                            logs4[-1][TIMESTEPS_TOTAL])
        self.assertEqual(logs4[-1][EPISODES_TOTAL], 45)
        self.assertNotEqual(logs4[-2][EPISODES_TOTAL],
                            logs4[-1][EPISODES_TOTAL])

    def testAllValuesReceived(self):
        """Every reported result is logged unmodified, with done logged
        once."""
        results1 = [
            dict(timesteps_total=(i + 1), my_score=i**2, done=i == 4)
            for i in range(5)
        ]

        logs1, _ = self.checkAndReturnConsistentLogs(results1)

        # check if the correct number of results were reported
        self.assertEqual(len(logs1), len(results1))

        def check_no_missing(reported_result, result):
            common_results = [reported_result[k] == result[k] for k in result]
            return all(common_results)

        # check that no result was dropped or modified
        complete_results = [
            check_no_missing(log, result)
            for log, result in zip(logs1, results1)
        ]
        self.assertTrue(all(complete_results))

        # check if done was logged exactly once
        self.assertEqual(len([r for r in logs1 if r.get("done")]), 1)

    def testNoDoneReceived(self):
        """Same as testAllValuesReceived, but without an explicit done."""
        # repeat same test but without explicitly reporting done=True
        results1 = [
            dict(timesteps_total=(i + 1), my_score=i**2) for i in range(5)
        ]

        logs1, trials = self.checkAndReturnConsistentLogs(results1)

        # check if the correct number of results were reported.
        self.assertEqual(len(logs1), len(results1))

        def check_no_missing(reported_result, result):
            common_results = [reported_result[k] == result[k] for k in result]
            return all(common_results)

        # check that no result was dropped or modified
        complete_results1 = [
            check_no_missing(log, result)
            for log, result in zip(logs1, results1)
        ]
        self.assertTrue(all(complete_results1))

    def testDurableTrainable(self):
        """A DurableTrainable can save to and restore from a (mocked)
        remote checkpoint directory."""

        class TestTrain(DurableTrainable):
            def setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def step(self):
                self.state["iter"] += 1
                return {"timesteps_this_iter": 1, "done": True}

            def save_checkpoint(self, path):
                return self.state

            def load_checkpoint(self, state):
                self.state = state

        # Register cleanup before anything can fail, so the mock remote dir
        # is removed even when an assertion below fails. ignore_errors
        # covers the case where the dir was never created.
        self.addCleanup(
            lambda: shutil.rmtree(MOCK_REMOTE_DIR, ignore_errors=True))

        sync_client = mock_storage_client()
        mock_get_client = "ray.tune.durable_trainable.get_cloud_sync_client"
        with patch(mock_get_client) as mock_get_cloud_sync_client:
            mock_get_cloud_sync_client.return_value = sync_client
            test_trainable = TestTrain(remote_checkpoint_dir=MOCK_REMOTE_DIR)
            checkpoint_path = test_trainable.save()
            test_trainable.train()
            test_trainable.state["hi"] = 2
            test_trainable.restore(checkpoint_path)
            self.assertEqual(test_trainable.state["hi"], 1)

    def testCheckpointDict(self):
        """save_checkpoint may return a dict, which round-trips through
        save()/restore() and supports checkpoint_at_end."""

        class TestTrain(Trainable):
            def setup(self, config):
                self.state = {"hi": 1}

            def step(self):
                return {"timesteps_this_iter": 1, "done": True}

            def save_checkpoint(self, path):
                return self.state

            def load_checkpoint(self, state):
                self.state = state

        test_trainable = TestTrain()
        result = test_trainable.save()
        test_trainable.state["hi"] = 2
        test_trainable.restore(result)
        self.assertEqual(test_trainable.state["hi"], 1)

        trials = run_experiments({
            "foo": {
                "run": TestTrain,
                "checkpoint_at_end": True
            }
        })
        for trial in trials:
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertTrue(trial.has_checkpoint())

    def testMultipleCheckpoints(self):
        """Distinct checkpoints restore their own distinct state."""

        class TestTrain(Trainable):
            def setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def step(self):
                self.state["iter"] += 1
                return {"timesteps_this_iter": 1, "done": True}

            def save_checkpoint(self, path):
                return self.state

            def load_checkpoint(self, state):
                self.state = state

        test_trainable = TestTrain()
        checkpoint_1 = test_trainable.save()
        test_trainable.train()
        checkpoint_2 = test_trainable.save()
        self.assertNotEqual(checkpoint_1, checkpoint_2)
        test_trainable.restore(checkpoint_2)
        self.assertEqual(test_trainable.state["iter"], 1)
        test_trainable.restore(checkpoint_1)
        self.assertEqual(test_trainable.state["iter"], 0)

        trials = run_experiments({
            "foo": {
                "run": TestTrain,
                "checkpoint_at_end": True
            }
        })
        for trial in trials:
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertTrue(trial.has_checkpoint())

    def testIterationCounter(self):
        """training_iteration counts reporter calls; a timesteps_total stop
        condition sees the accumulated timesteps_this_iter."""

        def train(config, reporter):
            for i in range(100):
                reporter(itr=i, timesteps_this_iter=1)

        register_trainable("exp", train)
        config = {
            "my_exp": {
                "run": "exp",
                "config": {
                    "iterations": 100,
                },
                "stop": {
                    "timesteps_total": 100
                },
            }
        }
        [trial] = run_experiments(config)
        self.assertEqual(trial.status, Trial.TERMINATED)
        self.assertEqual(trial.last_result[TRAINING_ITERATION], 100)
        self.assertEqual(trial.last_result["itr"], 99)

    def testBackwardsCompat(self):
        """Legacy underscore hooks (_setup/_train/_save/_restore) still
        work."""

        class TestTrain(Trainable):
            def _setup(self, config):
                self.state = {"hi": 1, "iter": 0}

            def _train(self):
                self.state["iter"] += 1
                return {"timesteps_this_iter": 1, "done": True}

            def _save(self, path):
                return self.state

            def _restore(self, state):
                self.state = state

        test_trainable = TestTrain()
        checkpoint_1 = test_trainable.save()
        test_trainable.train()
        checkpoint_2 = test_trainable.save()
        self.assertNotEqual(checkpoint_1, checkpoint_2)
        test_trainable.restore(checkpoint_2)
        self.assertEqual(test_trainable.state["iter"], 1)
        test_trainable.restore(checkpoint_1)
        self.assertEqual(test_trainable.state["iter"], 0)

        trials = run_experiments({
            "foo": {
                "run": TestTrain,
                "checkpoint_at_end": True
            }
        })
        for trial in trials:
            self.assertEqual(trial.status, Trial.TERMINATED)
            self.assertTrue(trial.has_checkpoint())

    def testLogToFile(self):
        """log_to_file redirects trial stdout/stderr into files in the trial
        logdir: False (none), True (stdout/stderr), a single name (both
        streams combined), or a (stdout, stderr) tuple of names."""

        def train(config, reporter):
            import sys
            from ray import logger
            for i in range(10):
                reporter(timesteps_total=i)
            print("PRINT_STDOUT")
            print("PRINT_STDERR", file=sys.stderr)
            logger.info("LOG_STDERR")

        register_trainable("f1", train)

        # Do not log to file
        [trial] = tune.run("f1", log_to_file=False).trials
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))

        # Log to default files
        [trial] = tune.run("f1", log_to_file=True).trials
        self.assertTrue(os.path.exists(os.path.join(trial.logdir, "stdout")))
        self.assertTrue(os.path.exists(os.path.join(trial.logdir, "stderr")))
        with open(os.path.join(trial.logdir, "stdout"), "rt") as fp:
            content = fp.read()
        self.assertIn("PRINT_STDOUT", content)
        with open(os.path.join(trial.logdir, "stderr"), "rt") as fp:
            content = fp.read()
        self.assertIn("PRINT_STDERR", content)
        self.assertIn("LOG_STDERR", content)

        # Log to one file
        [trial] = tune.run("f1", log_to_file="combined").trials
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))
        self.assertTrue(os.path.exists(os.path.join(trial.logdir, "combined")))
        with open(os.path.join(trial.logdir, "combined"), "rt") as fp:
            content = fp.read()
        self.assertIn("PRINT_STDOUT", content)
        self.assertIn("PRINT_STDERR", content)
        self.assertIn("LOG_STDERR", content)

        # Log to two files
        [trial] = tune.run(
            "f1", log_to_file=("alt.stdout", "alt.stderr")).trials
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stdout")))
        self.assertFalse(os.path.exists(os.path.join(trial.logdir, "stderr")))
        self.assertTrue(
            os.path.exists(os.path.join(trial.logdir, "alt.stdout")))
        self.assertTrue(
            os.path.exists(os.path.join(trial.logdir, "alt.stderr")))

        with open(os.path.join(trial.logdir, "alt.stdout"), "rt") as fp:
            content = fp.read()
        self.assertIn("PRINT_STDOUT", content)
        with open(os.path.join(trial.logdir, "alt.stderr"), "rt") as fp:
            content = fp.read()
        self.assertIn("PRINT_STDERR", content)
        self.assertIn("LOG_STDERR", content)

    def testTimeout(self):
        """time_budget_s bounds the wall-clock run time, alone or together
        with a metric stop condition or a TimeoutStopper."""
        from ray.tune.stopper import TimeoutStopper
        import datetime

        def train(config):
            for i in range(20):
                tune.report(metric=i)
                time.sleep(1)

        register_trainable("f1", train)

        start = time.time()
        tune.run("f1", time_budget_s=5)
        diff = time.time() - start
        self.assertLess(diff, 10)

        # Metric should fire first
        start = time.time()
        tune.run("f1", stop={"metric": 3}, time_budget_s=7)
        diff = time.time() - start
        self.assertLess(diff, 7)

        # Timeout should fire first
        start = time.time()
        tune.run("f1", stop={"metric": 10}, time_budget_s=5)
        diff = time.time() - start
        self.assertLess(diff, 10)

        # Combined stopper. Shorter timeout should win.
        start = time.time()
        tune.run(
            "f1",
            stop=TimeoutStopper(10),
            time_budget_s=datetime.timedelta(seconds=3))
        diff = time.time() - start
        self.assertLess(diff, 9)
class ShimCreationTest(unittest.TestCase):
    """Checks that the string-keyed factory shims (tune.create_scheduler /
    tune.create_searcher) build objects of the same type as constructing
    the scheduler/searcher class directly."""

    def testCreateScheduler(self):
        """"async_hyperband" maps to AsyncHyperBandScheduler."""
        kwargs = {"metric": "metric_foo", "mode": "min"}

        scheduler = "async_hyperband"
        shim_scheduler = tune.create_scheduler(scheduler, **kwargs)
        real_scheduler = AsyncHyperBandScheduler(**kwargs)
        # assertIs instead of a bare assert: bare asserts are stripped when
        # Python runs with -O, which would make this check a no-op.
        self.assertIs(type(shim_scheduler), type(real_scheduler))

    def testCreateSearcher(self):
        """"ax" maps to AxSearch and "hyperopt" maps to HyperOptSearch."""
        kwargs = {"metric": "metric_foo", "mode": "min"}

        searcher_ax = "ax"
        shim_searcher_ax = tune.create_searcher(searcher_ax, **kwargs)
        real_searcher_ax = AxSearch(space=[], **kwargs)
        self.assertIs(type(shim_searcher_ax), type(real_searcher_ax))

        searcher_hyperopt = "hyperopt"
        shim_searcher_hyperopt = tune.create_searcher(searcher_hyperopt,
                                                      **kwargs)
        real_searcher_hyperopt = HyperOptSearch({}, **kwargs)
        self.assertIs(
            type(shim_searcher_hyperopt), type(real_searcher_hyperopt))
if __name__ == "__main__":
    # Allow running this module directly: hand it to pytest in verbose
    # mode and use pytest's return code as the process exit status.
    import pytest

    raise SystemExit(pytest.main(["-v", __file__]))