Open-Assistant/backend/oasst_backend/utils/hugging_face.py
Andreas Köpf 14fa08e2e7 Message tree state machine (#555)
* add query_incomplete_rankings()

* Add SQL queries for TreeManager task selection

* first working version of TreeManager.next_task()

* remove old generate_task(), add mandatory_labels to text_labels task

* Add ConversationMessage list to Ranking tasks

* add more sophisticated sql queries to find extendible trees

* add TreeManager.query_extendible_parents()

* fix task validation, seed data insertion (reviewed)

* provide user for task selection in text-frontend

* enter 'growing' state

* enter 'aborted_low_grade' state

* enter 'ranking' state

* check tree 'growing' state upon reply insertion

* exclude user from labeling their own messages (added DEBUG_ALLOW_SELF_LABELING setting)

* add DEBUG_ALLOW_SELF_LABELING to docker-compose.yaml

* fix ranking submission

* add query_tree_ranking_results()

* add ranked_message_ids to RankingReactionPayload

* fix reply_messages instead of prompt_messages

* increment 'ranking_count' of ranked replies

* added logic to check_condition_for_scoring_state

* changes to msg_tree_state_machine

* pre-commit changes

* enter 'ready_for_scoring' state

* re-add HF embedding call (lost during merge)

* use prepare_conversation() helper for seed-data creation

* Partially add user specified task selection

Co-authored-by: Daniel Hug <danielpatrickhug@gmail.com>
2023-01-11 10:54:03 +01:00


from enum import Enum
from typing import Any, Dict

import aiohttp
from loguru import logger
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode


class HfUrl(str, Enum):
    HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
    HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"


class HfEmbeddingModel(str, Enum):
    MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"


class HuggingFaceAPI:
    """Client that sends POST requests to inference endpoints of models hosted on Hugging Face."""

    def __init__(self, api_url: str):
        # The API endpoint we want to access
        self.api_url: str = api_url
        # Access token for the API
        self.api_key: str = settings.HUGGING_FACE_API_KEY
        # Headers sent with every request (contain the secret token, so never log them)
        self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}

    async def post(self, input: str) -> Any:
        """Send a POST request to the endpoint and return the model's inference.

        Args:
            input (str): the input to pass to the model

        Raises:
            OasstError: if the API responds with an error status (HTTP 400 or above)

        Returns:
            inference: the decoded JSON response from the Hugging Face model
        """
        async with aiohttp.ClientSession() as session:
            payload: Dict[str, str] = {"inputs": input}
            async with session.post(self.api_url, headers=self.headers, json=payload) as response:
                # On a bad response, log status and body but not the headers, which hold the API key
                if not response.ok:
                    body = await response.text()
                    logger.error(f"HuggingFace API error (status {response.status}) from {self.api_url}: {body}")
                    raise OasstError(
                        f"Response Error HuggingFace API (Status: {response.status})",
                        error_code=OasstErrorCode.HUGGINGFACE_API_ERROR,
                    )
                # Decode the JSON body returned by the API
                inference = await response.json()
                return inference
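To illustrate what `HuggingFaceAPI.post()` sends when used for embeddings, the request can be sketched without any network access. This is a minimal sketch, assuming that the `HUGGINGFACE_FEATURE_EXTRACTION` URL takes the `HfEmbeddingModel` id as a path suffix; the helper `build_embedding_request` and the dummy key `hf_dummy` are hypothetical and not part of this module. The two enum values are copied from the file so the block runs on its own.

```python
from enum import Enum
from typing import Dict, Tuple


class HfUrl(str, Enum):
    HUGGINGFACE_FEATURE_EXTRACTION = "https://api-inference.huggingface.co/pipeline/feature-extraction"


class HfEmbeddingModel(str, Enum):
    MINILM = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"


def build_embedding_request(text: str, api_key: str) -> Tuple[str, Dict[str, str], Dict[str, str]]:
    """Hypothetical helper: return (url, headers, payload) shaped like HuggingFaceAPI.post()."""
    # Assumption: the model id is appended to the feature-extraction pipeline URL.
    # Using .value avoids depending on how the Python version formats enum members.
    url = f"{HfUrl.HUGGINGFACE_FEATURE_EXTRACTION.value}/{HfEmbeddingModel.MINILM.value}"
    # Same header and JSON body layout that HuggingFaceAPI builds.
    headers = {"Authorization": f"Bearer {api_key}"}
    payload = {"inputs": text}
    return url, headers, payload


url, headers, payload = build_embedding_request("Hello world", api_key="hf_dummy")
```

Calling `.value` explicitly matters here: on recent Python versions, formatting a `str`-mixin `Enum` member directly in an f-string can produce the member name (for example `HfUrl.HUGGINGFACE_FEATURE_EXTRACTION`) instead of the URL string.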