mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-03 17:10:10 +08:00
14fa08e2e7
* add query_incomplete_rankings() * Add SQL queries for TreeManager task selection * first working version of TreeManager.next_task() * remove old generate_task(), add mandatory_labels to text_labels task * Add ConversationMessage list to Ranking tasks * add more sophisticated sql queries to find extendible trees * add TreeManager.query_extendible_parents() * fix task validation, seed data insertion (reviewed) * provide user for task selection in text-frontend * enter 'growing' state * enter 'aborted_low_grade' state * enter 'ranking' state * check tree 'growing' state upon reply insertion * exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting) * add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml * fix ranking submission * add query_tree_ranking_results() * add ranked_message_ids to RankingReactionPayload * fix reply_messages instead of prompt_messages * increment 'ranking_count' of ranked replies * added logic to check_condition_for_scoring_state * changes to msg_tree_state_machine * pre-commit changes * enter 'ready_for_scoring' state * re-add HF embedding call (lost during merge) * use prepare_conversation() helper for seed-data creation * Partially add user specified task selection Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
66 lines
2.1 KiB
Python
66 lines
2.1 KiB
Python
from enum import Enum
|
|
from typing import Any, Dict
|
|
|
|
import aiohttp
|
|
from loguru import logger
|
|
from oasst_backend.config import settings
|
|
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
|
|
|
|
|
class HfUrl(str, Enum):
    """HuggingFace Inference API endpoint URLs.

    Because of the ``str`` mixin, each member compares equal to (and can be
    passed wherever a plain string is expected as) its URL.
    """

    # Plain string literals: a trailing comma here would turn the value into a
    # tuple, which only happens to work because Enum unpacks tuple values into
    # the ``str`` mixin's constructor.
    HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
    HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"
|
|
|
|
|
|
class HfEmbeddingModel(str, Enum):
    """HuggingFace model identifiers for embedding models.

    The ``str`` mixin makes each member compare equal to its plain model id.
    """

    # Multilingual MiniLM paraphrase model published under sentence-transformers.
    MINILM = (
        "sentence-transformers/"
        "paraphrase-multilingual-MiniLM-L12-v2"
    )
|
|
|
|
|
|
class HuggingFaceAPI:
    """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace.

    Each call to :meth:`post` opens its own ``aiohttp.ClientSession``, sends
    ``{"inputs": input}`` as a JSON body with a bearer-token ``Authorization``
    header, and returns the JSON-decoded response.
    """

    def __init__(
        self,
        api_url: str,
    ):
        """Create a client bound to a single endpoint.

        Args:
            api_url (str): the full URL of the endpoint to POST to
        """
        # The API endpoint we want to access
        self.api_url: str = api_url

        # Access token for the API, read from the backend settings
        self.api_key: str = settings.HUGGING_FACE_API_KEY

        # Headers sent with every request. They contain the secret token,
        # so they must never be logged.
        self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}

    async def post(self, input: str) -> Any:
        """Post request to the endpoint to get an inference.

        Args:
            input (str): the input that we will pass to the model; sent as
                the JSON body ``{"inputs": input}``

        Raises:
            OasstError: with ``OasstErrorCode.HUGGINGFACE_API_ERROR`` when the
                response is not ``ok`` (aiohttp: HTTP status >= 400)

        Returns:
            inference: the JSON-decoded response body obtained from the model in HF
        """
        # NOTE: `input` shadows the builtin, but it is the public parameter
        # name (callers may pass it by keyword), so it is kept as-is.
        payload: Dict[str, str] = {"inputs": input}

        async with aiohttp.ClientSession() as session:
            async with session.post(self.api_url, headers=self.headers, json=payload) as response:
                # If we get a bad response
                if not response.ok:
                    # Log the failed response for diagnosis. The request headers
                    # are deliberately NOT logged: they carry the bearer token.
                    logger.error(response)
                    raise OasstError(
                        f"Response Error HuggingFace API (Status: {response.status})",
                        error_code=OasstErrorCode.HUGGINGFACE_API_ERROR,
                    )

                # Decode the body while the response is still open
                inference = await response.json()

        return inference
|