mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-07-05 00:31:34 +08:00
Clean deprecated max_samples arguments (#89)
This commit is contained in:
+2
-8
@@ -173,10 +173,7 @@ def main():
|
||||
###############
|
||||
train_result = dpo_trainer.train()
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = (
|
||||
data_args.max_train_samples if data_args.max_train_samples is not None else len(raw_datasets["train"])
|
||||
)
|
||||
metrics["train_samples"] = min(max_train_samples, len(raw_datasets["train"]))
|
||||
metrics["train_samples"] = len(raw_datasets["train"])
|
||||
dpo_trainer.log_metrics("train", metrics)
|
||||
dpo_trainer.save_metrics("train", metrics)
|
||||
dpo_trainer.save_state()
|
||||
@@ -189,10 +186,7 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = dpo_trainer.evaluate()
|
||||
max_eval_samples = (
|
||||
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(raw_datasets["test"])
|
||||
)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(raw_datasets["test"]))
|
||||
metrics["eval_samples"] = len(raw_datasets["test"])
|
||||
dpo_trainer.log_metrics("eval", metrics)
|
||||
dpo_trainer.save_metrics("eval", metrics)
|
||||
|
||||
|
||||
+2
-4
@@ -151,8 +151,7 @@ def main():
|
||||
logger.info("*** Train ***")
|
||||
train_result = trainer.train()
|
||||
metrics = train_result.metrics
|
||||
max_train_samples = data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
||||
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
trainer.save_state()
|
||||
@@ -163,8 +162,7 @@ def main():
|
||||
if training_args.do_eval:
|
||||
logger.info("*** Evaluate ***")
|
||||
metrics = trainer.evaluate()
|
||||
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
||||
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
||||
metrics["eval_samples"] = len(eval_dataset)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
|
||||
@@ -197,24 +197,6 @@ class DataArguments:
|
||||
default_factory=lambda: ["train", "test"],
|
||||
metadata={"help": ("List of train test splits to use in the dataset")},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."},
|
||||
|
||||
Reference in New Issue
Block a user