mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:22:17 +08:00
🪁 (#129)
* Add Gemma 7B recipe * Use Gemma template * Make it work for dolly lol * Enable cahce * Clean up * DPO to the max * DPO, DPO, DPO * Add openhermes * Add custom configs * Add kwargs * Fix config * Bump deps * Move old recipes * Add doc * Add norte * Renable cache * Nuke * Clean * Apply suggestions from code review Co-authored-by: Alvaro Bartolome <alvaro@argilla.io> * Fix isort * Update README.md * Update config_full.yaml --------- Co-authored-by: Alvaro Bartolome <alvaro@argilla.io> Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
This commit is contained in:
+10
-10
@@ -197,16 +197,6 @@ def main():
|
||||
|
||||
logger.info("*** Training complete ***")
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
@@ -227,6 +217,16 @@ def main():
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
+11
-11
@@ -134,7 +134,6 @@ def main():
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
logger.info("*** Model loaded! ***")
|
||||
|
||||
########################
|
||||
# Initialize the Trainer
|
||||
@@ -150,6 +149,7 @@ def main():
|
||||
tokenizer=tokenizer,
|
||||
packing=True,
|
||||
peft_config=get_peft_config(model_args),
|
||||
dataset_kwargs=training_args.dataset_kwargs,
|
||||
)
|
||||
|
||||
###############
|
||||
@@ -168,16 +168,6 @@ def main():
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(eval_dataset)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
##################################
|
||||
# Save model and create model card
|
||||
##################################
|
||||
@@ -198,6 +188,16 @@ def main():
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.config.save_pretrained(training_args.output_dir)
|
||||
|
||||
##########
|
||||
# Evaluate
|
||||
##########
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
metrics["eval_samples"] = len(eval_dataset)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
if training_args.push_to_hub is True:
|
||||
logger.info("Pushing to hub...")
|
||||
trainer.push_to_hub(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user