mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 16:14:07 +08:00
Fix dataloading for cpt (#137)
* avoid mutable parameter * do not remove text_column for cpt * fix typo * add * remove constant KEEPCOLS * update tests with columns_to_keep
This commit is contained in:
+6
-1
@@ -83,7 +83,12 @@ def main():
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
|
||||
raw_datasets = get_datasets(
|
||||
data_args,
|
||||
splits=data_args.dataset_splits,
|
||||
configs=data_args.dataset_configs,
|
||||
columns_to_keep=[data_args.text_column],
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Training on the following datasets and their proportions:"
|
||||
|
||||
+6
-1
@@ -77,7 +77,12 @@ def main():
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
|
||||
raw_datasets = get_datasets(
|
||||
data_args,
|
||||
splits=data_args.dataset_splits,
|
||||
configs=data_args.dataset_configs,
|
||||
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
|
||||
)
|
||||
logger.info(
|
||||
f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
|
||||
+6
-1
@@ -85,7 +85,12 @@ def main():
|
||||
###############
|
||||
# Load datasets
|
||||
###############
|
||||
raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits, configs=data_args.dataset_configs)
|
||||
raw_datasets = get_datasets(
|
||||
data_args,
|
||||
splits=data_args.dataset_splits,
|
||||
configs=data_args.dataset_configs,
|
||||
columns_to_keep=["messages", "chosen", "rejected", "prompt", "completion", "label"],
|
||||
)
|
||||
logger.info(
|
||||
f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
|
||||
)
|
||||
|
||||
+25
-10
@@ -23,8 +23,6 @@ from .configs import DataArguments
|
||||
|
||||
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
COLUMNS_TO_KEEP = ["messages", "chosen", "rejected", "prompt", "completion", "label"]
|
||||
|
||||
|
||||
def maybe_insert_system_message(messages, tokenizer):
|
||||
if messages[0]["role"] == "system":
|
||||
@@ -97,8 +95,9 @@ def apply_chat_template(
|
||||
|
||||
def get_datasets(
|
||||
data_config: DataArguments | dict,
|
||||
splits: List[str] = ["train", "test"],
|
||||
splits: Optional[List[str]] = None,
|
||||
configs: Optional[List[str]] = None,
|
||||
columns_to_keep: Optional[List[str]] = None,
|
||||
shuffle: bool = True,
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
@@ -109,13 +108,17 @@ def get_datasets(
|
||||
Dataset configuration and split proportions.
|
||||
splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
|
||||
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
|
||||
configs (Optional[List[str]], *optional*, defaults to `None`):
|
||||
List of dataset config names. If given must be the same length as 'data_config' keys.
|
||||
columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
|
||||
Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
|
||||
and for cpt this should be (at least) the text column.
|
||||
shuffle (`bool`, *optional*, defaults to `True`):
|
||||
Whether to shuffle the training and testing/validation data.
|
||||
|
||||
Returns
|
||||
[`DatasetDict`]: The dataset dictionary containing the loaded datasets.
|
||||
"""
|
||||
|
||||
if type(data_config) is DataArguments:
|
||||
# Structure of the config to read the datasets and their mix
|
||||
# datasets_mixer:
|
||||
@@ -134,12 +137,18 @@ def get_datasets(
|
||||
else:
|
||||
raise ValueError(f"Data config {data_config} not recognized.")
|
||||
|
||||
raw_datasets = mix_datasets(dataset_mixer, splits=splits, configs=configs, shuffle=shuffle)
|
||||
raw_datasets = mix_datasets(
|
||||
dataset_mixer, splits=splits, configs=configs, columns_to_keep=columns_to_keep, shuffle=shuffle
|
||||
)
|
||||
return raw_datasets
|
||||
|
||||
|
||||
def mix_datasets(
|
||||
dataset_mixer: dict, configs: Optional[List[str]] = None, splits: Optional[List[str]] = None, shuffle=True
|
||||
dataset_mixer: dict,
|
||||
splits: Optional[List[str]] = None,
|
||||
configs: Optional[List[str]] = None,
|
||||
columns_to_keep: Optional[List[str]] = None,
|
||||
shuffle=True,
|
||||
) -> DatasetDict:
|
||||
"""
|
||||
Loads and mixes datasets according to proportions specified in `dataset_mixer`.
|
||||
@@ -147,14 +156,20 @@ def mix_datasets(
|
||||
Args:
|
||||
dataset_mixer (`dict`):
|
||||
Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
|
||||
configs (Optional[List[str]], *optional*, defaults to `None`):
|
||||
List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
|
||||
splits (Optional[List[str]], *optional*, defaults to `None`):
|
||||
Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
|
||||
configs (Optional[List[str]], *optional*, defaults to `None`):
|
||||
List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
|
||||
columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
|
||||
Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
|
||||
and for cpt this should be (at least) the text column.
|
||||
shuffle (`bool`, *optional*, defaults to `True`):
|
||||
Whether to shuffle the training and testing/validation data.
|
||||
"""
|
||||
splits = ["train", "test"] if splits is None else splits
|
||||
configs = [None] * len(dataset_mixer) if not configs else configs
|
||||
columns_to_keep = [] if columns_to_keep is None else columns_to_keep
|
||||
|
||||
if configs is not None and len(configs) != len(dataset_mixer):
|
||||
raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")
|
||||
|
||||
@@ -173,7 +188,7 @@ def mix_datasets(
|
||||
dataset = load_from_disk(os.path.join(ds, split))
|
||||
|
||||
# Remove redundant columns to avoid schema conflicts on load
|
||||
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in COLUMNS_TO_KEEP])
|
||||
dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
|
||||
if "train" in split:
|
||||
raw_train_datasets.append(dataset)
|
||||
elif "test" in split:
|
||||
@@ -202,7 +217,7 @@ def mix_datasets(
|
||||
|
||||
if len(raw_datasets) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
|
||||
f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
|
||||
)
|
||||
|
||||
return raw_datasets
|
||||
|
||||
+6
-6
@@ -33,7 +33,7 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
"HuggingFaceH4/testing_codealpaca_small": 0.2,
|
||||
}
|
||||
data_args = DataArguments(dataset_mixer=dataset_mixer)
|
||||
datasets = get_datasets(data_args)
|
||||
datasets = get_datasets(data_args, columns_to_keep=["prompt", "completion"])
|
||||
self.assertEqual(len(datasets["train"]), 100)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
@@ -43,7 +43,7 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
"HuggingFaceH4/testing_self_instruct_small": 0.3,
|
||||
"HuggingFaceH4/testing_codealpaca_small": 0.2,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
|
||||
self.assertEqual(len(datasets["train"]), 100)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
@@ -53,7 +53,7 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
"HuggingFaceH4/testing_self_instruct_small": 1.0,
|
||||
"HuggingFaceH4/testing_codealpaca_small": 1.0,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
|
||||
self.assertEqual(len(datasets["train"]), 300)
|
||||
self.assertEqual(len(datasets["test"]), 300)
|
||||
|
||||
@@ -62,7 +62,7 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
"HuggingFaceH4/testing_alpaca_small": 0.7,
|
||||
"HuggingFaceH4/testing_self_instruct_small": 0.4,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer)
|
||||
datasets = get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
|
||||
self.assertEqual(len(datasets["train"]), 70 + 40)
|
||||
self.assertEqual(len(datasets["test"]), 200)
|
||||
|
||||
@@ -72,13 +72,13 @@ class GetDatasetsTest(unittest.TestCase):
|
||||
"HuggingFaceH4/testing_self_instruct_small": -0.3,
|
||||
}
|
||||
with pytest.raises(ValueError, match=r"Dataset fractions cannot be negative."):
|
||||
get_datasets(dataset_mixer)
|
||||
get_datasets(dataset_mixer, columns_to_keep=["prompt", "completion"])
|
||||
|
||||
def test_loading_single_split_with_unit_fractions(self):
|
||||
dataset_mixer = {
|
||||
"HuggingFaceH4/testing_alpaca_small": 1.0,
|
||||
}
|
||||
datasets = get_datasets(dataset_mixer, splits=["test"])
|
||||
datasets = get_datasets(dataset_mixer, splits=["test"], columns_to_keep=["prompt", "completion"])
|
||||
self.assertEqual(len(datasets["test"]), 100)
|
||||
self.assertRaises(KeyError, lambda: datasets["train"])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user