From dfd2c352764440bbc4f5877d0c26b1a29d107afa Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 2 Feb 2023 08:28:00 -0800 Subject: [PATCH] Feat/task handler (#1056) * move task logic to task handlers * rename command * remove test code * rename classes and add missing handler creations * switch task back to random and fetch log_channel_id from db --- discord-bot/.env.example | 2 +- discord-bot/bot/extensions/work.py | 779 ++++++++++++------------ discord-bot/bot/messages.py | 198 +++--- oasst-shared/oasst_shared/api_client.py | 12 +- 4 files changed, 526 insertions(+), 465 deletions(-) diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 8474ee90..33262896 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -4,4 +4,4 @@ OWNER_IDS=[, ] PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs OASST_API_URL="http://localhost:8080" # No trailing '/' -OASST_API_KEY="" +OASST_API_KEY="1234" diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 51daca3b..7a57265f 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -9,27 +9,25 @@ import lightbulb.decorators import miru from aiosqlite import Connection from bot.messages import ( - assistant_reply_message, + assistant_reply_messages, confirm_label_response_message, confirm_ranking_response_message, confirm_text_response_message, - initial_prompt_message, - invalid_user_input_embed, - label_assistant_reply_message, - label_initial_prompt_message, - label_prompter_reply_message, + initial_prompt_messages, + label_assistant_reply_messages, + label_prompter_reply_messages, plain_embed, - prompter_reply_message, + prompter_reply_messages, rank_assistant_reply_message, - rank_initial_prompts_message, - rank_prompter_reply_message, + rank_conversation_reply_messages, + rank_initial_prompts_messages, + rank_prompter_reply_messages, task_complete_embed, ) from bot.settings import Settings from loguru import logger -from oasst_shared.api_client import OasstApiClient, TaskType +from oasst_shared.api_client import OasstApiClient from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import TaskRequestType plugin = lightbulb.Plugin("WorkPlugin") @@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds settings = Settings() +_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True) + + +class _TaskHandler(t.Generic[_Task_contra]): + """Handle user interaction for a task.""" + + def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None: + """Create a new `TaskHandler`. + + Args: + ctx (lightbulb.Context): The context of the command that started the task. + task (_Task_contra): The task to handle. + """ + self.ctx = ctx + self.task = task + self.task_messages = self.get_task_messages(task) + self.sent_messages: list[hikari.Message] = [] + + @staticmethod + def get_task_messages(task: _Task_contra) -> list[str]: + """Get the messages to send to the user for the task.""" + raise NotImplementedError + + async def send(self) -> t.Literal["accept", "next", "cancel"] | None: + """Send the task and wait for the user to accept/skip/cancel it.""" + # Send all but the last message because we need to attach buttons to the last one + logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}") + for task_msg in self.task_messages[:-1]: + if len(task_msg) > 2000: + logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}") + task_msg = task_msg[:1999] + self.sent_messages.append(await self.ctx.author.send(task_msg)) + + # Send the last message with buttons + task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) + logger.debug(f"TH Message length {len(self.task_messages[-1])}") + last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view) + + await task_accept_view.start(last_msg) + await task_accept_view.wait() + + return task_accept_view.choice + + async def handle(self) -> None: + """Handle the user's response to the task. + + This method should be called after `send` has been called.""" + # Ack task to the backend + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}") + + # Loop until the user's input is accepted + while True: + try: + # Wait for user to send a message + event = await self.ctx.bot.wait_for( + hikari.DMMessageCreateEvent, + predicate=lambda e: ( + e.author_id == self.ctx.author.id + and e.message.content is not None + and not e.message.content.startswith(settings.prefix) + ), + timeout=MAX_TASK_TIME, + ) + + # Validate the message + if event.content is None or not self.check_user_input(event.content): + await self.ctx.author.send("Invalid input") + continue + + # Confirm user input + if not (await self.confirm_user_input(event.content)): + continue + + # Message is valid and confirmed by user + break + + except asyncio.TimeoutError: + return + + next_task = await self.notify(event.content, event) + if not isinstance(next_task, protocol_schema.TaskDone): + raise TypeError(f"Unknown task type: {next_task!r}") + + return + + async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task: + """Notify the backend that the user completed the task.""" + raise NotImplementedError + + async def confirm_user_input(self, content: str) -> bool: + """Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms.""" + raise NotImplementedError + + def check_user_input(self, content: str) -> bool: + """Check the user's response to the task. Returns True if the response is valid.""" + raise NotImplementedError + + async def cancel(self, reason: str = "not specified") -> None: + """Cancel the task.""" + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + await oasst_api.nack_task(self.task.id, reason) + + +_Ranking_contra = t.TypeVar( + "_Ranking_contra", + bound=protocol_schema.RankAssistantRepliesTask + | protocol_schema.RankInitialPromptsTask + | protocol_schema.RankPrompterRepliesTask + | protocol_schema.RankConversationRepliesTask, + contravariant=True, +) + + +class _RankingTaskHandler(_TaskHandler[_Ranking_contra]): + """This should not be used directly. Use its subclasses instead.""" + + async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task: + oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api + + task = await oasst_api.post_interaction( + protocol_schema.MessageRanking( + user=protocol_schema.User( + id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username + ), + ranking=[int(r) - 1 for r in content.split(",")], + message_id=f"{self.sent_messages[0].id}", + ) + ) + + db: Connection = self.ctx.bot.d.db + async with db.cursor() as cursor: + row = await ( + await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,)) + ).fetchone() + log_channel = row[0] if row else None + log_messages: list[hikari.Message] = [] + + if log_channel is not None: + for message in self.task_messages[:-1]: + msg = await self.ctx.bot.rest.create_message(log_channel, message) + log_messages.append(msg) + await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention)) + + return task + + +class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]: + return rank_assistant_reply_message(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]): + def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None: + super().__init__(ctx, task) + + @staticmethod + def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]: + return rank_initial_prompts_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.prompt_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]: + return rank_prompter_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]): + @staticmethod + def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]: + return rank_conversation_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content.split(",")) == len(self.task.reply_messages) and all( + [r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")] + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send( + confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view + ) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]): + @staticmethod + def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]: + return initial_prompt_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]: + return prompter_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]: + return assistant_reply_messages(task) + + def check_user_input(self, content: str) -> bool: + return len(content) > 0 + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True) + + +class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]): + def check_user_input(self, content: str) -> bool: + user_labels = content.split(",") + return ( + all([l in self.task.valid_labels for l in user_labels]) + and self.task.mandatory_labels is not None + and all([m in user_labels for m in self.task.mandatory_labels]) + ) + + async def confirm_user_input(self, content: str) -> bool: + confirm_input_view = YesNoView() + msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view) + await confirm_input_view.start(msg) + await confirm_input_view.wait() + + return bool(confirm_input_view.choice) + + +class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]: + return label_assistant_reply_messages(task) + + +class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]): + @staticmethod + def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]: + return label_prompter_reply_messages(task) + + +summarize_story = "summarize_story" +rate_summary = "rate_summary" + @plugin.command -@lightbulb.option( - "type", - "The type of task to request.", - choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], - required=False, - default=str(TaskRequestType.random), - type=str, -) @lightbulb.command("work", "Complete a task.") @lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) -async def work(ctx: lightbulb.Context): - """Create and handle a task.""" - # Only send this message if started from a server - if ctx.guild_id is not None: - await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL) - - # make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it - currently_working: dict[ - hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None] - ] = ctx.bot.d.currently_working - +async def work2(ctx: lightbulb.Context) -> None: + """Complete a task.""" oasst_api: OasstApiClient = ctx.bot.d.oasst_api + currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working + + # Check if the user is already working on a task if ctx.author.id in currently_working: yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) msg = await ctx.author.send( @@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context): case False | None: return case True: - old_msg, task_id = currently_working[ctx.author.id] - if old_msg is not None: - logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}") - map(lambda c: c, old_msg.components) - await old_msg.delete() - if task_id is not None: - await oasst_api.nack_task(task_id, reason="user cancelled") + task_id = currently_working[ctx.author.id] + await oasst_api.nack_task(task_id, reason="user cancelled") - await msg.delete() + if ctx.guild_id: + await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL) - currently_working[ctx.author.id] = (None, None) - - # Create a TaskRequestType from the stringified enum value - task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) - - logger.debug(f"Starting task_type: {task_type!r}") + # Keep sending tasks until the user doesn't want more try: - await _handle_task(ctx, task_type) + while True: + task = await oasst_api.fetch_random_task( + user=protocol_schema.User( + id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord" + ), + ) + + # Ranking tasks + if isinstance(task, protocol_schema.RankAssistantRepliesTask): + task_handler = RankAssistantRepliesHandler(ctx, task) + elif isinstance(task, protocol_schema.RankInitialPromptsTask): + task_handler = RankInitialPromptHandler(ctx, task) + elif isinstance(task, protocol_schema.RankPrompterRepliesTask): + task_handler = RankPrompterReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.RankConversationRepliesTask): + task_handler = RankConversationReplyHandler(ctx, task) + + # Text input tasks + elif isinstance(task, protocol_schema.InitialPromptTask): + task_handler = InitialPromptHandler(ctx, task) + elif isinstance(task, protocol_schema.PrompterReplyTask): + task_handler = PrompterReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.AssistantReplyTask): + task_handler = AssistantReplyHandler(ctx, task) + + # Label tasks + elif isinstance(task, protocol_schema.LabelAssistantReplyTask): + task_handler = LabelAssistantReplyHandler(ctx, task) + elif isinstance(task, protocol_schema.LabelPrompterReplyTask): + task_handler = LabelPrompterReplyHandler(ctx, task) + + else: + raise ValueError(f"Unknown task type: {type(task)}") + + resp = await task_handler.send() + + match resp: + case "accept": + currently_working[ctx.author.id] = task.id + await task_handler.handle() + case "next": + await task_handler.cancel("user skipped task") + case "cancel": + await task_handler.cancel("user canceled work") + break + case None: + await task_handler.cancel("select timed out") + break finally: del currently_working[ctx.author.id] -async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None: - """Handle creating and collecting user input for a task. - - Continually present tasks to the user until they select one, cancel, or time out. - If they select one, present the task steps until a `task_done` task is received. - Finally, ask the user if they want to perform another task (of the same type). - """ - oasst_api: OasstApiClient = ctx.bot.d.oasst_api - - # Continue to complete tasks until the user doesn't want to do another - done = False - while not done: - - # Loop until the user accepts a task - task, msg_id = await _select_task(ctx, task_type) - - if task is None: - # User cancelled - return - - # Task action loop - completed = False - while not completed: - await ctx.author.send(embed=plain_embed("Please type your response below:")) - try: - event = await ctx.bot.wait_for( - hikari.DMMessageCreateEvent, - timeout=MAX_TASK_TIME, - predicate=lambda e: e.author.id == ctx.author.id - and not (e.message.content or "").startswith(settings.prefix), - ) - except asyncio.TimeoutError: - await ctx.author.send(embed=plain_embed("Task timed out. Exiting")) - await oasst_api.nack_task(task.id, reason="timed out") - logger.info(f"Task {task.id} timed out") - return - - # Invalid response - valid, err_msg = _validate_user_input(event.content, task) - if not valid or event.content is None: - - await ctx.author.send(embed=invalid_user_input_embed(err_msg)) - continue - - logger.debug(f"Successful user input received: {event.content}") - - # Confirm user input - if isinstance(task, protocol_schema.RankConversationRepliesTask): - content = confirm_ranking_response_message(event.content, task.replies) - elif isinstance(task, protocol_schema.RankInitialPromptsTask): - content = confirm_ranking_response_message(event.content, task.prompts) - elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask): - content = confirm_label_response_message(event.content) - elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): - content = confirm_text_response_message(event.content) - else: - logger.critical(f"Unknown task type: {task.type}") - raise ValueError(f"Unknown task type: {task.type}") - - confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME) - msg = await ctx.author.send(content, components=confirm_resp_view) - await confirm_resp_view.start(msg) - await confirm_resp_view.wait() - - match confirm_resp_view.choice: - case False | None: - continue - case True: - await msg.delete() # buttons are already gone - - # Send the response to the backend - if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask): - reply = protocol_schema.MessageRanking( - message_id=str(msg_id), - ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")], - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - ) - elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask): - labels = event.content.replace(" ", "").split(",") - labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels} - - reply = protocol_schema.TextLabels( - message_id=task.message_id, - labels=labels_dict, - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - ) - elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): - reply = protocol_schema.TextReplyToMessage( - message_id=str(msg_id), - user_message_id=str(event.message_id), - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - text=event.content, - ) - else: - logger.critical(f"Unexpected task type received: {task.type}") - raise ValueError(f"Unexpected task type received: {task.type}") - - logger.debug(f"Sending reply to backend: {reply!r}") - - # Get next task - new_task = await oasst_api.post_interaction(reply) - logger.info(f"New task {new_task}") - - if new_task.type == TaskType.done: - await ctx.author.send(embed=plain_embed("Task completed")) - completed = True - continue - else: - logger.critical(f"Unexpected task type received: {new_task.type}") - - # Send a message in all the log channels that the task is complete - conn: Connection = ctx.bot.d.db - async with conn.cursor() as cursor: - await cursor.execute("SELECT log_channel_id FROM guild_settings") - log_channel_ids = await cursor.fetchall() - - channels = [ - ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0]) - for id in log_channel_ids - ] - - done_embed = task_complete_embed(task, ctx.author.mention) - # This will definitely get the bot rate limited, but that's a future problem - asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel))) - - # ask the user if they want to do another task - another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) - msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view) - await another_task_view.start(msg) - await another_task_view.wait() - - match another_task_view.choice: - case False | None: - done = True - await msg.edit(embed=plain_embed("Exiting, goodbye!")) - case True: - pass - - -async def _select_task( - ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None -) -> tuple[protocol_schema.Task | None, str]: - """Present tasks to the user until they accept one, cancel, or time out.""" - oasst_api: OasstApiClient = ctx.bot.d.oasst_api - logger.debug(f"Starting task selection for {task_type}") - - # Loop until the user accepts a task, cancels, or times out - msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED - while True: - logger.debug(f"Requesting task of type {task_type}") - task = await oasst_api.fetch_task(task_type, user) - resp, msg = await _send_task(ctx, task, msg) - msg_id = str(msg.id) - - logger.debug(f"User choice: {resp}") - match resp: - case "accept": - logger.info(f"Task {task.id} accepted, sending ACK") - await oasst_api.ack_task(task.id, msg_id) - return task, msg_id - - case "next": - logger.info(f"Task {task.id} rejected, sending NACK") - await oasst_api.nack_task(task.id, "rejected") - continue - - case "cancel": - logger.info(f"Task {task.id} canceled, sending NACK") - await oasst_api.nack_task(task.id, "canceled") - await ctx.author.send(embed=plain_embed("Task canceled. Exiting")) - return None, msg_id - - case None: - logger.info(f"Task {task.id} timed out, sending NACK") - await oasst_api.nack_task(task.id, "timed out") - await ctx.author.send(embed=plain_embed("Task timed out. Exiting")) - return None, msg_id - - -async def _send_task( - ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message] -) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]: - """Send a task to the user. - - Returns the user's choice and the message ID of the task message. - """ - # The clean way to do this would be to attach a `to_embed` method to the task classes - # but the tasks aren't discord specific so that doesn't really make sense. - - embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED - content: hikari.UndefinedOr[str] = hikari.UNDEFINED - - # Create an embed based on the task's type - if task.type == TaskRequestType.initial_prompt: - assert isinstance(task, protocol_schema.InitialPromptTask) - logger.debug("sending initial prompt task") - content = initial_prompt_message(task) - - elif task.type == TaskRequestType.rank_initial_prompts: - assert isinstance(task, protocol_schema.RankInitialPromptsTask) - logger.debug("sending rank initial prompt task") - content = rank_initial_prompts_message(task) - - elif task.type == TaskRequestType.rank_prompter_replies: - assert isinstance(task, protocol_schema.RankPrompterRepliesTask) - logger.debug("sending rank user reply task") - content = rank_prompter_reply_message(task) - - elif task.type == TaskRequestType.rank_assistant_replies: - assert isinstance(task, protocol_schema.RankAssistantRepliesTask) - logger.debug("sending rank assistant reply task") - content = rank_assistant_reply_message(task) - - elif task.type == TaskRequestType.label_initial_prompt: - assert isinstance(task, protocol_schema.LabelInitialPromptTask) - logger.debug("sending label initial prompt task") - content = label_initial_prompt_message(task) - - elif task.type == TaskRequestType.label_prompter_reply: - assert isinstance(task, protocol_schema.LabelPrompterReplyTask) - logger.debug("sending label prompter reply task") - content = label_prompter_reply_message(task) - - elif task.type == TaskRequestType.label_assistant_reply: - assert isinstance(task, protocol_schema.LabelAssistantReplyTask) - logger.debug("sending label assistant reply task") - content = label_assistant_reply_message(task) - - elif task.type == TaskRequestType.prompter_reply: - assert isinstance(task, protocol_schema.PrompterReplyTask) - logger.debug("sending user reply task") - content = prompter_reply_message(task) - - elif task.type == TaskRequestType.assistant_reply: - assert isinstance(task, protocol_schema.AssistantReplyTask) - logger.debug("sending assistant reply task") - content = assistant_reply_message(task) - - elif task.type == TaskRequestType.summarize_story: - raise NotImplementedError - elif task.type == TaskRequestType.rate_summary: - raise NotImplementedError - - else: - logger.critical(f"unknown task type {task.type}") - raise ValueError(f"unknown task type {task.type}") - - view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) - if not msg: - msg = await ctx.author.send( - content, - embed=embed, - components=view, - ) - else: - await msg.edit( - content, - embed=embed, - components=view, - ) - - assert msg is not None - - # Set the choice id as the current msg id - ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id) - - await view.start(msg) - await view.wait() - - return view.choice, msg - - -def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]: - """Returns whether the user's input is valid for the task type and an error message.""" - if content is None: - return False, "No input provided" - - # User message input - if ( - task.type == TaskRequestType.initial_prompt - or task.type == TaskRequestType.prompter_reply - or task.type == TaskRequestType.assistant_reply - ): - assert isinstance( - task, - protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask, - ) - return len(content) > 0, "Message must be at least one character long." - - # Ranking tasks - elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies: - assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask) - num_replies = len(task.replies) - - rankings = content.replace(" ", "").split(",") - return ( - set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies, - "Message must contain numbers for all replies.", - ) - - elif task.type == TaskRequestType.rank_initial_prompts: - assert isinstance(task, protocol_schema.RankInitialPromptsTask) - num_prompts = len(task.prompts) - - rankings = content.replace(" ", "").split(",") - return ( - set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts, - "Message must contain numbers for all prompts.", - ) - - # Labels tasks - elif task.type in ( - TaskRequestType.label_initial_prompt, - TaskRequestType.label_prompter_reply, - TaskRequestType.label_assistant_reply, - ): - assert isinstance( - task, - protocol_schema.LabelInitialPromptTask - | protocol_schema.LabelPrompterReplyTask - | protocol_schema.LabelAssistantReplyTask, - ) - - labels = content.replace(" ", "").split(",") - valid_labels = set(task.valid_labels) - return ( - set(labels).issubset(valid_labels), - "Message must only contain labels from predefined set of labels.", - ) - - elif task.type == TaskRequestType.summarize_story: - raise NotImplementedError - elif task.type == TaskRequestType.rate_summary: - raise NotImplementedError - - else: - logger.critical(f"Unknown task type {task.type}") - raise ValueError(f"Unknown task type {task.type}") - - class TaskAcceptView(miru.View): """View with three buttons: accept, next, and cancel. diff --git a/discord-bot/bot/messages.py b/discord-bot/bot/messages.py index 7bd57bb9..26ad9158 100644 --- a/discord-bot/bot/messages.py +++ b/discord-bot/bot/messages.py @@ -1,4 +1,11 @@ -"""All user-facing messages and embeds.""" +"""All user-facing messages and embeds. + +When sending a conversation +- The function will return a list of strings + - use asyncio.gather to send all messages + +- +""" from datetime import datetime @@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str: return f":trophy: _{text}_" -def _label_prompt(text: str) -> str: - return f":question: _{text}" +def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str: + return f""":question: _{text}_ +Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"} +Valid labels: {", ".join(valid_labels)} +""" def _response_prompt(text: str) -> str: @@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str: """ -def _make_ordered_list(items: list[str]) -> list[str]: - return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)] +def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]: + return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)] -def _ordered_list(items: list[str]) -> str: +def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str: return "\n\n".join(_make_ordered_list(items)) -def _conversation(conv: protocol_schema.Conversation) -> str: - return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) - - -def _hint(hint: str | None) -> str: - return f"{NL}Hint: {hint}" if hint else "" +def _conversation(conv: protocol_schema.Conversation) -> list[str]: + # return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) + messages = map( + lambda m: f"""\ +:robot: __Assistant__: +{m.text} +""" + if m.is_assistant + else f"""\ +:person_red_hair: __User__: +{m.text} +""", + conv.messages, + ) + return list(messages) def _li(text: str) -> str: @@ -82,59 +101,80 @@ def _li(text: str) -> str: ### -def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str: +def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]: """Creates the message that gets sent to users when they request an `initial_prompt` task.""" - return f"""\ + return [ + f"""\ -{_h1("INITIAL PROMPT")} +:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond: -{_writing_prompt("Please provide an initial prompt to the assistant.")} -{_hint(task.hint)} +:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""} """ + ] -def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str: +def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_initial_prompts` task.""" - return f"""\ + return [ + f"""\ -{_h1("RANK INITIAL PROMPTS")} +:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond: -{_ordered_list(task.prompts)} +{_ordered_list(task.prompt_messages)} -{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} +:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_ """ + ] -def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str: +def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_prompter_replies` task.""" - return f"""\ + return [ + """\ -{_h1("RANK PROMPTER REPLIES")} +:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":person_red_hair: __User__: +{_ordered_list(task.reply_messages)} + +:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_ +""", + ] -{_conversation(task.conversation)} -{_user(None)} -{_ordered_list(task.replies)} - -{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} -""" - - -def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str: +def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]: """Creates the message that gets sent to users when they request a `rank_assistant_replies` task.""" - return f"""\ + return [ + """\ -{_h1("RANK ASSISTANT REPLIES")} +:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":robot: __Assistant__:, +{_ordered_list(task.reply_messages)} +:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_ +""", + ] -{_conversation(task.conversation)} -{_assistant(None)} -{_ordered_list(task.replies)} +def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]: + """Creates the message that gets sent to users when they request a `rank_conversation_replies` task.""" + return [ + """\ -{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} -""" +:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond: + +""", + *_conversation(task.conversation), + f""":person_red_hair: __User__: +{_ordered_list(task.reply_messages)} +""", + ] def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str: @@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) - {task.prompt} -{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")} +{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)} """ -def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str: +def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `label_prompter_reply` task.""" - return f"""\ + return [ + f"""\ {_h1("LABEL PROMPTER REPLY")} -{_conversation(task.conversation)} -{_user(None)} +""", + *_conversation(task.conversation), + f"""{_user(None)} {task.reply} -{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} -""" +{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)} +""", + ] -def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str: +def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `label_assistant_reply` task.""" - return f"""\ + return [ + f"""\ {_h1("LABEL ASSISTANT REPLY")} -{_conversation(task.conversation)} +""", + *_conversation(task.conversation), + f""" {_assistant(None)} {task.reply} -{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} -""" +{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)} +""", + ] -def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: +def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `prompter_reply` task.""" - return f"""\ + return [ + """\ +:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond: -{_h1("PROMPTER REPLY")} +""", + *_conversation(task.conversation), + f"""{f"{NL}Hint: {task.hint}" if task.hint else ""} + +:speech_balloon: _Please provide a reply to the assistant._ +""", + ] -{_conversation(task.conversation)} -{_hint(task.hint)} - -{_response_prompt("Please provide a reply to the assistant.")} -""" +# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]: +# """Creates the message that gets sent to users when they request a `prompter_reply` task.""" +# return [ +# message_templates.render("title.msg", "PROMPTER REPLY"), +# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation], +# message_templates.render("prompter_reply_task.msg", task.hint), +# ] -def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str: +def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]: """Creates the message that gets sent to users when they request a `assistant_reply` task.""" - return f"""\ -{_h1("ASSISTANT REPLY")} + return [ + """\ +:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond: +""", + *_conversation(task.conversation), + """\ -{_conversation(task.conversation)} - -{_response_prompt("Please provide an assistant reply to the prompter.")} -""" +:speech_balloon: _Please provide a reply to the user as the assistant._ +""", + ] def confirm_text_response_message(content: str) -> str: @@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str: """ -def confirm_ranking_response_message(content: str, items: list[str]) -> str: +def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str: user_rankings = [int(r) for r in content.replace(" ", "").split(",")] original_list = _make_ordered_list(items) user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings]) diff --git a/oasst-shared/oasst_shared/api_client.py b/oasst-shared/oasst_shared/api_client.py index 1ee2865b..26592fb8 100644 --- a/oasst-shared/oasst_shared/api_client.py +++ b/oasst-shared/oasst_shared/api_client.py @@ -68,12 +68,15 @@ class OasstApiClient: async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]: """Make a POST request to the backend.""" logger.debug(f"POST {self.backend_url}{path} DATA: {data}") - response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) + response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key}) + logger.debug(f"response: {response}") # If the response is not a 2XX, check to see # if the json has the fields to create an # OasstError. if response.status >= 300: + text = await response.text() + logger.debug(f"resp text: {text}") data = await response.json() try: oasst_error = protocol_schema.OasstErrorResponse(**(data or {})) @@ -114,20 +117,21 @@ class OasstApiClient: task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None, collective: bool = False, + lang: Optional[str] = None, ) -> protocol_schema.Task: """Fetch a task from the backend.""" logger.debug(f"Fetching task {task_type} for user {user}") - req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective) + req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang) resp = await self.post("/api/v1/tasks/", data=req.dict()) logger.debug(f"RESP {resp}") return self._parse_task(resp) async def fetch_random_task( - self, user: Optional[protocol_schema.User] = None, collective: bool = False + self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None ) -> protocol_schema.Task: """Fetch a random task from the backend.""" logger.debug(f"Fetching random for user {user}") - return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective) + return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang) async def ack_task(self, task_id: str | UUID, message_id: str) -> None: """Send an ACK for a task to the backend."""