mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 22:03:57 +08:00
[Tune] Update PBT Transformers Example (#10289)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu> Co-authored-by: krfricke <krfricke@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@ import os
|
||||
|
||||
import ray
|
||||
from ray.tune import CLIReporter
|
||||
from ray.tune.integration.wandb import wandb_mixin # noqa: F401
|
||||
from ray.tune.schedulers import PopulationBasedTraining
|
||||
|
||||
from ray import tune
|
||||
@@ -15,22 +16,15 @@ from transformers import (AutoConfig, AutoModelForSequenceClassification,
|
||||
Trainer, TrainingArguments)
|
||||
|
||||
|
||||
def get_trainer(model_name_or_path,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
task_name,
|
||||
training_args,
|
||||
wandb_args=None):
|
||||
def get_trainer(model_name_or_path, train_dataset, eval_dataset, task_name,
|
||||
training_args):
|
||||
try:
|
||||
num_labels = glue_tasks_num_labels[task_name]
|
||||
except KeyError:
|
||||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
config = AutoConfig.from_pretrained(
|
||||
model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=task_name,
|
||||
)
|
||||
model_name_or_path, num_labels=num_labels, finetuning_task=task_name)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_name_or_path,
|
||||
@@ -41,8 +35,7 @@ def get_trainer(model_name_or_path,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=build_compute_metrics_fn(task_name),
|
||||
wandb_args=wandb_args)
|
||||
compute_metrics=build_compute_metrics_fn(task_name))
|
||||
|
||||
return tune_trainer
|
||||
|
||||
@@ -62,6 +55,8 @@ def recover_checkpoint(tune_checkpoint_dir, model_name=None):
|
||||
|
||||
|
||||
# __train_begin__
|
||||
# Uncomment this line to use W&B!
|
||||
# @wandb_mixin
|
||||
def train_transformer(config, checkpoint_dir=None):
|
||||
data_args = DataTrainingArguments(
|
||||
task_name=config["task_name"], data_dir=config["data_dir"])
|
||||
@@ -96,21 +91,9 @@ def train_transformer(config, checkpoint_dir=None):
|
||||
logging_dir="./logs",
|
||||
)
|
||||
|
||||
# Arguments for W&B.
|
||||
name = tune.get_trial_name()
|
||||
wandb_args = {
|
||||
"project_name": "transformers_pbt",
|
||||
"watch": "false", # Either set to gradient, false, or all
|
||||
"run_name": name,
|
||||
}
|
||||
|
||||
tune_trainer = get_trainer(
|
||||
recover_checkpoint(checkpoint_dir, config["model_name"]),
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
config["task_name"],
|
||||
training_args,
|
||||
wandb_args=wandb_args)
|
||||
train_dataset, eval_dataset, config["task_name"], training_args)
|
||||
tune_trainer.train(
|
||||
recover_checkpoint(checkpoint_dir, config["model_name"]))
|
||||
|
||||
@@ -159,6 +142,11 @@ def tune_transformer(num_samples=8,
|
||||
"weight_decay": tune.uniform(0.0, 0.3),
|
||||
"num_epochs": tune.choice([2, 3, 4, 5]),
|
||||
"max_steps": 1 if smoke_test else -1, # Used for smoke test.
|
||||
"wandb": {
|
||||
"project": "pbt_transformers",
|
||||
"reinit": True,
|
||||
"allow_val_change": True
|
||||
}
|
||||
}
|
||||
|
||||
scheduler = PopulationBasedTraining(
|
||||
|
||||
@@ -5,28 +5,20 @@ from typing import Dict, Optional, Tuple
|
||||
from ray import tune
|
||||
|
||||
import transformers
|
||||
from transformers.file_utils import is_torch_tpu_available
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import wandb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
"""A Trainer class integrated with Tune.
|
||||
The only changes to the original transformers.Trainer are:
|
||||
- Report eval metrics to Tune
|
||||
- Save state using Tune's checkpoint directories
|
||||
- Pass in extra arguments for wandb
|
||||
"""
|
||||
|
||||
|
||||
class TuneTransformerTrainer(transformers.Trainer):
|
||||
def __init__(self, *args, wandb_args=None, **kwargs):
|
||||
self.wandb_args = wandb_args
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_optimizers(
|
||||
self, num_training_steps: int
|
||||
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]:
|
||||
@@ -58,22 +50,3 @@ class TuneTransformerTrainer(transformers.Trainer):
|
||||
os.path.join(output_dir, "optimizer.pt"))
|
||||
torch.save(self.current_scheduler.state_dict(),
|
||||
os.path.join(output_dir, "scheduler.pt"))
|
||||
|
||||
def _setup_wandb(self):
|
||||
if self.is_world_master() and self.wandb_args is not None:
|
||||
wandb.init(
|
||||
project=self.wandb_args["project_name"],
|
||||
name=self.wandb_args["run_name"],
|
||||
id=self.wandb_args["run_name"],
|
||||
dir=tune.get_trial_dir(),
|
||||
config=vars(self.args),
|
||||
reinit=True,
|
||||
allow_val_change=True,
|
||||
resume=self.wandb_args["run_name"])
|
||||
# keep track of model topology and gradients, unsupported on TPU
|
||||
if not is_torch_tpu_available(
|
||||
) and self.wandb_args["watch"] != "false":
|
||||
wandb.watch(
|
||||
self.model,
|
||||
log=self.wandb_args["watch"],
|
||||
log_freq=max(100, self.args.logging_steps))
|
||||
|
||||
@@ -103,10 +103,18 @@ def _set_api_key(wandb_config):
|
||||
if api_key:
|
||||
os.environ[WANDB_ENV_VAR] = api_key
|
||||
elif not os.environ.get(WANDB_ENV_VAR):
|
||||
try:
|
||||
# Check if user is already logged into wandb.
|
||||
wandb.ensure_configured()
|
||||
if wandb.api.api_key:
|
||||
logger.info("Already logged into W&B.")
|
||||
return
|
||||
except AttributeError:
|
||||
pass
|
||||
raise ValueError(
|
||||
"No WandB API key found. Either set the {} environment "
|
||||
"variable or pass `api_key` or `api_key_file` in the config".
|
||||
format(WANDB_ENV_VAR))
|
||||
"variable, pass `api_key` or `api_key_file` in the config, "
|
||||
"or run `wandb login` from the command line".format(WANDB_ENV_VAR))
|
||||
|
||||
|
||||
class _WandbLoggingProcess(Process):
|
||||
|
||||
Reference in New Issue
Block a user