Files
Open-Assistant/backend/oasst_backend/api/v1/text_labels.py
T
Andreas Köpf 8906854dbf Extract UserRepository and TaskRepository from PromptRepository
* Extract classes UserRepository and TaskRepository from PromptRepository
* move close_task() to TaskRepository and get_user_leaderboard to UserRepository()
* Use UserRepository in leaderboards endpoint, add type annotation to leaderboards endpoint
2023-01-08 19:08:47 +01:00

46 lines
1.4 KiB
Python

from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelOption, ValidLabelsResponse
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
text_labels: protocol_schema.TextLabels,
) -> None:
"""
Label a piece of text.
"""
api_client = deps.api_auth(api_key, db)
try:
logger.info(f"Labeling text {text_labels=}.")
pr = PromptRepository(db, api_client, client_user=text_labels.user)
pr.store_text_labels(text_labels)
except Exception:
logger.exception("Failed to store label.")
raise HTTPException(
status_code=HTTP_400_BAD_REQUEST,
)
@router.get("/valid_labels")
def get_valid_lables() -> ValidLabelsResponse:
return ValidLabelsResponse(
valid_labels=[
LabelOption(name=l.value, display_text=l.display_text, help_text=l.help_text)
for l in protocol_schema.TextLabel
]
)