Files
Open-Assistant/model/supervised_finetuning/models/__init__.py
T
ekurtulus 5b77dd2e9f better
2023-01-11 11:37:27 +03:00

37 lines
1.3 KiB
Python

import transformers
# from .gptj import get_model as get_gptj_model
# Model family names this package declares as supported.
SUPPORTED_MODELS = [
    "galactica",
    "gpt-j",
]
def freeze_top_n_layers(model, target_layers):
    """Disable gradients for embeddings and the first ``target_layers`` layers.

    Despite the name, the layers frozen are those whose index is *below*
    ``target_layers``, i.e. layers ``0 .. target_layers - 1``.

    Parameters are classified purely by their dotted names as yielded by
    ``model.named_parameters()``:

    * any name containing ``"embed"`` is frozen;
    * a name containing ``".layer"`` or ``".h."`` is treated as belonging to a
      repeated layer. The first purely numeric dot-separated component of the
      name is taken as the layer index. The parameter is frozen when that
      index is less than ``target_layers``.

    All other parameters are left untouched. ``requires_grad`` is only ever
    set to ``False`` here, never switched back on.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` has a writable
            ``requires_grad`` attribute (presumably a ``torch.nn.Module``).
        target_layers: Number of leading layers to freeze.

    Returns:
        The same ``model`` object, modified in place.
    """
    # NOTE: it may be possible to detect ``ModuleList`` containers directly and
    # freeze them without string parsing; name matching is used for now.
    for name, param in model.named_parameters():
        if "embed" in name:
            param.requires_grad = False
        elif ".layer" in name or ".h." in name:
            # Use ``isdecimal`` rather than ``isdigit``: ``isdigit`` also
            # accepts characters (e.g. superscripts) that make ``int`` raise,
            # while every ``isdecimal`` token is guaranteed to parse.
            layer_idx = next(
                (int(token) for token in name.split(".") if token.isdecimal()),
                None,
            )
            if layer_idx is not None and layer_idx < target_layers:
                param.requires_grad = False
    return model
def get_specific_model(model_name, cache_dir, quantization, seq2seqmodel):
    """Load a pretrained ``transformers`` model by name.

    Args:
        model_name: Identifier passed as the first argument to
            ``from_pretrained``.
        cache_dir: Cache directory forwarded to ``from_pretrained``.
        quantization: Accepted for interface compatibility; this function
            does not currently use it.
        seq2seqmodel: When truthy, load via ``AutoModelForSeq2SeqLM``
            (encoder-decoder models such as Flan-T5). Otherwise load via
            ``AutoModelForCausalLM``.

    Returns:
        Whatever ``from_pretrained`` returns for the selected auto class.
    """
    # Encoder-decoder support is chosen by an explicit flag for now; this
    # could be detected automatically in the future.
    model_cls = (
        transformers.AutoModelForSeq2SeqLM
        if seq2seqmodel
        else transformers.AutoModelForCausalLM
    )
    return model_cls.from_pretrained(model_name, cache_dir=cache_dir)