From fc7f8cbc1f759539e2acedd560e45038f36d5fc9 Mon Sep 17 00:00:00 2001 From: croumegous Date: Mon, 2 Jan 2023 22:55:08 +0100 Subject: [PATCH] chore: add fastapi response model to every endpoints, add openapi documentation for API response --- .../oasst_backend/api/v1/frontend_messages.py | 15 ++++++------ .../oasst_backend/api/v1/frontend_users.py | 9 ++++--- backend/oasst_backend/api/v1/messages.py | 24 +++++++++---------- backend/oasst_backend/api/v1/stats.py | 3 ++- backend/oasst_backend/api/v1/tasks.py | 11 +++++---- backend/oasst_backend/api/v1/text_labels.py | 4 ++-- backend/oasst_backend/api/v1/users.py | 16 ++++--------- 7 files changed, 39 insertions(+), 43 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index 261b24ea..956d9992 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -5,12 +5,13 @@ from oasst_backend.models import ApiClient from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.exceptions import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol from sqlmodel import Session router = APIRouter() -@router.get("/{message_id}") +@router.get("/{message_id}", response_model=protocol.Message) def get_message_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -27,7 +28,7 @@ def get_message_by_frontend_id( return utils.prepare_message(message) -@router.get("/{message_id}/conversation") +@router.get("/{message_id}/conversation", response_model=protocol.Conversation) def get_conv_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -41,7 +42,7 @@ def get_conv_by_frontend_id( return utils.prepare_conversation(messages) -@router.get("/{message_id}/tree") +@router.get("/{message_id}/tree", response_model=protocol.MessageTree) def get_tree_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -55,7 +56,7 @@ def get_tree_by_frontend_id( return utils.prepare_tree(tree, message.message_tree_id) -@router.get("/{message_id}/children") +@router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -68,7 +69,7 @@ def get_children_by_frontend_id( return utils.prepare_message_list(messages) -@router.get("/{message_id}/descendants") +@router.get("/{message_id}/descendants", response_model=protocol.MessageTree) def get_descendants_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -82,7 +83,7 @@ def get_descendants_by_frontend_id( return utils.prepare_tree(descendants, message.id) -@router.get("/{message_id}/longest_conversation_in_tree") +@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation) def get_longest_conv_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -96,7 +97,7 @@ def get_longest_conv_by_frontend_id( return utils.prepare_conversation(conv) -@router.get("/{message_id}/max_children_in_tree") +@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree) def get_max_children_by_frontend_id( message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 738e3cb0..0a745462 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -6,14 +6,14 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol from sqlmodel import Session -from starlette.responses import Response -from starlette.status import HTTP_200_OK +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/{username}/messages") +@router.get("/{username}/messages", response_model=list[protocol.Message]) def query_frontend_user_messages( username: str, api_client_id: UUID = None, @@ -43,11 +43,10 @@ def query_frontend_user_messages( return utils.prepare_message_list(messages) -@router.delete("/{username}/messages") +@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT) def mark_frontend_user_messages_deleted( username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): pr = PromptRepository(db, api_client, None) messages = pr.query_messages(username=username, api_client_id=api_client.id) pr.mark_messages_deleted(messages) - return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 20420690..951355b3 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -1,20 +1,21 @@ import datetime from uuid import UUID -from fastapi import APIRouter, Depends, Query, Response +from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository from oasst_shared.exceptions import OasstError, OasstErrorCode +from oasst_shared.schemas import protocol from sqlmodel import Session -from starlette.status import HTTP_200_OK +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/") +@router.get("/", response_model=list[protocol.Message]) def query_messages( username: str = None, api_client_id: str = None, @@ -45,7 +46,7 @@ def query_messages( return utils.prepare_message_list(messages) -@router.get("/{message_id}") +@router.get("/{message_id}", response_model=protocol.Message) def get_message( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -61,7 +62,7 @@ def get_message( return utils.prepare_message(message) -@router.get("/{message_id}/conversation") +@router.get("/{message_id}/conversation", response_model=protocol.Conversation) def get_conv( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -74,7 +75,7 @@ def get_conv( return utils.prepare_conversation(messages) -@router.get("/{message_id}/tree") +@router.get("/{message_id}/tree", response_model=protocol.MessageTree) def get_tree( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -87,7 +88,7 @@ def get_tree( return utils.prepare_tree(tree, message.message_tree_id) -@router.get("/{message_id}/children") +@router.get("/{message_id}/children", response_model=list[protocol.Message]) def get_children( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -99,7 +100,7 @@ def get_children( return utils.prepare_message_list(messages) -@router.get("/{message_id}/descendants") +@router.get("/{message_id}/descendants", response_model=protocol.MessageTree) def get_descendants( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -112,7 +113,7 @@ def get_descendants( return utils.prepare_tree(descendants, message.id) -@router.get("/{message_id}/longest_conversation_in_tree") +@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation) def get_longest_conv( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -125,7 +126,7 @@ def get_longest_conv( return utils.prepare_conversation(conv) -@router.get("/{message_id}/max_children_in_tree") +@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree) def get_max_children( message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db) ): @@ -138,10 +139,9 @@ def get_max_children( return utils.prepare_tree([message, *children], message.id) -@router.delete("/{message_id}") +@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT) def mark_message_deleted( message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): pr = PromptRepository(db, api_client, None) pr.mark_messages_deleted(message_id) - return Response(status_code=HTTP_200_OK) diff --git a/backend/oasst_backend/api/v1/stats.py b/backend/oasst_backend/api/v1/stats.py index 0f275b7d..a54aa07b 100644 --- a/backend/oasst_backend/api/v1/stats.py +++ b/backend/oasst_backend/api/v1/stats.py @@ -2,12 +2,13 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol from sqlmodel import Session router = APIRouter() -@router.get("/") +@router.get("/", response_model=protocol.SystemStats) def get_message_stats( db: Session = Depends(deps.get_db), api_client: ApiClient = Depends(deps.get_trusted_api_client), diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 49e82880..e9ecc854 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -10,6 +10,7 @@ from oasst_backend.prompt_repository import PromptRepository from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() @@ -159,14 +160,14 @@ def request_task( return task -@router.post("/{task_id}/ack", response_model=None) +@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT) def tasks_acknowledge( *, db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), task_id: UUID, ack_request: protocol_schema.TaskAck, -) -> Any: +) -> None: """ The frontend acknowledges a task. """ @@ -187,14 +188,14 @@ def tasks_acknowledge( raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED) -@router.post("/{task_id}/nack", response_model=None) +@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT) def tasks_acknowledge_failure( *, db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), task_id: UUID, nack_request: protocol_schema.TaskNAck, -) -> Any: +) -> None: """ The frontend reports failure to implement a task. """ @@ -265,7 +266,7 @@ def tasks_interaction( raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED) -@router.post("/close") +@router.post("/close", response_model=protocol_schema.TaskDone) def close_collective_task( close_task_request: protocol_schema.TaskClose, db: Session = Depends(deps.get_db), diff --git a/backend/oasst_backend/api/v1/text_labels.py b/backend/oasst_backend/api/v1/text_labels.py index 856aeea5..0613711c 100644 --- a/backend/oasst_backend/api/v1/text_labels.py +++ b/backend/oasst_backend/api/v1/text_labels.py @@ -6,7 +6,7 @@ from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session -from starlette.status import HTTP_400_BAD_REQUEST +from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST router = APIRouter() @@ -16,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel): user: protocol_schema.User -@router.post("/") +@router.post("/", status_code=HTTP_204_NO_CONTENT) def label_text( *, db: Session = Depends(deps.get_db), diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 16ab3133..8d55bfec 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -3,17 +3,17 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps +from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository from oasst_shared.schemas import protocol from sqlmodel import Session -from starlette.responses import Response -from starlette.status import HTTP_200_OK +from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() -@router.get("/{user_id}/messages") +@router.get("/{user_id}/messages", response_model=list[protocol.Message]) def query_user_messages( user_id: UUID, api_client_id: UUID = None, @@ -41,19 +41,13 @@ def query_user_messages( deleted=None if include_deleted else False, ) - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] + return utils.prepare_message_list(messages) -@router.delete("/{user_id}/messages") +@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT) def mark_user_messages_deleted( user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db) ): pr = PromptRepository(db, api_client, None) messages = pr.query_messages(user_id=user_id) pr.mark_messages_deleted(messages) - return Response(status_code=HTTP_200_OK)