Merge pull request #15 from LAION-AI/add-rank-initial-prompts-payload

added rank_initial_prompts payload and db saving
This commit is contained in:
Yannic Kilcher
2022-12-16 22:38:22 +01:00
committed by GitHub
2 changed files with 29 additions and 1 deletions
+8
View File
@@ -72,6 +72,14 @@ class RankConversationRepliesPayload(TaskPayload):
replies: list[str]
@payload_type
class RankInitialPromptsPayload(TaskPayload):
"""A task to rank a set of initial prompts."""
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
prompts: list[str]
@payload_type
class RankUserRepliesPayload(RankConversationRepliesPayload):
"""A task to rank a set of user replies to a conversation."""
+21 -1
View File
@@ -156,7 +156,9 @@ class PromptRepository:
# fetch work_package
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
work_payload: db_payload.RankConversationRepliesPayload = work_package.payload.payload
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
work_package.payload.payload
)
match type(work_payload):
@@ -176,6 +178,21 @@ class PromptRepository:
return reaction
case db_payload.RankInitialPromptsPayload:
# validate ranking
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
raise ValueError(
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
)
# store reaction to post
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
reaction = self.insert_reaction(post.id, reaction_payload)
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
return reaction
case _:
raise ValueError(
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
@@ -201,6 +218,9 @@ class PromptRepository:
case protocol_schema.AssistantReplyTask:
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
case protocol_schema.RankInitialPromptsTask:
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
case protocol_schema.RankUserRepliesTask:
payload = db_payload.RankUserRepliesPayload(
tpye=task.type, conversation=task.conversation, replies=task.replies