Fix dataloading for cpt (#137)

* avpid mutable parameter

* do not remove text_column for cpt

* fix typo

* add

* remove constant KEEPCOLS

* update tests with columns_to_keep
This commit is contained in:
Bram Vanroy
2024-03-21 20:05:53 +01:00
committed by GitHub
parent c44cb1cd1d
commit ba7e0e4fca
5 changed files with 49 additions and 19 deletions
+6 -1
View File
@@ -83,7 +83,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=[data_args.text_column],
)
logger.info(
f"Training on the following datasets and their proportions:"
+6 -1
View File
@@ -77,7 +77,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
)
logger.info(
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
+6 -1
View File
@@ -85,7 +85,12 @@ def main():
###############
# Load datasets
###############
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
raw_datasets = get_datasets(
data_args,
splits=data_args.dataset_splits,
configs=data_args.dataset_configs,
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
)
logger.info(
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
)
+25 -10
View File
@@ -23,8 +23,6 @@ from .configs import DataArguments
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
COLUMNS_TO_KEEP = ["messages", "chosen", "rejected", "prompt", "completion", "label"]
def maybe_insert_system_message(messages, tokenizer):
if messages[0]["role"] == "system":
@@ -97,8 +95,9 @@ def apply_chat_template(
def get_datasets(
data_config: DataArguments | dict,
splits: List[str] = ["train", "test"],
splits: Optional[List[str]] = None,
configs: Optional[List[str]] = None,
columns_to_keep: Optional[List[str]] = None,
shuffle: bool = True,
) -> DatasetDict:
"""
@@ -109,13 +108,17 @@ def get_datasets(
Dataset configuration and split proportions.
splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
configs (Optional[List[str]], *optional*, defaults to `None`):
List of dataset config names. If given must be the same length as 'data_config' keys.
columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
and for cpt this should be (at least) the text column.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training and testing/validation data.
Returns
[`DatasetDict`]: The dataset dictionary containing the loaded datasets.
"""
if type(data_config) is DataArguments:
# Structure of the config to read the datasets and their mix
# datasets_mixer:
@@ -134,12 +137,18 @@ def get_datasets(
else:
raise ValueError(f"Data config {data_config} not recognized.")
raw_datasets = mix_datasets(dataset_mixer, splits=splits, configs=configs, shuffle=shuffle)
raw_datasets = mix_datasets(
dataset_mixer, splits=splits, configs=configs, columns_to_keep=columns_to_keep, shuffle=shuffle
)
return raw_datasets
def mix_datasets(
dataset_mixer: dict, configs: Optional[List[str]] = None, splits: Optional[List[str]] = None, shuffle=True
dataset_mixer: dict,
splits: Optional[List[str]] = None,
configs: Optional[List[str]] = None,
columns_to_keep: Optional[List[str]] = None,
shuffle=True,
) -> DatasetDict:
"""
Loads and mixes datasets according to proportions specified in `dataset_mixer`.
@@ -147,14 +156,20 @@ def mix_datasets(
Args:
dataset_mixer (`dict`):
Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
configs (Optional[List[str]], *optional*, defaults to `None`):
List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
splits (Optional[List[str]], *optional*, defaults to `None`):
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
configs (Optional[List[str]], *optional*, defaults to `None`):
List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
and for cpt this should be (at least) the text column.
shuffle (`bool`, *optional*, defaults to `True`):
Whether to shuffle the training and testing/validation data.
"""
splits = ["train", "test"] if splits is None else splits
configs = [None] * len(dataset_mixer) if not configs else configs
columns_to_keep = [] if columns_to_keep is None else columns_to_keep
if configs is not None and len(configs) != len(dataset_mixer):
raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")
@@ -173,7 +188,7 @@ def mix_datasets(
dataset = load_from_disk(os.path.join(ds, split))
# Remove redundant columns to avoid schema conflicts on load
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in COLUMNS_TO_KEEP])
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
if "train" in split:
raw_train_datasets.append(dataset)
elif "test" in split:
@@ -202,7 +217,7 @@ def mix_datasets(
if len(raw_datasets) == 0:
raise ValueError(
f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
)
return raw_datasets
+6 -6
View File
@@ -33,7 +33,7 @@ class GetDatasetsTest(unittest.TestCase):
"HuggingFaceH4/testing_codealpaca_small": 0.2,
}
data_args = DataArguments(dataset_mixer=dataset_mixer)
datasets = get_datasets(data_args)
datasets = get_datasets(data_args, columns_to_keep=["prompt", "completion"])
self.assertEqual(len(datasets["train"]), 100)
self.assertEqual(len(datasets["test"]), 300)
@@ -43,7 +43,7 @@ class GetDatasetsTest(unittest.TestCase):
"HuggingFaceH4/testing_self_instruct_small": 0.3,
"HuggingFaceH4/testing_codealpaca_small": 0.2,
}
datasets = get_datasets(dataset_mixer)
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
self.assertEqual(len(datasets["train"]), 100)
self.assertEqual(len(datasets["test"]), 300)
@@ -53,7 +53,7 @@ class GetDatasetsTest(unittest.TestCase):
"HuggingFaceH4/testing_self_instruct_small": 1.0,
"HuggingFaceH4/testing_codealpaca_small": 1.0,
}
datasets = get_datasets(dataset_mixer)
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
self.assertEqual(len(datasets["train"]), 300)
self.assertEqual(len(datasets["test"]), 300)
@@ -62,7 +62,7 @@ class GetDatasetsTest(unittest.TestCase):
"HuggingFaceH4/testing_alpaca_small": 0.7,
"HuggingFaceH4/testing_self_instruct_small": 0.4,
}
datasets = get_datasets(dataset_mixer)
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
self.assertEqual(len(datasets["train"]), 70 + 40)
self.assertEqual(len(datasets["test"]), 200)
@@ -72,13 +72,13 @@ class GetDatasetsTest(unittest.TestCase):
"HuggingFaceH4/testing_self_instruct_small": -0.3,
}
with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
get_datasets(dataset_mixer)
get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
def test_loading_single_split_with_unit_fractions(self):
dataset_mixer = {
"HuggingFaceH4/testing_alpaca_small": 1.0,
}
datasets = get_datasets(dataset_mixer, splits=["test"])
datasets = get_datasets(dataset_mixer, splits=["test"], columns_to_keep=["prompt", "completion"])
self.assertEqual(len(datasets["test"]), 100)
self.assertRaises(KeyError, lambda: datasets["train"])