mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 17:47:01 +08:00
Fix BitsAndBytes JSON Serializable (#191)
* Update run_sft.py
* fix BitsAndBytes JSON serializable
* get_quantization_config to return dict
* to_dict() for load_in_8bit too
* convert quant test to use dict subscriptions instead of dot syntax
* Remove torch

---------

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
@@ -52,11 +52,11 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
|
|||||||
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
|
||||||
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
|
||||||
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
|
bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
|
||||||
)
|
).to_dict()
|
||||||
elif model_args.load_in_8bit:
|
elif model_args.load_in_8bit:
|
||||||
quantization_config = BitsAndBytesConfig(
|
quantization_config = BitsAndBytesConfig(
|
||||||
load_in_8bit=True,
|
load_in_8bit=True,
|
||||||
)
|
).to_dict()
|
||||||
else:
|
else:
|
||||||
quantization_config = None
|
quantization_config = None
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from alignment import (
|
from alignment import (
|
||||||
DataArguments,
|
DataArguments,
|
||||||
ModelArguments,
|
ModelArguments,
|
||||||
@@ -31,15 +29,15 @@ class GetQuantizationConfigTest(unittest.TestCase):
|
|||||||
def test_4bit(self):
|
def test_4bit(self):
|
||||||
model_args = ModelArguments(load_in_4bit=True)
|
model_args = ModelArguments(load_in_4bit=True)
|
||||||
quantization_config = get_quantization_config(model_args)
|
quantization_config = get_quantization_config(model_args)
|
||||||
self.assertTrue(quantization_config.load_in_4bit)
|
self.assertTrue(quantization_config["load_in_4bit"])
|
||||||
self.assertEqual(quantization_config.bnb_4bit_compute_dtype, torch.float16)
|
self.assertEqual(quantization_config["bnb_4bit_compute_dtype"], "float16")
|
||||||
self.assertEqual(quantization_config.bnb_4bit_quant_type, "nf4")
|
self.assertEqual(quantization_config["bnb_4bit_quant_type"], "nf4")
|
||||||
self.assertFalse(quantization_config.bnb_4bit_use_double_quant)
|
self.assertFalse(quantization_config["bnb_4bit_use_double_quant"])
|
||||||
|
|
||||||
def test_8bit(self):
|
def test_8bit(self):
|
||||||
model_args = ModelArguments(load_in_8bit=True)
|
model_args = ModelArguments(load_in_8bit=True)
|
||||||
quantization_config = get_quantization_config(model_args)
|
quantization_config = get_quantization_config(model_args)
|
||||||
self.assertTrue(quantization_config.load_in_8bit)
|
self.assertTrue(quantization_config["load_in_8bit"])
|
||||||
|
|
||||||
def test_no_quantization(self):
|
def test_no_quantization(self):
|
||||||
model_args = ModelArguments()
|
model_args = ModelArguments()
|
||||||
|
|||||||
Reference in New Issue
Block a user