diff --git a/discord-bot/bot/extensions/EXAMPLES.md b/discord-bot/bot/extensions/EXAMPLES.md index f031cd72..29598fde 100644 --- a/discord-bot/bot/extensions/EXAMPLES.md +++ b/discord-bot/bot/extensions/EXAMPLES.md @@ -396,9 +396,6 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None: await view.start(await resp.message()) -# TODO: Database example -# TODO: Rest client example - def load(bot: lightbulb.BotApp): """Add the plugin to the bot.""" diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 1f278ca4..53d0a1fd 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -69,7 +69,6 @@ class LabelModal(miru.Modal): .add_field("Total Labeled Message", "0", inline=True) .add_field("Server Ranking", "0/0", inline=True) .add_field("Global Ranking", "0/0", inline=True) - .set_footer("Message ID: TODO") ) channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id) assert isinstance(channel, hikari.TextableChannel) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index ba71f41b..8e3ad7b5 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -31,7 +31,7 @@ logger.setLevel(logging.DEBUG) "The type of task to request.", choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], required=False, - default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random + default=str(TaskRequestType.random), type=str, ) @lightbulb.command("work", "Complete a task.") @@ -79,11 +79,11 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) return # Invalid response - if event.content is None: - await ctx.author.send("No content in message") + if event.content is None or not _validate_user_input(event.content, task.type): + await ctx.author.send("Invalid response") continue - logger.info(f"User input received: {event.content}") + logger.info(f"Successful user input received: {event.content}") # Send the response to the backend reply = protocol_schema.TextReplyToPost( @@ -108,7 +108,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) logger.fatal(f"Unexpected task type received: {new_task.type}") # Send a message in the log channel that the task is complete - # TODO: Maybe do something with the msg ID + # TODO: Maybe do something with the msg ID so users can rate the "answer" assert ctx.guild_id is not None conn: Connection = ctx.bot.d.db guild_settings = await GuildSettings.from_db(conn, ctx.guild_id) @@ -252,6 +252,37 @@ async def _send_task( return view.choice, str(msg.id) +# TODO check what the backend expects +def _validate_user_input(content: str | None, task_type: str) -> bool: + """Returns whether the user's input is valid for the task type.""" + if content is None: + return False + + if ( + task_type == TaskRequestType.initial_prompt + or task_type == TaskRequestType.user_reply + or task_type == TaskRequestType.assistant_reply + ): + return len(content) > 0 + + elif ( + task_type == TaskRequestType.rank_initial_prompts + or task_type == TaskRequestType.rank_user_replies + or task_type == TaskRequestType.rank_assistant_replies + ): + rankings = [int(r) for r in content.split(",")] + return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5 + + elif task_type == TaskRequestType.summarize_story: + raise NotImplementedError + elif task_type == TaskRequestType.rate_summary: + raise NotImplementedError + + else: + logger.fatal(f"Unknown task type {task_type}") + raise ValueError(f"Unknown task type {task_type}") + + class TaskAcceptView(miru.View): """View with three buttons: accept, next, and cancel.