# Mirror of https://github.com/wassname/Open-Assistant.git
# Synced 2026-06-27 16:10:30 +08:00 · 201 lines · 6.2 KiB · Python
import argparse
|
|
import os
|
|
from dataclasses import dataclass
|
|
from distutils.util import strtobool
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import Dataset
|
|
from transformers import (
|
|
DataCollator,
|
|
EvalPrediction,
|
|
PreTrainedModel,
|
|
PreTrainedTokenizerBase,
|
|
Trainer,
|
|
TrainerCallback,
|
|
TrainingArguments,
|
|
get_cosine_schedule_with_warmup,
|
|
)
|
|
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
|
|
|
# Set at import time, before any trainer is built. The training arguments in the __main__ block use
# report_to="wandb"; presumably the W&B integration reads this variable to choose the project name.
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
|
|
|
|
|
@dataclass
class CustomTrainingArguments(TrainingArguments):
    """`TrainingArguments` extended with the name of the loss function to train with."""

    # Name passed to `get_loss` in `SFTTrainer.__init__` to build the loss callable.
    loss_function: str = "CrossEntropyLoss"
|
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute token-level accuracy over the supervised (non-masked) positions.

    Args:
        eval_pred: Object exposing `predictions` (predicted token ids) and `label_ids`
            (target token ids) as same-shaped arrays that support boolean-mask indexing.

    Returns:
        A dict with a single key, "accuracy": the fraction of kept positions where the
        prediction equals the label.
    """
    pred_ids = eval_pred.predictions
    labels = eval_pred.label_ids

    # Ignored positions carry negative ids (`SFTTrainer.prediction_step` writes -1 there), so
    # `>= 0` keeps every real token. The previous `> 0` also discarded vocabulary id 0.
    valid = labels >= 0
    return {"accuracy": (pred_ids[valid] == labels[valid]).mean()}
|
|
|
|
|
|
def preprocess_logits_for_metrics(logits, labels):
    """Reduce logits to predicted ids by taking the argmax over the last dimension.

    Args:
        logits: Tensor of scores; the last dimension is the one reduced.
        labels: Unused; accepted so the function matches the two-argument hook signature
            that `SFTTrainer.__init__` declares for `preprocess_logits_for_metrics`.

    Returns:
        Tensor of indices with the last dimension of `logits` removed.
    """
    return torch.argmax(logits, dim=-1)
|
|
|
|
|
|
class SFTTrainer(Trainer):
    """`Trainer` subclass that trains and evaluates with a masked next-token loss.

    The loss callable comes from `get_loss(args.loss_function)`. It is called with the model
    logits, the input ids rolled one position to the left (so position i is scored against
    token i + 1), and the batch's `label_masks` entry as `mask`.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Callable[[], PreTrainedModel] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
    ):
        """Forward every argument to `Trainer.__init__`, then build the loss callable.

        `args` must expose a `loss_function` attribute (see `CustomTrainingArguments`);
        passing `args=None` fails when that attribute is read.
        """
        # Forward by keyword so the call does not depend on the parent's positional order.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.loss_fct = get_loss(args.loss_function)

    def fetch_scheduler(self):
        """Return a single-cycle cosine schedule with warmup for `self.optimizer`.

        NOTE(review): this reads `self.num_train_steps`, which nothing in this file assigns,
        and nothing in this file calls this method. Confirm the attribute is set before use.
        """
        return get_cosine_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.args.warmup_steps,
            num_training_steps=self.num_train_steps,
            num_cycles=1,
            last_epoch=-1,
        )

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the masked next-token loss for one batch.

        Args:
            model: Callable invoked as `model(**inputs)`; its result must support `.get("logits")`.
            inputs: Batch mapping containing at least "input_ids" and "label_masks". The
                "label_masks" entry is popped (the mapping is mutated) so it is not passed
                to the model.
            return_outputs: When true, return `(loss, outputs)` instead of just `loss`.

        Returns:
            The loss, or a `(loss, outputs)` tuple when `return_outputs` is true.
        """
        labels_mask = inputs.pop("label_masks")

        outputs = model(**inputs)

        # Shift left by one along the last dimension. The final position wraps around to the
        # first token; presumably `label_masks` excludes it - TODO confirm against the collator.
        targets = torch.roll(inputs["input_ids"], -1, -1)
        loss = self.loss_fct(outputs.get("logits"), targets, mask=labels_mask)

        return (loss, outputs) if return_outputs else loss

    def _compute_loss(self, model, inputs):
        """Evaluation variant of `compute_loss` that also returns logits, targets and mask.

        Returns:
            A `(loss, logits, targets, labels_mask)` tuple.
        """
        # Prepare the whole batch first so the mask goes through the same `_prepare_inputs`
        # handling as the other tensors; it was previously popped before this call.
        inputs = self._prepare_inputs(inputs)
        labels_mask = inputs.pop("label_masks")

        outputs = model(**inputs)

        logits = outputs.get("logits")
        targets = torch.roll(inputs["input_ids"], -1, -1)
        loss = self.loss_fct(logits, targets, mask=labels_mask)

        return loss, logits, targets, labels_mask

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run one evaluation step without gradients.

        Args:
            model: Model to evaluate.
            inputs: Batch mapping; see `compute_loss` for the required keys.
            prediction_loss_only: When true, only the loss is returned. `None` falls back to
                `self.args.prediction_loss_only`.
            ignore_keys: Unused; accepted for signature compatibility.

        Returns:
            `(loss, logits, labels)`, or `(loss, None, None)` when only the loss is requested.
            Positions where the mask is false are set to -1 in `labels`.
        """
        with torch.no_grad():
            loss, logits, labels, labels_mask = self._compute_loss(model, inputs)
            # Mark positions outside the mask with a negative id so metrics can skip them.
            labels[~labels_mask] = -1

        loss = loss.mean().detach()

        # Honour the argument the caller passed; it was previously ignored in favour of
        # `self.args.prediction_loss_only`.
        if prediction_loss_only is None:
            prediction_loss_only = self.args.prediction_loss_only
        if prediction_loss_only:
            return (loss, None, None)

        return (loss, logits, labels)
