Allow loading datasets from disk using load_from_disk method. (#53)

* feat: Allow loading datasets from disk using `load_from_disk` method.

* Fixing the type of error being caught.
This commit is contained in:
Dragan Milchevski
2023-12-01 11:05:35 +01:00
committed by GitHub
parent 80e952ec47
commit 15279e7157
+12 -14
View File
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from typing import List, Literal, Optional
from datasets import DatasetDict, concatenate_datasets, load_dataset
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from .configs import DataArguments
@@ -145,20 +146,17 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
for ds, frac in dataset_mixer.items():
fracs.append(frac)
for split in splits:
try:
# Try first if dataset on a Hub repo
dataset = load_dataset(ds, split=split)
except DatasetGenerationError:
# If not, check local dataset
dataset = load_from_disk(os.path.join(ds, split))
if "train" in split:
raw_train_datasets.append(
load_dataset(
ds,
split=split,
)
)
raw_train_datasets.append(dataset)
elif "test" in split:
raw_val_datasets.append(
load_dataset(
ds,
split=split,
)
)
raw_val_datasets.append(dataset)
else:
raise ValueError(f"Split type {split} not recognized as one of test or train.")