diff --git a/backend/app/models/db_payload.py b/backend/app/models/db_payload.py index ff2b5f6e..52eadb67 100644 --- a/backend/app/models/db_payload.py +++ b/backend/app/models/db_payload.py @@ -72,6 +72,14 @@ class RankConversationRepliesPayload(TaskPayload): replies: list[str] +@payload_type +class RankInitialPromptsPayload(TaskPayload): + """A task to rank a set of initial prompts.""" + + type: Literal["rank_initial_prompts"] = "rank_initial_prompts" + prompts: list[str] + + @payload_type class RankUserRepliesPayload(RankConversationRepliesPayload): """A task to rank a set of user replies to a conversation.""" diff --git a/backend/app/prompt_repository.py b/backend/app/prompt_repository.py index e4877e9d..43706d00 100644 --- a/backend/app/prompt_repository.py +++ b/backend/app/prompt_repository.py @@ -156,7 +156,9 @@ class PromptRepository: # fetch work_package work_package = self.fetch_workpackage_by_postid(ranking.post_id) - work_payload: db_payload.RankConversationRepliesPayload = work_package.payload.payload + work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = ( + work_package.payload.payload + ) match type(work_payload): @@ -176,6 +178,21 @@ class PromptRepository: return reaction + case db_payload.RankInitialPromptsPayload: + # validate ranking + if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))): + raise ValueError( + f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})." + ) + + # store reaction to post + reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking) + reaction = self.insert_reaction(post.id, reaction_payload) + + logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.") + + return reaction + case _: raise ValueError( f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}" @@ -201,6 +218,9 @@ class PromptRepository: case protocol_schema.AssistantReplyTask: payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation) + case protocol_schema.RankInitialPromptsTask: + payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts) + case protocol_schema.RankUserRepliesTask: payload = db_payload.RankUserRepliesPayload( tpye=task.type, conversation=task.conversation, replies=task.replies