From 534c99610beb2eced2275f66046a11fc65de686d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Fri, 30 Dec 2022 23:56:16 +0100 Subject: [PATCH] missing 'user' -> 'prompter' replacement --- backend/oasst_backend/api/v1/tasks.py | 2 +- backend/oasst_backend/prompt_repository.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index bb8c2efc..60a19281 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -69,7 +69,7 @@ def generate_task( parent_message_id = messages[-1].id case protocol_schema.TaskRequestType.assistant_reply: logger.info("Generating a AssistantReplyTask.") - messages = pr.fetch_random_conversation("user") + messages = pr.fetch_random_conversation("prompter") task_messages = [ protocol_schema.ConversationMessage( text=msg.payload.payload.text, is_assistant=(msg.role == "assistant") diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 5606ca69..15ed3816 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -419,22 +419,22 @@ class PromptRepository: the user should reply as a human and hence the last message of the conversation needs to have "assistant" role. """ - mt_messages = self.fetch_random_message_tree(last_message_role) - if not mt_messages: + messages_tree = self.fetch_random_message_tree(last_message_role) + if not messages_tree: raise OasstError("No message tree found", OasstErrorCode.NO_MESSAGE_TREE_FOUND) if last_message_role: - conv_messages = [m for m in mt_messages if m.role == last_message_role] + conv_messages = [m for m in messages_tree if m.role == last_message_role] conv_messages = [random.choice(conv_messages)] else: - conv_messages = [random.choice(mt_messages)] - mt_messages = {m.id: m for m in mt_messages} + conv_messages = [random.choice(messages_tree)] + messages_tree = {m.id: m for m in messages_tree} while True: if not conv_messages[-1].parent_id: # reached the start of the conversation break - parent_message = mt_messages[conv_messages[-1].parent_id] + parent_message = messages_tree[conv_messages[-1].parent_id] conv_messages.append(parent_message) return list(reversed(conv_messages))