Clean deprecated max_samples arguments (#89)

This commit is contained in:
Kirill
2024-01-05 02:06:47 +04:00
committed by GitHub
parent e316174e1c
commit 98fe28fb14
3 changed files with 4 additions and 30 deletions
+2 -8
View File
@@ -173,10 +173,7 @@ def main():
###############
train_result = dpo_trainer.train()
metrics = train_result.metrics
max_train_samples = (
data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
)
metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
metrics["train_samples"] = len(raw_datasets["train"])
dpo_trainer.log_metrics("train", metrics)
dpo_trainer.save_metrics("train", metrics)
dpo_trainer.save_state()
@@ -189,10 +186,7 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = dpo_trainer.evaluate()
max_eval_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
)
metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
metrics["eval_samples"] = len(raw_datasets["test"])
dpo_trainer.log_metrics("eval", metrics)
dpo_trainer.save_metrics("eval", metrics)
+2 -4
View File
@@ -151,8 +151,7 @@ def main():
logger.info("*** Train ***")
train_result = trainer.train()
metrics = train_result.metrics
max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
@@ -163,8 +162,7 @@ def main():
if training_args.do_eval:
logger.info("*** Evaluate ***")
metrics = trainer.evaluate()
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
metrics["eval_samples"] = len(eval_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
-18
View File
@@ -197,24 +197,6 @@ class DataArguments:
default_factory=lambda: ["train", "test"],
metadata={"help": ("List of train test splits to use in the dataset")},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
preprocessing_num_workers: Optional[int] = field(
default=None,
metadata={"help": "The number of processes to use for the preprocessing."},