Files
2023-01-04 20:51:19 -05:00

28 lines
1.1 KiB
Python

import torch
from transformers import AutoModel
class RankGenModel(torch.nn.Module):
    """Score (prefix, suffix) pairs with a pretrained RankGen encoder.

    The wrapped encoder is loaded from the Hugging Face hub and applied to
    both the prefix batch and the suffix batch; the score of each pair is
    the dot product of the two resulting embeddings.
    """

    # Hub checkpoints this wrapper is allowed to load. The list doubles as a
    # safety gate, because loading uses ``trust_remote_code=True``.
    SUPPORTED_MODELS = (
        "kalpeshk2011/rankgen-t5-xl-all",
        "kalpeshk2011/rankgen-t5-xl-pg19",
        "kalpeshk2011/rankgen-t5-base-all",
        "kalpeshk2011/rankgen-t5-large-all",
    )

    def __init__(self, model_name):
        """Load the RankGen checkpoint called ``model_name``.

        Args:
            model_name: Hub identifier; must be one of ``SUPPORTED_MODELS``.

        Raises:
            AssertionError: If ``model_name`` is not a supported checkpoint.
                The exception type matches the ``assert`` this check
                replaces, so existing callers are unaffected.
        """
        super().__init__()
        # Raise explicitly instead of using ``assert``: assert statements are
        # stripped under ``python -O``, which would silently disable the
        # allow-list guarding the remote-code load below.
        if model_name not in self.SUPPORTED_MODELS:
            raise AssertionError(
                f"Unsupported RankGen checkpoint {model_name!r}; "
                f"expected one of: {', '.join(self.SUPPORTED_MODELS)}"
            )
        self.rankgen_hf_hub = model_name
        # NOTE: trust_remote_code=True executes Python code shipped in the hub
        # repository, so only allow-listed checkpoints may reach this call.
        self.model = AutoModel.from_pretrained(
            self.rankgen_hf_hub, trust_remote_code=True
        )

    def forward(self, prefixes, suffixes):
        """Return one similarity score per (prefix, suffix) pair.

        Args:
            prefixes: Keyword arguments for the encoder, unpacked with ``**``
                (presumably a tokenizer output such as ``input_ids`` /
                ``attention_mask`` -- confirm against the caller).
            suffixes: Same as ``prefixes``, for the suffix batch.

        Returns:
            The element-wise product of the two embeddings summed over
            ``dim=1``. Assumes the encoder returns a 2-D ``(batch, dim)``
            tensor, giving one score per row -- TODO confirm.
        """
        embedded_prefixes = self.model(**prefixes)
        embedded_suffixes = self.model(**suffixes)
        # Row-wise dot product: multiply element-wise, then reduce over dim 1.
        dot_products = torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
        return dot_products