diff --git a/model/supervised_finetuning/trainer.py b/model/supervised_finetuning/trainer.py index 043534ea..c500f8df 100644 --- a/model/supervised_finetuning/trainer.py +++ b/model/supervised_finetuning/trainer.py @@ -9,7 +9,8 @@ import torch from torch import nn from torch.utils.data import DataLoader from transformers import PreTrainedModel, Trainer, TrainingArguments -from transformers.trainer_pt_utils import IterableDatasetShard, seed_worker +from transformers.trainer_pt_utils import IterableDatasetShard +from transformers.trainer_utils import seed_worker from transformers.training_args import OptimizerNames from transformers.utils import is_datasets_available from utils import get_dataset, get_loss, get_metrics, get_model, get_tokenizer, read_yamls