Merge pull request #295 from croumegous/fastapi-response-model

chore: add fastapi response model to every endpoints, add openapi documentation for API response
This commit is contained in:
Yannic Kilcher
2023-01-02 23:17:48 +01:00
committed by GitHub
7 changed files with 39 additions and 43 deletions
@@ -5,12 +5,13 @@ from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@router.get("/{message_id}")
@router.get("/{message_id}", response_model=protocol.Message)
def get_message_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -27,7 +28,7 @@ def get_message_by_frontend_id(
return utils.prepare_message(message)
@router.get("/{message_id}/conversation")
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -41,7 +42,7 @@ def get_conv_by_frontend_id(
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree")
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -55,7 +56,7 @@ def get_tree_by_frontend_id(
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children")
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -68,7 +69,7 @@ def get_children_by_frontend_id(
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants")
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -82,7 +83,7 @@ def get_descendants_by_frontend_id(
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree")
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -96,7 +97,7 @@ def get_longest_conv_by_frontend_id(
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree")
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children_by_frontend_id(
message_id: str, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -6,14 +6,14 @@ from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/{username}/messages")
@router.get("/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
username: str,
api_client_id: UUID = None,
@@ -43,11 +43,10 @@ def query_frontend_user_messages(
return utils.prepare_message_list(messages)
@router.delete("/{username}/messages")
@router.delete("/{username}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_frontend_user_messages_deleted(
username: str, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(username=username, api_client_id=api_client.id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)
+12 -12
View File
@@ -1,20 +1,21 @@
import datetime
from uuid import UUID
from fastapi import APIRouter, Depends, Query, Response
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/")
@router.get("/", response_model=list[protocol.Message])
def query_messages(
username: str = None,
api_client_id: str = None,
@@ -45,7 +46,7 @@ def query_messages(
return utils.prepare_message_list(messages)
@router.get("/{message_id}")
@router.get("/{message_id}", response_model=protocol.Message)
def get_message(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -61,7 +62,7 @@ def get_message(
return utils.prepare_message(message)
@router.get("/{message_id}/conversation")
@router.get("/{message_id}/conversation", response_model=protocol.Conversation)
def get_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -74,7 +75,7 @@ def get_conv(
return utils.prepare_conversation(messages)
@router.get("/{message_id}/tree")
@router.get("/{message_id}/tree", response_model=protocol.MessageTree)
def get_tree(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -87,7 +88,7 @@ def get_tree(
return utils.prepare_tree(tree, message.message_tree_id)
@router.get("/{message_id}/children")
@router.get("/{message_id}/children", response_model=list[protocol.Message])
def get_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -99,7 +100,7 @@ def get_children(
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants")
@router.get("/{message_id}/descendants", response_model=protocol.MessageTree)
def get_descendants(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -112,7 +113,7 @@ def get_descendants(
return utils.prepare_tree(descendants, message.id)
@router.get("/{message_id}/longest_conversation_in_tree")
@router.get("/{message_id}/longest_conversation_in_tree", response_model=protocol.Conversation)
def get_longest_conv(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -125,7 +126,7 @@ def get_longest_conv(
return utils.prepare_conversation(conv)
@router.get("/{message_id}/max_children_in_tree")
@router.get("/{message_id}/max_children_in_tree", response_model=protocol.MessageTree)
def get_max_children(
message_id: UUID, api_client: ApiClient = Depends(deps.get_api_client), db: Session = Depends(deps.get_db)
):
@@ -138,10 +139,9 @@ def get_max_children(
return utils.prepare_tree([message, *children], message.id)
@router.delete("/{message_id}")
@router.delete("/{message_id}", status_code=HTTP_204_NO_CONTENT)
def mark_message_deleted(
message_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
pr.mark_messages_deleted(message_id)
return Response(status_code=HTTP_200_OK)
+2 -1
View File
@@ -2,12 +2,13 @@ from fastapi import APIRouter, Depends
from oasst_backend.api import deps
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@router.get("/")
@router.get("/", response_model=protocol.SystemStats)
def get_message_stats(
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
+6 -5
View File
@@ -10,6 +10,7 @@ from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@@ -159,14 +160,14 @@ def request_task(
return task
@router.post("/{task_id}/ack", response_model=None)
@router.post("/{task_id}/ack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
ack_request: protocol_schema.TaskAck,
) -> Any:
) -> None:
"""
The frontend acknowledges a task.
"""
@@ -187,14 +188,14 @@ def tasks_acknowledge(
raise OasstError("Failed to acknowledge task.", OasstErrorCode.TASK_ACK_FAILED)
@router.post("/{task_id}/nack", response_model=None)
@router.post("/{task_id}/nack", response_model=None, status_code=HTTP_204_NO_CONTENT)
def tasks_acknowledge_failure(
*,
db: Session = Depends(deps.get_db),
api_key: APIKey = Depends(deps.get_api_key),
task_id: UUID,
nack_request: protocol_schema.TaskNAck,
) -> Any:
) -> None:
"""
The frontend reports failure to implement a task.
"""
@@ -265,7 +266,7 @@ def tasks_interaction(
raise OasstError("Interaction request failed.", OasstErrorCode.TASK_INTERACTION_REQUEST_FAILED)
@router.post("/close")
@router.post("/close", response_model=protocol_schema.TaskDone)
def close_collective_task(
close_task_request: protocol_schema.TaskClose,
db: Session = Depends(deps.get_db),
+2 -2
View File
@@ -6,7 +6,7 @@ from oasst_backend.api import deps
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_400_BAD_REQUEST
from starlette.status import HTTP_204_NO_CONTENT, HTTP_400_BAD_REQUEST
router = APIRouter()
@@ -16,7 +16,7 @@ class LabelTextRequest(pydantic.BaseModel):
user: protocol_schema.User
@router.post("/")
@router.post("/", status_code=HTTP_204_NO_CONTENT)
def label_text(
*,
db: Session = Depends(deps.get_db),
+5 -11
View File
@@ -3,17 +3,17 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/{user_id}/messages")
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
def query_user_messages(
user_id: UUID,
api_client_id: UUID = None,
@@ -41,19 +41,13 @@ def query_user_messages(
deleted=None if include_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)
@router.delete("/{user_id}/messages")
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
user_id: UUID, api_client: ApiClient = Depends(deps.get_trusted_api_client), db: Session = Depends(deps.get_db)
):
pr = PromptRepository(db, api_client, None)
messages = pr.query_messages(user_id=user_id)
pr.mark_messages_deleted(messages)
return Response(status_code=HTTP_200_OK)