diff --git a/model/supervised_finetuning/custom_datasets/dialogue_collator.py b/model/supervised_finetuning/custom_datasets/dialogue_collator.py index f9e1bb5e..479931f6 100644 --- a/model/supervised_finetuning/custom_datasets/dialogue_collator.py +++ b/model/supervised_finetuning/custom_datasets/dialogue_collator.py @@ -24,26 +24,24 @@ class DialogueDataCollator: flatten_messages = [] label_masks = [] - for messages in features: - assert len(messages) % 2 == 0, "Number of messages must be even" + for feature_one in features: + assert len(feature_one) % 2 == 0, "Number of messages must be even" messages = [ (QA_SPECIAL_TOKENS["Question"] if i % 2 == 0 else "") + x + (QA_SPECIAL_TOKENS["Answer"] if i % 2 == 0 else "") - for i, x in enumerate(messages) + for i, x in enumerate(feature_one) ] # Add a way for the model to terminate generation # When we predict the start of a new expected question, we want to be able to stop generation messages.append(QA_SPECIAL_TOKENS["Question"]) - flatten_messages.append( - self.tokenizer( - "".join(messages), - truncation=True, - max_length=self.max_length, - return_offsets_mapping=True, - ) + flatten_message = self.tokenizer( + "".join(messages), + truncation=True, + max_length=self.max_length, + return_offsets_mapping=True, ) message_change_indices = np.cumsum([len(x) for x in messages[:-1]]) @@ -57,18 +55,19 @@ class DialogueDataCollator: message_indices = list( map( lambda x: next((i for i, val in enumerate(message_change_indices) if val >= x), -2), - list(map(lambda x: x[1], flatten_messages[-1]["offset_mapping"])), + list(map(lambda x: x[1], flatten_message["offset_mapping"])), ) ) label_mask = np.roll(list(map(lambda x: x % 2 == 1, message_indices)), -1, -1) try: label_mask[[i for i in range(len(message_indices)) if message_indices[i] == -2][0] - 1] = True except IndexError: - # an aftermath of padding - pass + # due to truncation, we might not have the last termination token + label_mask[-1] = False label_masks.append(label_mask) - flatten_messages[-1].pop("offset_mapping") + + flatten_messages.append({k: v for k, v in flatten_message.items() if k != "offset_mapping"}) batch = self.tokenizer.pad( flatten_messages, @@ -79,10 +78,9 @@ class DialogueDataCollator: ) dim = batch["input_ids"].shape[-1] - batch["label_masks"] = torch.stack([F.pad(torch.tensor(x), (0, dim - len(x))) for x in label_masks]) - - for k in list(batch.keys()): - if k not in ["input_ids", "attention_mask", "label_masks"]: - batch.pop(k) + batch["label_masks"] = torch.stack( + [F.pad(torch.tensor(x), (0, dim - len(x)), value=False) for x in label_masks] + ) + batch["targets"] = torch.roll(batch["input_ids"], -1, -1) return batch diff --git a/model/supervised_finetuning/models/gptj.py b/model/supervised_finetuning/models/gptj.py index 3cbec3ce..d954c830 100644 --- a/model/supervised_finetuning/models/gptj.py +++ b/model/supervised_finetuning/models/gptj.py @@ -171,7 +171,7 @@ def get_model(model_name, cache_dir, quantization): if quantization is None: model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) elif quantization == "8bit": - print("Loading 8-bit model") + raise ValueError("Loading 8-bit model. Bitsandbytes does not behave so far...") transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) add_adapters(model) diff --git a/model/supervised_finetuning/requirements.txt b/model/supervised_finetuning/requirements.txt index d579468f..798b5950 100644 --- a/model/supervised_finetuning/requirements.txt +++ b/model/supervised_finetuning/requirements.txt @@ -4,3 +4,6 @@ PyYAML==6.0 scikit_learn==1.2.0 torch==1.13.1 transformers==4.25.1 +deepspeed==0.7.7 +mpi4py==3.1.4 +accelerate==0.15.0 \ No newline at end of file