mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
[feature] add deepspeed, rallio dialogue dataset and codegen parameters
This commit is contained in:
@@ -6,7 +6,7 @@ defaults:
|
||||
per_device_eval_batch_size: 2
|
||||
weight_decay: 0.00
|
||||
warmup_steps: 600
|
||||
eval_steps: 200
|
||||
eval_steps: 100
|
||||
save_steps: 500
|
||||
max_length: 512
|
||||
num_train_epochs: 3
|
||||
@@ -17,6 +17,7 @@ defaults:
|
||||
freeze_layer:
|
||||
datasets:
|
||||
- webgpt
|
||||
- prompt_dialogue
|
||||
cache_dir: ~/.cache
|
||||
loss_fn: CrossEntropyLoss
|
||||
eval_size:
|
||||
@@ -43,6 +44,17 @@ gpt-jt:
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
codegen:
|
||||
learning_rate: 2e-6
|
||||
model_name: Salesforce/codegen-2B-multi
|
||||
weight_decay: 0.01
|
||||
max_length: 812
|
||||
warmup_steps: 600
|
||||
gradient_checkpointing: false
|
||||
gradient_accumulation_steps: 5
|
||||
per_device_train_batch_size: 4
|
||||
per_device_eval_batch_size: 4
|
||||
|
||||
debug:
|
||||
eval_steps: 20
|
||||
eval_size: 100
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
{
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"loss_scale_window": 1000,
|
||||
"initial_scale_power": 16,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 0.1
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": "auto",
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"reduce_bucket_size": "auto",
|
||||
"stage3_prefetch_bucket_size": "auto",
|
||||
"stage3_param_persistence_threshold": "auto",
|
||||
"sub_group_size": 1e9,
|
||||
"stage3_max_live_parameters": 1e9,
|
||||
"stage3_max_reuse_distance": 1e9,
|
||||
"stage3_gather_16bit_weights_on_model_save": true
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": 2.0,
|
||||
"steps_per_print": 2000,
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"wall_clock_breakdown": false
|
||||
}
|
||||
@@ -2,6 +2,8 @@ from datasets import load_dataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import Dataset, Subset
|
||||
|
||||
from .prompt_dialogue import PromptGeneratedDataset
|
||||
|
||||
QA_SPECIAL_TOKENS = {"Question": "<question>", "Answer": "<answer>"}
|
||||
|
||||
|
||||
@@ -63,6 +65,9 @@ def get_one_dataset(conf, dataset_name):
|
||||
elif dataset_name == "webgpt":
|
||||
dataset = WebGPT()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
elif dataset_name == "prompt_dialogue":
|
||||
dataset = PromptGeneratedDataset()
|
||||
train, eval = train_val_dataset(dataset, val_split=0.2)
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset {dataset_name}")
|
||||
|
||||
|
||||
@@ -81,6 +81,7 @@ class DialogueDataCollator:
|
||||
|
||||
batch["label_masks"] = torch.stack([F.pad(torch.tensor(x), (0, dim - len(x))) for x in label_masks])
|
||||
|
||||
# why the fuck?
|
||||
for k in list(batch.keys()):
|
||||
if k not in ["input_ids", "attention_mask", "label_masks"]:
|
||||
batch.pop(k)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
from urllib.request import urlopen
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class PromptGeneratedDataset(Dataset):
|
||||
"""Generates from flan 11B
|
||||
User: What are the best methods for preventing a slave trade?
|
||||
|
||||
Rosey: The best methods ....
|
||||
<|endoftext|>
|
||||
|
||||
we are ignoring results with multiple lines for now
|
||||
"""
|
||||
|
||||
url = "https://github.com/Rallio67/language-model-agents/raw/main/chat_dialogue_v2_c.txt"
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
os.makedirs("datasets", exist_ok=True)
|
||||
chat_dialogue = os.path.join("datasets", "chat_dialogue_v2_c.txt")
|
||||
if not os.path.exists(chat_dialogue):
|
||||
with urlopen(self.url) as file:
|
||||
content = file.read().decode()
|
||||
with open(chat_dialogue, "w") as fout:
|
||||
fout.write(content)
|
||||
|
||||
question = ""
|
||||
answer = ""
|
||||
self.pairs = []
|
||||
with open(chat_dialogue, "r") as f:
|
||||
|
||||
for line in f:
|
||||
try:
|
||||
line = line.rstrip()
|
||||
if len(line) == 0:
|
||||
continue
|
||||
elif line == "<|endoftext|>" and len(question) > 0 and len(answer) > 0:
|
||||
self.pairs.append((question, answer))
|
||||
question = ""
|
||||
answer = ""
|
||||
elif line[:4] == "User":
|
||||
question = line.split(":", maxsplit=1)[1]
|
||||
elif line == "Rosey:":
|
||||
question = ""
|
||||
answer = ""
|
||||
elif len(line) > 4: # should be bot answer
|
||||
answer = line.split(":", maxsplit=1)[1]
|
||||
except IndexError:
|
||||
question = ""
|
||||
answer = ""
|
||||
|
||||
if len(question) > 0 and len(answer) > 0:
|
||||
self.pairs.append((question, answer))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.pairs)
|
||||
|
||||
def __getitem__(self, index):
|
||||
question, answer = self.pairs[index]
|
||||
return question, answer
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .dialogue_collator import DialogueDataCollator
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-multi")
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
dataset = PromptGeneratedDataset()
|
||||
collate_fn = DialogueDataCollator(tokenizer, padding=True, max_length=128)
|
||||
dataloader = DataLoader(dataset, collate_fn=collate_fn, batch_size=5)
|
||||
for batch in dataloader:
|
||||
print(batch["input_ids"].shape)
|
||||
@@ -7,7 +7,7 @@ class CrossEntropyLoss(nn.CrossEntropyLoss):
|
||||
|
||||
def forward(self, input, target, mask=None):
|
||||
if mask is not None:
|
||||
mask = mask.view(-1)
|
||||
mask = mask.view(-1).bool()
|
||||
input = input.view(-1, input.size(-1))
|
||||
target = target.view(-1)
|
||||
input = input[mask]
|
||||
|
||||
@@ -1,32 +1,17 @@
|
||||
import argparse
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import (
|
||||
DataCollator,
|
||||
EvalPrediction,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
get_cosine_schedule_with_warmup,
|
||||
)
|
||||
from transformers import PreTrainedModel, Trainer, TrainingArguments, get_cosine_schedule_with_warmup
|
||||
from utils import get_dataset, get_loss, get_model, get_tokenizer, read_yamls
|
||||
|
||||
os.environ["WANDB_PROJECT"] = "supervised-finetuning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class CustomTrainingArguments(TrainingArguments):
|
||||
loss_function: str = "CrossEntropyLoss"
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
pred_ids = eval_pred.predictions
|
||||
labels = eval_pred.label_ids
|
||||
@@ -44,32 +29,13 @@ class SFTTrainer(Trainer):
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Dataset] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Callable[[], PreTrainedModel] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] = None,
|
||||
loss_function: str = "CrossEntropyLoss",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
super().__init__(model, args, **kwargs)
|
||||
|
||||
# By default CrossEntropyLoss ignores padding_index -100, but just in case use our own loss_fct
|
||||
self.loss_fct = get_loss(args.loss_function)
|
||||
self.loss_fct = get_loss(loss_function)
|
||||
|
||||
def fetch_scheduler(self):
|
||||
return get_cosine_schedule_with_warmup(
|
||||
@@ -131,6 +97,10 @@ def _strtobool(x):
|
||||
def argument_parsing(notebook=False, notebook_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--configs", nargs="+", required=True)
|
||||
parser.add_argument("--local_rank", type=int, default=-1)
|
||||
parser.add_argument("--deepspeed", action="store_true")
|
||||
parser.add_argument("--no-deepspeed", dest="deepspeed", action="store_false")
|
||||
parser.set_defaults(deepspeed=False)
|
||||
|
||||
if notebook:
|
||||
args, remaining = parser.parse_known_args(notebook_args)
|
||||
@@ -147,6 +117,8 @@ def argument_parsing(notebook=False, notebook_args=None):
|
||||
else:
|
||||
conf.update(configs[name])
|
||||
|
||||
conf["local_rank"] = args.local_rank
|
||||
conf["deepspeed"] = args.deepspeed
|
||||
# Override config from command-line
|
||||
parser = argparse.ArgumentParser()
|
||||
for key, value in conf.items():
|
||||
@@ -166,13 +138,14 @@ if __name__ == "__main__":
|
||||
|
||||
train, evals, collate_fn = get_dataset(training_conf, tokenizer)
|
||||
|
||||
args = CustomTrainingArguments(
|
||||
args = TrainingArguments(
|
||||
output_dir=f"{training_conf.model_name}-{training_conf.log_dir}-finetuned",
|
||||
num_train_epochs=training_conf.num_train_epochs,
|
||||
warmup_steps=training_conf.warmup_steps,
|
||||
loss_function=training_conf.loss_fn,
|
||||
learning_rate=float(training_conf.learning_rate),
|
||||
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
|
||||
fp16=True,
|
||||
local_rank=training_conf.local_rank,
|
||||
gradient_checkpointing=training_conf.gradient_checkpointing,
|
||||
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,
|
||||
per_device_train_batch_size=training_conf.per_device_train_batch_size,
|
||||
@@ -192,6 +165,7 @@ if __name__ == "__main__":
|
||||
trainer = SFTTrainer(
|
||||
model,
|
||||
args,
|
||||
loss_function=training_conf.loss_fn,
|
||||
train_dataset=train,
|
||||
eval_dataset=evals,
|
||||
data_collator=collate_fn,
|
||||
|
||||
@@ -8,14 +8,16 @@ from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
SUPPORTED_MODELS = ["galactica", "GPT-JT"] # deprecated ..
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
elif "GPT-JT" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
|
||||
elif "codegen" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
@@ -30,11 +32,6 @@ def get_tokenizer(conf):
|
||||
|
||||
|
||||
def get_model(conf, tokenizer):
|
||||
if not any([x in conf.model_name for x in SUPPORTED_MODELS]):
|
||||
raise ValueError(
|
||||
f"Model {conf.model_name} not supported. Supported models: {SUPPORTED_MODELS}. "
|
||||
"To include more make sure the masking is dne correctly... (decoder only supported for now)"
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user