mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-02 17:00:28 +08:00
update user input validator
This commit is contained in:
@@ -77,7 +77,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task.type):
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
continue
|
||||
|
||||
@@ -250,35 +250,45 @@ async def _send_task(
|
||||
return view.choice, str(msg.id)
|
||||
|
||||
|
||||
# TODO check what the backend expects
|
||||
def _validate_user_input(content: str | None, task_type: str) -> bool:
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
if content is None:
|
||||
return False
|
||||
|
||||
# User message input
|
||||
if (
|
||||
task_type == TaskRequestType.initial_prompt
|
||||
or task_type == TaskRequestType.user_reply
|
||||
or task_type == TaskRequestType.assistant_reply
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.user_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask
|
||||
)
|
||||
return len(content) > 0
|
||||
|
||||
elif (
|
||||
task_type == TaskRequestType.rank_initial_prompts
|
||||
or task_type == TaskRequestType.rank_user_replies
|
||||
or task_type == TaskRequestType.rank_assistant_replies
|
||||
):
|
||||
rankings = [int(r) for r in content.split(",")]
|
||||
return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
elif task_type == TaskRequestType.summarize_story:
|
||||
rankings = [int(r) for r in content.split(",")]
|
||||
return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = [int(r) for r in content.split(",")]
|
||||
return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task_type == TaskRequestType.rate_summary:
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"Unknown task type {task_type}")
|
||||
raise ValueError(f"Unknown task type {task_type}")
|
||||
logger.critical(f"Unknown task type {task.type}")
|
||||
raise ValueError(f"Unknown task type {task.type}")
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
|
||||
Reference in New Issue
Block a user