diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 5244920b..28ef64c2 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -77,7 +77,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) return # Invalid response - if event.content is None or not _validate_user_input(event.content, task.type): + if event.content is None or not _validate_user_input(event.content, task): await ctx.author.send("Invalid response") continue @@ -250,35 +250,45 @@ async def _send_task( return view.choice, str(msg.id) -# TODO check what the backend expects -def _validate_user_input(content: str | None, task_type: str) -> bool: +def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool: """Returns whether the user's input is valid for the task type.""" if content is None: return False + # User message input if ( - task_type == TaskRequestType.initial_prompt - or task_type == TaskRequestType.user_reply - or task_type == TaskRequestType.assistant_reply + task.type == TaskRequestType.initial_prompt + or task.type == TaskRequestType.user_reply + or task.type == TaskRequestType.assistant_reply ): + assert isinstance( + task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask + ) return len(content) > 0 - elif ( - task_type == TaskRequestType.rank_initial_prompts - or task_type == TaskRequestType.rank_user_replies - or task_type == TaskRequestType.rank_assistant_replies - ): - rankings = [int(r) for r in content.split(",")] - return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5 + # Ranking tasks + elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies: + assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask) + num_replies = len(task.replies) - elif task_type == TaskRequestType.summarize_story: + rankings = [int(r) for r in content.split(",")] + return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies + + elif task.type == TaskRequestType.rank_initial_prompts: + assert isinstance(task, protocol_schema.RankInitialPromptsTask) + num_prompts = len(task.prompts) + + rankings = [int(r) for r in content.split(",")] + return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts + + elif task.type == TaskRequestType.summarize_story: raise NotImplementedError - elif task_type == TaskRequestType.rate_summary: + elif task.type == TaskRequestType.rate_summary: raise NotImplementedError else: - logger.critical(f"Unknown task type {task_type}") - raise ValueError(f"Unknown task type {task_type}") + logger.critical(f"Unknown task type {task.type}") + raise ValueError(f"Unknown task type {task.type}") class TaskAcceptView(miru.View):