Various fixes to the supervised fine-tuning (SFT) code

This commit is contained in:
Yannic Kilcher
2023-02-11 21:37:35 +01:00
parent 74e310e7ba
commit b6a0eedf81
4 changed files with 4 additions and 2 deletions
+1
View File
@@ -0,0 +1 @@
.cache
@@ -13,6 +13,7 @@ defaults:
logging_steps: 10
max_grad_norm: 2.0
save_total_limit: 4
+fp16: false
eval_accumulation_steps:
freeze_layer:
datasets:
@@ -20,7 +20,7 @@ class OAPrivate(Dataset):
total_prob = reduce(lambda prev, split: prev + split[1], self.splits.items(), 0)
assert math.isclose(total_prob, 1), "Make sure OAPrivate split ratios add to 1"
-jsonl_file = os.path.join(data_path, self.file)
+jsonl_file = os.path.join(data_path, file)
with open(jsonl_file, "r", encoding="utf-8") as f:
lines = f.readlines()
+1 -1
View File
@@ -217,7 +217,7 @@ if __name__ == "__main__":
learning_rate=float(training_conf.learning_rate),
deepspeed="configs/zero_config.json" if training_conf.deepspeed else None,
optim=optimizer,
-fp16=True,
+fp16=training_conf.fp16,
local_rank=training_conf.local_rank,
gradient_checkpointing=training_conf.gradient_checkpointing,
gradient_accumulation_steps=training_conf.gradient_accumulation_steps,