Fix dataloading for cpt (#137)

* avpid mutable parameter

* do not remove text_column for cpt

* fix typo

* add

* remove constant KEEPCOLS

* update tests with columns_to_keep
This commit is contained in:
Bram Vanroy
2024-03-21 20:05:53 +01:00
committed by GitHub
parent c44cb1cd1d
commit ba7e0e4fca
5 changed files with 49 additions and 19 deletions
+6 -1
View File
@@ -83,7 +83,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=[data_args.text_column],
)
logger.info(
f"Training on the following datasets and their proportions:"
+6 -1
View File
@@ -77,7 +77,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
+6 -1
View File
@@ -85,7 +85,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
)
logger.info(
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)