add to protocol.Message

This commit is contained in:
Andreas Köpf
2022-12-31 19:23:54 +01:00
parent bd2a7e93e3
commit 7b29582cbb
6 changed files with 36 additions and 44 deletions
@@ -7,7 +7,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
router = APIRouter()
@@ -27,7 +26,7 @@ def get_message_by_frontend_id(
# Unexpected message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
return utils.prepare_message(message)
@router.get("/{message_id}/conversation")
@@ -68,12 +67,7 @@ def get_children_by_frontend_id(
pr = PromptRepository(db, api_client, user=None)
message = pr.fetch_message_by_frontend_message_id(message_id)
messages = pr.fetch_message_children(message.id)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants")
@@ -4,9 +4,9 @@ from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.responses import Response
from starlette.status import HTTP_200_OK
@@ -41,13 +41,7 @@ def query_frontend_user_messages(
only_roots=only_roots,
deleted=None if include_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)
@router.delete("/{username}/messages")
+3 -14
View File
@@ -10,7 +10,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode
from oasst_backend.models import ApiClient
from oasst_backend.models.db_payload import MessagePayload
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_200_OK
@@ -45,12 +44,7 @@ def query_messages(
deleted=None if allow_deleted else False,
)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)
@router.get("/{message_id}")
@@ -66,7 +60,7 @@ def get_message(
# Unexptcted message payload
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
return utils.prepare_message(message)
@router.get("/{message_id}/conversation")
@@ -104,12 +98,7 @@ def get_children(
"""
pr = PromptRepository(db, api_client, user=None)
messages = pr.fetch_message_children(message_id)
return [
protocol.Message(
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
)
for m in messages
]
return utils.prepare_message_list(messages)
@router.get("/{message_id}/descendants")
+17 -8
View File
@@ -9,6 +9,22 @@ from oasst_backend.models.db_payload import MessagePayload
from oasst_shared.schemas import protocol
def prepare_message(m: Message) -> protocol.Message:
if not isinstance(m.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
return protocol.Message(
id=m.id,
parent_id=m.parent_id,
text=m.payload.payload.text,
is_assistant=(m.role == "assistant"),
created_date=m.created_date,
)
def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
return [prepare_message(m) for m in messages]
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
conv_messages = []
for message in messages:
@@ -26,13 +42,6 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
for message in tree:
if not isinstance(message.payload.payload, MessagePayload):
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
tree_messages.append(
protocol.Message(
id=message.id,
parent_id=message.parent_id,
text=message.payload.payload.text,
is_assistant=(message.role == "assistant"),
)
)
tree_messages.append(prepare_message(message))
return protocol.MessageTree(id=tree_id, messages=tree_messages)
+10 -6
View File
@@ -16,7 +16,7 @@ from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
from sqlalchemy import update
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class PromptRepository:
@@ -72,7 +72,7 @@ class PromptRepository:
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
@@ -88,7 +88,7 @@ class PromptRepository:
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
if task is None:
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
if task.expired:
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
if task.done or task.ack is not None:
@@ -108,7 +108,9 @@ class PromptRepository:
)
if fail_if_missing and message is None:
raise OasstError(
f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND
f"Message with frontend_message_id {frontend_message_id} not found.",
OasstErrorCode.MESSAGE_NOT_FOUND,
HTTP_404_NOT_FOUND,
)
return message
@@ -488,7 +490,7 @@ class PromptRepository:
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
if fail_if_missing and not message:
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND)
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
return message
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
@@ -499,7 +501,9 @@ class PromptRepository:
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
if not task:
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
raise OasstError(
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
)
if task.expired:
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
if not allow_personal_tasks and not task.collective:
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
import enum
from datetime import datetime
from typing import Literal, Optional, Union
from uuid import UUID, uuid4
@@ -41,6 +42,7 @@ class Conversation(BaseModel):
class Message(ConversationMessage):
id: UUID
parent_id: Optional[UUID] = None
created_date: Optional[datetime] = None
class MessageTree(BaseModel):