From 27f7dbf00663dab66ad7334afb7a1311fa251f41 Mon Sep 17 00:00:00 2001 From: Chansung Park Date: Tue, 20 Aug 2024 22:02:18 +0900 Subject: [PATCH] Fix BitsAndBytes JSON Serializable (#191) * Update run_sft.py * fix BitsAndBytes JSON serializable * get_quantization_config to return dict * to_dict() for load_in_8bit too * convert quant test to use dict subscriptions instead of dot syntax * Remove torch --------- Co-authored-by: lewtun --- src/alignment/model_utils.py | 4 ++-- tests/test_model_utils.py | 12 +++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/alignment/model_utils.py b/src/alignment/model_utils.py index 1115b67..6795b6a 100644 --- a/src/alignment/model_utils.py +++ b/src/alignment/model_utils.py @@ -52,11 +52,11 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, - ) + ).to_dict() elif model_args.load_in_8bit: quantization_config = BitsAndBytesConfig( load_in_8bit=True, - ) + ).to_dict() else: quantization_config = None diff --git a/tests/test_model_utils.py b/tests/test_model_utils.py index e0fc6fe..16ada92 100644 --- a/tests/test_model_utils.py +++ b/tests/test_model_utils.py @@ -14,8 +14,6 @@ # limitations under the License. import unittest -import torch - from alignment import ( DataArguments, ModelArguments, @@ -31,15 +29,15 @@ class GetQuantizationConfigTest(unittest.TestCase): def test_4bit(self): model_args = ModelArguments(load_in_4bit=True) quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config.load_in_4bit) - self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16) - self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4") - self.assertFalse(quantization_config.bnb_4bit_use_double_quant) + self.assertTrue(quantization_config["load_in_4bit"]) + self.assertEqual(quantization_config["bnb_4bit_compute_dtype"], "float16") + self.assertEqual(quantization_config["bnb_4bit_quant_type"], "nf4") + self.assertFalse(quantization_config["bnb_4bit_use_double_quant"]) def test_8bit(self): model_args = ModelArguments(load_in_8bit=True) quantization_config = get_quantization_config(model_args) - self.assertTrue(quantization_config.load_in_8bit) + self.assertTrue(quantization_config["load_in_8bit"]) def test_no_quantization(self): model_args = ModelArguments()