mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
First version of single GPU sampling working
This commit is contained in:
@@ -16,29 +16,32 @@ defaults:
|
||||
eval_accumulation_steps:
|
||||
freeze_layer:
|
||||
datasets:
|
||||
- webgpt
|
||||
- prompt_dialogue
|
||||
- squad_v2
|
||||
- adversarial_qa
|
||||
- trivia_qa_nocontext
|
||||
- xsum
|
||||
- cnn_dailymail
|
||||
- prompt_dialogue
|
||||
- multi_news
|
||||
- scitldr
|
||||
- soda
|
||||
- joke
|
||||
- gsm8k
|
||||
- dive_mt
|
||||
- wmt2019_zh-en
|
||||
- wmt2019_ru-en
|
||||
- wmt2019_de-en
|
||||
- ted_trans_nl-en
|
||||
- ted_trans_de-ja
|
||||
- instruct_tuning
|
||||
- wmt2019_de-en
|
||||
- samsum
|
||||
- soda_dialogue
|
||||
- webgpt:
|
||||
fraction : 0.001
|
||||
- prompt_dialogue:
|
||||
size : 10
|
||||
- squad_v2:
|
||||
size : 10
|
||||
# - adversarial_qa
|
||||
# - trivia_qa_nocontext
|
||||
# - xsum
|
||||
# - cnn_dailymail
|
||||
# - prompt_dialogue
|
||||
# - multi_news
|
||||
# - scitldr
|
||||
# - soda
|
||||
# - joke
|
||||
# - gsm8k
|
||||
# - dive_mt
|
||||
# - wmt2019_zh-en
|
||||
# - wmt2019_ru-en
|
||||
# - wmt2019_de-en
|
||||
# - ted_trans_nl-en
|
||||
# - ted_trans_de-ja
|
||||
# - instruct_tuning
|
||||
# - wmt2019_de-en
|
||||
# - samsum
|
||||
# - soda_dialogue
|
||||
cache_dir: .cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
|
||||
@@ -4,7 +4,7 @@ datasets==2.8.0
|
||||
deepspeed==0.7.7
|
||||
evaluate==0.4.0
|
||||
gdown
|
||||
mpi4py==3.1.4
|
||||
# mpi4py==3.1.4
|
||||
nltk==3.8.1
|
||||
numpy>=1.22.4
|
||||
py7zr
|
||||
|
||||
@@ -8,7 +8,8 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments
|
||||
from transformers.training_args import OptimizerNames
|
||||
from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls
|
||||
from utils import (build_train_sampler, get_dataset, get_loss, get_metrics,
|
||||
get_model, get_tokenizer, read_yamls)
|
||||
|
||||
|
||||
def compute_metrics(eval_pred, preprocess_fns, metrics):
|
||||
@@ -30,6 +31,7 @@ class SFTTrainer(Trainer):
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
sampler: torch.utils.data.sampler.Sampler = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
poly_eps: float = 1.0,
|
||||
**kwargs,
|
||||
@@ -38,6 +40,7 @@ class SFTTrainer(Trainer):
|
||||
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(loss_function, poly_eps)
|
||||
self.sampler = sampler
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
labels_mask = inputs.pop("label_masks")
|
||||
@@ -88,6 +91,22 @@ class SFTTrainer(Trainer):
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
def get_train_dataloader(self):
    """Build the training ``DataLoader``, honouring the optional custom sampler.

    When ``self.sampler`` is ``None`` the loader shuffles the training set
    itself; otherwise the iteration order is delegated to ``self.sampler``.
    ``DataLoader`` treats ``shuffle`` and ``sampler`` as mutually exclusive,
    so exactly one of them is passed.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``self.train_dataset`` that
        batches ``self.args.per_device_train_batch_size`` items and collates
        them with ``self.data_collator``.
    """
    # NOTE(review): the loader is built directly here, so loader-related
    # training arguments other than the batch size are not applied - confirm
    # that this is intended.
    if self.sampler is None:
        ordering = {"shuffle": True}
    else:
        ordering = {"sampler": self.sampler}

    # Both branches must return the loader; the sampler-less branch used to
    # construct it and then fall through, handing None back to the caller.
    return torch.utils.data.DataLoader(
        self.train_dataset,
        batch_size=self.args.per_device_train_batch_size,
        collate_fn=self.data_collator,
        **ordering,
    )
