mirror of
https://github.com/wassname/alpaca_convert.git
synced 2026-06-27 18:57:59 +08:00
fix bug
This commit is contained in:
@@ -131,6 +131,7 @@ class AutogradMatmul4bit(torch.autograd.Function):
|
||||
def backward(ctx, grad_output):
|
||||
qweight, scales, zeros = ctx.saved_tensors
|
||||
buffer = get_buffer(qweight.shape, dtype=scales.dtype, device=qweight.device)
|
||||
quant.quant_cuda.vecquant4recons(qweight, buffer, scales, zeros)
|
||||
grad = torch.matmul(grad_output, buffer.T)
|
||||
return grad, None, None, None
|
||||
|
||||
@@ -229,4 +230,4 @@ def load_llama_model_4bit_low_ram(config_path, model_path, half=False):
|
||||
print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user