diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index affc81bb..594ba0df 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -1,13 +1,19 @@ +from typing import Optional +from uuid import UUID + from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey from loguru import logger from oasst_backend.api import deps +from oasst_backend.config import settings +from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TextLabel +from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -43,12 +49,28 @@ def label_text( @router.get("/valid_labels") -def get_valid_lables() -> ValidLabelsResponse: +def get_valid_lables( + *, + message_id: Optional[UUID] = None, + db: Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_api_client), +) -> ValidLabelsResponse: + if message_id: + pr = PromptRepository(db, api_client=api_client) + message = pr.fetch_message(message_id=message_id) + if message.parent_id is None: + valid_labels = settings.tree_manager.labels_initial_prompt + elif message.role == "assistant": + valid_labels = settings.tree_manager.labels_assistant_reply + else: + valid_labels = settings.tree_manager.labels_prompter_reply + else: + valid_labels = [l for l in TextLabel if l != TextLabel.fails_task] + return ValidLabelsResponse( valid_labels=[ LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text) - for l in TextLabel - if l != TextLabel.fails_task + for l in valid_labels ] )