Merge pull request #1262 from jackapbutler/create-tokeniser-configs

Add tokenizer config classes
This commit is contained in:
theblackcat102
2023-02-08 09:34:57 +08:00
committed by GitHub
+37 -12
View File
@@ -1,5 +1,6 @@
# from functools import partial
from pathlib import Path
from typing import NamedTuple, Optional
import evaluate
@@ -16,19 +17,43 @@ from sklearn.model_selection import train_test_split
from torch.utils.data import ConcatDataset, Subset
def get_tokenizer(conf):
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
class SpecialTokens(NamedTuple):
    """Special-token strings to add to a tokenizer for one model family.

    Every field defaults to the empty string. Positional construction follows
    the declaration order: ``pad_token``, ``eos_token``, ``sep_token``.
    """

    pad_token: str = ""  # padding token string
    eos_token: str = ""  # end-of-sequence token string
    sep_token: str = ""  # separator token string
if "galactica" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
elif "GPT-JT" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
elif "codegen" in conf.model_name:
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
elif "pythia" in conf.model_name:
tokenizer.add_special_tokens(
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
)
class TokenizerConfig(NamedTuple):
    """Tokenizer settings associated with one model family.

    ``special_tokens`` holds the special tokens to add for that family, or
    ``None`` when none are configured.
    """

    # The default is None rather than a dict literal: a `{}` default is a
    # single mutable object shared by every instance and does not match the
    # SpecialTokens annotation. None stays falsy, so truthiness checks on this
    # field keep behaving as they did with the empty dict.
    special_tokens: Optional[SpecialTokens] = None
# Tokenizer settings keyed by a substring of the model name.
# SpecialTokens fields that are not listed keep their "" default.
TOKENIZER_CONFIGS = {
    "galactica": TokenizerConfig(
        special_tokens=SpecialTokens(pad_token="<pad>", eos_token="</s>"),
    ),
    "GPT-JT": TokenizerConfig(
        special_tokens=SpecialTokens(sep_token="<|extratoken_100|>"),
    ),
    "codegen": TokenizerConfig(
        special_tokens=SpecialTokens(pad_token="<|endoftext|>", sep_token="<|endoftext|>"),
    ),
    "pythia": TokenizerConfig(
        special_tokens=SpecialTokens(
            pad_token="<|padding|>",
            eos_token="<|endoftext|>",
            sep_token="<|endoftext|>",
        ),
    ),
}
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
    """Return the tokenizer configuration whose key occurs in ``model_name``.

    Each key of ``TOKENIZER_CONFIGS`` is tested as a case-sensitive substring
    of ``model_name``; exactly one key must match.

    Raises:
        ValueError: if no key matches, or if more than one key matches.
    """
    candidates = [TOKENIZER_CONFIGS[key] for key in TOKENIZER_CONFIGS if key in model_name]

    if len(candidates) == 1:
        return candidates[0]
    if candidates:
        # More than one key matched, so the choice would be ambiguous.
        raise ValueError(f"Found multiple tokeniser configuration matches for {model_name=}")
    raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")
def get_tokenizer(conf) -> transformers.AutoTokenizer:
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
tokenizer_config = match_tokenizer_name(conf.model_name)
if tokenizer_config.special_tokens:
if "GPT-JT" in conf.model_name:
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
additional_special_tokens = (
[]