mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add to protocol.Message
This commit is contained in:
@@ -7,7 +7,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
router = APIRouter()
|
||||
@@ -27,7 +26,7 @@ def get_message_by_frontend_id(
|
||||
# Unexpected message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@@ -68,12 +67,7 @@ def get_children_by_frontend_id(
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
messages = pr.fetch_message_children(message.id)
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
|
||||
@@ -4,9 +4,9 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.responses import Response
|
||||
from starlette.status import HTTP_200_OK
|
||||
@@ -41,13 +41,7 @@ def query_frontend_user_messages(
|
||||
only_roots=only_roots,
|
||||
deleted=None if include_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.delete("/{username}/messages")
|
||||
|
||||
@@ -10,7 +10,6 @@ from oasst_backend.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_200_OK
|
||||
|
||||
@@ -45,12 +44,7 @@ def query_messages(
|
||||
deleted=None if allow_deleted else False,
|
||||
)
|
||||
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}")
|
||||
@@ -66,7 +60,7 @@ def get_message(
|
||||
# Unexptcted message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
@router.get("/{message_id}/conversation")
|
||||
@@ -104,12 +98,7 @@ def get_children(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
messages = pr.fetch_message_children(message_id)
|
||||
return [
|
||||
protocol.Message(
|
||||
id=m.id, parent_id=m.parent_id, text=m.payload.payload.text, is_assistant=(m.role == "assistant")
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
return utils.prepare_message_list(messages)
|
||||
|
||||
|
||||
@router.get("/{message_id}/descendants")
|
||||
|
||||
@@ -9,6 +9,22 @@ from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_message(m: Message) -> protocol.Message:
|
||||
if not isinstance(m.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return protocol.Message(
|
||||
id=m.id,
|
||||
parent_id=m.parent_id,
|
||||
text=m.payload.payload.text,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
)
|
||||
|
||||
|
||||
def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
return [prepare_message(m) for m in messages]
|
||||
|
||||
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
@@ -26,13 +42,6 @@ def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(
|
||||
protocol.Message(
|
||||
id=message.id,
|
||||
parent_id=message.parent_id,
|
||||
text=message.payload.payload.text,
|
||||
is_assistant=(message.role == "assistant"),
|
||||
)
|
||||
)
|
||||
tree_messages.append(prepare_message(message))
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
@@ -16,7 +16,7 @@ from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
from sqlalchemy import update
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class PromptRepository:
|
||||
@@ -72,7 +72,7 @@ class PromptRepository:
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
@@ -88,7 +88,7 @@ class PromptRepository:
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
if task is None:
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)
|
||||
if task.done or task.ack is not None:
|
||||
@@ -108,7 +108,9 @@ class PromptRepository:
|
||||
)
|
||||
if fail_if_missing and message is None:
|
||||
raise OasstError(
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.", OasstErrorCode.MESSAGE_NOT_FOUND
|
||||
f"Message with frontend_message_id {frontend_message_id} not found.",
|
||||
OasstErrorCode.MESSAGE_NOT_FOUND,
|
||||
HTTP_404_NOT_FOUND,
|
||||
)
|
||||
return message
|
||||
|
||||
@@ -488,7 +490,7 @@ class PromptRepository:
|
||||
def fetch_message(self, message_id: UUID, fail_if_missing: bool = True) -> Optional[Message]:
|
||||
message = self.db.query(Message).filter(Message.id == message_id).one_or_none()
|
||||
if fail_if_missing and not message:
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND)
|
||||
raise OasstError("Message not found", OasstErrorCode.MESSAGE_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
return message
|
||||
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
@@ -499,7 +501,9 @@ class PromptRepository:
|
||||
task = self.fetch_task_by_frontend_message_id(frontend_message_id)
|
||||
|
||||
if not task:
|
||||
raise OasstError(f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND)
|
||||
raise OasstError(
|
||||
f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
|
||||
)
|
||||
if task.expired:
|
||||
raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
|
||||
if not allow_personal_tasks and not task.collective:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
@@ -41,6 +42,7 @@ class Conversation(BaseModel):
|
||||
class Message(ConversationMessage):
|
||||
id: UUID
|
||||
parent_id: Optional[UUID] = None
|
||||
created_date: Optional[datetime] = None
|
||||
|
||||
|
||||
class MessageTree(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user