diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index f7a0ab15..c3b8264f 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -1,5 +1,6 @@ # from functools import partial from pathlib import Path +from typing import NamedTuple import evaluate @@ -16,19 +17,43 @@ from sklearn.model_selection import train_test_split from torch.utils.data import ConcatDataset, Subset -def get_tokenizer(conf): - tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) +class SpecialTokens(NamedTuple): + pad_token: str = "" + eos_token: str = "" + sep_token: str = "" - if "galactica" in conf.model_name: - tokenizer.add_special_tokens({"pad_token": "", "eos_token": ""}) - elif "GPT-JT" in conf.model_name: - tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"}) - elif "codegen" in conf.model_name: - tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"}) - elif "pythia" in conf.model_name: - tokenizer.add_special_tokens( - {"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"} - ) + +class TokenizerConfig(NamedTuple): + special_tokens: SpecialTokens = {} + + +TOKENIZER_CONFIGS = { + "galactica": TokenizerConfig(special_tokens=SpecialTokens("", "")), + "GPT-JT": TokenizerConfig(special_tokens=SpecialTokens(sep_token="<|extratoken_100|>")), + "codegen": TokenizerConfig(special_tokens=SpecialTokens("<|endoftext|>", sep_token="<|endoftext|>")), + "pythia": TokenizerConfig(special_tokens=SpecialTokens("<|padding|>", "<|endoftext|>", "<|endoftext|>")), +} + + +def match_tokenizer_name(model_name: str) -> TokenizerConfig: + """Match a partial model name to a tokenizer configuration""" + tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name] + if not tokenizer_config_matches: + raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}") + elif 1 < len(tokenizer_config_matches): + raise ValueError(f"Found multiple tokeniser configuration matches for {model_name=}") + else: + return tokenizer_config_matches[0] + + +def get_tokenizer(conf) -> transformers.AutoTokenizer: + tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir) + tokenizer_config = match_tokenizer_name(conf.model_name) + + if tokenizer_config.special_tokens: + if "GPT-JT" in conf.model_name: + tokenizer_config.special_tokens.pad_token = tokenizer.eos_token + tokenizer.add_special_tokens(tokenizer_config.special_tokens) additional_special_tokens = ( []