mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-05 17:30:48 +08:00
Detoxify Rest API Client (#404)
* [NEW] utils: Endpoint Toxic Roberta * [NEW] Constants API URL * [NEW] Git ignore venv * [NEW] Lint * [NEW] Backend default args * [NEW] HUGGINGFACE_API_ERROR * [NEW] Requests package * [NEW] Get Toxicity Endpoint * [NEW] Schema: ToxicityClassification [NEW] Constants module [FIX] Module * [FIX] Test Key HF * [NEW] settings: HUGGING_FACE_API_KEY * [NEW] Remove requests * [NEW] HuggingFace client * [NEW] Cleaning code
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
__pycache__
|
||||
.env
|
||||
notes.txt
|
||||
venv
|
||||
|
||||
+2
-2
@@ -230,8 +230,8 @@ if __name__ == "__main__":
|
||||
help="Dumps the openapi schema to stdout",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
)
|
||||
parser.add_argument("--host", help="The host to run the server")
|
||||
parser.add_argument("--port", help="The port to run the server")
|
||||
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
|
||||
parser.add_argument("--port", help="The port to run the server", default=8080)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import (
|
||||
frontend_messages,
|
||||
frontend_users,
|
||||
hugging_face,
|
||||
leaderboards,
|
||||
messages,
|
||||
stats,
|
||||
@@ -19,3 +20,4 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
|
||||
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.schemas.hugging_face import ToxicityClassification
|
||||
from oasst_backend.utils.hugging_face import HuggingFaceAPI
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HF_url(str, Enum):
|
||||
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
|
||||
|
||||
|
||||
@router.get("/text_toxicity")
|
||||
async def get_text_toxicity(
|
||||
msg: str,
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
) -> List[List[ToxicityClassification]]:
|
||||
"""Get the Message Toxicity from HuggingFace Roberta model.
|
||||
|
||||
Args:
|
||||
msg (str): the message that we want to analyze.
|
||||
api_client (ApiClient, optional): authentification of the user of the request.
|
||||
Defaults to Depends(deps.get_trusted_api_client).
|
||||
|
||||
Returns:
|
||||
ToxicityClassification: the score of toxicity of the message.
|
||||
"""
|
||||
|
||||
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
|
||||
hugging_face_api = HuggingFaceAPI(api_url)
|
||||
response = await hugging_face_api.post(msg)
|
||||
|
||||
return response
|
||||
@@ -22,6 +22,8 @@ class Settings(BaseSettings):
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToxicityClassification(BaseModel):
|
||||
label: str
|
||||
score: float
|
||||
@@ -0,0 +1,49 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from oasst_backend.config import settings
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
|
||||
|
||||
class HuggingFaceAPI:
|
||||
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_url: str,
|
||||
):
|
||||
|
||||
# The API endpoint we want to access
|
||||
self.api_url: str = api_url
|
||||
|
||||
# Access token for the api
|
||||
self.api_key: str = settings.HUGGING_FACE_API_KEY
|
||||
|
||||
# Headers going to be used
|
||||
self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
async def post(self, input: str) -> Any:
|
||||
"""Post request to the endpoint to get an inference
|
||||
|
||||
Args:
|
||||
input (str): the input that we will pass to the model
|
||||
|
||||
Raises:
|
||||
OasstError: in the case we get a bad response
|
||||
|
||||
Returns:
|
||||
inference: the inference we obtain from the model in HF
|
||||
"""
|
||||
|
||||
session = aiohttp.ClientSession()
|
||||
payload: Dict[str, str] = {"inputs": input}
|
||||
response = await session.post(self.api_url, headers=self.headers, json=payload)
|
||||
|
||||
# If we get a bad response
|
||||
if response.status != 200:
|
||||
raise OasstError("Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR)
|
||||
|
||||
# Get the response from the API call
|
||||
inference = await response.json()
|
||||
|
||||
return inference
|
||||
@@ -10,6 +10,7 @@ class OasstErrorCode(IntEnum):
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
3000-4000: external resources
|
||||
"""
|
||||
|
||||
# 0-1000: general errors
|
||||
@@ -45,6 +46,9 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
|
||||
# 3000-4000: external resources
|
||||
HUGGINGFACE_API_ERROR = 3001
|
||||
|
||||
|
||||
class OasstError(Exception):
|
||||
"""Base class for Open-Assistant exceptions."""
|
||||
|
||||
Reference in New Issue
Block a user