diff --git a/src/alignment/data.py b/src/alignment/data.py
index b6a9a24..532bddc 100644
--- a/src/alignment/data.py
+++ b/src/alignment/data.py
@@ -99,7 +99,7 @@ def get_datasets(
         splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
             Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
         shuffle (`bool`, *optional*, defaults to `True`):
-            Whether to shuffle the training data.
+            Whether to shuffle the training and testing/validation data.
 
     Returns
         [`DatasetDict`]: The dataset dictionary containing the loaded datasets.
@@ -137,7 +137,7 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
         splits (Optional[List[str]], *optional*, defaults to `None`):
             Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
         shuffle (`bool`, *optional*, defaults to `True`):
-            Whether to shuffle the training data.
+            Whether to shuffle the training and testing/validation data.
     """
     raw_datasets = DatasetDict()
     raw_train_datasets = []