From 090c5cbcc24614d6a4d1b09987c7e31aae37dddd Mon Sep 17 00:00:00 2001 From: "jack.butler" Date: Thu, 9 Feb 2023 18:47:38 +0000 Subject: [PATCH 1/3] fix tokenizer matching and add tests --- .../supervised_finetuning/tests/test_utils.py | 21 ++++++++++++++++++- model/supervised_finetuning/utils.py | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/tests/test_utils.py b/model/supervised_finetuning/tests/test_utils.py index ad40e534..96637e98 100644 --- a/model/supervised_finetuning/tests/test_utils.py +++ b/model/supervised_finetuning/tests/test_utils.py @@ -1,9 +1,28 @@ from argparse import Namespace -from utils import get_tokenizer +import pytest +from utils import TOKENIZER_CONFIGS, get_tokenizer, match_tokenizer_name def test_tokenizer(): get_tokenizer(Namespace(model_name="Salesforce/codegen-2B-multi", cache_dir=".cache")) get_tokenizer(Namespace(model_name="facebook/galactica-1.3b", cache_dir=".cache")) get_tokenizer(Namespace(model_name="", cache_dir=".cache")) + + +def test_tokenizer_successful_match(): + for config_name, config in TOKENIZER_CONFIGS: + found_config = match_tokenizer_name(config_name) + assert found_config == config + + +def test_tokenizer_partial_match(): + for config_name, config in TOKENIZER_CONFIGS: + found_config = match_tokenizer_name(config_name[: len(config_name) - 1]) + assert found_config == config + + +def test_tokenizer_failed_match(): + for fake_config_name in ["not-a-model", "fake"]: + with pytest.raises(ValueError): + match_tokenizer_name(fake_config_name) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index c3b8264f..43ff6c8c 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -37,7 +37,7 @@ TOKENIZER_CONFIGS = { def match_tokenizer_name(model_name: str) -> TokenizerConfig: """Match a partial model name to a tokenizer configuration""" - tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name] + tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if model_name in name] if not tokenizer_config_matches: raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}") elif 1 < len(tokenizer_config_matches): From 2fbf2fa4575d1545c1513adf0a369bdc65ea63ef Mon Sep 17 00:00:00 2001 From: "jack.butler" Date: Fri, 10 Feb 2023 09:46:58 +0000 Subject: [PATCH 2/3] add docstring info about tokenizer matching --- model/supervised_finetuning/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/model/supervised_finetuning/utils.py b/model/supervised_finetuning/utils.py index 43ff6c8c..7be7bd27 100644 --- a/model/supervised_finetuning/utils.py +++ b/model/supervised_finetuning/utils.py @@ -36,8 +36,11 @@ TOKENIZER_CONFIGS = { def match_tokenizer_name(model_name: str) -> TokenizerConfig: - """Match a partial model name to a tokenizer configuration""" - tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if model_name in name] + """ + Match a partial model name to a tokenizer configuration + i.e. model_name `Salesforce/codegen-2B-multi` has config name `codegen` + """ + tokenizer_config_matches = [config for name, config in TOKENIZER_CONFIGS.items() if name in model_name] if not tokenizer_config_matches: raise ValueError(f"Cannot find any tokeniser configuration to match {model_name=}") elif 1 < len(tokenizer_config_matches): From 24b07523aafdb5a0633d015eec6f7441b6534abd Mon Sep 17 00:00:00 2001 From: "jack.butler" Date: Fri, 10 Feb 2023 09:47:20 +0000 Subject: [PATCH 3/3] add test for tokenizer matching behaviour --- model/supervised_finetuning/tests/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/supervised_finetuning/tests/test_utils.py b/model/supervised_finetuning/tests/test_utils.py index 96637e98..c4982024 100644 --- a/model/supervised_finetuning/tests/test_utils.py +++ b/model/supervised_finetuning/tests/test_utils.py @@ -11,15 +11,15 @@ def test_tokenizer(): def test_tokenizer_successful_match(): - for config_name, config in TOKENIZER_CONFIGS: + for config_name, config in TOKENIZER_CONFIGS.items(): found_config = match_tokenizer_name(config_name) assert found_config == config def test_tokenizer_partial_match(): - for config_name, config in TOKENIZER_CONFIGS: - found_config = match_tokenizer_name(config_name[: len(config_name) - 1]) - assert found_config == config + for config_name in ["facebook/galactica-1.3b", "togethercomputer/GPT-JT-6B-v1", "Salesforce/codegen-2B-multi"]: + found_config = match_tokenizer_name(config_name) + assert found_config def test_tokenizer_failed_match():