From 61a11a5c7d66179ed0a930b0dd12e532fce701dd Mon Sep 17 00:00:00 2001 From: Scott Fleming Date: Wed, 6 Dec 2023 01:44:17 -0800 Subject: [PATCH] Update docstring for `data.py` to reflect true behavior of `shuffle` parameter (#60) * Update data.py The docs state that the `shuffle` parameter in `mix_datasets` from `data.py` controls `Whether to shuffle the training data`, but then in the code if `shuffle` is set to `True` it also shuffles the test data. This small change makes the functionality consistent with the docstring. (If you instead want to keep the functionality the same, then we should update the docstring). * Update data.py Reverted to the original code structure but updated the docstring. * Update docstring in `get_dataset` and `mix_datasets` Updated docstrings to reflect the fact that `shuffle` being set to `True` leads to shuffling of both the training and testing/validation data. --- src/alignment/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/alignment/data.py b/src/alignment/data.py index b6a9a24..532bddc 100644 --- a/src/alignment/data.py +++ b/src/alignment/data.py @@ -99,7 +99,7 @@ def get_datasets( splits (`List[str]`, *optional*, defaults to `['train', 'test']`): Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. shuffle (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training data. + Whether to shuffle the training and testing/validation data. Returns [`DatasetDict`]: The dataset dictionary containing the loaded datasets. @@ -137,7 +137,7 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl splits (Optional[List[str]], *optional*, defaults to `None`): Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. shuffle (`bool`, *optional*, defaults to `True`): - Whether to shuffle the training data. + Whether to shuffle the training and testing/validation data. """ raw_datasets = DatasetDict() raw_train_datasets = []