From 7b29582cbb0111280c519774e35653ddbb60c376 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 31 Dec 2022 19:23:54 +0100 Subject: [PATCH] add to protocol.Message --- .../oasst_backend/api/v1/frontend_messages.py | 10 ++------ .../oasst_backend/api/v1/frontend_users.py | 10 ++------ backend/oasst_backend/api/v1/messages.py | 17 +++---------- backend/oasst_backend/api/v1/utils.py | 25 +++++++++++++------ backend/oasst_backend/prompt_repository.py | 16 +++++++----- oasst-shared/oasst_shared/schemas/protocol.py | 2 ++ 6 files changed, 36 insertions(+), 44 deletions(-) diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index 85ab74ca..6ee27aa1 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -7,7 +7,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.models import ApiClient from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.schemas import protocol from sqlmodel import Session router = APIRouter() @@ -27,7 +26,7 @@ def get_message_by_frontend_id( # Unexpected message payload raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) - return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + return utils.prepare_message(message) @router.get("/{message_id}/conversation") @@ -68,12 +67,7 @@ def get_children_by_frontend_id( pr = PromptRepository(db, api_client, user=None) message = pr.fetch_message_by_frontend_message_id(message_id) messages = pr.fetch_message_children(message.id) - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] + return utils.prepare_message_list(messages) @router.get("/{message_id}/descendants") diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 35048331..940c7bb3 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -4,9 +4,9 @@ from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps +from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.responses import Response from starlette.status import HTTP_200_OK @@ -41,13 +41,7 @@ def query_frontend_user_messages( only_roots=only_roots, deleted=None if include_deleted else False, ) - - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] + return utils.prepare_message_list(messages) @router.delete("/{username}/messages") diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 9221afc3..71e4e3eb 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -10,7 +10,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.models import ApiClient from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_200_OK @@ -45,12 +44,7 @@ def query_messages( deleted=None if allow_deleted else False, ) - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] + return utils.prepare_message_list(messages) @router.get("/{message_id}") @@ -66,7 +60,7 @@ def get_message( # Unexptcted message payload raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) - return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + return utils.prepare_message(message) @router.get("/{message_id}/conversation") @@ -104,12 +98,7 @@ def get_children( """ pr = PromptRepository(db, api_client, user=None) messages = pr.fetch_message_children(message_id) - return [ - protocol.Message( - id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant") - ) - for m in messages - ] + return utils.prepare_message_list(messages) @router.get("/{message_id}/descendants") diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 2dac7947..0fa452bb 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -9,6 +9,22 @@ from oasst_backend.models.db_payload import MessagePayload from oasst_shared.schemas import protocol +def prepare_message(m: Message) -> protocol.Message: + if not isinstance(m.payload.payload, MessagePayload): + raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) + return protocol.Message( + id=m.id, + parent_id=m.parent_id, + text=m.payload.payload.text, + is_assistant=(m.role == "assistant"), + created_date=m.created_date, + ) + + +def prepare_message_list(messages: list[Message]) -> list[protocol.Message]: + return [prepare_message(m) for m in messages] + + def prepare_conversation(messages: list[Message]) -> protocol.Conversation: conv_messages = [] for message in messages: @@ -26,13 +42,6 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree: for message in tree: if not isinstance(message.payload.payload, MessagePayload): raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) - tree_messages.append( - protocol.Message( - id=message.id, - parent_id=message.parent_id, - text=message.payload.payload.text, - is_assistant=(message.role == "assistant"), - ) - ) + tree_messages.append(prepare_message(message)) return protocol.MessageTree(id=tree_id, messages=tree_messages) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index b95f9d53..8cc770c5 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -16,7 +16,7 @@ from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats from sqlalchemy import update from sqlmodel import Session, func -from starlette.status import HTTP_403_FORBIDDEN +from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND class PromptRepository: @@ -72,7 +72,7 @@ class PromptRepository: # find task task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) if task.expired: raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) if task.done or task.ack is not None: @@ -88,7 +88,7 @@ class PromptRepository: # find task task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() if task is None: - raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND) if task.expired: raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED) if task.done or task.ack is not None: @@ -108,7 +108,9 @@ class PromptRepository: ) if fail_if_missing and message is None: raise OasstError( - f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND + f"Message with frontend_message_id {frontend_message_id} not found.", + OasstErrorCode.MESSAGE_NOT_FOUND, + HTTP_404_NOT_FOUND, ) return message @@ -488,7 +490,7 @@ class PromptRepository: def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]: message = self.db.query(Message).filter(Message.id == message_id).one_or_none() if fail_if_missing and not message: - raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND) + raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND) return message def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): @@ -499,7 +501,9 @@ class PromptRepository: task = self.fetch_task_by_frontend_message_id(frontend_message_id) if not task: - raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND) + raise OasstError( + f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND + ) if task.expired: raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED) if not allow_personal_tasks and not task.collective: diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index ed7dc780..8a6685c2 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- import enum +from datetime import datetime from typing import Literal, Optional, Union from uuid import UUID, uuid4 @@ -41,6 +42,7 @@ class Conversation(BaseModel): class Message(ConversationMessage): id: UUID parent_id: Optional[UUID] = None + created_date: Optional[datetime] = None class MessageTree(BaseModel):