mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 11:01:06 +08:00
[tune] Added WandbLogger (#9725)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: Kai Fricke <kai@anyscale.com>
This commit is contained in:
@@ -0,0 +1 @@
|
||||
project_id: 643
|
||||
@@ -0,0 +1,77 @@
|
||||
# An unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: sgd-pytorch
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers default to 0.
|
||||
min_workers: 2
|
||||
initial_workers: 2
|
||||
max_workers: 2
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
|
||||
# If a node is idle for this many minutes, it will be removed.
|
||||
idle_timeout_minutes: 20
|
||||
# docker:
|
||||
# image: tensorflow/tensorflow:1.5.0-py3
|
||||
# container_name: ray_docker
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-west-2
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
head_node:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
# Run workers on spot by default. Comment this out to use on-demand.
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
setup_commands:
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- pip install -U ipdb ray[rllib] torch torchvision
|
||||
- cp -r ~/tune ~/anaconda3/lib/python3.6/site-packages/ray
|
||||
- cp -r ~/torch_ ~/anaconda3/lib/python3.6/site-packages/ray/util/sgd
|
||||
- cp -r ~/autoscaler ~/anaconda3/lib/python3.6/site-packages/ray/
|
||||
# Install apex.
|
||||
# - rm -rf apex || true
|
||||
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
|
||||
|
||||
|
||||
file_mounts: {
|
||||
~/tune: ./tune/,
|
||||
~/torch_: ./util/sgd/torch/,
|
||||
~/autoscaler: ./autoscaler/
|
||||
}
|
||||
|
||||
# Custom commands that will be run on the head node after common setup.
|
||||
head_setup_commands: []
|
||||
|
||||
# Custom commands that will be run on worker nodes after common setup.
|
||||
worker_setup_commands: []
|
||||
|
||||
# # Command to start ray on the head node. You don't need to change this.
|
||||
head_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --head --port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
|
||||
|
||||
# Command to start ray on worker nodes. You don't need to change this.
|
||||
worker_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
project_id: 578
|
||||
@@ -93,6 +93,14 @@ py_test(
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_integration_wandb",
|
||||
size = "small",
|
||||
srcs = ["tests/test_integration_wandb.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_logger",
|
||||
size = "small",
|
||||
@@ -523,6 +531,15 @@ py_test(
|
||||
args = ["--smoke-test"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "wandb_example",
|
||||
size = "small",
|
||||
srcs = ["examples/wandb_example.py"],
|
||||
deps = [":tune_lib"],
|
||||
tags = ["exclusive", "example"],
|
||||
args = ["--mock-api"]
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "xgboost_example",
|
||||
size = "small",
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
class Trial:
|
||||
hypers: dict = {} # static
|
||||
config: dict = {} # static
|
||||
status: str = None
|
||||
trace: List[Dict] = []
|
||||
checkpoints: List[str] = []
|
||||
|
||||
space = {}
|
||||
trials = []
|
||||
|
||||
trial_checkpoints = {}
|
||||
|
||||
while not Optimizer.is_finished():
|
||||
while Optimizer.has_next(space, trials, state):
|
||||
trials += [Optimizer.next(space, trials, state)]
|
||||
trial = Optimizer.choose(trials, state)
|
||||
if Optimizer.should_stop(trial, trials, state):
|
||||
Executor.stop(trial)
|
||||
elif Optimizer.should_pause(trial, state):
|
||||
Executor.pause(trial)
|
||||
elif Optimizer.should_restore(trial, state):
|
||||
restore(trial, trial.checkpoints[-1])
|
||||
elif Optimizer.should_save(trial, state):
|
||||
checkpoint = save(trial)
|
||||
elif Optimizer.should_continue(trial, state):
|
||||
step(trial)
|
||||
|
||||
|
||||
exp = Experiment(logdir, name, restore=True)
|
||||
failed_trials = exp.get_failed_trials()
|
||||
run(failed_trials)
|
||||
|
||||
exp = Experiment(logdir, name, restore=True)
|
||||
trials = exp.trials_finished()
|
||||
trials.reset_status()
|
||||
run(trials)
|
||||
|
||||
|
||||
optimizer = Optimizer(sweep, metric, *parameters)
|
||||
sweep.configure_server()
|
||||
sweep.add_logger(Logging)
|
||||
sweep.set_executor(executor)
|
||||
sweep.run(func, verbose=verbose)
|
||||
@@ -0,0 +1,207 @@
|
||||
storage = TrialStorage(location)
|
||||
trials = storage.get_trials()
|
||||
failed_trials = trials.filter(status=Failed)
|
||||
parameters = [t.hypers for t in failed_trials]
|
||||
|
||||
|
||||
# Builder Pattern
|
||||
|
||||
factory = TrialFactory()
|
||||
factory.queue(grid)
|
||||
run(func, factory)
|
||||
|
||||
factory = TrialFactory()
|
||||
factory.queue(distribution, num_samples=3, repeat=5)
|
||||
run(func, factory)
|
||||
|
||||
factory = TrialFactory(optimizer)
|
||||
factory.queue(distribution, delay_feedback=3, num_samples=20, max_concurrent=3)
|
||||
run(func, factory)
|
||||
|
||||
optimizer.restore(storage)
|
||||
|
||||
factory = TrialFactory(optimizer)
|
||||
factory.queue(parameter_list)
|
||||
factory.queue(distribution, num_samples=3)
|
||||
|
||||
# single process
|
||||
|
||||
trials = []
|
||||
|
||||
while factory.has_next():
|
||||
x = factory.next()
|
||||
trial = build(func, x)
|
||||
trials.put(trial)
|
||||
storage.save(factory, trials)
|
||||
|
||||
while not trial.done():
|
||||
result = get_next_result(trial)
|
||||
log(result)
|
||||
storage.update_checkpoint(trial)
|
||||
factory.update(result)
|
||||
storage.save(factory, trials)
|
||||
|
||||
|
||||
# concurrent
|
||||
|
||||
trials = []
|
||||
|
||||
class Actor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def configure():
|
||||
pass
|
||||
|
||||
def step():
|
||||
pass
|
||||
|
||||
def save():
|
||||
pass
|
||||
|
||||
def restore():
|
||||
pass
|
||||
|
||||
factory, trials = storage.recover()
|
||||
optimizer = factory.optimizer
|
||||
result_streams = []
|
||||
|
||||
while factory.has_next() or not trials.not_done():
|
||||
while factory.has_next():
|
||||
x = factory.next()
|
||||
trial = build(func, x)
|
||||
trials.put(trial)
|
||||
storage.save(factory, trials)
|
||||
|
||||
while Cluster.has_space(trials.live()) and trials.has_pending():
|
||||
trial = trials.pop_pending()
|
||||
handle = Actor.configure(trial)
|
||||
result_streams.add(handle)
|
||||
|
||||
trial, handle, payload = process_next(result_streams)
|
||||
|
||||
if payload.type == "SAVE":
|
||||
trial.update(payload.checkpoint)
|
||||
storage.save(trials)
|
||||
elif payload.type == "STEP":
|
||||
trial.track(payload.result)
|
||||
log(payload.result)
|
||||
else:
|
||||
pass
|
||||
|
||||
if should_checkpoint(trial):
|
||||
Executor.save(handle)
|
||||
elif not is_finished(trial):
|
||||
action = Scheduler(trial, trials)
|
||||
Executor.execute_action(action, trial)
|
||||
elif is_finished(trial):
|
||||
factory.update(trial, result)
|
||||
storage.save(factory, trials)
|
||||
|
||||
|
||||
|
||||
# concurrent with checkpointing
|
||||
|
||||
|
||||
|
||||
# concurrent with pbt
|
||||
|
||||
while factory.has_next() or not trials.not_done():
|
||||
# ...
|
||||
|
||||
trial, handle, payload = process_next(result_streams)
|
||||
elif not is_finished(trial):
|
||||
action = pbt(trial, trials)
|
||||
factory.queue(new_hps, trial3.checkpoint)
|
||||
|
||||
Executor.execute_action(action, trial)
|
||||
|
||||
|
||||
|
||||
# Restore last experiment
|
||||
exp = Experiment.restore(storage=X)
|
||||
trials = exp.get_trial(filter=failed)
|
||||
|
||||
run(func, manual_list)
|
||||
run(func, space, searcher)
|
||||
run(func, grid)
|
||||
run(func, manual_list, checkpoints)
|
||||
run(func, manual_list)
|
||||
|
||||
|
||||
|
||||
|
||||
run(func, exp)
|
||||
|
||||
|
||||
# Core concepts:
|
||||
# Result: Dict[str, value]
|
||||
# t_state: Any
|
||||
# Trial: hps[Dict], static_config[Dict]
|
||||
# TrialTrace: List[Result], t_state, Trial
|
||||
|
||||
# Trainable: t_state, Trial -> t_state, Result
|
||||
|
||||
# Optimizer: o_state, List[TrialTrace], Trainable -> (
|
||||
# o_state, List[TrialTrace])
|
||||
|
||||
# SearchAlg: state, Dict[hps, Result] -> state, hps
|
||||
|
||||
|
||||
# Execution concepts
|
||||
|
||||
# Checkpoint
|
||||
# LiveTrial: TrialTrace, location, status, is_idle
|
||||
# Status: PENDING, SAVING, RESTORING, TRAINING, SETUP, STOP, ERROR
|
||||
# Trainer: Trainable, location, t_state, Trial -> t_state, Result
|
||||
step(o_state, LiveTrial, List[LiveTrial]) -> LiveTrial, *args
|
||||
Server(List[LiveTrial]) -> List[LiveTrial]
|
||||
checkpointer(LiveTrial, manager_state) -> TrialTrace
|
||||
Logger(TrialTrace)
|
||||
Optimizer(o_trace, ...)
|
||||
Syncer()
|
||||
|
||||
|
||||
|
||||
|
||||
TrialExecutor(reuse_actors, queue_trials)
|
||||
ServerConfig(server_port)
|
||||
Optimizer(stop, search_alg, scheduler)
|
||||
Experiment(resume, local_dir)
|
||||
CheckpointManager(
|
||||
sync_on_checkpoint,
|
||||
keep_checkpoints_num,
|
||||
global_checkpoint_period,
|
||||
export_formats,
|
||||
checkpoint_score_attr
|
||||
)
|
||||
### Tune commands
|
||||
tune.set_log_config(
|
||||
upload_dir,
|
||||
sync_to_cloud,
|
||||
trial_name_creator,
|
||||
sync_to_driver,
|
||||
progress_reporter,
|
||||
loggers,
|
||||
verbose
|
||||
)
|
||||
tune.set_server(ServerConfig)
|
||||
tune.run(
|
||||
experiment,
|
||||
trainable_fn,
|
||||
raise_on_failed_trial, # where can this go?
|
||||
max_failures: int or "fail-fast",
|
||||
trial_executor,
|
||||
restore_from, # checkpoint path to restore from
|
||||
resources_per_trial,
|
||||
num_samples,
|
||||
search_space, # I'm not a big fan of this because Search Algs have their own search_space too
|
||||
Optimizer,
|
||||
CheckpointManager)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
checkpoint_manager = State(location)
|
||||
checkpoint_manager.optimizer_state
|
||||
checkpoint_manager.generator_state
|
||||
checkpoint_manager.trial_state
|
||||
|
||||
# How much have we learned
|
||||
optimizer = Optimizer.from_checkpoint(checkpoint_manager)
|
||||
|
||||
optimizer = Optimizer(space, checkpoint=checkpoint_manager)
|
||||
for x, y in warm_start:
|
||||
optimizer.report(x, y)
|
||||
|
||||
samples = [optimizer.sample(random=True) for i in range(50)]
|
||||
|
||||
spec = TrialSpec(func, local_dir, checkpoint)
|
||||
|
||||
generator = TrialGenerator.from_checkpoint(checkpoint, optimizer)
|
||||
|
||||
generator = TrialGenerator.from_trials(trials)
|
||||
|
||||
generator = TrialGenerator.from_spec(spec, optimizer)
|
||||
generator.configure(checkpoint_callback)
|
||||
generator.queue(samples)
|
||||
generator.queue(num_samples=50, repeat=3, max_concurrent=4)
|
||||
generator.next()
|
||||
|
||||
generator = TrialGenerator.from_multi_spec(spec)
|
||||
|
||||
run(generator)
|
||||
###################################################
|
||||
|
||||
# Exploration process
|
||||
trial_list = get_trials(checkpoint_manager)
|
||||
failed_trials = [t.reset() for t in trial_list if t.status == "FAILED"]
|
||||
|
||||
generator = TrialGenerator.from_trials(failed_trials)
|
||||
tune.run(generator)
|
||||
|
||||
builder = Builder()
|
||||
|
||||
for params in samples:
|
||||
yield builder.build(params)
|
||||
@@ -0,0 +1,71 @@
|
||||
# An unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: sgd-pytorch
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers default to 0.
|
||||
min_workers: 0
|
||||
initial_workers: 0
|
||||
max_workers: 0
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
|
||||
# If a node is idle for this many minutes, it will be removed.
|
||||
idle_timeout_minutes: 20
|
||||
# docker:
|
||||
# image: tensorflow/tensorflow:1.5.0-py3
|
||||
# container_name: ray_docker
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-west-2
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
head_node:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: p3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
# Run workers on spot by default. Comment this out to use on-demand.
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
setup_commands:
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- pip install -U ipdb ray[rllib] torch torchvision
|
||||
# Install apex.
|
||||
# - rm -rf apex || true
|
||||
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
|
||||
|
||||
|
||||
file_mounts: {
|
||||
}
|
||||
|
||||
# Custom commands that will be run on the head node after common setup.
|
||||
head_setup_commands: []
|
||||
|
||||
# Custom commands that will be run on worker nodes after common setup.
|
||||
worker_setup_commands: []
|
||||
|
||||
# # Command to start ray on the head node. You don't need to change this.
|
||||
head_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --head --redis-port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
|
||||
|
||||
# Command to start ray on worker nodes. You don't need to change this.
|
||||
worker_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import argparse
|
||||
import tempfile
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import numpy as np
|
||||
import wandb
|
||||
|
||||
from ray import tune
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.integration.wandb import WandbLogger, WandbTrainableMixin, \
|
||||
wandb_mixin
|
||||
from ray.tune.logger import DEFAULT_LOGGERS
|
||||
|
||||
|
||||
def train_function(config, checkpoint_dir=None):
|
||||
for i in range(30):
|
||||
loss = config["mean"] + config["sd"] * np.random.randn()
|
||||
tune.report(loss=loss)
|
||||
|
||||
|
||||
def tune_function(api_key_file):
|
||||
"""Example for using a WandbLogger with the function API"""
|
||||
tune.run(
|
||||
train_function,
|
||||
config={
|
||||
"mean": tune.grid_search([1, 2, 3, 4, 5]),
|
||||
"sd": tune.uniform(0.2, 0.8),
|
||||
"wandb": {
|
||||
"api_key_file": api_key_file,
|
||||
"project": "Wandb_example"
|
||||
}
|
||||
},
|
||||
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
|
||||
|
||||
|
||||
@wandb_mixin
|
||||
def decorated_train_function(config, checkpoint_dir=None):
|
||||
for i in range(30):
|
||||
loss = config["mean"] + config["sd"] * np.random.randn()
|
||||
tune.report(loss=loss)
|
||||
wandb.log(dict(loss=loss))
|
||||
|
||||
|
||||
def tune_decorated(api_key_file):
|
||||
"""Example for using the @wandb_mixin decorator with the function API"""
|
||||
tune.run(
|
||||
decorated_train_function,
|
||||
config={
|
||||
"mean": tune.grid_search([1, 2, 3, 4, 5]),
|
||||
"sd": tune.uniform(0.2, 0.8),
|
||||
"wandb": {
|
||||
"api_key_file": api_key_file,
|
||||
"project": "Wandb_example"
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
class WandbTrainable(WandbTrainableMixin, Trainable):
|
||||
def step(self):
|
||||
for i in range(30):
|
||||
loss = self.config["mean"] + self.config["sd"] * np.random.randn()
|
||||
wandb.log({"loss": loss})
|
||||
return {"loss": loss, "done": True}
|
||||
|
||||
|
||||
def tune_trainable(api_key_file):
|
||||
"""Example for using a WandTrainableMixin with the class API"""
|
||||
tune.run(
|
||||
WandbTrainable,
|
||||
config={
|
||||
"mean": tune.grid_search([1, 2, 3, 4, 5]),
|
||||
"sd": tune.uniform(0.2, 0.8),
|
||||
"wandb": {
|
||||
"api_key_file": api_key_file,
|
||||
"project": "Wandb_example"
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--mock-api", action="store_true", help="Mock Wandb API access")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
api_key_file = "~/.wandb_api_key"
|
||||
|
||||
if args.mock_api:
|
||||
WandbLogger._logger_process_cls = MagicMock
|
||||
decorated_train_function.__mixins__ = tuple()
|
||||
WandbTrainable._wandb = MagicMock()
|
||||
wandb = MagicMock() # noqa: F811
|
||||
temp_file = tempfile.NamedTemporaryFile()
|
||||
temp_file.write(b"1234")
|
||||
temp_file.flush()
|
||||
api_key_file = temp_file.name
|
||||
|
||||
tune_function(api_key_file)
|
||||
tune_decorated(api_key_file)
|
||||
tune_trainable(api_key_file)
|
||||
|
||||
if args.mock_api:
|
||||
temp_file.close()
|
||||
@@ -369,7 +369,15 @@ def detect_checkpoint_function(train_func, abort=False):
|
||||
|
||||
|
||||
def wrap_function(train_func):
|
||||
class ImplicitFunc(FunctionRunner):
|
||||
if hasattr(train_func, "__mixins__"):
|
||||
inherit_from = train_func.__mixins__ + (FunctionRunner, )
|
||||
else:
|
||||
inherit_from = (FunctionRunner, )
|
||||
|
||||
class ImplicitFunc(*inherit_from):
|
||||
_name = train_func.__name__ if hasattr(train_func, "__name__") \
|
||||
else "func"
|
||||
|
||||
def _trainable_func(self, config, reporter, checkpoint_dir):
|
||||
func_args = inspect.getfullargspec(train_func).args
|
||||
if len(func_args) > 1: # more arguments than just the config
|
||||
|
||||
@@ -0,0 +1,345 @@
|
||||
import os
|
||||
|
||||
from multiprocessing import Process, Queue
|
||||
from numbers import Number
|
||||
|
||||
from ray import logger
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.function_runner import FunctionRunner
|
||||
from ray.tune.logger import Logger
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError:
|
||||
logger.error("pip install 'wandb' to use WandbLogger/WandbTrainableMixin.")
|
||||
wandb = None
|
||||
|
||||
WANDB_ENV_VAR = "WANDB_API_KEY"
|
||||
_WANDB_QUEUE_END = (None, )
|
||||
|
||||
|
||||
def wandb_mixin(func):
|
||||
"""wandb_mixin
|
||||
|
||||
Weights and biases (https://www.wandb.com/) is a tool for experiment
|
||||
tracking, model optimization, and dataset versioning. This Ray Tune
|
||||
Trainable mixin helps initializing the Wandb API for use with the
|
||||
``Trainable`` class or with `@wandb_mixin` for the function API.
|
||||
|
||||
For basic usage, just prepend your training function with the
|
||||
``@wandb_mixin`` decorator:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.integration.wandb import wandb_mixin
|
||||
|
||||
@wandb_mixin
|
||||
def train_fn(config):
|
||||
wandb.log()
|
||||
|
||||
|
||||
Wandb configuration is done by passing a ``wandb`` key to
|
||||
the ``config`` parameter of ``tune.run()`` (see example below).
|
||||
|
||||
The content of the ``wandb`` config entry is passed to ``wandb.init()``
|
||||
as keyword arguments. The exception are the following settings, which
|
||||
are used to configure the ``WandbTrainableMixin`` itself:
|
||||
|
||||
Args:
|
||||
api_key_file (str): Path to file containing the Wandb API KEY.
|
||||
api_key (str): Wandb API Key. Alternative to setting `api_key_file`.
|
||||
|
||||
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
|
||||
by Tune, but can be overwritten by filling out the respective configuration
|
||||
values.
|
||||
|
||||
Please see here for all other valid configuration settings:
|
||||
https://docs.wandb.com/library/init
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray import tune
|
||||
from ray.tune.integration.wandb import wandb_mixin
|
||||
|
||||
@wandb_mixin
|
||||
def train_fn(config):
|
||||
for i in range(10):
|
||||
loss = self.config["a"] + self.config["b"]
|
||||
wandb.log({"loss": loss})
|
||||
tune.report(loss=loss, done=True)
|
||||
|
||||
tune.run(
|
||||
train_fn,
|
||||
config={
|
||||
# define search space here
|
||||
"a": tune.choice([1, 2, 3]),
|
||||
"b": tune.choice([4, 5, 6]),
|
||||
# wandb configuration
|
||||
"wandb": {
|
||||
"project": "Optimization_Project",
|
||||
"api_key_file": "/path/to/file"
|
||||
}
|
||||
})
|
||||
|
||||
"""
|
||||
func.__mixins__ = (WandbTrainableMixin, )
|
||||
func.__wandb_group__ = func.__name__
|
||||
return func
|
||||
|
||||
|
||||
def _set_api_key(wandb_config):
|
||||
"""Set WandB API key from `wandb_config`. Will pop the
|
||||
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
|
||||
api_key_file = os.path.expanduser(wandb_config.pop("api_key_file", ""))
|
||||
api_key = wandb_config.pop("api_key", None)
|
||||
|
||||
if api_key_file:
|
||||
if api_key:
|
||||
raise ValueError("Both WandB `api_key_file` and `api_key` set.")
|
||||
with open(api_key_file, "rt") as fp:
|
||||
api_key = fp.readline().strip()
|
||||
if api_key:
|
||||
os.environ[WANDB_ENV_VAR] = api_key
|
||||
elif not os.environ.get(WANDB_ENV_VAR):
|
||||
raise ValueError(
|
||||
"No WandB API key found. Either set the {} environment "
|
||||
"variable or pass `api_key` or `api_key_file` in the config".
|
||||
format(WANDB_ENV_VAR))
|
||||
|
||||
|
||||
class _WandbLoggingProcess(Process):
|
||||
"""
|
||||
We need a `multiprocessing.Process` to allow multiple concurrent
|
||||
wandb logging instances locally.
|
||||
"""
|
||||
|
||||
def __init__(self, queue, exclude, to_config, *args, **kwargs):
|
||||
super(_WandbLoggingProcess, self).__init__()
|
||||
self.queue = queue
|
||||
self._exclude = set(exclude)
|
||||
self._to_config = set(to_config)
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
def run(self):
|
||||
wandb.init(*self.args, **self.kwargs)
|
||||
while True:
|
||||
result = self.queue.get()
|
||||
if result == _WANDB_QUEUE_END:
|
||||
break
|
||||
log, config_update = self._handle_result(result)
|
||||
wandb.config.update(config_update, allow_val_change=True)
|
||||
wandb.log(log)
|
||||
wandb.join()
|
||||
|
||||
def _handle_result(self, result):
|
||||
config_update = result.get("config", {}).copy()
|
||||
log = {}
|
||||
|
||||
for k, v in result.items():
|
||||
if k in self._to_config:
|
||||
config_update[k] = v
|
||||
elif k in self._exclude:
|
||||
continue
|
||||
elif not isinstance(v, Number):
|
||||
continue
|
||||
else:
|
||||
log[k] = v
|
||||
|
||||
config_update.pop("callbacks", None) # Remove callbacks
|
||||
return log, config_update
|
||||
|
||||
|
||||
class WandbLogger(Logger):
|
||||
"""WandbLogger
|
||||
|
||||
Weights and biases (https://www.wandb.com/) is a tool for experiment
|
||||
tracking, model optimization, and dataset versioning. This Ray Tune
|
||||
``Logger`` sends metrics to Wandb for automatic tracking and
|
||||
visualization.
|
||||
|
||||
Wandb configuration is done by passing a ``wandb`` key to
|
||||
the ``config`` parameter of ``tune.run()`` (see example below).
|
||||
|
||||
The content of the ``wandb`` config entry is passed to ``wandb.init()``
|
||||
as keyword arguments. The exception are the following settings, which
|
||||
are used to configure the WandbLogger itself:
|
||||
|
||||
Args:
|
||||
api_key_file (str): Path to file containing the Wandb API KEY.
|
||||
api_key (str): Wandb API Key. Alternative to setting ``api_key_file``.
|
||||
excludes (list): List of metrics that should be excluded from
|
||||
the log.
|
||||
log_config (bool): Boolean indicating if the ``config`` parameter of
|
||||
the ``results`` dict should be logged. This makes sense if
|
||||
parameters will change during training, e.g. with
|
||||
PopulationBasedTraining. Defaults to False.
|
||||
|
||||
Wandb's ``group``, ``run_id`` and ``run_name`` are automatically selected
|
||||
by Tune, but can be overwritten by filling out the respective configuration
|
||||
values.
|
||||
|
||||
Please see here for all other valid configuration settings:
|
||||
https://docs.wandb.com/library/init
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from ray.tune.logger import DEFAULT_LOGGERS
|
||||
from ray.tune.integration.wandb import WandbLogger
|
||||
tune.run(
|
||||
train_fn,
|
||||
config={
|
||||
# define search space here
|
||||
"parameter_1": tune.choice([1, 2, 3]),
|
||||
"parameter_2": tune.choice([4, 5, 6]),
|
||||
# wandb configuration
|
||||
"wandb": {
|
||||
"project": "Optimization_Project",
|
||||
"api_key_file": "/path/to/file",
|
||||
"log_config": True
|
||||
}
|
||||
},
|
||||
loggers=DEFAULT_LOGGERS + (WandbLogger, ))
|
||||
|
||||
"""
|
||||
|
||||
# Do not log these result keys
|
||||
_exclude_results = ["done", "should_checkpoint"]
|
||||
|
||||
# Use these result keys to update `wandb.config`
|
||||
_config_results = [
|
||||
"trial_id", "experiment_tag", "node_ip", "experiment_id", "hostname",
|
||||
"pid", "date"
|
||||
]
|
||||
|
||||
_logger_process_cls = _WandbLoggingProcess
|
||||
|
||||
def _init(self):
|
||||
config = self.config.copy()
|
||||
|
||||
try:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb logger specified but no configuration has been passed. "
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification.")
|
||||
|
||||
_set_api_key(wandb_config)
|
||||
|
||||
exclude_results = self._exclude_results.copy()
|
||||
|
||||
# Additional excludes
|
||||
additional_excludes = wandb_config.pop("excludes", [])
|
||||
exclude_results += additional_excludes
|
||||
|
||||
# Log config keys on each result?
|
||||
log_config = wandb_config.pop("log_config", False)
|
||||
if not log_config:
|
||||
exclude_results += ["config"]
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = self.trial.trial_id
|
||||
trial_name = str(self.trial)
|
||||
|
||||
# Project name for Wandb
|
||||
try:
|
||||
wandb_project = wandb_config.pop("project")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"You need to specify a `project` in your wandb `config` dict.")
|
||||
|
||||
# Grouping
|
||||
wandb_group = wandb_config.pop("group", self.trial.trainable_name)
|
||||
|
||||
wandb_init_kwargs = dict(
|
||||
id=trial_id,
|
||||
name=trial_name,
|
||||
resume=True,
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config)
|
||||
wandb_init_kwargs.update(wandb_config)
|
||||
|
||||
self._queue = Queue()
|
||||
self._wandb = self._logger_process_cls(
|
||||
queue=self._queue,
|
||||
exclude=exclude_results,
|
||||
to_config=self._config_results,
|
||||
**wandb_init_kwargs)
|
||||
self._wandb.start()
|
||||
|
||||
def on_result(self, result):
|
||||
self._queue.put(result)
|
||||
|
||||
def close(self):
|
||||
self._queue.put(_WANDB_QUEUE_END)
|
||||
self._wandb.join(timeout=10)
|
||||
|
||||
|
||||
class WandbTrainableMixin:
|
||||
_wandb = wandb
|
||||
|
||||
def __init__(self, config, *args, **kwargs):
|
||||
if not isinstance(self, Trainable):
|
||||
raise ValueError(
|
||||
"The `WandbTrainableMixin` can only be used as a mixin "
|
||||
"for `tune.Trainable` classes. Please make sure your "
|
||||
"class inherits from both. For example: "
|
||||
"`class YourTrainable(WandbTrainableMixin)`.")
|
||||
|
||||
super().__init__(config, *args, **kwargs)
|
||||
|
||||
config = config.copy()
|
||||
|
||||
try:
|
||||
wandb_config = config.pop("wandb").copy()
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Wandb mixin specified but no configuration has been passed. "
|
||||
"Make sure to include a `wandb` key in your `config` dict "
|
||||
"containing at least a `project` specification.")
|
||||
|
||||
_set_api_key(wandb_config)
|
||||
|
||||
# Fill trial ID and name
|
||||
trial_id = self.trial_id
|
||||
trial_name = self.trial_name
|
||||
|
||||
# Project name for Wandb
|
||||
try:
|
||||
wandb_project = wandb_config.pop("project")
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"You need to specify a `project` in your wandb `config` dict.")
|
||||
|
||||
# Grouping
|
||||
if isinstance(self, FunctionRunner):
|
||||
default_group = self._name
|
||||
else:
|
||||
default_group = type(self).__name__
|
||||
wandb_group = wandb_config.pop("group", default_group)
|
||||
|
||||
wandb_init_kwargs = dict(
|
||||
id=trial_id,
|
||||
name=trial_name,
|
||||
resume=True,
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
group=wandb_group,
|
||||
project=wandb_project,
|
||||
config=config)
|
||||
wandb_init_kwargs.update(wandb_config)
|
||||
|
||||
self.wandb = self._wandb.init(**wandb_init_kwargs)
|
||||
|
||||
def stop(self):
|
||||
self._wandb.join()
|
||||
if hasattr(super(), "stop"):
|
||||
super().stop()
|
||||
@@ -469,6 +469,10 @@ def _get_trial_info(trial, parameters, metrics):
|
||||
result = trial.last_result
|
||||
config = trial.config
|
||||
trial_info = [str(trial), trial.status, str(trial.location)]
|
||||
trial_info += [unflattened_lookup(param, config) for param in parameters]
|
||||
trial_info += [unflattened_lookup(metric, result) for metric in metrics]
|
||||
trial_info += [
|
||||
unflattened_lookup(param, config, default=None) for param in parameters
|
||||
]
|
||||
trial_info += [
|
||||
unflattened_lookup(metric, result, default=None) for metric in metrics
|
||||
]
|
||||
return trial_info
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
# An unique identifier for the head node and workers of this cluster.
|
||||
cluster_name: sgd-pytorch
|
||||
|
||||
# The maximum number of workers nodes to launch in addition to the head
|
||||
# node. This takes precedence over min_workers. min_workers default to 0.
|
||||
min_workers: 0
|
||||
initial_workers: 0
|
||||
max_workers: 0
|
||||
|
||||
target_utilization_fraction: 0.9
|
||||
|
||||
# If a node is idle for this many minutes, it will be removed.
|
||||
idle_timeout_minutes: 20
|
||||
# docker:
|
||||
# image: tensorflow/tensorflow:1.5.0-py3
|
||||
# container_name: ray_docker
|
||||
|
||||
# Cloud-provider specific configuration.
|
||||
provider:
|
||||
type: aws
|
||||
region: us-west-2
|
||||
|
||||
# How Ray will authenticate with newly launched nodes.
|
||||
auth:
|
||||
ssh_user: ubuntu
|
||||
|
||||
head_node:
|
||||
InstanceType: g3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
|
||||
worker_nodes:
|
||||
InstanceType: g3.8xlarge
|
||||
ImageId: latest_dlami
|
||||
# Run workers on spot by default. Comment this out to use on-demand.
|
||||
InstanceMarketOptions:
|
||||
MarketType: spot
|
||||
# SpotOptions:
|
||||
# MaxPrice: "9.0"
|
||||
|
||||
setup_commands:
|
||||
- ray || pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-0.9.0.dev0-cp36-cp36m-manylinux1_x86_64.whl
|
||||
- pip install -U ipdb ray[rllib] torch torchvision
|
||||
# Install apex.
|
||||
# - rm -rf apex || true
|
||||
# - git clone https://github.com/NVIDIA/apex && cd apex && pip install -v --no-cache-dir ./ || true
|
||||
|
||||
|
||||
file_mounts: {
|
||||
~/anaconda3/lib/python3.6/site-packages/ray/tune: ./tune/,
|
||||
~/anaconda3/lib/python3.6/site-packages/ray/util/sgd/torch: ./util/sgd/torch/
|
||||
}
|
||||
|
||||
# Custom commands that will be run on the head node after common setup.
|
||||
head_setup_commands: []
|
||||
|
||||
# Custom commands that will be run on worker nodes after common setup.
|
||||
worker_setup_commands: []
|
||||
|
||||
# # Command to start ray on the head node. You don't need to change this.
|
||||
head_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --head --redis-port=6379 --object-manager-port=8076 --autoscaling-config=~/ray_bootstrap_config.yaml --object-store-memory=1000000000
|
||||
|
||||
# Command to start ray on worker nodes. You don't need to change this.
|
||||
worker_start_ray_commands:
|
||||
- ray stop
|
||||
- ray start --address=$RAY_HEAD_IP:6379 --object-manager-port=8076 --object-store-memory=1000000000
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,298 @@
|
||||
import os
|
||||
import tempfile
|
||||
from collections import namedtuple
|
||||
from multiprocessing import Queue
|
||||
import unittest
|
||||
|
||||
from ray.tune import Trainable
|
||||
from ray.tune.function_runner import wrap_function
|
||||
from ray.tune.integration.wandb import _WandbLoggingProcess, \
|
||||
_WANDB_QUEUE_END, WandbLogger, WANDB_ENV_VAR, WandbTrainableMixin, \
|
||||
wandb_mixin
|
||||
from ray.tune.result import TRIAL_INFO
|
||||
from ray.tune.trial import TrialInfo
|
||||
|
||||
Trial = namedtuple("MockTrial",
|
||||
["config", "trial_id", "trial_name", "trainable_name"])
|
||||
Trial.__str__ = lambda t: t.trial_name
|
||||
|
||||
|
||||
class _MockWandbLoggingProcess(_WandbLoggingProcess):
|
||||
def __init__(self, queue, exclude, to_config, *args, **kwargs):
|
||||
super(_MockWandbLoggingProcess,
|
||||
self).__init__(queue, exclude, to_config, *args, **kwargs)
|
||||
|
||||
self.logs = Queue()
|
||||
self.config_updates = Queue()
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
result = self.queue.get()
|
||||
if result == _WANDB_QUEUE_END:
|
||||
break
|
||||
log, config_update = self._handle_result(result)
|
||||
self.config_updates.put(config_update)
|
||||
self.logs.put(log)
|
||||
|
||||
|
||||
class WandbTestLogger(WandbLogger):
|
||||
_logger_process_cls = _MockWandbLoggingProcess
|
||||
|
||||
|
||||
class _MockWandbAPI(object):
|
||||
def init(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
return self
|
||||
|
||||
|
||||
class _MockWandbTrainableMixin(WandbTrainableMixin):
|
||||
_wandb = _MockWandbAPI()
|
||||
|
||||
|
||||
class WandbTestTrainable(_MockWandbTrainableMixin, Trainable):
|
||||
pass
|
||||
|
||||
|
||||
class WandbIntegrationTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def testWandbLoggerConfig(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# Needs at least a project
|
||||
with self.assertRaises(ValueError):
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
# No API key
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
with self.assertRaises(ValueError):
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
# API Key in config
|
||||
trial_config["wandb"] = {"project": "test_project", "api_key": "1234"}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
||||
|
||||
logger.close()
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key file
|
||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
||||
fp.write("5678")
|
||||
fp.flush()
|
||||
|
||||
trial_config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"api_key_file": fp.name
|
||||
}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
||||
|
||||
logger.close()
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key in env
|
||||
os.environ[WANDB_ENV_VAR] = "9012"
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
logger.close()
|
||||
|
||||
# From now on, the API key is in the env variable.
|
||||
|
||||
# Default configuration
|
||||
trial_config["wandb"] = {"project": "test_project"}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertEqual(logger._wandb.kwargs["project"], "test_project")
|
||||
self.assertEqual(logger._wandb.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(logger._wandb.kwargs["name"], trial.trial_name)
|
||||
self.assertEqual(logger._wandb.kwargs["group"], trial.trainable_name)
|
||||
self.assertIn("config", logger._wandb._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
# log config.
|
||||
trial_config["wandb"] = {"project": "test_project", "log_config": True}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertNotIn("config", logger._wandb._exclude)
|
||||
self.assertNotIn("metric", logger._wandb._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
# Exclude metric.
|
||||
trial_config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"excludes": ["metric"]
|
||||
}
|
||||
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
self.assertIn("config", logger._wandb._exclude)
|
||||
self.assertIn("metric", logger._wandb._exclude)
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbLoggerReporting(self):
|
||||
trial_config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(trial_config, 0, "trial_0", "trainable")
|
||||
|
||||
trial_config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"api_key": "1234",
|
||||
"excludes": ["metric2"]
|
||||
}
|
||||
logger = WandbTestLogger(trial_config, "/tmp", trial)
|
||||
|
||||
r1 = {
|
||||
"metric1": 0.8,
|
||||
"metric2": 1.4,
|
||||
"const": "text",
|
||||
"config": trial_config
|
||||
}
|
||||
|
||||
logger.on_result(r1)
|
||||
|
||||
logged = logger._wandb.logs.get(timeout=10)
|
||||
self.assertIn("metric1", logged)
|
||||
self.assertNotIn("metric2", logged)
|
||||
self.assertNotIn("const", logged)
|
||||
self.assertNotIn("config", logged)
|
||||
|
||||
logger.close()
|
||||
|
||||
def testWandbMixinConfig(self):
|
||||
config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(config, 0, "trial_0", "trainable")
|
||||
trial_info = TrialInfo(trial)
|
||||
|
||||
config[TRIAL_INFO] = trial_info
|
||||
|
||||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# Needs at least a project
|
||||
with self.assertRaises(ValueError):
|
||||
trainable = WandbTestTrainable(config)
|
||||
|
||||
# No API key
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
with self.assertRaises(ValueError):
|
||||
trainable = WandbTestTrainable(config)
|
||||
|
||||
# API Key in config
|
||||
config["wandb"] = {"project": "test_project", "api_key": "1234"}
|
||||
trainable = WandbTestTrainable(config)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
||||
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key file
|
||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
||||
fp.write("5678")
|
||||
fp.flush()
|
||||
|
||||
config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"api_key_file": fp.name
|
||||
}
|
||||
|
||||
trainable = WandbTestTrainable(config)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
||||
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key in env
|
||||
os.environ[WANDB_ENV_VAR] = "9012"
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
trainable = WandbTestTrainable(config)
|
||||
|
||||
# From now on, the API key is in the env variable.
|
||||
|
||||
# Default configuration
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
config[TRIAL_INFO] = trial_info
|
||||
|
||||
trainable = WandbTestTrainable(config)
|
||||
self.assertEqual(trainable.wandb.kwargs["project"], "test_project")
|
||||
self.assertEqual(trainable.wandb.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(trainable.wandb.kwargs["name"], trial.trial_name)
|
||||
self.assertEqual(trainable.wandb.kwargs["group"], "WandbTestTrainable")
|
||||
|
||||
def testWandbDecoratorConfig(self):
|
||||
config = {"par1": 4, "par2": 9.12345678}
|
||||
trial = Trial(config, 0, "trial_0", "trainable")
|
||||
trial_info = TrialInfo(trial)
|
||||
|
||||
@wandb_mixin
|
||||
def train_fn(config):
|
||||
return 1
|
||||
|
||||
train_fn.__mixins__ = (_MockWandbTrainableMixin, )
|
||||
|
||||
config[TRIAL_INFO] = trial_info
|
||||
|
||||
if WANDB_ENV_VAR in os.environ:
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# Needs at least a project
|
||||
with self.assertRaises(ValueError):
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
|
||||
# No API key
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
with self.assertRaises(ValueError):
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
|
||||
# API Key in config
|
||||
config["wandb"] = {"project": "test_project", "api_key": "1234"}
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "1234")
|
||||
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key file
|
||||
with tempfile.NamedTemporaryFile("wt") as fp:
|
||||
fp.write("5678")
|
||||
fp.flush()
|
||||
|
||||
config["wandb"] = {
|
||||
"project": "test_project",
|
||||
"api_key_file": fp.name
|
||||
}
|
||||
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
self.assertEqual(os.environ[WANDB_ENV_VAR], "5678")
|
||||
|
||||
del os.environ[WANDB_ENV_VAR]
|
||||
|
||||
# API Key in env
|
||||
os.environ[WANDB_ENV_VAR] = "9012"
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
|
||||
# From now on, the API key is in the env variable.
|
||||
|
||||
# Default configuration
|
||||
config["wandb"] = {"project": "test_project"}
|
||||
config[TRIAL_INFO] = trial_info
|
||||
|
||||
wrapped = wrap_function(train_fn)(config)
|
||||
self.assertEqual(wrapped.wandb.kwargs["project"], "test_project")
|
||||
self.assertEqual(wrapped.wandb.kwargs["id"], trial.trial_id)
|
||||
self.assertEqual(wrapped.wandb.kwargs["name"], trial.trial_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
import sys
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -0,0 +1,28 @@
|
||||
import pickle
|
||||
import ray
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ConvNet(nn.Module):
|
||||
def __init__(self):
|
||||
super(ConvNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(1, 3, kernel_size=3)
|
||||
self.fc = nn.Linear(192, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(F.max_pool2d(self.conv1(x), 3))
|
||||
x = x.view(-1, 192)
|
||||
x = self.fc(x)
|
||||
return F.log_softmax(x, dim=1)
|
||||
|
||||
def save_me():
|
||||
model = ConvNet()
|
||||
torch.save(model, "./test.th")
|
||||
return 1
|
||||
|
||||
ray_func = ray.remote(save_me)
|
||||
|
||||
ray.init()
|
||||
ray.get(ray_func.remote())
|
||||
@@ -0,0 +1 @@
|
||||
trialv2.py
|
||||
@@ -216,7 +216,7 @@ def flatten_dict(dt, delimiter="/"):
|
||||
return dt
|
||||
|
||||
|
||||
def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
|
||||
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
|
||||
"""
|
||||
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
|
||||
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
|
||||
@@ -232,8 +232,10 @@ def unflattened_lookup(flat_key, lookup, delimiter="/", default=None):
|
||||
base = base[int(key)]
|
||||
else:
|
||||
raise KeyError()
|
||||
except KeyError:
|
||||
return default
|
||||
except KeyError as e:
|
||||
if "default" in kwargs:
|
||||
return kwargs["default"]
|
||||
raise e
|
||||
return base
|
||||
|
||||
|
||||
|
||||
@@ -24,5 +24,6 @@ timm
|
||||
torch>=1.5.0
|
||||
torchvision>=0.6.0
|
||||
tune-sklearn==0.0.5
|
||||
wandb
|
||||
xgboost
|
||||
zoopt>=0.4.0
|
||||
|
||||
Reference in New Issue
Block a user