mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Cleaned up default argument logic.
This commit is contained in:
@@ -168,7 +168,7 @@ if __name__ == "__main__":
|
||||
loss_function=training_conf["loss"],
|
||||
learning_rate=training_conf["learning_rate"],
|
||||
# half_precision_backend="apex",
|
||||
- fp16=training_conf["fp16"] if "fp16" in training_conf else True,
|
||||
+ fp16=training_conf["fp16"],
|
||||
gradient_checkpointing=training_conf["gradient_checkpointing"],
|
||||
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
|
||||
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
|
||||
@@ -180,7 +180,7 @@ if __name__ == "__main__":
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=training_conf["eval_steps"],
|
||||
save_steps=1000,
|
||||
- report_to="wandb",
|
||||
+ report_to="local",
|
||||
)
|
||||
train_datasets, evals = [], {}
|
||||
if "webgpt" in training_conf["datasets"]:
|
||||
@@ -196,10 +196,7 @@ if __name__ == "__main__":
|
||||
evals["hfsummary"] = sum_eval
|
||||
train = ConcatDataset(train_datasets)
|
||||
|
||||
- if "tokenizer_name" in training_conf:
|
||||
- tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
- else:
|
||||
- tokenizer = get_tokenizer(model_name)
|
||||
+ tokenizer = get_tokenizer(training_conf["tokenizer_name"])
|
||||
|
||||
if "rankgen" in model_name:
|
||||
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
|
||||
|
||||
@@ -71,6 +71,10 @@ def freeze_top_n_layers(model, target_layers):
|
||||
|
||||
|
||||
def argument_parsing(parser):
|
||||
+ args = parser.parse_args()
|
||||
+ with open(args.config, "r", encoding="utf-8") as f:
|
||||
+ training_conf = yaml.safe_load(f.read())
|
||||
|
||||
default_params = {
|
||||
"num_train_epochs": 4,
|
||||
"learning_rate": 3e-5,
|
||||
@@ -82,10 +86,9 @@ def argument_parsing(parser):
|
||||
"gradient_accumulation_steps": 8,
|
||||
"gradient_checkpointing": False,
|
||||
"datasets": ["webgpt"],
|
||||
+ "fp16": True,
|
||||
+ "tokenizer_name": training_conf["model_name"],
|
||||
}
|
||||
- args = parser.parse_args()
|
||||
- with open(args.config, "r", encoding="utf-8") as f:
|
||||
- training_conf = yaml.safe_load(f.read())
|
||||
|
||||
params = {**default_params, **training_conf}
|
||||
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])
|
||||
|
||||
Reference in New Issue
Block a user