Merge pull request #1398 from jackapbutler/fix-tokenizer-match

Add tests and update docstring to tokenizer matching
This commit is contained in:
sanagnos
2023-02-10 11:37:00 +01:00
committed by GitHub
2 changed files with 24 additions and 2 deletions
@@ -1,9 +1,28 @@
from argparse import Namespace
from utils import get_tokenizer
import pytest
from utils import TOKENIZER_CONFIGS, get_tokenizer, match_tokenizer_name
def test_tokenizer():
get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache"))
get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache"))
get_tokenizer(Namespace(model_name="", cache_dir=".cache"))
def test_tokenizer_successful_match():
for config_name, config in TOKENIZER_CONFIGS.items():
found_config = match_tokenizer_name(config_name)
assert found_config == config
def test_tokenizer_partial_match():
for config_name in ["facebook/galactica-1.3b", "togethercomputer/GPT-JT-6B-v1", "Salesforce/codegen-2B-multi"]:
found_config = match_tokenizer_name(config_name)
assert found_config
def test_tokenizer_failed_match():
for fake_config_name in ["not-a-model", "fake"]:
with pytest.raises(ValueError):
match_tokenizer_name(fake_config_name)
+4 -1
View File
@@ -106,7 +106,10 @@ TOKENIZER_CONFIGS = {
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
"""Match a partial model name to a tokenizer configuration"""
"""
Match a partial model name to a tokenizer configuration
i.e. model_name `Salesforce/codegen-2B-multi` has config name `codegen`
"""
tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name]
if not tokenizer_config_matches:
raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")