|
|
|
|
|
|
def _strtobool(x):
|
|
return bool(strtobool(x))
|
|
|
|
|
|
def argument_parsing(notebook=False, notebook_args=None):
    """Build the run configuration from YAML presets plus command-line overrides.

    The required `--configs` option names one or more entries of the mapping returned by
    `read_yamls("./configs")`; an item may itself be a comma-separated list of names. The
    selected entries are merged in order, later ones overriding earlier ones. Every merged
    key then becomes a `--<key>` option whose default is the merged value, so any setting
    can be overridden on the command line.

    Args:
        notebook: When true, parse `notebook_args` instead of the process arguments.
        notebook_args: Argument list used when `notebook` is true.

    Returns:
        The `argparse.Namespace` holding the final value of every configuration key.

    Raises:
        KeyError: If a requested config name is not in the loaded mapping.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--configs", nargs="+", required=True)

    # `parse_known_args(None)` reads the process arguments, so one call covers both modes.
    args, remaining = parser.parse_known_args(notebook_args if notebook else None)

    # Config from YAML
    conf = {}
    configs = read_yamls("./configs")
    for name in args.configs:
        # `split` yields `[name]` when there is no comma, so no separate branch is needed.
        for n in name.split(","):
            conf.update(configs[n])

    # Override config from command-line
    parser = argparse.ArgumentParser()
    for key, value in conf.items():
        # Reuse the YAML value's type as the converter; `None` values are read as strings.
        type_ = type(value) if value is not None else str
        if type_ is bool:
            # `bool("false")` is True, so booleans need a real parser.
            type_ = _strtobool
        parser.add_argument(f"--{key}", type=type_, default=value)

    return parser.parse_args(remaining)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: resolve the configuration, build model, tokenizer and datasets,
    # then run supervised fine-tuning.
    training_conf = argument_parsing()

    model = get_model(training_conf)
    tokenizer = get_tokenizer(training_conf)

    train, evals, collate_fn = get_dataset(training_conf, tokenizer)

    args = CustomTrainingArguments(
        output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
        num_train_epochs=training_conf.num_train_epochs,
        warmup_steps=training_conf.warmup_steps,
        loss_function=training_conf.loss_fn,
        # Cast explicitly: the configured value may arrive as a string.
        learning_rate=float(training_conf.learning_rate),
        fp16=True,
        gradient_checkpointing=training_conf.gradient_checkpointing,
        gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
        per_device_train_batch_size=training_conf.per_device_train_batch_size,
        per_device_eval_batch_size=training_conf.per_device_eval_batch_size,
        weight_decay=training_conf.weight_decay,
        max_grad_norm=training_conf.max_grad_norm,
        logging_steps=training_conf.logging_steps,
        save_total_limit=training_conf.save_total_limit,
        evaluation_strategy="steps",
        eval_steps=training_conf.eval_steps,
        save_steps=training_conf.save_steps,
        eval_accumulation_steps=training_conf.eval_accumulation_steps,
        report_to="wandb",
    )

    # Evaluation runs every `eval_steps`, so training without evaluation data is an error.
    # Raised explicitly because `assert` statements are stripped under `python -O`.
    if len(evals) == 0:
        raise ValueError("no evaluation data: `get_dataset` returned an empty `evals`")

    trainer = SFTTrainer(
        model,
        args,
        train_dataset=train,
        eval_dataset=evals,
        data_collator=collate_fn,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    trainer.train()
|