Detoxify Rest API Client (#404)

* [NEW] utils: Endpoint Toxic Roberta

* [NEW] Constants API URL

* [NEW] Git ignore venv

* [NEW] Lint

* [NEW] Backend default args

* [NEW] HUGGINGFACE_API_ERROR

* [NEW] Requests package

* [NEW] Get Toxicity Endpoint

* [NEW] Schema:  ToxicityClassification

[NEW] Constants module

[FIX] Module

* [FIX] Test Key HF

* [NEW] settings: HUGGING_FACE_API_KEY

* [NEW] Remove requests

* [NEW] HuggingFace client

* [NEW] Cleaning code
This commit is contained in:
Nil Andreu
2023-01-05 16:18:04 +01:00
committed by GitHub
parent 3dbe0ae1ba
commit ee50b573e1
10 changed files with 103 additions and 2 deletions
+1
View File
@@ -1,3 +1,4 @@
__pycache__
.env
notes.txt
venv
+2 -2
View File
@@ -230,8 +230,8 @@ if __name__ == "__main__":
help="Dumps the openapi schema to stdout",
action=argparse.BooleanOptionalAction,
)
parser.add_argument("--host", help="The host to run the server")
parser.add_argument("--port", help="The port to run the server")
parser.add_argument("--host", help="The host to run the server", default="0.0.0.0")
parser.add_argument("--port", help="The port to run the server", default=8080)
args = parser.parse_args()
+2
View File
@@ -2,6 +2,7 @@ from fastapi import APIRouter
from oasst_backend.api.v1 import (
frontend_messages,
frontend_users,
hugging_face,
leaderboards,
messages,
stats,
@@ -19,3 +20,4 @@ api_router.include_router(users.router, prefix="/users", tags=["users"])
api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"])
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
@@ -0,0 +1,37 @@
from enum import Enum
from typing import List
from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.schemas.hugging_face import ToxicityClassification
from oasst_backend.utils.hugging_face import HuggingFaceAPI
router = APIRouter()
class HF_url(str, Enum):
HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta"
@router.get("/text_toxicity")
async def get_text_toxicity(
msg: str,
api_client: ApiClient = Depends(deps.get_trusted_api_client),
) -> List[List[ToxicityClassification]]:
"""Get the Message Toxicity from HuggingFace Roberta model.
Args:
msg (str): the message that we want to analyze.
api_client (ApiClient, optional): authentification of the user of the request.
Defaults to Depends(deps.get_trusted_api_client).
Returns:
ToxicityClassification: the score of toxicity of the message.
"""
api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value
hugging_face_api = HuggingFaceAPI(api_url)
response = await hugging_face_api.post(msg)
return response
+2
View File
@@ -22,6 +22,8 @@ class Settings(BaseSettings):
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
HUGGING_FACE_API_KEY: str = ""
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ToxicityClassification(BaseModel):
label: str
score: float
@@ -0,0 +1,49 @@
from typing import Any, Dict
import aiohttp
from oasst_backend.config import settings
from oasst_shared.exceptions import OasstError, OasstErrorCode
class HuggingFaceAPI:
"""Class Object to make post calls to endpoints for inference in models hosted in HuggingFace"""
def __init__(
self,
api_url: str,
):
# The API endpoint we want to access
self.api_url: str = api_url
# Access token for the api
self.api_key: str = settings.HUGGING_FACE_API_KEY
# Headers going to be used
self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
async def post(self, input: str) -> Any:
"""Post request to the endpoint to get an inference
Args:
input (str): the input that we will pass to the model
Raises:
OasstError: in the case we get a bad response
Returns:
inference: the inference we obtain from the model in HF
"""
session = aiohttp.ClientSession()
payload: Dict[str, str] = {"inputs": input}
response = await session.post(self.api_url, headers=self.headers, json=payload)
# If we get a bad response
if response.status != 200:
raise OasstError("Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR)
# Get the response from the API call
inference = await response.json()
return inference
@@ -10,6 +10,7 @@ class OasstErrorCode(IntEnum):
0-1000: general errors
1000-2000: tasks endpoint
2000-3000: prompt_repository
3000-4000: external resources
"""
# 0-1000: general errors
@@ -45,6 +46,9 @@ class OasstErrorCode(IntEnum):
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001
class OasstError(Exception):
"""Base class for Open-Assistant exceptions."""