mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Fetch conversation for seed data tasks, minor model fixes (#485)
* Fetch conversation for seed data, fix models, remove redundant payload type checks
This commit is contained in:
+10
-5
@@ -132,12 +132,17 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = protocol_schema.Conversation(
|
||||
messages=[
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.text, is_assistant=msg.role == "assistant"
|
||||
)
|
||||
),
|
||||
for msg in conversation_messages
|
||||
]
|
||||
)
|
||||
task = pr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
|
||||
@@ -2,9 +2,7 @@ from fastapi import APIRouter, Depends
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
|
||||
@@ -20,11 +18,6 @@ def get_message_by_frontend_id(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message_by_frontend_message_id(message_id)
|
||||
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexpected message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -55,10 +53,6 @@ def get_message(
|
||||
"""
|
||||
pr = PromptRepository(db, api_client, user=None)
|
||||
message = pr.fetch_message(message_id)
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
# Unexptcted message payload
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE)
|
||||
|
||||
return utils.prepare_message(message)
|
||||
|
||||
|
||||
|
||||
@@ -57,9 +57,7 @@ def generate_task(
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant"))
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
@@ -70,9 +68,7 @@ def generate_task(
|
||||
logger.info("Generating a AssistantReplyTask.")
|
||||
messages = pr.fetch_random_conversation("prompter")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=msg.payload.payload.text, is_assistant=(msg.role == "assistant")
|
||||
)
|
||||
protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant"))
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
@@ -83,19 +79,19 @@ def generate_task(
|
||||
logger.info("Generating a RankInitialPromptsTask.")
|
||||
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankPrompterRepliesTask(
|
||||
conversation=protocol_schema.Conversation(
|
||||
messages=task_messages,
|
||||
@@ -109,12 +105,12 @@ def generate_task(
|
||||
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
text=p.payload.payload.text,
|
||||
text=p.text,
|
||||
is_assistant=(p.role == "assistant"),
|
||||
)
|
||||
for p in conversation
|
||||
]
|
||||
replies = [p.payload.payload.text for p in replies]
|
||||
replies = [p.text for p in replies]
|
||||
task = protocol_schema.RankAssistantRepliesTask(
|
||||
conversation=protocol_schema.Conversation(messages=task_messages),
|
||||
replies=replies,
|
||||
@@ -125,14 +121,14 @@ def generate_task(
|
||||
message = pr.fetch_random_initial_prompts(1)[0]
|
||||
task = protocol_schema.LabelInitialPromptTask(
|
||||
message_id=message.id,
|
||||
prompt=message.payload.payload.text,
|
||||
prompt=message.text,
|
||||
valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)),
|
||||
)
|
||||
|
||||
case protocol_schema.TaskRequestType.label_prompter_reply:
|
||||
logger.info("Generating a LabelPrompterReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant")
|
||||
message = messages[0].payload.payload.text
|
||||
message = messages[0].text
|
||||
task = protocol_schema.LabelPrompterReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
@@ -143,7 +139,7 @@ def generate_task(
|
||||
case protocol_schema.TaskRequestType.label_assistant_reply:
|
||||
logger.info("Generating a LabelAssistantReplyTask.")
|
||||
conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter")
|
||||
message = messages[0].payload.payload.text
|
||||
message = messages[0].text
|
||||
task = protocol_schema.LabelAssistantReplyTask(
|
||||
message_id=message.id,
|
||||
conversation=conversation,
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import Message
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol
|
||||
|
||||
|
||||
def prepare_message(m: Message) -> protocol.Message:
|
||||
if not isinstance(m.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
return protocol.Message(
|
||||
id=m.id,
|
||||
parent_id=m.parent_id,
|
||||
text=m.payload.payload.text,
|
||||
text=m.text,
|
||||
is_assistant=(m.role == "assistant"),
|
||||
created_date=m.created_date,
|
||||
)
|
||||
@@ -26,10 +21,8 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]:
|
||||
def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
conv_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
conv_messages.append(
|
||||
protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant"))
|
||||
protocol.ConversationMessage(text=message.text, is_assistant=(message.role == "assistant"))
|
||||
)
|
||||
|
||||
return protocol.Conversation(messages=conv_messages)
|
||||
@@ -38,8 +31,6 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation:
|
||||
def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree:
|
||||
tree_messages = []
|
||||
for message in tree:
|
||||
if not isinstance(message.payload.payload, MessagePayload):
|
||||
raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
tree_messages.append(prepare_message(message))
|
||||
|
||||
return protocol.MessageTree(id=tree_id, messages=tree_messages)
|
||||
|
||||
@@ -32,7 +32,7 @@ class Journal(SQLModel, table=True):
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
|
||||
@@ -49,7 +49,7 @@ class JournalIntegration(SQLModel, table=True):
|
||||
),
|
||||
)
|
||||
description: str = Field(max_length=512, primary_key=True)
|
||||
last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_error: str = Field(nullable=True)
|
||||
next_run: datetime = Field(nullable=True)
|
||||
last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True)
|
||||
last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
last_error: Optional[str] = Field(nullable=True)
|
||||
next_run: Optional[datetime] = Field(nullable=True)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_backend.models.db_payload import MessagePayload
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from sqlalchemy import false
|
||||
from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
@@ -19,19 +22,30 @@ class Message(SQLModel, table=True):
|
||||
pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()")
|
||||
),
|
||||
)
|
||||
parent_id: UUID = Field(nullable=True)
|
||||
parent_id: Optional[UUID] = Field(nullable=True)
|
||||
message_tree_id: UUID = Field(nullable=False, index=True)
|
||||
task_id: UUID = Field(nullable=True, index=True)
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant"
|
||||
task_id: Optional[UUID] = Field(nullable=True, index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
role: str = Field(nullable=False, max_length=128, regex="^prompter|assistant$")
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
frontend_message_id: str = Field(max_length=200, nullable=False)
|
||||
created_date: Optional[datetime] = Field(
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True))
|
||||
payload: Optional[PayloadContainer] = Field(
|
||||
sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)
|
||||
)
|
||||
lang: str = Field(nullable=False, max_length=200, default="en-US")
|
||||
depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False))
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false()))
|
||||
|
||||
def ensure_is_message(self) -> None:
|
||||
if not self.payload or not isinstance(self.payload.payload, MessagePayload):
|
||||
raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||
|
||||
@property
|
||||
def text(self) -> str:
|
||||
self.ensure_is_message()
|
||||
return self.payload.payload.text
|
||||
|
||||
@@ -22,7 +22,7 @@ class Task(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()),
|
||||
)
|
||||
expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True))
|
||||
user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True)
|
||||
payload_type: str = Field(nullable=False, max_length=200)
|
||||
payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False))
|
||||
api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id")
|
||||
|
||||
@@ -54,5 +54,19 @@
|
||||
"parent_message_id": "cec432cf",
|
||||
"text": "I'm unsure how to interpret this. Is it a riddle?",
|
||||
"role": "assistant"
|
||||
},
|
||||
{
|
||||
"task_message_id": "b8e98ed6",
|
||||
"user_message_id": "89384709",
|
||||
"parent_message_id": "0e276b98",
|
||||
"text": "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?",
|
||||
"role": "prompter"
|
||||
},
|
||||
{
|
||||
"task_message_id": "9a0e7683",
|
||||
"user_message_id": "6d452c57",
|
||||
"parent_message_id": "0e276b98",
|
||||
"text": "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?",
|
||||
"role": "prompter"
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user