Fix BitsAndBytesConfig not being JSON serializable (#191)

* Update run_sft.py

* Fix BitsAndBytesConfig JSON serialization

* Make get_quantization_config return a dict

* Call to_dict() on the load_in_8bit config too

* Convert the quantization tests to use dict subscripting instead of attribute access

* Remove the torch import from the tests

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
Chansung Park
2024-08-20 22:02:18 +09:00
committed by GitHub
parent a8dcde2cd3
commit 27f7dbf006
2 changed files with 7 additions and 9 deletions
+2 -2
View File
@@ -52,11 +52,11 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
) ).to_dict()
elif model_args.load_in_8bit: elif model_args.load_in_8bit:
quantization_config = BitsAndBytesConfig( quantization_config = BitsAndBytesConfig(
load_in_8bit=True, load_in_8bit=True,
) ).to_dict()
else: else:
quantization_config = None quantization_config = None
+5 -7
View File
@@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import unittest import unittest
import torch
from alignment import ( from alignment import (
DataArguments, DataArguments,
ModelArguments, ModelArguments,
@@ -31,15 +29,15 @@ class GetQuantizationConfigTest(unittest.TestCase):
def test_4bit(self): def test_4bit(self):
model_args = ModelArguments(load_in_4bit=True) model_args = ModelArguments(load_in_4bit=True)
quantization_config = get_quantization_config(model_args) quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_4bit) self.assertTrue(quantization_config["load_in_4bit"])
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16) self.assertEqual(quantization_config["bnb_4bit_compute_dtype"], "float16")
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4") self.assertEqual(quantization_config["bnb_4bit_quant_type"], "nf4")
self.assertFalse(quantization_config.bnb_4bit_use_double_quant) self.assertFalse(quantization_config["bnb_4bit_use_double_quant"])
def test_8bit(self): def test_8bit(self):
model_args = ModelArguments(load_in_8bit=True) model_args = ModelArguments(load_in_8bit=True)
quantization_config = get_quantization_config(model_args) quantization_config = get_quantization_config(model_args)
self.assertTrue(quantization_config.load_in_8bit) self.assertTrue(quantization_config["load_in_8bit"])
def test_no_quantization(self): def test_no_quantization(self):
model_args = ModelArguments() model_args = ModelArguments()