[Tune] Update PBT Transformers Example (#10289)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
Amog Kamsetty
2020-08-27 08:25:05 -07:00
committed by GitHub
parent 53ab228b75
commit 0aec4cbccb
4 changed files with 33 additions and 54 deletions
@@ -2,6 +2,7 @@ import os
import ray
from ray.tune import CLIReporter
from ray.tune.integration.wandb import wandb_mixin # noqa: F401
from ray.tune.schedulers import PopulationBasedTraining
from ray import tune
@@ -15,22 +16,15 @@ from transformers import (AutoConfig, AutoModelForSequenceClassification,
Trainer, TrainingArguments)
def get_trainer(model_name_or_path,
train_dataset,
eval_dataset,
task_name,
training_args,
wandb_args=None):
def get_trainer(model_name_or_path, train_dataset, eval_dataset, task_name,
training_args):
try:
num_labels = glue_tasks_num_labels[task_name]
except KeyError:
raise ValueError("Task not found: %s" % (task_name))
config = AutoConfig.from_pretrained(
model_name_or_path,
num_labels=num_labels,
finetuning_task=task_name,
)
model_name_or_path, num_labels=num_labels, finetuning_task=task_name)
model = AutoModelForSequenceClassification.from_pretrained(
model_name_or_path,
@@ -41,8 +35,7 @@ def get_trainer(model_name_or_path,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
compute_metrics=build_compute_metrics_fn(task_name),
wandb_args=wandb_args)
compute_metrics=build_compute_metrics_fn(task_name))
return tune_trainer
@@ -62,6 +55,8 @@ def recover_checkpoint(tune_checkpoint_dir, model_name=None):
# __train_begin__
# Uncomment this line to use W&B!
# @wandb_mixin
def train_transformer(config, checkpoint_dir=None):
data_args = DataTrainingArguments(
task_name=config["task_name"], data_dir=config["data_dir"])
@@ -96,21 +91,9 @@ def train_transformer(config, checkpoint_dir=None):
logging_dir="./logs",
)
# Arguments for W&B.
name = tune.get_trial_name()
wandb_args = {
"project_name": "transformers_pbt",
"watch": "false", # Either set to gradient, false, or all
"run_name": name,
}
tune_trainer = get_trainer(
recover_checkpoint(checkpoint_dir, config["model_name"]),
train_dataset,
eval_dataset,
config["task_name"],
training_args,
wandb_args=wandb_args)
train_dataset, eval_dataset, config["task_name"], training_args)
tune_trainer.train(
recover_checkpoint(checkpoint_dir, config["model_name"]))
@@ -159,6 +142,11 @@ def tune_transformer(num_samples=8,
"weight_decay": tune.uniform(0.0, 0.3),
"num_epochs": tune.choice([2, 3, 4, 5]),
"max_steps": 1 if smoke_test else -1, # Used for smoke test.
"wandb": {
"project": "pbt_transformers",
"reinit": True,
"allow_val_change": True
}
}
scheduler = PopulationBasedTraining(
@@ -5,28 +5,20 @@ from typing import Dict, Optional, Tuple
from ray import tune
import transformers
from transformers.file_utils import is_torch_tpu_available
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
import torch
from torch.utils.data import Dataset
import wandb
logger = logging.getLogger(__name__)
"""A Trainer class integrated with Tune.
The only changes to the original transformers.Trainer are:
- Report eval metrics to Tune
- Save state using Tune's checkpoint directories
- Pass in extra arguments for wandb
"""
class TuneTransformerTrainer(transformers.Trainer):
# Constructor: takes an optional keyword-only `wandb_args` and forwards every
# other positional/keyword argument unchanged to the parent class initializer.
def __init__(self, *args, wandb_args=None, **kwargs):
# Stored on the instance before delegating to the parent initializer.
# NOTE(review): presumably this ordering matters because the parent
# __init__ may call _setup_wandb(), which reads self.wandb_args -- confirm.
self.wandb_args = wandb_args
super().__init__(*args, **kwargs)
def get_optimizers(
self, num_training_steps: int
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
@@ -58,22 +50,3 @@ class TuneTransformerTrainer(transformers.Trainer):
os.path.join(output_dir, "optimizer.pt"))
torch.save(self.current_scheduler.state_dict(),
os.path.join(output_dir, "scheduler.pt"))
# Start a Weights & Biases run for the current trial and optionally attach
# wandb.watch to the model. The wandb.init call only happens when this process
# is the world master and `wandb_args` was supplied. `wandb_args` is read for
# the keys "project_name", "run_name" and "watch".
# NOTE(review): the enclosing hunk header (-58,22 +50,3) indicates these lines
# are on the removed side of the diff; indentation was lost in this rendering.
def _setup_wandb(self):
if self.is_world_master() and self.wandb_args is not None:
# The same "run_name" value is passed as the display name, the run id and
# the resume argument; run files go under Tune's trial directory, and the
# attributes of self.args are recorded as the run config.
wandb.init(
project=self.wandb_args["project_name"],
name=self.wandb_args["run_name"],
id=self.wandb_args["run_name"],
dir=tune.get_trial_dir(),
config=vars(self.args),
reinit=True,
allow_val_change=True,
resume=self.wandb_args["run_name"])
# keep track of model topology and gradients, unsupported on TPU
# NOTE(review): presumably nested under the guard above, since it indexes
# self.wandb_args, which may be None -- confirm against the original file.
# The literal string "false" (not the boolean False) disables watching.
if not is_torch_tpu_available(
) and self.wandb_args["watch"] != "false":
# log_freq is never below 100, even when logging_steps is smaller.
wandb.watch(
self.model,
log=self.wandb_args["watch"],
log_freq=max(100, self.args.logging_steps))
+10 -2
View File
@@ -103,10 +103,18 @@ def _set_api_key(wandb_config):
if api_key:
os.environ[WANDB_ENV_VAR] = api_key
elif not os.environ.get(WANDB_ENV_VAR):
try:
# Check if user is already logged into wandb.
wandb.ensure_configured()
if wandb.api.api_key:
logger.info("Already logged into W&B.")
return
except AttributeError:
pass
raise ValueError(
"No WandB API key found. Either set the {} environment "
"variable or pass `api_key` or `api_key_file` in the config".
format(WANDB_ENV_VAR))
"variable, pass `api_key` or `api_key_file` in the config, "
"or run `wandb login` from the command line".format(WANDB_ENV_VAR))
class _WandbLoggingProcess(Process):