From 96d6717be40fb06f93fbdee5ee48127f7620a2e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Sat, 7 Jan 2023 15:59:54 +0100 Subject: [PATCH] Fetch conversation for seed data tasks, minor model fixes (#485) * Fetch conversation for seed data, fix models, remove redundant payload type checks --- backend/main.py | 15 ++++++++---- .../oasst_backend/api/v1/frontend_messages.py | 7 ------ backend/oasst_backend/api/v1/messages.py | 6 ----- backend/oasst_backend/api/v1/tasks.py | 24 ++++++++----------- backend/oasst_backend/api/v1/utils.py | 13 ++-------- backend/oasst_backend/models/journal.py | 10 ++++---- backend/oasst_backend/models/message.py | 24 +++++++++++++++---- backend/oasst_backend/models/task.py | 2 +- .../test_data/generic/test_generic_data.json | 14 +++++++++++ 9 files changed, 61 insertions(+), 54 deletions(-) diff --git a/backend/main.py b/backend/main.py index 2a3bb230..1ddae390 100644 --- a/backend/main.py +++ b/backend/main.py @@ -132,12 +132,17 @@ if settings.DEBUG_USE_SEED_DATA: parent_message = pr.fetch_message_by_frontend_message_id( msg.parent_message_id, fail_if_missing=True ) - task = pr.store_task( - protocol_schema.AssistantReplyTask( - conversation=protocol_schema.Conversation( - messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)] + conversation_messages = pr.fetch_message_conversation(parent_message) + conversation = protocol_schema.Conversation( + messages=[ + protocol_schema.ConversationMessage( + text=msg.text, is_assistant=msg.role == "assistant" ) - ), + for msg in conversation_messages + ] + ) + task = pr.store_task( + protocol_schema.AssistantReplyTask(conversation=conversation), message_tree_id=parent_message.message_tree_id, parent_message_id=parent_message.id, ) diff --git a/backend/oasst_backend/api/v1/frontend_messages.py b/backend/oasst_backend/api/v1/frontend_messages.py index 956d9992..420f0d1b 100644 --- a/backend/oasst_backend/api/v1/frontend_messages.py +++ b/backend/oasst_backend/api/v1/frontend_messages.py @@ -2,9 +2,7 @@ from fastapi import APIRouter, Depends from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient -from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol from sqlmodel import Session @@ -20,11 +18,6 @@ def get_message_by_frontend_id( """ pr = PromptRepository(db, api_client, user=None) message = pr.fetch_message_by_frontend_message_id(message_id) - - if not isinstance(message.payload.payload, MessagePayload): - # Unexpected message payload - raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) - return utils.prepare_message(message) diff --git a/backend/oasst_backend/api/v1/messages.py b/backend/oasst_backend/api/v1/messages.py index 951355b3..7a2fd2e9 100644 --- a/backend/oasst_backend/api/v1/messages.py +++ b/backend/oasst_backend/api/v1/messages.py @@ -5,9 +5,7 @@ from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient -from oasst_backend.models.db_payload import MessagePayload from oasst_backend.prompt_repository import PromptRepository -from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -55,10 +53,6 @@ def get_message( """ pr = PromptRepository(db, api_client, user=None) message = pr.fetch_message(message_id) - if not isinstance(message.payload.payload, MessagePayload): - # Unexptcted message payload - raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE) - return utils.prepare_message(message) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index 9f81eabb..c1671e79 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -57,9 +57,7 @@ def generate_task( logger.info("Generating a PrompterReplyTask.") messages = pr.fetch_random_conversation("assistant") task_messages = [ - protocol_schema.ConversationMessage( - text=msg.payload.payload.text, is_assistant=(msg.role == "assistant") - ) + protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant")) for msg in messages ] @@ -70,9 +68,7 @@ def generate_task( logger.info("Generating a AssistantReplyTask.") messages = pr.fetch_random_conversation("prompter") task_messages = [ - protocol_schema.ConversationMessage( - text=msg.payload.payload.text, is_assistant=(msg.role == "assistant") - ) + protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant")) for msg in messages ] @@ -83,19 +79,19 @@ def generate_task( logger.info("Generating a RankInitialPromptsTask.") messages = pr.fetch_random_initial_prompts() - task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages]) + task = protocol_schema.RankInitialPromptsTask(prompts=[msg.text for msg in messages]) case protocol_schema.TaskRequestType.rank_prompter_replies: logger.info("Generating a RankPrompterRepliesTask.") conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant") task_messages = [ protocol_schema.ConversationMessage( - text=p.payload.payload.text, + text=p.text, is_assistant=(p.role == "assistant"), ) for p in conversation ] - replies = [p.payload.payload.text for p in replies] + replies = [p.text for p in replies] task = protocol_schema.RankPrompterRepliesTask( conversation=protocol_schema.Conversation( messages=task_messages, @@ -109,12 +105,12 @@ def generate_task( task_messages = [ protocol_schema.ConversationMessage( - text=p.payload.payload.text, + text=p.text, is_assistant=(p.role == "assistant"), ) for p in conversation ] - replies = [p.payload.payload.text for p in replies] + replies = [p.text for p in replies] task = protocol_schema.RankAssistantRepliesTask( conversation=protocol_schema.Conversation(messages=task_messages), replies=replies, @@ -125,14 +121,14 @@ def generate_task( message = pr.fetch_random_initial_prompts(1)[0] task = protocol_schema.LabelInitialPromptTask( message_id=message.id, - prompt=message.payload.payload.text, + prompt=message.text, valid_labels=list(map(lambda x: x.value, protocol_schema.TextLabel)), ) case protocol_schema.TaskRequestType.label_prompter_reply: logger.info("Generating a LabelPrompterReplyTask.") conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="assistant") - message = messages[0].payload.payload.text + message = messages[0].text task = protocol_schema.LabelPrompterReplyTask( message_id=message.id, conversation=conversation, @@ -143,7 +139,7 @@ def generate_task( case protocol_schema.TaskRequestType.label_assistant_reply: logger.info("Generating a LabelAssistantReplyTask.") conversation, messages = pr.fetch_multiple_random_replies(max_size=1, message_role="prompter") - message = messages[0].payload.payload.text + message = messages[0].text task = protocol_schema.LabelAssistantReplyTask( message_id=message.id, conversation=conversation, diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 55a7c572..5299aab6 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -1,19 +1,14 @@ -from http import HTTPStatus from uuid import UUID from oasst_backend.models import Message -from oasst_backend.models.db_payload import MessagePayload -from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol def prepare_message(m: Message) -> protocol.Message: - if not isinstance(m.payload.payload, MessagePayload): - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) return protocol.Message( id=m.id, parent_id=m.parent_id, - text=m.payload.payload.text, + text=m.text, is_assistant=(m.role == "assistant"), created_date=m.created_date, ) @@ -26,10 +21,8 @@ def prepare_message_list(messages: list[Message]) -> list[protocol.Message]: def prepare_conversation(messages: list[Message]) -> protocol.Conversation: conv_messages = [] for message in messages: - if not isinstance(message.payload.payload, MessagePayload): - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) conv_messages.append( - protocol.ConversationMessage(text=message.payload.payload.text, is_assistant=(message.role == "assistant")) + protocol.ConversationMessage(text=message.text, is_assistant=(message.role == "assistant")) ) return protocol.Conversation(messages=conv_messages) @@ -38,8 +31,6 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation: def prepare_tree(tree: list[Message], tree_id: UUID) -> protocol.MessageTree: tree_messages = [] for message in tree: - if not isinstance(message.payload.payload, MessagePayload): - raise OasstError("Server error", OasstErrorCode.SERVER_ERROR, HTTPStatus.INTERNAL_SERVER_ERROR) tree_messages.append(prepare_message(message)) return protocol.MessageTree(id=tree_id, messages=tree_messages) diff --git a/backend/oasst_backend/models/journal.py b/backend/oasst_backend/models/journal.py index 0d5a78af..b5000add 100644 --- a/backend/oasst_backend/models/journal.py +++ b/backend/oasst_backend/models/journal.py @@ -32,7 +32,7 @@ class Journal(SQLModel, table=True): created_date: Optional[datetime] = Field( sa_column=sa.Column(sa.DateTime(timezone=True), nullable=False, server_default=sa.func.current_timestamp()) ) - user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True) + user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) message_id: Optional[UUID] = Field(foreign_key="message.id", nullable=True) api_client_id: UUID = Field(foreign_key="api_client.id") @@ -49,7 +49,7 @@ class JournalIntegration(SQLModel, table=True): ), ) description: str = Field(max_length=512, primary_key=True) - last_journal_id: UUID = Field(foreign_key="journal.id", nullable=True) - last_run: datetime = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) - last_error: str = Field(nullable=True) - next_run: datetime = Field(nullable=True) + last_journal_id: Optional[UUID] = Field(foreign_key="journal.id", nullable=True) + last_run: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) + last_error: Optional[str] = Field(nullable=True) + next_run: Optional[datetime] = Field(nullable=True) diff --git a/backend/oasst_backend/models/message.py b/backend/oasst_backend/models/message.py index f07ca881..6d24fd13 100644 --- a/backend/oasst_backend/models/message.py +++ b/backend/oasst_backend/models/message.py @@ -1,9 +1,12 @@ from datetime import datetime +from http import HTTPStatus from typing import Optional from uuid import UUID, uuid4 import sqlalchemy as sa import sqlalchemy.dialects.postgresql as pg +from oasst_backend.models.db_payload import MessagePayload +from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from sqlalchemy import false from sqlmodel import Field, Index, SQLModel @@ -19,19 +22,30 @@ class Message(SQLModel, table=True): pg.UUID(as_uuid=True), primary_key=True, default=uuid4, server_default=sa.text("gen_random_uuid()") ), ) - parent_id: UUID = Field(nullable=True) + parent_id: Optional[UUID] = Field(nullable=True) message_tree_id: UUID = Field(nullable=False, index=True) - task_id: UUID = Field(nullable=True, index=True) - user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True) - role: str = Field(nullable=False, max_length=128) # valid: "prompter" | "assistant" + task_id: Optional[UUID] = Field(nullable=True, index=True) + user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) + role: str = Field(nullable=False, max_length=128, regex="^prompter|assistant$") api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") frontend_message_id: str = Field(max_length=200, nullable=False) created_date: Optional[datetime] = Field( sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) ) payload_type: str = Field(nullable=False, max_length=200) - payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True)) + payload: Optional[PayloadContainer] = Field( + sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=True) + ) lang: str = Field(nullable=False, max_length=200, default="en-US") depth: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) children_count: int = Field(sa_column=sa.Column(sa.Integer, default=0, server_default=sa.text("0"), nullable=False)) deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=false())) + + def ensure_is_message(self) -> None: + if not self.payload or not isinstance(self.payload.payload, MessagePayload): + raise OasstError("Invalid message", OasstErrorCode.INVALID_MESSAGE, HTTPStatus.INTERNAL_SERVER_ERROR) + + @property + def text(self) -> str: + self.ensure_is_message() + return self.payload.payload.text diff --git a/backend/oasst_backend/models/task.py b/backend/oasst_backend/models/task.py index 356eafea..a980c1b5 100644 --- a/backend/oasst_backend/models/task.py +++ b/backend/oasst_backend/models/task.py @@ -22,7 +22,7 @@ class Task(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()), ) expiry_date: Optional[datetime] = Field(sa_column=sa.Column(sa.DateTime(), nullable=True)) - user_id: UUID = Field(nullable=True, foreign_key="user.id", index=True) + user_id: Optional[UUID] = Field(nullable=True, foreign_key="user.id", index=True) payload_type: str = Field(nullable=False, max_length=200) payload: PayloadContainer = Field(sa_column=sa.Column(payload_column_type(PayloadContainer), nullable=False)) api_client_id: UUID = Field(nullable=False, foreign_key="api_client.id") diff --git a/backend/test_data/generic/test_generic_data.json b/backend/test_data/generic/test_generic_data.json index b634902d..4b4b4a39 100644 --- a/backend/test_data/generic/test_generic_data.json +++ b/backend/test_data/generic/test_generic_data.json @@ -54,5 +54,19 @@ "parent_message_id": "cec432cf", "text": "I'm unsure how to interpret this. Is it a riddle?", "role": "assistant" + }, + { + "task_message_id": "b8e98ed6", + "user_message_id": "89384709", + "parent_message_id": "0e276b98", + "text": "No, I just wanted to see how you reply when I type random characters. Can you tell me who invented Wikipedia?", + "role": "prompter" + }, + { + "task_message_id": "9a0e7683", + "user_message_id": "6d452c57", + "parent_message_id": "0e276b98", + "text": "Sorry, my cat sat on my keyboard. Can you print a cat in ASCII art?", + "role": "prompter" } ]