update user input validator

This commit is contained in:
Alex Ott
2022-12-30 17:44:20 -08:00
parent 37f30f4e31
commit 004a868cb4
+27 -17
View File
@@ -77,7 +77,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
return
# Invalid response
if event.content is None or not _validate_user_input(event.content, task.type):
if event.content is None or not _validate_user_input(event.content, task):
await ctx.author.send("Invalid response")
continue
@@ -250,35 +250,45 @@ async def _send_task(
return view.choice, str(msg.id)
# TODO check what the backend expects
def _validate_user_input(content: str | None, task_type: str) -> bool:
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
"""Returns whether the user's input is valid for the task type."""
if content is None:
return False
# User message input
if (
task_type == TaskRequestType.initial_prompt
or task_type == TaskRequestType.user_reply
or task_type == TaskRequestType.assistant_reply
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.user_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask
)
return len(content) > 0
elif (
task_type == TaskRequestType.rank_initial_prompts
or task_type == TaskRequestType.rank_user_replies
or task_type == TaskRequestType.rank_assistant_replies
):
rankings = [int(r) for r in content.split(",")]
return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5
# Ranking tasks
elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
elif task_type == TaskRequestType.summarize_story:
rankings = [int(r) for r in content.split(",")]
return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = [int(r) for r in content.split(",")]
return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task_type == TaskRequestType.rate_summary:
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task_type}")
raise ValueError(f"Unknown task type {task_type}")
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View):