mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #1262 from jackapbutler/create-tokeniser-configs
Add tokenizer config classes
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple
|
||||
|
||||
import evaluate
|
||||
|
||||
@@ -16,19 +17,43 @@ from sklearn.model_selection import train_test_split
|
||||
from torch.utils.data import ConcatDataset, Subset
|
||||
|
||||
|
||||
def get_tokenizer(conf):
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
class SpecialTokens(NamedTuple):
|
||||
pad_token: str = ""
|
||||
eos_token: str = ""
|
||||
sep_token: str = ""
|
||||
|
||||
if "galactica" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<pad>", "eos_token": "</s>"})
|
||||
elif "GPT-JT" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token, "sep_token": "<|extratoken_100|>"})
|
||||
elif "codegen" in conf.model_name:
|
||||
tokenizer.add_special_tokens({"pad_token": "<|endoftext|>", "sep_token": "<|endoftext|>"})
|
||||
elif "pythia" in conf.model_name:
|
||||
tokenizer.add_special_tokens(
|
||||
{"pad_token": "<|padding|>", "sep_token": "<|endoftext|>", "eos_token": "<|endoftext|>"}
|
||||
)
|
||||
|
||||
class TokenizerConfig(NamedTuple):
|
||||
special_tokens: SpecialTokens = {}
|
||||
|
||||
|
||||
TOKENIZER_CONFIGS = {
|
||||
"galactica": TokenizerConfig(special_tokens=SpecialTokens("<pad>", "</s>")),
|
||||
"GPT-JT": TokenizerConfig(special_tokens=SpecialTokens(sep_token="<|extratoken_100|>")),
|
||||
"codegen": TokenizerConfig(special_tokens=SpecialTokens("<|endoftext|>", sep_token="<|endoftext|>")),
|
||||
"pythia": TokenizerConfig(special_tokens=SpecialTokens("<|padding|>", "<|endoftext|>", "<|endoftext|>")),
|
||||
}
|
||||
|
||||
|
||||
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
|
||||
"""Match a partial model name to a tokenizer configuration"""
|
||||
tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name]
|
||||
if not tokenizer_config_matches:
|
||||
raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")
|
||||
elif 1 < len(tokenizer_config_matches):
|
||||
raise ValueError(f"Found multiple tokeniser configuration matches for {model_name=}")
|
||||
else:
|
||||
return tokenizer_config_matches[0]
|
||||
|
||||
|
||||
def get_tokenizer(conf) -> transformers.AutoTokenizer:
|
||||
tokenizer = transformers.AutoTokenizer.from_pretrained(conf.model_name, cache_dir=conf.cache_dir)
|
||||
tokenizer_config = match_tokenizer_name(conf.model_name)
|
||||
|
||||
if tokenizer_config.special_tokens:
|
||||
if "GPT-JT" in conf.model_name:
|
||||
tokenizer_config.special_tokens.pad_token = tokenizer.eos_token
|
||||
tokenizer.add_special_tokens(tokenizer_config.special_tokens)
|
||||
|
||||
additional_special_tokens = (
|
||||
[]
|
||||
|
||||
Reference in New Issue
Block a user