From de28d67031df1af94d6acd2ae8de54c37541c6a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Mon, 16 Jan 2023 00:19:30 +0100 Subject: [PATCH] infer role from task in store_text_reply() --- backend/oasst_backend/prompt_repository.py | 23 +++++++++++++++---- .../exceptions/oasst_api_error.py | 1 + 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 08557ba8..6483cdc2 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -150,7 +150,7 @@ class PromptRepository: self._validate_task(task) # If there's no parent message assume user started new conversation - role = "prompter" + role = None depth = 0 if task.parent_message_id: @@ -170,10 +170,23 @@ class PromptRepository: self.db.add(parent_message) depth = parent_message.depth + 1 - if parent_message.role == "assistant": - role = "prompter" - else: - role = "assistant" + + task_payload: db_payload.TaskPayload = task.payload.payload + if isinstance(task_payload, db_payload.InitialPromptPayload): + role = "prompter" + elif isinstance(task_payload, db_payload.PrompterReplyPayload): + role = "prompter" + elif isinstance(task_payload, db_payload.AssistantReplyPayload): + role = "assistant" + elif isinstance(task_payload, db_payload.SummarizationStoryPayload): + raise NotImplementedError("SummarizationStory task not implemented.") + else: + raise OasstError( + f"Unexpected task payload type: {type(task_payload).__name__}", + OasstErrorCode.TASK_UNEXPECTED_PAYLOAD_TYPE_, + ) + + assert role in ("assistant", "prompter") # create reply message new_message_id = uuid4() diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index b4432252..31ba00f6 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -59,6 +59,7 @@ class OasstErrorCode(IntEnum): TASK_ALREADY_DONE = 2105 TASK_NOT_COLLECTIVE = 2106 TASK_NOT_ASSIGNED_TO_USER = 2106 + TASK_UNEXPECTED_PAYLOAD_TYPE_ = 2107 USER_NOT_FOUND = 2200 # 3000-4000: external resources