|
||||
|
||||
|
||||
def _strtobool(x):
    """Return the result of ``strtobool(x)`` coerced to a real ``bool``."""
    parsed = strtobool(x)
    return bool(parsed)
|
||||
@@ -141,8 +160,8 @@ if __name__ == "__main__":
|
||||
model = get_model(training_conf, tokenizer)
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
sampler = build_train_sampler(training_conf, train.datasets)
|
||||
metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
|
||||
|
||||
optimizer = OptimizerNames.ADAMW_BNB if training_conf.quantization else OptimizerNames.ADAMW_HF
|
||||
|
||||
if training_conf.quantization:
|
||||
@@ -159,7 +178,7 @@ if __name__ == "__main__":
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
|
||||
optim=optimizer,
|
||||
fp16=True,
|
||||
# fp16=True,
|
||||
local_rank=training_conf.local_rank,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
|
||||
@@ -177,19 +196,20 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
assert len(evals) > 0
|
||||
|
||||
if not training_conf.deepspeed or training_conf.local_rank == 0:
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project="supervised-finetuning",
|
||||
entity=training_conf.wandb_entity,
|
||||
# entity=training_conf.wandb_entity,
|
||||
entity="maw501",
|
||||
name=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
)
|
||||
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
model=model,
|
||||
args=args,
|
||||
sampler=sampler,
|
||||
loss_function=training_conf.loss_fn,
|
||||
poly_eps=training_conf.poly_eps,
|
||||
train_dataset=train,
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
from pathlib import Path
|
||||
|
||||
import evaluate
|
||||
import random
|
||||
|
||||
# import nltk
|
||||
# import numpy as np
|
||||
import numpy as np
|
||||
import transformers
|
||||
import yaml
|
||||
from custom_datasets import get_one_dataset
|
||||
@@ -14,6 +15,35 @@ from losses import CrossEntropyLoss, PolyLoss
|
||||
from models import freeze_top_n_layers, get_specific_model
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
|
||||
class ClassSampler(Sampler):
    """Sampler which returns a fixed number of samples per class, per epoch.

    ``class_labels[k]`` is the class of dataset index ``k``. The distinct
    labels, in sorted order, are paired positionally with ``class_sizes``:
    the i-th distinct label contributes ``int(class_sizes[i])`` indices per
    epoch, drawn without replacement with the ``random`` module. The combined
    indices are then shuffled.

    Args:
        class_labels: one label per dataset index.
        class_sizes: number of indices to draw per epoch for each distinct
            label, in sorted label order.

    Raises:
        ValueError: if the number of sizes differs from the number of
            distinct labels, or a size is negative or larger than the number
            of indices carrying that label.
    """

    def __init__(self, class_labels, class_sizes):
        self.class_labels = class_labels
        self.class_sizes = class_sizes

        # The labels are fixed for the lifetime of the sampler, so the
        # per-class index lists are computed once here instead of rescanning
        # every label for every class on every epoch.
        labels = np.asarray(class_labels)
        classes = np.unique(labels)
        if len(classes) != len(class_sizes):
            raise ValueError(
                f"Got {len(class_sizes)} class sizes for {len(classes)} distinct class labels"
            )

        self._class_indices = []
        self._samples_per_class = []
        for _class, size in zip(classes, class_sizes):
            # tolist() yields plain Python ints, not numpy scalars.
            members = np.flatnonzero(labels == _class).tolist()
            wanted = int(size)
            # Fail at construction time with a clear message rather than in
            # the middle of an epoch inside random.sample.
            if wanted < 0 or wanted > len(members):
                raise ValueError(
                    f"Cannot draw {wanted} samples from class {_class!r} with {len(members)} members"
                )
            self._class_indices.append(members)
            self._samples_per_class.append(wanted)

    def __iter__(self):
        out = []
        for members, wanted in zip(self._class_indices, self._samples_per_class):
            out.extend(random.sample(members, wanted))
        random.shuffle(out)
        return iter(out)

    def __len__(self):
        # Sum the truncated per-class counts so the reported length always
        # matches the number of indices __iter__ yields.
        return sum(self._samples_per_class)
