mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
add optional message_id query param to text_labels/valid_labels endpoint
This commit is contained in:
@@ -1,13 +1,19 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security.api_key import APIKey
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.schemas.text_labels import LabelDescription, ValidLabelsResponse
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TextLabel
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
@@ -43,12 +49,28 @@ def label_text(
|
||||
|
||||
|
||||
@router.get("/valid_labels")
|
||||
def get_valid_lables() -> ValidLabelsResponse:
|
||||
def get_valid_lables(
|
||||
*,
|
||||
message_id: Optional[UUID] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
) -> ValidLabelsResponse:
|
||||
if message_id:
|
||||
pr = PromptRepository(db, api_client=api_client)
|
||||
message = pr.fetch_message(message_id=message_id)
|
||||
if message.parent_id is None:
|
||||
valid_labels = settings.tree_manager.labels_initial_prompt
|
||||
elif message.role == "assistant":
|
||||
valid_labels = settings.tree_manager.labels_assistant_reply
|
||||
else:
|
||||
valid_labels = settings.tree_manager.labels_prompter_reply
|
||||
else:
|
||||
valid_labels = [l for l in TextLabel if l != TextLabel.fails_task]
|
||||
|
||||
return ValidLabelsResponse(
|
||||
valid_labels=[
|
||||
LabelDescription(name=l.value, widget=l.widget.value, display_text=l.display_text, help_text=l.help_text)
|
||||
for l in TextLabel
|
||||
if l != TextLabel.fails_task
|
||||
for l in valid_labels
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user