mirror of
https://github.com/wassname/alpaca_convert.git
synced 2026-06-27 16:14:08 +08:00
add patch for encode function to remove eos token at the beginning of left side
This commit is contained in:
@@ -5,6 +5,8 @@ from autograd_4bit import load_llama_model_4bit_low_ram, Autograd4bitQuantLinear
|
||||
from peft import PeftModel
|
||||
from peft.tuners.lora import Linear4bitLt
|
||||
|
||||
patch_encode_func = False
|
||||
|
||||
def load_model_llama(*args, **kwargs):
|
||||
|
||||
config_path = '../llama-13b-4bit/'
|
||||
@@ -41,4 +43,15 @@ shared.settings['name2'] = 'Assistant'
|
||||
shared.settings['chat_prompt_size_max'] = 2048
|
||||
shared.settings['chat_prompt_size'] = 2048
|
||||
|
||||
if patch_encode_func:
    # Optional monkey patch: wrap text_generation.encode so that a leading
    # token with id 0 is dropped from the encoded prompt.
    from modules import text_generation

    # Capture the original encoder only once. If this script is executed
    # again (e.g. reloaded), text_generation.encode is already the patched
    # wrapper; rebinding encode_old to it would make the wrapper call
    # itself and recurse without end.
    if not getattr(text_generation.encode, '_is_encode_patch', False):
        text_generation.encode_old = text_generation.encode

    def encode_patched(*args, **kwargs):
        """Encode via the original encoder, then strip a leading id-0 token.

        All arguments are forwarded unchanged to
        ``text_generation.encode_old``.

        Returns:
            The encoded ids, without the first column when the first id of
            the first row is 0; otherwise the ids unchanged.
        """
        input_ids = text_generation.encode_old(*args, **kwargs)
        # NOTE(review): assumes input_ids is a 2-D tensor/array exposing
        # `.shape` (it is indexed as [row, col] below) -- confirm.
        # The length check avoids an IndexError on an empty sequence.
        if input_ids.shape[-1] > 0 and input_ids[0, 0] == 0:
            # Presumably id 0 is the stray special token this patch is
            # meant to remove from the start of the prompt -- verify
            # against the tokenizer in use.
            input_ids = input_ids[:, 1:]
        return input_ids

    # Marker used by the re-entry guard above.
    encode_patched._is_encode_patch = True
    text_generation.encode = encode_patched
    print('Encode Function Patched.')

print('Monkey Patch Completed.')
|
||||
|
||||
Reference in New Issue
Block a user