[tune] Distributed example + walkthrough (#5157)

This commit is contained in:
Richard Liaw
2019-08-02 09:17:20 -07:00
committed by GitHub
parent 13fb9fe3db
commit 1eaa57c98f
28 changed files with 990 additions and 396 deletions
+62 -42
View File
@@ -47,38 +47,6 @@ class Analysis(object):
self._trial_dataframes = {}
self.fetch_trial_dataframes()
def fetch_trial_dataframes(self):
    """Read every trial's progress CSV into `self.trial_dataframes`.

    Trial paths whose progress file cannot be read are skipped; how many
    were skipped is reported at debug level.

    Returns:
        The `self.trial_dataframes` mapping, keyed by trial path.
    """
    unreadable = 0
    for trial_path in self._get_trial_paths():
        try:
            progress_csv = os.path.join(trial_path, EXPR_PROGRESS_FILE)
            self.trial_dataframes[trial_path] = pd.read_csv(progress_csv)
        except Exception:
            # Best effort: a trial without readable results is only counted.
            unreadable += 1

    if unreadable:
        logger.debug(
            "Couldn't read results from {} paths".format(unreadable))
    return self.trial_dataframes
def get_all_configs(self, prefix=False):
    """Return the configurations of all trials, keyed by trial path.

    Parameters:
        prefix (bool): If True, every top-level key of each config is
            renamed to `config:<key>`.
    """
    unreadable = 0
    for trial_path in self._get_trial_paths():
        try:
            with open(os.path.join(trial_path, EXPR_PARAM_FILE)) as param_file:
                trial_config = json.load(param_file)
            if prefix:
                # Iterate over a snapshot of the keys because the dict is
                # mutated while renaming.
                for key in list(trial_config):
                    trial_config["config:" + key] = trial_config.pop(key)
            self._configs[trial_path] = trial_config
        except Exception:
            # Best effort: an unreadable config is only counted.
            unreadable += 1

    if unreadable:
        logger.warning(
            "Couldn't read config from {} paths".format(unreadable))
    return self._configs
def dataframe(self, metric=None, mode=None):
"""Returns a pandas.DataFrame object constructed from the trials.
@@ -110,6 +78,58 @@ class Analysis(object):
best_path = compare_op(rows, key=lambda k: rows[k][metric])
return all_configs[best_path]
def get_best_logdir(self, metric, mode="max"):
    """Retrieve the logdir corresponding to the best trial.

    Args:
        metric (str): Key for trial info to order on.
        mode (str): One of [min, max].

    Returns:
        The `logdir` value of the row of `self.dataframe()` whose
        `metric` is largest (mode="max") or smallest (mode="min").

    Raises:
        ValueError: If `mode` is neither "max" nor "min".
    """
    if mode not in ("max", "min"):
        # Previously an unknown mode fell through and returned None.
        raise ValueError(
            "mode must be 'max' or 'min', got {!r}".format(mode))

    df = self.dataframe()
    column = df[metric]
    # idxmax/idxmin return an index *label*, so the row must be fetched with
    # label-based `.loc`; positional `.iloc` is only correct when the frame
    # happens to carry a default 0..n-1 index.
    best_label = column.idxmax() if mode == "max" else column.idxmin()
    return df.loc[best_label].logdir
def fetch_trial_dataframes(self):
    """Read every trial's progress CSV into `self.trial_dataframes`.

    Trial paths whose progress file cannot be read are skipped; how many
    were skipped is reported at debug level.

    Returns:
        The `self.trial_dataframes` mapping, keyed by trial path.
    """
    unreadable = 0
    for trial_path in self._get_trial_paths():
        try:
            progress_csv = os.path.join(trial_path, EXPR_PROGRESS_FILE)
            self.trial_dataframes[trial_path] = pd.read_csv(progress_csv)
        except Exception:
            # Best effort: a trial without readable results is only counted.
            unreadable += 1

    if unreadable:
        logger.debug(
            "Couldn't read results from {} paths".format(unreadable))
    return self.trial_dataframes
def get_all_configs(self, prefix=False):
    """Return the configurations of all trials, keyed by trial path.

    Parameters:
        prefix (bool): If True, every top-level key of each config is
            renamed to `config/<key>`. Nested dicts are left as they are.
    """
    unreadable = 0
    for trial_path in self._get_trial_paths():
        try:
            with open(os.path.join(trial_path, EXPR_PARAM_FILE)) as param_file:
                trial_config = json.load(param_file)
            if prefix:
                # Iterate over a snapshot of the keys because the dict is
                # mutated while renaming.
                for key in list(trial_config):
                    trial_config["config/" + key] = trial_config.pop(key)
            self._configs[trial_path] = trial_config
        except Exception:
            # Best effort: an unreadable config is only counted.
            unreadable += 1

    if unreadable:
        logger.warning(
            "Couldn't read config from {} paths".format(unreadable))
    return self._configs
def _retrieve_rows(self, metric=None, mode=None):
assert mode is None or mode in ["max", "min"]
rows = {}
@@ -135,15 +155,9 @@ class Analysis(object):
self._experiment_dir))
return _trial_paths
def get_best_logdir(self, metric, mode="max"):
    """Return the logdir of the trial with the best value of `metric`.

    Args:
        metric (str): Column of `self.dataframe()` to rank trials by.
        mode (str): "max" to pick the largest value, "min" the smallest.

    Returns:
        The `logdir` value of the selected row.

    Raises:
        ValueError: If `mode` is neither "max" nor "min".
    """
    if mode not in ("max", "min"):
        # Previously an unknown mode fell through and returned None.
        raise ValueError(
            "mode must be 'max' or 'min', got {!r}".format(mode))

    df = self.dataframe()
    column = df[metric]
    # idxmax/idxmin return an index *label*, so the row must be fetched with
    # label-based `.loc`; positional `.iloc` is only correct when the frame
    # happens to carry a default 0..n-1 index.
    best_label = column.idxmax() if mode == "max" else column.idxmin()
    return df.loc[best_label].logdir
@property
def trial_dataframes(self):
"""Dict of all trial dataframes, keyed by trial path."""
return self._trial_dataframes
@@ -189,9 +203,15 @@ class ExperimentAnalysis(Analysis):
def _get_trial_paths(self):
"""Overwrites Analysis to only have trials of one experiment."""
_trial_paths = [
checkpoint["logdir"] for checkpoint in self._checkpoints
]
if self.trials:
_trial_paths = [t.logdir for t in self.trials]
else:
logger.warning("No `self.trials`. Drawing logdirs from checkpoint "
"file. This may result in some information that is "
"out of sync, as checkpointing is periodic.")
_trial_paths = [
checkpoint["logdir"] for checkpoint in self._checkpoints
]
if not _trial_paths:
raise TuneError("No trials found.")
return _trial_paths
+1 -1
View File
@@ -148,7 +148,7 @@ def list_trials(experiment_path,
info_keys = DEFAULT_EXPERIMENT_INFO_KEYS
col_keys = [
k for k in checkpoints_df.columns
if k in info_keys or k.startswith("config:")
if k in info_keys or k.startswith("config/")
]
checkpoints_df = checkpoints_df[col_keys]
+25 -15
View File
@@ -4,8 +4,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import argparse
from filelock import FileLock
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -22,9 +24,9 @@ EPOCH_SIZE = 512
TEST_SIZE = 256
class Net(nn.Module):
def __init__(self, config):
super(Net, self).__init__()
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
self.fc = nn.Linear(192, 10)
@@ -35,7 +37,7 @@ class Net(nn.Module):
return F.log_softmax(x, dim=1)
def train(model, optimizer, train_loader, device):
def train(model, optimizer, train_loader, device=torch.device("cpu")):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if batch_idx * len(data) > EPOCH_SIZE:
@@ -48,7 +50,7 @@ def train(model, optimizer, train_loader, device):
optimizer.step()
def test(model, data_loader, device):
def test(model, data_loader, device=torch.device("cpu")):
model.eval()
correct = 0
total = 0
@@ -70,11 +72,18 @@ def get_data_loaders():
[transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data", train=True, download=True, transform=mnist_transforms),
batch_size=64,
shuffle=True)
# We add FileLock here because multiple workers will want to
# download data, and this may cause overwrites since
# DataLoader is not threadsafe.
with FileLock(os.path.expanduser("~/data.lock")):
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=mnist_transforms),
batch_size=64,
shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST("~/data", train=False, transform=mnist_transforms),
batch_size=64,
@@ -86,7 +95,7 @@ def train_mnist(config):
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
train_loader, test_loader = get_data_loaders()
model = Net(config).to(device)
model = ConvNet().to(device)
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
@@ -112,24 +121,25 @@ if __name__ == "__main__":
args = parser.parse_args()
if args.ray_redis_address:
ray.init(redis_address=args.ray_redis_address)
datasets.MNIST("~/data", train=True, download=True)
sched = AsyncHyperBandScheduler(
time_attr="training_iteration", metric="mean_accuracy")
tune.run(
analysis = tune.run(
train_mnist,
name="exp",
scheduler=sched,
stop={
"mean_accuracy": 0.98,
"training_iteration": 5 if args.smoke_test else 20
"training_iteration": 5 if args.smoke_test else 100
},
resources_per_trial={
"cpu": 2,
"gpu": int(args.cuda)
},
num_samples=1 if args.smoke_test else 10,
num_samples=1 if args.smoke_test else 50,
config={
"lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
"momentum": tune.uniform(0.1, 0.9),
"use_gpu": int(args.cuda)
})
print("Best config is:", analysis.get_best_config(metric="mean_accuracy"))
@@ -5,12 +5,13 @@ from __future__ import print_function
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from ray.tune import Trainable
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.examples.mnist_pytorch import (train, test, get_data_loaders,
ConvNet)
# Change these values if you want the training to run quicker or slower.
EPOCH_SIZE = 512
@@ -19,155 +20,35 @@ TEST_SIZE = 256
# Training settings
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
parser.add_argument(
"--batch-size",
type=int,
default=64,
metavar="N",
help="input batch size for training (default: 64)")
parser.add_argument(
"--test-batch-size",
type=int,
default=1000,
metavar="N",
help="input batch size for testing (default: 1000)")
parser.add_argument(
"--epochs",
type=int,
default=1,
metavar="N",
help="number of epochs to train (default: 1)")
parser.add_argument(
"--lr",
type=float,
default=0.01,
metavar="LR",
help="learning rate (default: 0.01)")
parser.add_argument(
"--momentum",
type=float,
default=0.5,
metavar="M",
help="SGD momentum (default: 0.5)")
parser.add_argument(
"--no-cuda",
"--use-gpu",
action="store_true",
default=False,
help="disables CUDA training")
help="enables CUDA training")
parser.add_argument(
"--redis-address",
default=None,
type=str,
help="The Redis address of the cluster.")
parser.add_argument(
"--seed",
type=int,
default=1,
metavar="S",
help="random seed (default: 1)")
"--ray-redis-address", type=str, help="The Redis address of the cluster.")
parser.add_argument(
"--smoke-test", action="store_true", help="Finish quickly for testing")
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=1)
class TrainMNIST(Trainable):
# Below comments are for documentation purposes only.
# yapf: disable
# __trainable_example_begin__
class TrainMNIST(tune.Trainable):
def _setup(self, config):
args = config.pop("args", parser.parse_args([]))
vars(args).update(config)
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {"num_workers": 1, "pin_memory": True} if args.cuda else {}
self.train_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.batch_size,
shuffle=True,
**kwargs)
self.test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
"~/data",
train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
])),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs)
self.model = Net()
if args.cuda:
self.model.cuda()
use_cuda = config.get("use_gpu") and torch.cuda.is_available()
self.device = torch.device("cuda" if use_cuda else "cpu")
self.train_loader, self.test_loader = get_data_loaders()
self.model = ConvNet().to(self.device)
self.optimizer = optim.SGD(
self.model.parameters(), lr=args.lr, momentum=args.momentum)
self.args = args
def _train_iteration(self):
self.model.train()
for batch_idx, (data, target) in enumerate(self.train_loader):
if batch_idx * len(data) > EPOCH_SIZE:
return
if self.args.cuda:
data, target = data.cuda(), target.cuda()
self.optimizer.zero_grad()
output = self.model(data)
loss = F.nll_loss(output, target)
loss.backward()
self.optimizer.step()
def _test(self):
self.model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for batch_idx, (data, target) in enumerate(self.test_loader):
if batch_idx * len(data) > TEST_SIZE:
break
if self.args.cuda:
data, target = data.cuda(), target.cuda()
output = self.model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction="sum").item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(
target.data.view_as(pred)).long().cpu().sum()
test_loss = test_loss / len(self.test_loader.dataset)
accuracy = correct.item() / len(self.test_loader.dataset)
return {"mean_loss": test_loss, "mean_accuracy": accuracy}
self.model.parameters(),
lr=config.get("lr", 0.01),
momentum=config.get("momentum", 0.9))
def _train(self):
self._train_iteration()
return self._test()
train(
self.model, self.optimizer, self.train_loader, device=self.device)
acc = test(self.model, self.test_loader, self.device)
return {"mean_accuracy": acc}
def _save(self, checkpoint_dir):
checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
@@ -178,34 +59,33 @@ class TrainMNIST(Trainable):
self.model.load_state_dict(torch.load(checkpoint_path))
# __trainable_example_end__
# yapf: enable
if __name__ == "__main__":
datasets.MNIST("~/data", train=True, download=True)
args = parser.parse_args()
import ray
from ray import tune
from ray.tune.schedulers import HyperBandScheduler
ray.init(redis_address=args.redis_address)
sched = HyperBandScheduler(
time_attr="training_iteration", metric="mean_loss", mode="min")
tune.run(
ray.init(redis_address=args.ray_redis_address)
sched = ASHAScheduler(metric="mean_accuracy")
analysis = tune.run(
TrainMNIST,
scheduler=sched,
**{
"stop": {
"mean_accuracy": 0.95,
"training_iteration": 1 if args.smoke_test else 20,
"training_iteration": 3 if args.smoke_test else 20,
},
"resources_per_trial": {
"cpu": 3,
"gpu": int(not args.no_cuda)
"gpu": int(args.use_gpu)
},
"num_samples": 1 if args.smoke_test else 20,
"checkpoint_at_end": True,
"checkpoint_freq": 3,
"config": {
"args": args,
"lr": tune.uniform(0.001, 0.1),
"momentum": tune.uniform(0.1, 0.9),
}
})
print("Best config is:", analysis.get_best_config(metric="mean_accuracy"))
@@ -0,0 +1,51 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: tune-example
# The minimum number of worker nodes to launch in addition to the head
# node. This number should be >= 0.
min_workers: 2
# The maximum number of worker nodes to launch in addition to the head
# node. This takes precedence over min_workers.
max_workers: 2
# Cloud-provider specific configuration.
provider:
type: aws
region: us-west-2
# Availability zone(s), comma-separated, that nodes may be launched in.
# Nodes are currently spread between zones by a round-robin approach,
# however this implementation detail should not be relied upon.
availability_zone: us-west-2a,us-west-2b
# How Ray will authenticate with newly launched nodes.
# By default Ray creates a new private keypair, but you can also use your own.
auth:
ssh_user: ubuntu
# Provider-specific config for the head node, e.g. instance type.
head_node:
InstanceType: c5.xlarge
ImageId: ami-0b294f219d14e6a82 # Deep Learning AMI (Ubuntu) Version 21.0
# Provider-specific config for worker nodes, e.g. instance type.
worker_nodes:
InstanceType: c5.xlarge
ImageId: ami-0b294f219d14e6a82 # Deep Learning AMI (Ubuntu) Version 21.0
# Run workers on spot by default. Comment this out to use on-demand.
InstanceMarketOptions:
MarketType: spot
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
}
# List of shell commands to run to set up each node.
setup_commands:
- pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.8.0.dev3-cp36-cp36m-manylinux1_x86_64.whl
- pip install torch torchvision tabulate tensorboard filelock
+8 -3
View File
@@ -33,10 +33,11 @@ class StatusReporter(object):
>>> reporter(timesteps_this_iter=1)
"""
def __init__(self, result_queue, continue_semaphore):
def __init__(self, result_queue, continue_semaphore, logdir=None):
self._queue = result_queue
self._last_report_time = None
self._continue_semaphore = continue_semaphore
self._logdir = logdir
def __call__(self, **kwargs):
"""Report updated training status.
@@ -77,6 +78,10 @@ class StatusReporter(object):
def _start(self):
self._last_report_time = time.time()
@property
def logdir(self):
"""The `logdir` value stored by `__init__` (None unless one was passed)."""
return self._logdir
class _RunnerThread(threading.Thread):
"""Supervisor thread that runs your script."""
@@ -131,8 +136,8 @@ class FunctionRunner(Trainable):
# reporting to block until finished.
self._error_queue = queue.Queue(1)
self._status_reporter = StatusReporter(self._results_queue,
self._continue_semaphore)
self._status_reporter = StatusReporter(
self._results_queue, self._continue_semaphore, self.logdir)
self._last_result = {}
config = config.copy()
+1 -1
View File
@@ -52,7 +52,7 @@ class Logger(object):
raise NotImplementedError
def update_config(self, config):
"""Updates the config for all loggers."""
"""Updates the config for logger."""
pass
@@ -45,7 +45,7 @@ class AsyncHyperBandScheduler(FIFOScheduler):
metric="episode_reward_mean",
mode="max",
max_t=100,
grace_period=10,
grace_period=1,
reduction_factor=4,
brackets=1):
assert max_t > 0, "Max (time_attr) not valid!"
+36
View File
@@ -0,0 +1,36 @@
# flake8: noqa
# This is an example quickstart for Tune.
# To connect to a cluster, uncomment below:
# import ray
# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument("--redis-address")
# args = parser.parse_args()
# ray.init(redis_address=args.redis_address)
# __quick_start_begin__
import torch.optim as optim
from ray import tune
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test
def train_mnist(config):
"""Train a ConvNet on MNIST with SGD and report accuracy to Tune.

Args:
config (dict): Must contain "lr", used as the SGD learning rate.
"""
train_loader, test_loader = get_data_loaders()
model = ConvNet()
optimizer = optim.SGD(model.parameters(), lr=config["lr"])
for i in range(10):
train(model, optimizer, train_loader)
acc = test(model, test_loader)
# Report the accuracy under the key that `get_best_config` is queried with.
tune.track.log(mean_accuracy=acc)
analysis = tune.run(
train_mnist, config={"lr": tune.grid_search([0.001, 0.01, 0.1])})
print("Best config: ", analysis.get_best_config(metric="mean_accuracy"))
# Get a dataframe for analyzing trial results.
df = analysis.dataframe()
# __quick_start_end__
@@ -34,7 +34,6 @@ class ExperimentAnalysisSuite(unittest.TestCase):
global_checkpoint_period=0,
name=self.test_name,
local_dir=self.test_dir,
return_trials=False,
stop={"training_iteration": 1},
num_samples=self.num_samples,
config={
+106
View File
@@ -0,0 +1,106 @@
# flake8: noqa
# Original Code: https://github.com/pytorch/examples/blob/master/mnist/main.py
# yapf: disable
# __tutorial_imports_begin__
import numpy as np
import torch
import torch.optim as optim
from torchvision import datasets
from ray import tune
from ray.tune import track
from ray.tune.schedulers import ASHAScheduler
from ray.tune.examples.mnist_pytorch import get_data_loaders, ConvNet, train, test
# __tutorial_imports_end__
# yapf: enable
# yapf: disable
# __train_func_begin__
def train_mnist(config):
"""Train a ConvNet on MNIST with SGD, logging accuracy and saving the model.

Args:
config (dict): Must contain "lr" and "momentum" for the SGD optimizer.
"""
model = ConvNet()
train_loader, test_loader = get_data_loaders()
optimizer = optim.SGD(
model.parameters(), lr=config["lr"], momentum=config["momentum"])
for i in range(10):
train(model, optimizer, train_loader)
acc = test(model, test_loader)
track.log(mean_accuracy=acc)
# Save on iterations where i is a multiple of 5.
if i % 5 == 0:
# This saves the model to the trial directory
torch.save(model, "./model.pth")
# __train_func_end__
# yapf: enable
# __eval_func_begin__
search_space = {
"lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
"momentum": tune.uniform(0.1, 0.9)
}
# Uncomment this to enable distributed execution
# `ray.init(redis_address=...)`
analysis = tune.run(train_mnist, config=search_space)
# __eval_func_end__
#__plot_begin__
dfs = analysis.trial_dataframes
# Plot each trial's accuracy curve. A plain loop is used because only the
# plotting side effect matters, not a list of the return values.
for d in dfs.values():
    d.mean_accuracy.plot()
#__plot_end__
# __run_scheduler_begin__
analysis = tune.run(
train_mnist,
num_samples=30,
scheduler=ASHAScheduler(metric="mean_accuracy", mode="max"),
config=search_space)
# Obtain a trial dataframe from all run trials of this `tune.run` call.
dfs = analysis.trial_dataframes
# __run_scheduler_end__
# yapf: disable
# __plot_scheduler_begin__
# Plot by epoch
ax = None # This plots everything on the same plot
for d in dfs.values():
ax = d.mean_accuracy.plot(ax=ax, legend=False)
# __plot_scheduler_end__
# yapf: enable
# __run_searchalg_begin__
from hyperopt import hp
from ray.tune.suggest.hyperopt import HyperOptSearch
space = {
    # hp.loguniform draws exp(uniform(low, high)), so its bounds are given in
    # log space. Passing 1e-10 and 0.1 directly would sample learning rates
    # in roughly [1.0, 1.105] instead of [1e-10, 0.1].
    "lr": hp.loguniform("lr", np.log(1e-10), np.log(0.1)),
    "momentum": hp.uniform("momentum", 0.1, 0.9),
}
hyperopt_search = HyperOptSearch(
space, max_concurrent=2, reward_attr="mean_accuracy")
analysis = tune.run(train_mnist, num_samples=10, search_alg=hyperopt_search)
# __run_searchalg_end__
# __run_analysis_begin__
import os
df = analysis.dataframe()
logdir = analysis.get_best_logdir("mean_accuracy", mode="max")
model = torch.load(os.path.join(logdir, "model.pth"))
# __run_analysis_end__
from ray.tune.examples.mnist_pytorch_trainable import TrainMNIST
# __trainable_run_begin__
search_space = {
"lr": tune.sample_from(lambda spec: 10**(-10 * np.random.rand())),
"momentum": tune.uniform(0.1, 0.9)
}
analysis = tune.run(
TrainMNIST, config=search_space, stop={"training_iteration": 10})
# __trainable_run_end__
+6 -5
View File
@@ -51,6 +51,7 @@ class TrackSession(object):
self.trial_id = trial_name + "_" + self.trial_id
if self.is_tune_session:
self._logger = _ReporterHook(_tune_reporter)
self._logdir = _tune_reporter.logdir
else:
self._initialize_logging(trial_name, experiment_dir, upload_dir,
trial_config)
@@ -60,6 +61,8 @@ class TrackSession(object):
experiment_dir=None,
upload_dir=None,
trial_config=None):
if upload_dir:
raise NotImplementedError("Upload Dir is not yet implemented.")
# TODO(rliaw): In other parts of the code, this is `local_dir`.
if experiment_dir is None:
@@ -74,11 +77,10 @@ class TrackSession(object):
# misc metadata to save as well
self.trial_config["trial_id"] = self.trial_id
self._logger = UnifiedLogger(self.trial_config, self._logdir,
self._upload_dir)
self._logger = UnifiedLogger(self.trial_config, self._logdir)
def log(self, **metrics):
"""Logs all named arguments specified in **metrics.
"""Logs all named arguments specified in `metrics`.
This will log trial metrics locally, and they will be synchronized
with the driver periodically through ray.
@@ -86,10 +88,9 @@ class TrackSession(object):
Arguments:
metrics: named arguments with corresponding values to log.
"""
self._iteration += 1
# TODO: Implement a batching mechanism for multiple calls to `log`
# within the same iteration.
self._iteration += 1
metrics_dict = metrics.copy()
metrics_dict.update({"trial_id": self.trial_id})
+3 -1
View File
@@ -459,7 +459,9 @@ class Trainable(object):
Args:
checkpoint_dir (str): The directory where the checkpoint
file must be stored.
file must be stored. In a Tune run, this defaults to
`<self.logdir>/checkpoint_<ITER>` (which is the same as
`local_dir/exp_name/trial_name/checkpoint_<ITER>`).
Returns:
checkpoint (str | dict): If string, the return value is
+3
View File
@@ -266,6 +266,9 @@ def run(run_or_experiment,
trials = runner.get_trials()
if return_trials:
return trials
logger.info("Returning an analysis object by default. You can call "
"`analysis.trials` to retrieve a list of trials. "
"This message will be removed in future versions of Tune.")
return ExperimentAnalysis(runner.checkpoint_file, trials=trials)
+1 -1
View File
@@ -180,7 +180,7 @@ def deep_update(original, new_dict, new_keys_allowed, whitelist):
return original
def flatten_dict(dt, delimiter=":"):
def flatten_dict(dt, delimiter="/"):
dt = copy.deepcopy(dt)
while any(isinstance(v, dict) for v in dt.values()):
remove = []