add optional message_id query param to text_labels/valid_labels endpoint

This commit is contained in:
Andreas Köpf
2023-01-28 15:29:38 +01:00
parent 264e914225
commit 19116f7251
+25 -3
View File
@@ -1,13 +1,19 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security.api_key import APIKey
from loguru import logger
from oasst_backend.api import deps
from oasst_backend.config import settings
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TextLabel
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -43,12 +49,28 @@ def label_text(
@router.get("/valid_labels")
def get_valid_lables() -> ValidLabelsResponse:
def get_valid_lables(
*,
message_id: Optional[UUID] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_api_client),
) -> ValidLabelsResponse:
if message_id:
pr = PromptRepository(db, api_client=api_client)
message = pr.fetch_message(message_id=message_id)
if message.parent_id is None:
valid_labels = settings.tree_manager.labels_initial_prompt
elif message.role == "assistant":
valid_labels = settings.tree_manager.labels_assistant_reply
else:
valid_labels = settings.tree_manager.labels_prompter_reply
else:
valid_labels = [l for l in TextLabel if l != TextLabel.fails_task]
return ValidLabelsResponse(
valid_labels=[
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
for l in TextLabel
if l != TextLabel.fails_task
for l in valid_labels
]
)