[tune] Added WandbLogger (#9725)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
krfricke
2020-07-30 22:09:03 +02:00
committed by GitHub
parent 68f3fec744
commit 619e44e54a
24 changed files with 1376 additions and 6 deletions
+1
View File
@@ -0,0 +1 @@
project_id: 643
+77
View File
@@ -0,0 +1,77 @@
# An unique identifier for the head node and workers of this cluster.
cluster_name: sgd-pytorch
# The maximum number of workers nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers default to 0.
min_workers: 2
initial_workers: 2
max_workers: 2
target_utilization_fraction: 0.9
# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 20
# docker:
# image: tensorflow/tensorflow:1.5.0-py3
# container_name: ray_docker
# Cloud-provider specific configuration.
provider:
type: aws
region: us-west-2
# How Ray will authenticate with newly launched nodes.
auth:
ssh_user: ubuntu
head_node:
InstanceType: p3.8xlarge
ImageId: latest_dlami
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
worker_nodes:
InstanceType: p3.8xlarge
ImageId: latest_dlami
# Run workers on spot by default. Comment this out to use on-demand.
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
setup_commands:
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
- pip install -U ipdb ray[rllib] torch torchvision
- cp -r ~/tune ~/anaconda3/lib/python3.6/site-packages/ray
- cp -r ~/torch_ ~/anaconda3/lib/python3.6/site-packages/ray/util/sgd
- cp -r ~/autoscaler ~/anaconda3/lib/python3.6/site-packages/ray/
# Install apex.
# - rm -rf apex || true
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
file_mounts: {
~/tune: ./tune/,
~/torch_: ./util/sgd/torch/,
~/autoscaler: ./autoscaler/
}
# Custom commands that will be run on the head node after common setup.
head_setup_commands: []
# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []
# # Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
+1
View File
@@ -0,0 +1 @@
project_id: 578
+17
View File
@@ -93,6 +93,14 @@ py_test(
tags = ["exclusive"],
)
py_test(
name = "test_integration_wandb",
size = "small",
srcs = ["tests/test_integration_wandb.py"],
deps = [":tune_lib"],
tags = ["exclusive"],
)
py_test(
name = "test_logger",
size = "small",
@@ -523,6 +531,15 @@ py_test(
args = ["--smoke-test"]
)
py_test(
name = "wandb_example",
size = "small",
srcs = ["examples/wandb_example.py"],
deps = [":tune_lib"],
tags = ["exclusive", "example"],
args = ["--mock-api"]
)
py_test(
name = "xgboost_example",
size = "small",
+43
View File
@@ -0,0 +1,43 @@
class Trial:
hypers: dict = {} # static
config: dict = {} # static
status: str = None
trace: List[Dict] = []
checkpoints: List[str] = []
space = {}
trials = []
trial_checkpoints = {}
while not Optimizer.is_finished():
while Optimizer.has_next(space, trials, state):
trials += [Optimizer.next(space, trials, state)]
trial = Optimizer.choose(trials, state)
if Optimizer.should_stop(trial, trials, state):
Executor.stop(trial)
elif Optimizer.should_pause(trial, state):
Executor.pause(trial)
elif Optimizer.should_restore(trial, state):
restore(trial, trial.checkpoints[-1])
elif Optimizer.should_save(trial, state):
checkpoint = save(trial)
elif Optimizer.should_continue(trial, state):
step(trial)
exp = Experiment(logdir, name, restore=True)
failed_trials = exp.get_failed_trials()
run(failed_trials)
exp = Experiment(logdir, name, restore=True)
trials = exp.trials_finished()
trials.reset_status()
run(trials)
optimizer = Optimizer(sweep, metric, *parameters)
sweep.configure_server()
sweep.add_logger(Logging)
sweep.set_executor(executor)
sweep.run(func, verbose=verbose)
+207
View File
@@ -0,0 +1,207 @@
storage = TrialStorage(location)
trials = storage.get_trials()
failed_trials = trials.filter(status=Failed)
parameters = [t.hypers for t in failed_trials]
# Builder Pattern
factory = TrialFactory()
factory.queue(grid)
run(func, factory)
factory = TrialFactory()
factory.queue(distribution, num_samples=3, repeat=5)
run(func, factory)
factory = TrialFactory(optimizer)
factory.queue(distribution, delay_feedback=3, num_samples=20, max_concurrent=3)
run(func, factory)
optimizer.restore(storage)
factory = TrialFactory(optimizer)
factory.queue(parameter_list)
factory.queue(distribution, num_samples=3)
# single process
trials = []
while factory.has_next():
x = factory.next()
trial = build(func, x)
trials.put(trial)
storage.save(factory, trials)
while not trial.done():
result = get_next_result(trial)
log(result)
storage.update_checkpoint(trial)
factory.update(result)
storage.save(factory, trials)
# concurrent
trials = []
class Actor:
def __init__(self):
pass
def configure():
pass
def step():
pass
def save():
pass
def restore():
pass
factory, trials = storage.recover()
optimizer = factory.optimizer
result_streams = []
while factory.has_next() or not trials.not_done():
while factory.has_next():
x = factory.next()
trial = build(func, x)
trials.put(trial)
storage.save(factory, trials)
while Cluster.has_space(trials.live()) and trials.has_pending():
trial = trials.pop_pending()
handle = Actor.configure(trial)
result_streams.add(handle)
trial, handle, payload = process_next(result_streams)
if payload.type == "SAVE":
trial.update(payload.checkpoint)
storage.save(trials)
elif payload.type == "STEP":
trial.track(payload.result)
log(payload.result)
else:
pass
if should_checkpoint(trial):
Executor.save(handle)
elif not is_finished(trial):
action = Scheduler(trial, trials)
Executor.execute_action(action, trial)
elif is_finished(trial):
factory.update(trial, result)
storage.save(factory, trials)
# concurrent with checkpointing
# concurrent with pbt
while factory.has_next() or not trials.not_done():
# ...
trial, handle, payload = process_next(result_streams)
elif not is_finished(trial):
action = pbt(trial, trials)
factory.queue(new_hps, trial3.checkpoint)
Executor.execute_action(action, trial)
# Restore last experiment
exp = Experiment.restore(storage=X)
trials = exp.get_trial(filter=failed)
run(func, manual_list)
run(func, space, searcher)
run(func, grid)
run(func, manual_list, checkpoints)
run(func, manual_list)
run(func, exp)
# Core concepts:
# Result: Dict[str, value]
# t_state: Any
# Trial: hps[Dict], static_config[Dict]
# TrialTrace: List[Result], t_state, Trial
# Trainable: t_state, Trial -> t_state, Result
# Optimizer: o_state, List[TrialTrace], Trainable -> (
# o_state, List[TrialTrace])
# SearchAlg: state, Dict[hps, Result] -> state, hps
# Execution concepts
# Checkpoint
# LiveTrial: TrialTrace, location, status, is_idle
# Status: PENDING, SAVING, RESTORING, TRAINING, SETUP, STOP, ERROR
# Trainer: Trainable, location, t_state, Trial -> t_state, Result
step(o_state, LiveTrial, List[LiveTrial]) -> LiveTrial, *args
Server(List[LiveTrial]) -> List[LiveTrial]
checkpointer(LiveTrial, manager_state) -> TrialTrace
Logger(TrialTrace)
Optimizer(o_trace, ...)
Syncer()
TrialExecutor(reuse_actors, queue_trials)
ServerConfig(server_port)
Optimizer(stop, search_alg, scheduler)
Experiment(resume, local_dir)
CheckpointManager(
sync_on_checkpoint,
keep_checkpoints_num,
global_checkpoint_period,
export_formats,
checkpoint_score_attr
)
### Tune commands
tune.set_log_config(
upload_dir,
sync_to_cloud,
trial_name_creator,
sync_to_driver,
progress_reporter,
loggers,
verbose
)
tune.set_server(ServerConfig)
tune.run(
experiment,
trainable_fn,
raise_on_failed_trial, # where can this go?
max_failures: int or "fail-fast",
trial_executor,
restore_from, # checkpoint path to restore from
resources_per_trial,
num_samples,
search_space, # I'm not a big fan of this because Search Algs have their own search_space too
Optimizer,
CheckpointManager)
+42
View File
@@ -0,0 +1,42 @@
checkpoint_manager = State(location)
checkpoint_manager.optimizer_state
checkpoint_manager.generator_state
checkpoint_manager.trial_state
# How much have we learned
optimizer = Optimizer.from_checkpoint(checkpoint_manager)
optimizer = Optimizer(space, checkpoint=checkpoint_manager)
for x, y in warm_start:
optimizer.report(x, y)
samples = [optimizer.sample(random=True) for i in range(50)]
spec = TrialSpec(func, local_dir, checkpoint)
generator = TrialGenerator.from_checkpoint(checkpoint, optimizer)
generator = TrialGenerator.from_trials(trials)
generator = TrialGenerator.from_spec(spec, optimizer)
generator.configure(checkpoint_callback)
generator.queue(samples)
generator.queue(num_samples=50, repeat=3, max_concurrent=4)
generator.next()
generator = TrialGenerator.from_multi_spec(spec)
run(generator)
###################################################
# Exploration process
trial_list = get_trials(checkpoint_manager)
failed_trials = [t.reset() for t in trial_list if t.status == "FAILED"]
generator = TrialGenerator.from_trials(failed_trials)
tune.run(generator)
builder = Builder()
for params in samples:
yield builder.build(params)
+71
View File
@@ -0,0 +1,71 @@
# An unique identifier for the head node and workers of this cluster.
cluster_name: sgd-pytorch
# The maximum number of workers nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers default to 0.
min_workers: 0
initial_workers: 0
max_workers: 0
target_utilization_fraction: 0.9
# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 20
# docker:
# image: tensorflow/tensorflow:1.5.0-py3
# container_name: ray_docker
# Cloud-provider specific configuration.
provider:
type: aws
region: us-west-2
# How Ray will authenticate with newly launched nodes.
auth:
ssh_user: ubuntu
head_node:
InstanceType: p3.8xlarge
ImageId: latest_dlami
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
worker_nodes:
InstanceType: p3.8xlarge
ImageId: latest_dlami
# Run workers on spot by default. Comment this out to use on-demand.
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
setup_commands:
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
- pip install -U ipdb ray[rllib] torch torchvision
# Install apex.
# - rm -rf apex || true
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
file_mounts: {
}
# Custom commands that will be run on the head node after common setup.
head_setup_commands: []
# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []
# # Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ray start --head --redis-port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
+103
View File
@@ -0,0 +1,103 @@
import argparse
import tempfile
from unittest.mock import MagicMock
import numpy as np
import wandb
from ray import tune
from ray.tune import Trainable
from ray.tune.integration.wandb import WandbLogger, WandbTrainableMixin, \
wandb_mixin
from ray.tune.logger import DEFAULT_LOGGERS
def train_function(config, checkpoint_dir=None):
for i in range(30):
loss = config["mean"] + config["sd"] * np.random.randn()
tune.report(loss=loss)
def tune_function(api_key_file):
"""Example for using a WandbLogger with the function API"""
tune.run(
train_function,
config={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
"wandb": {
"api_key_file": api_key_file,
"project": "Wandb_example"
}
},
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
@wandb_mixin
def decorated_train_function(config, checkpoint_dir=None):
for i in range(30):
loss = config["mean"] + config["sd"] * np.random.randn()
tune.report(loss=loss)
wandb.log(dict(loss=loss))
def tune_decorated(api_key_file):
"""Example for using the @wandb_mixin decorator with the function API"""
tune.run(
decorated_train_function,
config={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
"wandb": {
"api_key_file": api_key_file,
"project": "Wandb_example"
}
})
class WandbTrainable(WandbTrainableMixin, Trainable):
def step(self):
for i in range(30):
loss = self.config["mean"] + self.config["sd"] * np.random.randn()
wandb.log({"loss": loss})
return {"loss": loss, "done": True}
def tune_trainable(api_key_file):
"""Example for using a WandTrainableMixin with the class API"""
tune.run(
WandbTrainable,
config={
"mean": tune.grid_search([1, 2, 3, 4, 5]),
"sd": tune.uniform(0.2, 0.8),
"wandb": {
"api_key_file": api_key_file,
"project": "Wandb_example"
}
})
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--mock-api", action="store_true", help="Mock Wandb API access")
args, _ = parser.parse_known_args()
api_key_file = "~/.wandb_api_key"
if args.mock_api:
WandbLogger._logger_process_cls = MagicMock
decorated_train_function.__mixins__ = tuple()
WandbTrainable._wandb = MagicMock()
wandb = MagicMock() # noqa: F811
temp_file = tempfile.NamedTemporaryFile()
temp_file.write(b"1234")
temp_file.flush()
api_key_file = temp_file.name
tune_function(api_key_file)
tune_decorated(api_key_file)
tune_trainable(api_key_file)
if args.mock_api:
temp_file.close()
+9 -1
View File
@@ -369,7 +369,15 @@ def detect_checkpoint_function(train_func, abort=False):
def wrap_function(train_func):
class ImplicitFunc(FunctionRunner):
if hasattr(train_func, "__mixins__"):
inherit_from = train_func.__mixins__ + (FunctionRunner, )
else:
inherit_from = (FunctionRunner, )
class ImplicitFunc(*inherit_from):
_name = train_func.__name__ if hasattr(train_func, "__name__") \
else "func"
def _trainable_func(self, config, reporter, checkpoint_dir):
func_args = inspect.getfullargspec(train_func).args
if len(func_args) > 1: # more arguments than just the config
+345
View File
@@ -0,0 +1,345 @@
import os
from multiprocessing import Process, Queue
from numbers import Number
from ray import logger
from ray.tune import Trainable
from ray.tune.function_runner import FunctionRunner
from ray.tune.logger import Logger
try:
import wandb
except ImportError:
logger.error("pip install 'wandb' to use WandbLogger/WandbTrainableMixin.")
wandb = None
WANDB_ENV_VAR = "WANDB_API_KEY"
_WANDB_QUEUE_END = (None, )
def wandb_mixin(func):
"""wandb_mixin
Weights and biases (https://www.wandb.com/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
Trainable mixin helps initializing the Wandb API for use with the
``Trainable`` class or with `@wandb_mixin` for the function API.
For basic usage, just prepend your training function with the
``@wandb_mixin`` decorator:
.. code-block:: python
from ray.tune.integration.wandb import wandb_mixin
@wandb_mixin
def train_fn(config):
wandb.log()
Wandb configuration is done by passing a ``wandb`` key to
the ``config`` parameter of ``tune.run()`` (see example below).
The content of the ``wandb`` config entry is passed to ``wandb.init()``
as keyword arguments. The exception are the following settings, which
are used to configure the ``WandbTrainableMixin`` itself:
Args:
api_key_file (str): Path to file containing the Wandb API KEY.
api_key (str): Wandb API Key. Alternative to setting `api_key_file`.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
by Tune, but can be overwritten by filling out the respective configuration
values.
Please see here for all other valid configuration settings:
https://docs.wandb.com/library/init
Example:
.. code-block:: python
from ray import tune
from ray.tune.integration.wandb import wandb_mixin
@wandb_mixin
def train_fn(config):
for i in range(10):
loss = self.config["a"] + self.config["b"]
wandb.log({"loss": loss})
tune.report(loss=loss, done=True)
tune.run(
train_fn,
config={
# define search space here
"a": tune.choice([1, 2, 3]),
"b": tune.choice([4, 5, 6]),
# wandb configuration
"wandb": {
"project": "Optimization_Project",
"api_key_file": "/path/to/file"
}
})
"""
func.__mixins__ = (WandbTrainableMixin, )
func.__wandb_group__ = func.__name__
return func
def _set_api_key(wandb_config):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
api_key = wandb_config.pop("api_key", None)
if api_key_file:
if api_key:
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
with open(api_key_file, "rt") as fp:
api_key = fp.readline().strip()
if api_key:
os.environ[WANDB_ENV_VAR] = api_key
elif not os.environ.get(WANDB_ENV_VAR):
raise ValueError(
"No WandB API key found. Either set the {} environment "
"variable or pass `api_key` or `api_key_file` in the config".
format(WANDB_ENV_VAR))
class _WandbLoggingProcess(Process):
"""
We need a `multiprocessing.Process` to allow multiple concurrent
wandb logging instances locally.
"""
def __init__(self, queue, exclude, to_config, *args, **kwargs):
super(_WandbLoggingProcess, self).__init__()
self.queue = queue
self._exclude = set(exclude)
self._to_config = set(to_config)
self.args = args
self.kwargs = kwargs
def run(self):
wandb.init(*self.args, **self.kwargs)
while True:
result = self.queue.get()
if result == _WANDB_QUEUE_END:
break
log, config_update = self._handle_result(result)
wandb.config.update(config_update, allow_val_change=True)
wandb.log(log)
wandb.join()
def _handle_result(self, result):
config_update = result.get("config", {}).copy()
log = {}
for k, v in result.items():
if k in self._to_config:
config_update[k] = v
elif k in self._exclude:
continue
elif not isinstance(v, Number):
continue
else:
log[k] = v
config_update.pop("callbacks", None) # Remove callbacks
return log, config_update
class WandbLogger(Logger):
"""WandbLogger
Weights and biases (https://www.wandb.com/) is a tool for experiment
tracking, model optimization, and dataset versioning. This Ray Tune
``Logger`` sends metrics to Wandb for automatic tracking and
visualization.
Wandb configuration is done by passing a ``wandb`` key to
the ``config`` parameter of ``tune.run()`` (see example below).
The content of the ``wandb`` config entry is passed to ``wandb.init()``
as keyword arguments. The exception are the following settings, which
are used to configure the WandbLogger itself:
Args:
api_key_file (str): Path to file containing the Wandb API KEY.
api_key (str): Wandb API Key. Alternative to setting ``api_key_file``.
excludes (list): List of metrics that should be excluded from
the log.
log_config (bool): Boolean indicating if the ``config`` parameter of
the ``results`` dict should be logged. This makes sense if
parameters will change during training, e.g. with
PopulationBasedTraining. Defaults to False.
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
by Tune, but can be overwritten by filling out the respective configuration
values.
Please see here for all other valid configuration settings:
https://docs.wandb.com/library/init
Example:
.. code-block:: python
from ray.tune.logger import DEFAULT_LOGGERS
from ray.tune.integration.wandb import WandbLogger
tune.run(
train_fn,
config={
# define search space here
"parameter_1": tune.choice([1, 2, 3]),
"parameter_2": tune.choice([4, 5, 6]),
# wandb configuration
"wandb": {
"project": "Optimization_Project",
"api_key_file": "/path/to/file",
"log_config": True
}
},
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
"""
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]
# Use these result keys to update `wandb.config`
_config_results = [
"trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
"pid", "date"
]
_logger_process_cls = _WandbLoggingProcess
def _init(self):
config = self.config.copy()
try:
wandb_config = config.pop("wandb").copy()
except KeyError:
raise ValueError(
"Wandb logger specified but no configuration has been passed. "
"Make sure to include a `wandb` key in your `config` dict "
"containing at least a `project` specification.")
_set_api_key(wandb_config)
exclude_results = self._exclude_results.copy()
# Additional excludes
additional_excludes = wandb_config.pop("excludes", [])
exclude_results += additional_excludes
# Log config keys on each result?
log_config = wandb_config.pop("log_config", False)
if not log_config:
exclude_results += ["config"]
# Fill trial ID and name
trial_id = self.trial.trial_id
trial_name = str(self.trial)
# Project name for Wandb
try:
wandb_project = wandb_config.pop("project")
except KeyError:
raise ValueError(
"You need to specify a `project` in your wandb `config` dict.")
# Grouping
wandb_group = wandb_config.pop("group", self.trial.trainable_name)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
wandb_init_kwargs.update(wandb_config)
self._queue = Queue()
self._wandb = self._logger_process_cls(
queue=self._queue,
exclude=exclude_results,
to_config=self._config_results,
**wandb_init_kwargs)
self._wandb.start()
def on_result(self, result):
self._queue.put(result)
def close(self):
self._queue.put(_WANDB_QUEUE_END)
self._wandb.join(timeout=10)
class WandbTrainableMixin:
_wandb = wandb
def __init__(self, config, *args, **kwargs):
if not isinstance(self, Trainable):
raise ValueError(
"The `WandbTrainableMixin` can only be used as a mixin "
"for `tune.Trainable` classes. Please make sure your "
"class inherits from both. For example: "
"`class YourTrainable(WandbTrainableMixin)`.")
super().__init__(config, *args, **kwargs)
config = config.copy()
try:
wandb_config = config.pop("wandb").copy()
except KeyError:
raise ValueError(
"Wandb mixin specified but no configuration has been passed. "
"Make sure to include a `wandb` key in your `config` dict "
"containing at least a `project` specification.")
_set_api_key(wandb_config)
# Fill trial ID and name
trial_id = self.trial_id
trial_name = self.trial_name
# Project name for Wandb
try:
wandb_project = wandb_config.pop("project")
except KeyError:
raise ValueError(
"You need to specify a `project` in your wandb `config` dict.")
# Grouping
if isinstance(self, FunctionRunner):
default_group = self._name
else:
default_group = type(self).__name__
wandb_group = wandb_config.pop("group", default_group)
wandb_init_kwargs = dict(
id=trial_id,
name=trial_name,
resume=True,
reinit=True,
allow_val_change=True,
group=wandb_group,
project=wandb_project,
config=config)
wandb_init_kwargs.update(wandb_config)
self.wandb = self._wandb.init(**wandb_init_kwargs)
def stop(self):
self._wandb.join()
if hasattr(super(), "stop"):
super().stop()
+6 -2
View File
@@ -469,6 +469,10 @@ def _get_trial_info(trial, parameters, metrics):
result = trial.last_result
config = trial.config
trial_info = [str(trial), trial.status, str(trial.location)]
trial_info += [unflattened_lookup(param, config) for param in parameters]
trial_info += [unflattened_lookup(metric, result) for metric in metrics]
trial_info += [
unflattened_lookup(param, config, default=None) for param in parameters
]
trial_info += [
unflattened_lookup(metric, result, default=None) for metric in metrics
]
return trial_info
+73
View File
@@ -0,0 +1,73 @@
# An unique identifier for the head node and workers of this cluster.
cluster_name: sgd-pytorch
# The maximum number of workers nodes to launch in addition to the head
# node. This takes precedence over min_workers. min_workers default to 0.
min_workers: 0
initial_workers: 0
max_workers: 0
target_utilization_fraction: 0.9
# If a node is idle for this many minutes, it will be removed.
idle_timeout_minutes: 20
# docker:
# image: tensorflow/tensorflow:1.5.0-py3
# container_name: ray_docker
# Cloud-provider specific configuration.
provider:
type: aws
region: us-west-2
# How Ray will authenticate with newly launched nodes.
auth:
ssh_user: ubuntu
head_node:
InstanceType: g3.8xlarge
ImageId: latest_dlami
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
worker_nodes:
InstanceType: g3.8xlarge
ImageId: latest_dlami
# Run workers on spot by default. Comment this out to use on-demand.
InstanceMarketOptions:
MarketType: spot
# SpotOptions:
# MaxPrice: "9.0"
setup_commands:
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
- pip install -U ipdb ray[rllib] torch torchvision
# Install apex.
# - rm -rf apex || true
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
file_mounts: {
~/anaconda3/lib/python3.6/site-packages/ray/tune: ./tune/,
~/anaconda3/lib/python3.6/site-packages/ray/util/sgd/torch: ./util/sgd/torch/
}
# Custom commands that will be run on the head node after common setup.
head_setup_commands: []
# Custom commands that will be run on worker nodes after common setup.
worker_setup_commands: []
# # Command to start ray on the head node. You don't need to change this.
head_start_ray_commands:
- ray stop
- ray start --head --redis-port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
# Command to start ray on worker nodes. You don't need to change this.
worker_start_ray_commands:
- ray stop
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
Binary file not shown.
@@ -0,0 +1,298 @@
import os
import tempfile
from collections import namedtuple
from multiprocessing import Queue
import unittest
from ray.tune import Trainable
from ray.tune.function_runner import wrap_function
from ray.tune.integration.wandb import _WandbLoggingProcess, \
_WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, \
wandb_mixin
from ray.tune.result import TRIAL_INFO
from ray.tune.trial import TrialInfo
Trial = namedtuple("MockTrial",
["config", "trial_id", "trial_name", "trainable_name"])
Trial.__str__ = lambda t: t.trial_name
class _MockWandbLoggingProcess(_WandbLoggingProcess):
def __init__(self, queue, exclude, to_config, *args, **kwargs):
super(_MockWandbLoggingProcess,
self).__init__(queue, exclude, to_config, *args, **kwargs)
self.logs = Queue()
self.config_updates = Queue()
def run(self):
while True:
result = self.queue.get()
if result == _WANDB_QUEUE_END:
break
log, config_update = self._handle_result(result)
self.config_updates.put(config_update)
self.logs.put(log)
class WandbTestLogger(WandbLogger):
_logger_process_cls = _MockWandbLoggingProcess
class _MockWandbAPI(object):
def init(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
return self
class _MockWandbTrainableMixin(WandbTrainableMixin):
_wandb = _MockWandbAPI()
class WandbTestTrainable(_MockWandbTrainableMixin, Trainable):
pass
class WandbIntegrationTest(unittest.TestCase):
def setUp(self):
pass
def tearDown(self):
pass
def testWandbLoggerConfig(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
if WANDB_ENV_VAR in os.environ:
del os.environ[WANDB_ENV_VAR]
# Needs at least a project
with self.assertRaises(ValueError):
logger = WandbTestLogger(trial_config, "/tmp", trial)
# No API key
trial_config["wandb"] = {"project": "test_project"}
with self.assertRaises(ValueError):
logger = WandbTestLogger(trial_config, "/tmp", trial)
# API Key in config
trial_config["wandb"] = {"project": "test_project", "api_key": "1234"}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
logger.close()
del os.environ[WANDB_ENV_VAR]
# API Key file
with tempfile.NamedTemporaryFile("wt") as fp:
fp.write("5678")
fp.flush()
trial_config["wandb"] = {
"project": "test_project",
"api_key_file": fp.name
}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
logger.close()
del os.environ[WANDB_ENV_VAR]
# API Key in env
os.environ[WANDB_ENV_VAR] = "9012"
trial_config["wandb"] = {"project": "test_project"}
logger = WandbTestLogger(trial_config, "/tmp", trial)
logger.close()
# From now on, the API key is in the env variable.
# Default configuration
trial_config["wandb"] = {"project": "test_project"}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertEqual(logger._wandb.kwargs["project"], "test_project")
self.assertEqual(logger._wandb.kwargs["id"], trial.trial_id)
self.assertEqual(logger._wandb.kwargs["name"], trial.trial_name)
self.assertEqual(logger._wandb.kwargs["group"], trial.trainable_name)
self.assertIn("config", logger._wandb._exclude)
logger.close()
# log config.
trial_config["wandb"] = {"project": "test_project", "log_config": True}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertNotIn("config", logger._wandb._exclude)
self.assertNotIn("metric", logger._wandb._exclude)
logger.close()
# Exclude metric.
trial_config["wandb"] = {
"project": "test_project",
"excludes": ["metric"]
}
logger = WandbTestLogger(trial_config, "/tmp", trial)
self.assertIn("config", logger._wandb._exclude)
self.assertIn("metric", logger._wandb._exclude)
logger.close()
def testWandbLoggerReporting(self):
trial_config = {"par1": 4, "par2": 9.12345678}
trial = Trial(trial_config, 0, "trial_0", "trainable")
trial_config["wandb"] = {
"project": "test_project",
"api_key": "1234",
"excludes": ["metric2"]
}
logger = WandbTestLogger(trial_config, "/tmp", trial)
r1 = {
"metric1": 0.8,
"metric2": 1.4,
"const": "text",
"config": trial_config
}
logger.on_result(r1)
logged = logger._wandb.logs.get(timeout=10)
self.assertIn("metric1", logged)
self.assertNotIn("metric2", logged)
self.assertNotIn("const", logged)
self.assertNotIn("config", logged)
logger.close()
def testWandbMixinConfig(self):
config = {"par1": 4, "par2": 9.12345678}
trial = Trial(config, 0, "trial_0", "trainable")
trial_info = TrialInfo(trial)
config[TRIAL_INFO] = trial_info
if WANDB_ENV_VAR in os.environ:
del os.environ[WANDB_ENV_VAR]
# Needs at least a project
with self.assertRaises(ValueError):
trainable = WandbTestTrainable(config)
# No API key
config["wandb"] = {"project": "test_project"}
with self.assertRaises(ValueError):
trainable = WandbTestTrainable(config)
# API Key in config
config["wandb"] = {"project": "test_project", "api_key": "1234"}
trainable = WandbTestTrainable(config)
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
del os.environ[WANDB_ENV_VAR]
# API Key file
with tempfile.NamedTemporaryFile("wt") as fp:
fp.write("5678")
fp.flush()
config["wandb"] = {
"project": "test_project",
"api_key_file": fp.name
}
trainable = WandbTestTrainable(config)
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
del os.environ[WANDB_ENV_VAR]
# API Key in env
os.environ[WANDB_ENV_VAR] = "9012"
config["wandb"] = {"project": "test_project"}
trainable = WandbTestTrainable(config)
# From now on, the API key is in the env variable.
# Default configuration
config["wandb"] = {"project": "test_project"}
config[TRIAL_INFO] = trial_info
trainable = WandbTestTrainable(config)
self.assertEqual(trainable.wandb.kwargs["project"], "test_project")
self.assertEqual(trainable.wandb.kwargs["id"], trial.trial_id)
self.assertEqual(trainable.wandb.kwargs["name"], trial.trial_name)
self.assertEqual(trainable.wandb.kwargs["group"], "WandbTestTrainable")
def testWandbDecoratorConfig(self):
config = {"par1": 4, "par2": 9.12345678}
trial = Trial(config, 0, "trial_0", "trainable")
trial_info = TrialInfo(trial)
@wandb_mixin
def train_fn(config):
return 1
train_fn.__mixins__ = (_MockWandbTrainableMixin, )
config[TRIAL_INFO] = trial_info
if WANDB_ENV_VAR in os.environ:
del os.environ[WANDB_ENV_VAR]
# Needs at least a project
with self.assertRaises(ValueError):
wrapped = wrap_function(train_fn)(config)
# No API key
config["wandb"] = {"project": "test_project"}
with self.assertRaises(ValueError):
wrapped = wrap_function(train_fn)(config)
# API Key in config
config["wandb"] = {"project": "test_project", "api_key": "1234"}
wrapped = wrap_function(train_fn)(config)
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
del os.environ[WANDB_ENV_VAR]
# API Key file
with tempfile.NamedTemporaryFile("wt") as fp:
fp.write("5678")
fp.flush()
config["wandb"] = {
"project": "test_project",
"api_key_file": fp.name
}
wrapped = wrap_function(train_fn)(config)
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
del os.environ[WANDB_ENV_VAR]
# API Key in env
os.environ[WANDB_ENV_VAR] = "9012"
config["wandb"] = {"project": "test_project"}
wrapped = wrap_function(train_fn)(config)
# From now on, the API key is in the env variable.
# Default configuration
config["wandb"] = {"project": "test_project"}
config[TRIAL_INFO] = trial_info
wrapped = wrap_function(train_fn)(config)
self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))
+28
View File
@@ -0,0 +1,28 @@
import pickle
import ray
import torch
import torch.nn.functional as F
import torch.nn as nn
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
self.fc = nn.Linear(192, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 3))
x = x.view(-1, 192)
x = self.fc(x)
return F.log_softmax(x, dim=1)
def save_me():
model = ConvNet()
torch.save(model, "./test.th")
return 1
ray_func = ray.remote(save_me)
ray.init()
ray.get(ray_func.remote())
+1
View File
@@ -0,0 +1 @@
trialv2.py
+5 -3
View File
@@ -216,7 +216,7 @@ def flatten_dict(dt, delimiter="/"):
return dt
def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
"""
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
@@ -232,8 +232,10 @@ def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
base = base[int(key)]
else:
raise KeyError()
except KeyError:
return default
except KeyError as e:
if "default" in kwargs:
return kwargs["default"]
raise e
return base
+1
View File
@@ -24,5 +24,6 @@ timm
torch>=1.5.0
torchvision>=0.6.0
tune-sklearn==0.0.5
wandb
xgboost
zoopt>=0.4.0