mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
28 lines
1.1 KiB
Python
28 lines
1.1 KiB
Python
import torch
|
|
from transformers import AutoModel
|
|
|
|
|
|
class _UnsupportedRankGenModelError(ValueError, AssertionError):
    """Raised when a checkpoint name is not on the RankGen allow-list.

    Also derives from ``AssertionError`` because this check used to be a bare
    ``assert``; code that catches ``AssertionError`` keeps working.
    """


class RankGenModel(torch.nn.Module):
    """Score prefix/suffix pairs with a pretrained RankGen encoder.

    The wrapped encoder is loaded through ``AutoModel.from_pretrained`` with
    ``trust_remote_code=True``, so only the checkpoints listed in
    ``SUPPORTED_MODELS`` are accepted.
    """

    # Allow-list of hub checkpoints that may be loaded with remote code enabled.
    SUPPORTED_MODELS = (
        "kalpeshk2011/rankgen-t5-xl-all",
        "kalpeshk2011/rankgen-t5-xl-pg19",
        "kalpeshk2011/rankgen-t5-base-all",
        "kalpeshk2011/rankgen-t5-large-all",
    )

    def __init__(self, model_name):
        """Load the RankGen checkpoint identified by ``model_name``.

        Args:
            model_name: Hub identifier; must be one of ``SUPPORTED_MODELS``.

        Raises:
            ValueError: If ``model_name`` is not an allowed checkpoint. The
                raised exception is also an ``AssertionError`` instance.
        """
        super().__init__()
        # Explicit raise instead of ``assert``: asserts are removed under
        # ``python -O``, which would silently disable the allow-list that
        # guards ``trust_remote_code=True`` below.
        if model_name not in self.SUPPORTED_MODELS:
            raise _UnsupportedRankGenModelError(
                f"Unsupported RankGen model {model_name!r}; "
                f"expected one of {list(self.SUPPORTED_MODELS)}"
            )
        self.rankgen_hf_hub = model_name
        self.model = AutoModel.from_pretrained(self.rankgen_hf_hub, trust_remote_code=True)

    def forward(self, prefixes, suffixes) -> torch.Tensor:
        """Return the row-wise dot product of prefix and suffix embeddings.

        Args:
            prefixes: Mapping of keyword arguments forwarded to the wrapped
                model (presumably tokenizer output such as ``input_ids`` --
                confirm against the caller).
            suffixes: Same as ``prefixes``, for the suffix texts.

        Returns:
            The element-wise product of the two model outputs summed over
            dimension 1. Assumes each model call returns a
            ``(batch, dim)`` tensor, giving one score per row -- TODO confirm.
        """
        embedded_prefixes = self.model(**prefixes)
        embedded_suffixes = self.model(**suffixes)
        # Multiply element-wise, then reduce over dim 1: a dot product taken
        # independently for each row.
        return torch.sum(embedded_prefixes * embedded_suffixes, dim=1)
|