mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge pull request #295 from croumegous/fastapi-response-model
chore: add fastapi response model to every endpoints, add openapi documentation for API response
This commit is contained in:
@@ -5,12 +5,13 @@ from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -27,7 +28,7 @@ def get_message_by_frontend_id(
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -41,7 +42,7 @@ def get_conv_by_frontend_id(
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -55,7 +56,7 @@ def get_tree_by_frontend_id(
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -68,7 +69,7 @@ def get_children_by_frontend_id(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -82,7 +83,7 @@ def get_descendants_by_frontend_id(
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -96,7 +97,7 @@ def get_longest_conv_by_frontend_id(
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children_by_frontend_id(
|
||||
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
|
||||
@@ -6,14 +6,14 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{username}/messages")
|
||||
@router.get("/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
api_client_id: UUID = None,
|
||||
@@ -43,11 +43,10 @@ def query_frontend_user_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{username}/messages")
|
||||
@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_frontend_user_messages_deleted(
|
||||
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(username=username, api_client_id=api_client.id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Response
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", response_model=list[protocol.Message])
|
||||
def query_messages(
|
||||
username: str = None,
|
||||
api_client_id: str = None,
|
||||
@@ -45,7 +46,7 @@ def query_messages(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
@router.get("/{message_id}", response_model=protocol.Message)
|
||||
def get_message(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -61,7 +62,7 @@ def get_message(
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
|
||||
def get_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -74,7 +75,7 @@ def get_conv(
|
||||
return utils.prepare_conversation(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/tree")
|
||||
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
|
||||
def get_tree(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -87,7 +88,7 @@ def get_tree(
|
||||
return utils.prepare_tree(tree, message.message_tree_id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/children")
|
||||
@router.get("/{message_id}/children", response_model=list[protocol.Message])
|
||||
def get_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -99,7 +100,7 @@ def get_children(
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
|
||||
def get_descendants(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -112,7 +113,7 @@ def get_descendants(
|
||||
return utils.prepare_tree(descendants, message.id)
|
||||
|
||||
|
||||
@router.get("/{message_id}/longest_conversation_in_tree")
|
||||
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
|
||||
def get_longest_conv(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -125,7 +126,7 @@ def get_longest_conv(
|
||||
return utils.prepare_conversation(conv)
|
||||
|
||||
|
||||
@router.get("/{message_id}/max_children_in_tree")
|
||||
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
|
||||
def get_max_children(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
@@ -138,10 +139,9 @@ def get_max_children(
|
||||
return utils.prepare_tree([message, *children], message.id)
|
||||
|
||||
|
||||
@router.delete("/{message_id}")
|
||||
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_message_deleted(
|
||||
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
pr.mark_messages_deleted(message_id)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
@@ -2,12 +2,13 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/")
|
||||
@router.get("/", response_model=protocol.SystemStats)
|
||||
def get_message_stats(
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
|
||||
@@ -10,6 +10,7 @@ from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -159,14 +160,14 @@ def request_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/{task_id}/ack", response_model=None)
|
||||
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
ack_request: protocol_schema.TaskAck,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""
|
||||
The frontend acknowledges a task.
|
||||
"""
|
||||
@@ -187,14 +188,14 @@ def tasks_acknowledge(
|
||||
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
|
||||
|
||||
|
||||
@router.post("/{task_id}/nack", response_model=None)
|
||||
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
|
||||
def tasks_acknowledge_failure(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_key: APIKey = Depends(deps.get_api_key),
|
||||
task_id: UUID,
|
||||
nack_request: protocol_schema.TaskNAck,
|
||||
) -> Any:
|
||||
) -> None:
|
||||
"""
|
||||
The frontend reports failure to implement a task.
|
||||
"""
|
||||
@@ -265,7 +266,7 @@ def tasks_interaction(
|
||||
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
|
||||
|
||||
|
||||
@router.post("/close")
|
||||
@router.post("/close", response_model=protocol_schema.TaskDone)
|
||||
def close_collective_task(
|
||||
close_task_request: protocol_schema.TaskClose,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -6,7 +6,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_400_BAD_REQUEST
|
||||
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -16,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel):
|
||||
user: protocol_schema.User
|
||||
|
||||
|
||||
@router.post("/")
|
||||
@router.post("/", status_code=HTTP_204_NO_CONTENT)
|
||||
def label_text(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -3,17 +3,17 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages")
|
||||
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
@@ -41,19 +41,13 @@ def query_user_messages(
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{user_id}/messages")
|
||||
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
|
||||
def mark_user_messages_deleted(
|
||||
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
|
||||
):
|
||||
pr = PromptRepository(db, api_client, None)
|
||||
messages = pr.query_messages(user_id=user_id)
|
||||
pr.mark_messages_deleted(messages)
|
||||
return Response(status_code=HTTP_200_OK)
|
||||
|
||||
Reference in New Issue
Block a user