mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
Merge pull request #15 from LAION-AI/add-rank-initial-prompts-payload
added rank_initial_prompts payload and db saving
This commit is contained in:
@@ -72,6 +72,14 @@ class RankConversationRepliesPayload(TaskPayload):
|
||||
replies: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankInitialPromptsPayload(TaskPayload):
|
||||
"""A task to rank a set of initial prompts."""
|
||||
|
||||
type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
|
||||
prompts: list[str]
|
||||
|
||||
|
||||
@payload_type
|
||||
class RankUserRepliesPayload(RankConversationRepliesPayload):
|
||||
"""A task to rank a set of user replies to a conversation."""
|
||||
|
||||
@@ -156,7 +156,9 @@ class PromptRepository:
|
||||
|
||||
# fetch work_package
|
||||
work_package = self.fetch_workpackage_by_postid(ranking.post_id)
|
||||
work_payload: db_payload.RankConversationRepliesPayload = work_package.payload.payload
|
||||
work_payload: db_payload.RankConversationRepliesPayload | db_payload.RankInitialPromptsPayload = (
|
||||
work_package.payload.payload
|
||||
)
|
||||
|
||||
match type(work_payload):
|
||||
|
||||
@@ -176,6 +178,21 @@ class PromptRepository:
|
||||
|
||||
return reaction
|
||||
|
||||
case db_payload.RankInitialPromptsPayload:
|
||||
# validate ranking
|
||||
if sorted(ranking.ranking) != list(range(num_prompts := len(work_payload.prompts))):
|
||||
raise ValueError(
|
||||
f"Invalid ranking submitted. Each reply index must appear exactly once ({num_prompts=})."
|
||||
)
|
||||
|
||||
# store reaction to post
|
||||
reaction_payload = db_payload.RankingReactionPayload(ranking=ranking.ranking)
|
||||
reaction = self.insert_reaction(post.id, reaction_payload)
|
||||
|
||||
logger.info(f"Ranking {ranking.ranking} stored for work_package {work_package.id}.")
|
||||
|
||||
return reaction
|
||||
|
||||
case _:
|
||||
raise ValueError(
|
||||
f"work_package payload type mismatch: {type(work_payload)=} != {db_payload.RankConversationRepliesPayload}"
|
||||
@@ -201,6 +218,9 @@ class PromptRepository:
|
||||
case protocol_schema.AssistantReplyTask:
|
||||
payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)
|
||||
|
||||
case protocol_schema.RankInitialPromptsTask:
|
||||
payload = db_payload.RankInitialPromptsPayload(tpye=task.type, prompts=task.prompts)
|
||||
|
||||
case protocol_schema.RankUserRepliesTask:
|
||||
payload = db_payload.RankUserRepliesPayload(
|
||||
tpye=task.type, conversation=task.conversation, replies=task.replies
|
||||
|
||||
Reference in New Issue
Block a user