diff --git a/backend/.gitignore b/backend/.gitignore index 098a83e4..30c79448 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -1,3 +1,4 @@ __pycache__ .env notes.txt +venv diff --git a/backend/main.py b/backend/main.py index cb682a9f..edbad943 100644 --- a/backend/main.py +++ b/backend/main.py @@ -230,8 +230,8 @@ if __name__ == "__main__": help="Dumps the openapi schema to stdout", action=argparse.BooleanOptionalAction, ) - parser.add_argument("--host", help="The host to run the server") - parser.add_argument("--port", help="The port to run the server") + parser.add_argument("--host", help="The host to run the server", default="0.0.0.0") + parser.add_argument("--port", help="The port to run the server", default=8080) args = parser.parse_args() diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index a9d09457..5bdf1c97 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -2,6 +2,7 @@ from fastapi import APIRouter from oasst_backend.api.v1 import ( frontend_messages, frontend_users, + hugging_face, leaderboards, messages, stats, @@ -19,3 +20,4 @@ api_router.include_router(users.router, prefix="/users", tags=["users"]) api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=["frontend_users"]) api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) api_router.include_router(leaderboards.router, prefix="/experimental/leaderboards", tags=["leaderboards"]) +api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) diff --git a/backend/oasst_backend/api/v1/hugging_face.py b/backend/oasst_backend/api/v1/hugging_face.py new file mode 100644 index 00000000..1e7f1ffe --- /dev/null +++ b/backend/oasst_backend/api/v1/hugging_face.py @@ -0,0 +1,37 @@ +from enum import Enum +from typing import List + +from fastapi import APIRouter, Depends +from oasst_backend.api import deps +from oasst_backend.models import ApiClient +from oasst_backend.schemas.hugging_face import ToxicityClassification +from oasst_backend.utils.hugging_face import HuggingFaceAPI + +router = APIRouter() + + +class HF_url(str, Enum): + HUGGINGFACE_TOXIC_ROBERTA = "https://api-inference.huggingface.co/models/unitary/multilingual-toxic-xlm-roberta" + + +@router.get("/text_toxicity") +async def get_text_toxicity( + msg: str, + api_client: ApiClient = Depends(deps.get_trusted_api_client), +) -> List[List[ToxicityClassification]]: + """Get the Message Toxicity from HuggingFace Roberta model. + + Args: + msg (str): the message that we want to analyze. + api_client (ApiClient, optional): authentification of the user of the request. + Defaults to Depends(deps.get_trusted_api_client). + + Returns: + ToxicityClassification: the score of toxicity of the message. + """ + + api_url: str = HF_url.HUGGINGFACE_TOXIC_ROBERTA.value + hugging_face_api = HuggingFaceAPI(api_url) + response = await hugging_face_api.post(msg) + + return response diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index fef59832..df37dc9f 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -22,6 +22,8 @@ class Settings(BaseSettings): DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False + HUGGING_FACE_API_KEY: str = "" + @validator("DATABASE_URI", pre=True) def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any: if isinstance(v, str): diff --git a/backend/oasst_backend/schemas/__init__.py b/backend/oasst_backend/schemas/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/schemas/hugging_face.py b/backend/oasst_backend/schemas/hugging_face.py new file mode 100644 index 00000000..f4da3e74 --- /dev/null +++ b/backend/oasst_backend/schemas/hugging_face.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ToxicityClassification(BaseModel): + label: str + score: float diff --git a/backend/oasst_backend/utils/__init__.py b/backend/oasst_backend/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/backend/oasst_backend/utils/hugging_face.py b/backend/oasst_backend/utils/hugging_face.py new file mode 100644 index 00000000..ff73a1c5 --- /dev/null +++ b/backend/oasst_backend/utils/hugging_face.py @@ -0,0 +1,49 @@ +from typing import Any, Dict + +import aiohttp +from oasst_backend.config import settings +from oasst_shared.exceptions import OasstError, OasstErrorCode + + +class HuggingFaceAPI: + """Class Object to make post calls to endpoints for inference in models hosted in HuggingFace""" + + def __init__( + self, + api_url: str, + ): + + # The API endpoint we want to access + self.api_url: str = api_url + + # Access token for the api + self.api_key: str = settings.HUGGING_FACE_API_KEY + + # Headers going to be used + self.headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"} + + async def post(self, input: str) -> Any: + """Post request to the endpoint to get an inference + + Args: + input (str): the input that we will pass to the model + + Raises: + OasstError: in the case we get a bad response + + Returns: + inference: the inference we obtain from the model in HF + """ + + session = aiohttp.ClientSession() + payload: Dict[str, str] = {"inputs": input} + response = await session.post(self.api_url, headers=self.headers, json=payload) + + # If we get a bad response + if response.status != 200: + raise OasstError("Response Error Detoxify HuggingFace", error_code=OasstErrorCode.HUGGINGFACE_API_ERROR) + + # Get the response from the API call + inference = await response.json() + + return inference diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 49eeb088..6cc25918 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -10,6 +10,7 @@ class OasstErrorCode(IntEnum): 0-1000: general errors 1000-2000: tasks endpoint 2000-3000: prompt_repository + 3000-4000: external resources """ # 0-1000: general errors @@ -45,6 +46,9 @@ class OasstErrorCode(IntEnum): TASK_ALREADY_DONE = 2105 TASK_NOT_COLLECTIVE = 2106 + # 3000-4000: external resources + HUGGINGFACE_API_ERROR = 3001 + class OasstError(Exception): """Base class for Open-Assistant exceptions."""