mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
fix tokenizer matching and add tests
This commit is contained in:
@@ -1,9 +1,28 @@
|
||||
from argparse import Namespace
|
||||
|
||||
from utils import get_tokenizer
|
||||
import pytest
|
||||
from utils import TOKENIZER_CONFIGS, get_tokenizer, match_tokenizer_name
|
||||
|
||||
|
||||
def test_tokenizer():
|
||||
get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache"))
|
||||
get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache"))
|
||||
get_tokenizer(Namespace(model_name="", cache_dir=".cache"))
|
||||
|
||||
|
||||
def test_tokenizer_successful_match():
|
||||
for config_name, config in TOKENIZER_CONFIGS:
|
||||
found_config = match_tokenizer_name(config_name)
|
||||
assert found_config == config
|
||||
|
||||
|
||||
def test_tokenizer_partial_match():
|
||||
for config_name, config in TOKENIZER_CONFIGS:
|
||||
found_config = match_tokenizer_name(config_name[: len(config_name) - 1])
|
||||
assert found_config == config
|
||||
|
||||
|
||||
def test_tokenizer_failed_match():
|
||||
for fake_config_name in ["not-a-model", "fake"]:
|
||||
with pytest.raises(ValueError):
|
||||
match_tokenizer_name(fake_config_name)
|
||||
|
||||
@@ -37,7 +37,7 @@ TOKENIZER_CONFIGS = {
|
||||
|
||||
def match_tokenizer_name(model_name: str) -> TokenizerConfig:
|
||||
"""Match a partial model name to a tokenizer configuration"""
|
||||
tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name]
|
||||
tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if model_name in name]
|
||||
if not tokenizer_config_matches:
|
||||
raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}")
|
||||
elif 1 < len(tokenizer_config_matches):
|
||||
|
||||
Reference in New Issue
Block a user