mirror of
https://github.com/wassname/alignment-handbook.git
synced 2026-06-27 18:41:19 +08:00
Allow loading datasets from disk using load_from_disk method. (#53)
* feat: Allow loading datasets from disk using `load_from_disk` method. * Fixing the type of error being catched.
This commit is contained in:
committed by
GitHub
parent
80e952ec47
commit
15279e7157
+12
-14
@@ -12,11 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset
|
||||
from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
|
||||
from datasets.builder import DatasetGenerationError
|
||||
|
||||
from .configs import DataArguments
|
||||
|
||||
@@ -145,20 +146,17 @@ def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffl
|
||||
for ds, frac in dataset_mixer.items():
|
||||
fracs.append(frac)
|
||||
for split in splits:
|
||||
try:
|
||||
# Try first if dataset on a Hub repo
|
||||
dataset = load_dataset(ds, split=split)
|
||||
except DatasetGenerationError:
|
||||
# If not, check local dataset
|
||||
dataset = load_from_disk(os.path.join(ds, split))
|
||||
|
||||
if "train" in split:
|
||||
raw_train_datasets.append(
|
||||
load_dataset(
|
||||
ds,
|
||||
split=split,
|
||||
)
|
||||
)
|
||||
raw_train_datasets.append(dataset)
|
||||
elif "test" in split:
|
||||
raw_val_datasets.append(
|
||||
load_dataset(
|
||||
ds,
|
||||
split=split,
|
||||
)
|
||||
)
|
||||
raw_val_datasets.append(dataset)
|
||||
else:
|
||||
raise ValueError(f"Split type {split} not recognized as one of test or train.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user