diff --git a/backend/main.py b/backend/main.py index 1ddae390..1c93fc9f 100644 --- a/backend/main.py +++ b/backend/main.py @@ -136,9 +136,12 @@ if settings.DEBUG_USE_SEED_DATA: conversation = protocol_schema.Conversation( messages=[ protocol_schema.ConversationMessage( - text=msg.text, is_assistant=msg.role == "assistant" + text=cmsg.text, + is_assistant=cmsg.role == "assistant", + message_id=cmsg.id, + fronend_message_id=cmsg.frontend_message_id, ) - for msg in conversation_messages + for cmsg in conversation_messages ] ) task = pr.store_task( diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index c1671e79..3860bb07 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -57,7 +57,12 @@ def generate_task( logger.info("Generating a PrompterReplyTask.") messages = pr.fetch_random_conversation("assistant") task_messages = [ - protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant")) + protocol_schema.ConversationMessage( + text=msg.text, + is_assistant=(msg.role == "assistant"), + message_id=msg.id, + front_end_id=msg.front_end_id, + ) for msg in messages ] @@ -68,7 +73,12 @@ def generate_task( logger.info("Generating a AssistantReplyTask.") messages = pr.fetch_random_conversation("prompter") task_messages = [ - protocol_schema.ConversationMessage(text=msg.text, is_assistant=(msg.role == "assistant")) + protocol_schema.ConversationMessage( + text=msg.text, + is_assistant=(msg.role == "assistant"), + message_id=msg.id, + front_end_id=msg.front_end_id, + ) for msg in messages ] @@ -88,6 +98,8 @@ def generate_task( protocol_schema.ConversationMessage( text=p.text, is_assistant=(p.role == "assistant"), + message_id=p.id, + front_end_id=p.front_end_id, ) for p in conversation ] @@ -107,6 +119,8 @@ def generate_task( protocol_schema.ConversationMessage( text=p.text, is_assistant=(p.role == "assistant"), + message_id=p.id, + front_end_id=p.front_end_id, ) for p in conversation ] diff --git a/backend/oasst_backend/api/v1/utils.py b/backend/oasst_backend/api/v1/utils.py index 5299aab6..4e20395f 100644 --- a/backend/oasst_backend/api/v1/utils.py +++ b/backend/oasst_backend/api/v1/utils.py @@ -22,7 +22,12 @@ def prepare_conversation(messages: list[Message]) -> protocol.Conversation: conv_messages = [] for message in messages: conv_messages.append( - protocol.ConversationMessage(text=message.text, is_assistant=(message.role == "assistant")) + protocol.ConversationMessage( + text=message.text, + is_assistant=(message.role == "assistant"), + message_id=message.id, + frontend_message_id=message.frontend_message_id, + ) ) return protocol.Conversation(messages=conv_messages) diff --git a/oasst-shared/oasst_shared/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py index 1cafc93d..f8af590e 100644 --- a/oasst-shared/oasst_shared/schemas/protocol.py +++ b/oasst-shared/oasst_shared/schemas/protocol.py @@ -34,6 +34,8 @@ class ConversationMessage(BaseModel): text: str is_assistant: bool + message_id: Optional[UUID] = None + frontend_message_id: Optional[str] = None class Conversation(BaseModel):