Open-Assistant/backend/oasst_backend/api/v1/hugging_face.py
Nil Andreu a902c600fa Store Message Toxicity in database (#553)
* [NEW] MessageToxicity table

* [NEW] Alembic message Toxicity

* [NEW] Model name enum

* [NEW] Refactor Enum HF

* [NEW] Settings: DEBUG_SKIP_TOXICITY_CALCULATION

* [NEW] Store toxicity values

* [FIX] Merge conflict

* [FIX] Documentation

* [NEW] save_toxicity: function

* [FIX] Formatted string

* [NEW] DEBUG_SKIP_TOXICITY_CALCULATION=True

* [FIX] HfClassificationModel

* [FIX] Alembic merge heads

* [NEW] Refactor save_toxicity

* [NEW] Separating score/label

* [NEW] Store score and label

* [FIX] Cleaning Alembic

* [NEW] Clean HF names

* [NEW] Not type hinting

* [NEW] Update alembic versions

* [NEW] Revert the changes

* [NEW] Type hinting label & score

* Updated down_revision in migration script

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
2023-01-14 12:22:55 +00:00

34 lines
1.1 KiB
Python

from typing import List

from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HfClassificationModel, HfUrl, HuggingFaceAPI

router = APIRouter()


@router.get("/text_toxicity")
async def get_text_toxicity(
    msg: str,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> List[List[ToxicityClassification]]:
    """Get the message toxicity from the HuggingFace RoBERTa model.

    Args:
        msg (str): the message that we want to analyze.
        api_client (ApiClient, optional): authentication of the user making the request.
            Defaults to Depends(deps.get_trusted_api_client).

    Returns:
        List[List[ToxicityClassification]]: the toxicity classification results for the message.
    """
    # Build the inference URL from the base endpoint and the model name.
    api_url: str = HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION.value + "/" + HfClassificationModel.TOXIC_ROBERTA.value
    hugging_face_api = HuggingFaceAPI(api_url)

    # Forward the message to the model and return its response unchanged.
    response = await hugging_face_api.post(msg)

    return response
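
The endpoint above does two things: it joins a base URL and a model name into one inference URL, and it returns a nested list of classification results. The following is a minimal, self-contained sketch of those two steps that runs without `oasst_backend`. The enum values, the `label`/`score` fields (assumed from the "Store score and label" commit note), the helpers `build_api_url` and `parse_response`, and the sample payload are all hypothetical stand-ins, not the project's actual definitions.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List


class HfUrl(str, Enum):
    # Placeholder base URL (assumption); the real value is defined in
    # oasst_backend.utils.hugging_face.
    HUGGINGFACE_TOXIC_CLASSIFICATION = "https://example.invalid/models"


class HfClassificationModel(str, Enum):
    # Placeholder model identifier (assumption).
    TOXIC_ROBERTA = "some-org/toxic-roberta"


@dataclass
class ToxicityClassification:
    # Assumed fields; the real schema lives in oasst_backend.schemas.hugging_face.
    label: str
    score: float


def build_api_url(base: HfUrl, model: HfClassificationModel) -> str:
    """Join base URL and model name with a slash, as the endpoint does."""
    return base.value + "/" + model.value


def parse_response(raw: List[List[Dict[str, Any]]]) -> List[List[ToxicityClassification]]:
    """Convert a nested list of dicts into nested ToxicityClassification objects."""
    return [[ToxicityClassification(**item) for item in group] for group in raw]


api_url = build_api_url(HfUrl.HUGGINGFACE_TOXIC_CLASSIFICATION, HfClassificationModel.TOXIC_ROBERTA)

# Made-up payload: an outer list of inner lists, each entry one label/score pair.
sample = [[{"label": "toxic", "score": 0.9}, {"label": "not_toxic", "score": 0.1}]]
parsed = parse_response(sample)
```

Keeping the URL parts in enums, as the original file does, means a different model can be selected by changing one enum member rather than editing string literals in each endpoint.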