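Cleaned up default argument logic: the `fp16` and `tokenizer_name` fallbacks now live in `default_params` inside `argument_parsing` instead of being checked inline at the call sites; also switched `report_to` from "wandb" to "local".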

This commit is contained in:
Bobak Hashemi
2023-01-03 21:45:16 -05:00
parent 4569bcf354
commit da79aa04a0
2 changed files with 9 additions and 9 deletions
+3 -6
View File
@@ -168,7 +168,7 @@ if __name__ == "__main__":
loss_function=training_conf["loss"],
learning_rate=training_conf["learning_rate"],
# half_precision_backend="apex",
fp16=training_conf["fp16"] if "fp16" in training_conf else True,
fp16=training_conf["fp16"],
gradient_checkpointing=training_conf["gradient_checkpointing"],
gradient_accumulation_steps=training_conf["gradient_accumulation_steps"],
per_device_train_batch_size=training_conf["per_device_train_batch_size"],
@@ -180,7 +180,7 @@ if __name__ == "__main__":
evaluation_strategy="steps",
eval_steps=training_conf["eval_steps"],
save_steps=1000,
report_to="wandb",
report_to="local",
)
train_datasets, evals = [], {}
if "webgpt" in training_conf["datasets"]:
@@ -196,10 +196,7 @@ if __name__ == "__main__":
evals["hfsummary"] = sum_eval
train = ConcatDataset(train_datasets)
if "tokenizer_name" in training_conf:
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
else:
tokenizer = get_tokenizer(model_name)
tokenizer = get_tokenizer(training_conf["tokenizer_name"])
if "rankgen" in model_name:
collate_fn = RankGenCollator(tokenizer, max_length=training_conf["max_length"])
+6 -3
View File
@@ -71,6 +71,10 @@ def freeze_top_n_layers(model, target_layers):
def argument_parsing(parser):
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
default_params = {
"num_train_epochs": 4,
"learning_rate": 3e-5,
@@ -82,10 +86,9 @@ def argument_parsing(parser):
"gradient_accumulation_steps": 8,
"gradient_checkpointing": False,
"datasets": ["webgpt"],
"fp16": True,
"tokenizer_name": training_conf["model_name"],
}
args = parser.parse_args()
with open(args.config, "r", encoding="utf-8") as f:
training_conf = yaml.safe_load(f.read())
params = {**default_params, **training_conf}
params["gradient_accumulation_steps"] = int(params["gradient_accumulation_steps"])