|
||||
|
||||
|
||||
def build_train_sampler(training_conf, datasets):
    """Create a ``ClassSampler`` that draws a configured share of each dataset per epoch.

    Args:
        training_conf: configuration object; its ``datasets`` attribute is the
            list of per-dataset entries passed to ``get_dataset_fractions``.
        datasets: the individual training datasets, in the order in which
            they are concatenated for training.

    Returns:
        A ``ClassSampler`` whose classes are the dataset positions and whose
        per-class sizes are the per-epoch sample counts.
    """
    train_sizes = [len(dataset) for dataset in datasets]
    fractions = get_dataset_fractions(training_conf.datasets, train_sizes)

    # Round away float noise before truncating: a fraction derived from an
    # absolute size can come back marginally short (49 * (1 / 49) is
    # 0.9999999999999999), which plain int() would turn into one sample fewer.
    dataset_size_per_epoch = [int(round(size * frac, 6)) for size, frac in zip(train_sizes, fractions)]

    # One label per example of the concatenated dataset, so the positions the
    # sampler draws are valid indices into it. Building the labels from the
    # per-epoch sizes would restrict sampling to the first few indices only.
    dataset_labels = [label for label, size in enumerate(train_sizes) for _ in range(size)]
    return ClassSampler(dataset_labels, dataset_size_per_epoch)
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
@@ -115,10 +145,35 @@ def get_model(conf, tokenizer):
|
||||
return model
|
||||
|
||||
|
||||
def get_dataset_name_from_data_config(data_config):
    """Return the dataset name for one entry of the dataset configuration.

    A dict entry is keyed by the dataset name, so its first key is returned;
    any other entry is returned unchanged.
    """
    if not isinstance(data_config, dict):
        return data_config
    names = list(data_config)
    return names[0]
|
||||
|
||||
|
||||
def get_dataset_fractions(conf, dataset_sizes):
    """Compute, for every configured dataset, the share of it to use per epoch.

    Args:
        conf: list of dataset entries. A plain (non-dict) entry means "use
            the whole dataset". A dict entry maps the dataset name to its
            options, which must contain either ``fraction`` (capped at 1) or
            ``size`` (an absolute number of examples).
        dataset_sizes: number of examples of each dataset, aligned with
            ``conf``.

    Returns:
        A list with one fraction per entry of ``conf``.

    Raises:
        ValueError: if a dict entry specifies neither ``fraction`` nor
            ``size``, if ``size`` exceeds the dataset size, or if the given
            value is negative.
    """
    fractions = []
    for i, data_config in enumerate(conf):
        if not isinstance(data_config, dict):
            fractions.append(1)
            continue

        # The options are the value stored under the entry's first key.
        options = list(data_config.values())[0]
        if options is None:
            # An entry written as "- name:" with nothing below it parses to
            # None; treat it as having no options so it reaches the
            # ValueError below instead of failing with a TypeError.
            options = {}

        if "fraction" in options:
            fraction = options["fraction"]
            if fraction < 0:
                raise ValueError("Please specify a non-negative fraction")
            fractions.append(min(1, fraction))
        elif "size" in options:
            size = options["size"]
            if size > dataset_sizes[i]:
                raise ValueError(f"Please specify a size smaller than number of examples ({dataset_sizes[i]})")
            if size < 0:
                raise ValueError("Please specify a non-negative size")
            # size == 0 on an empty dataset would otherwise divide by zero.
            fractions.append(size / dataset_sizes[i] if dataset_sizes[i] else 0)
        else:
            raise ValueError("Please specify either fraction or size in config.yaml")
    return fractions
|
||||
|
||||
|
||||
def get_dataset(conf, tokenizer):
|
||||
train_datasets, evals = [], {}
|
||||
|
||||
for dataset_name in conf.datasets:
|
||||
for data_config in conf.datasets:
|
||||
dataset_name = get_dataset_name_from_data_config(data_config)
|
||||
train, val = get_one_dataset(conf, dataset_name)
|
||||
train_datasets.append(train)
|
||||
evals[dataset_name] = Subset(val, list(range(min(len(val), conf.eval_size)))) if conf.eval_size else val
|
||||
|
||||
Reference in New Issue
Block a user