diff --git a/discord-bot/.env.example b/discord-bot/.env.example index ec114c8f..8474ee90 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,7 +1,7 @@ BOT_TOKEN= DECLARE_GLOBAL_COMMANDS= OWNER_IDS=[, ] -PREFIX="/" # Don't change, this allows for slash commands in DMs +PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs OASST_API_URL="http://localhost:8080" # No trailing '/' OASST_API_KEY="" diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index df3c5f2f..8c604e1a 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -6,7 +6,7 @@ import hikari import lightbulb import miru from bot.settings import Settings -from bot.utils import EMPTY, mention +from bot.utils import mention from oasst_shared.api_client import OasstApiClient settings = Settings() @@ -34,8 +34,11 @@ async def on_starting(event: hikari.StartingEvent): bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key) - # A set of user id's that are currently doing work. - bot.d.currently_working = set() + # A `dict[hikari.Message | None, UUID | None]]` that maps user IDs to (task msg ID, task UUIDs). + # Either both are `None` or both are not `None`. + # If both are `None`, the user is not currently selecting a task. + # TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow + bot.d.currently_working = {} @bot.listen() @@ -50,13 +53,13 @@ async def _send_error_embed( ) -> None: ctx.command embed = hikari.Embed( - title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }", + title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }", description=content, color=0xFF0000, timestamp=datetime.now().astimezone(), ).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url)) - await ctx.respond(EMPTY, embed=embed) + await ctx.respond(embed=embed) @bot.listen(lightbulb.CommandErrorEvent) @@ -65,6 +68,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None: # Unwrap the exception to get the original cause exc = event.exception.__cause__ or event.exception ctx = event.context + if not ctx.bot.rest.is_alive: + return if isinstance(event.exception, lightbulb.CommandInvocationError): if not event.context.command: @@ -114,6 +119,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None: ctx, ) elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment): - await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx) + await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx) + elif isinstance(exc, lightbulb.errors.CommandNotFound): + await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.") else: raise exc diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py index 62f21305..5940f33a 100644 --- a/discord-bot/bot/extensions/guild_settings.py +++ b/discord-bot/bot/extensions/guild_settings.py @@ -78,7 +78,6 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None: # if the bot's permissions for this channel don't contain SEND_MESSAGE # This will also filter out categories and voice channels - print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES) if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES: await ctx.respond(f"I don't have permission to send messages in {ch.mention}.") return diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 388a93f0..a2607aec 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -7,7 +7,6 @@ import lightbulb import miru from aiosqlite import Connection from bot.db.schemas import GuildSettings -from bot.utils import EMPTY from loguru import logger plugin = lightbulb.Plugin( @@ -74,7 +73,7 @@ class LabelModal(miru.Modal): ) channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id) assert isinstance(channel, hikari.TextableChannel) - await channel.send(EMPTY, embed=embed) + await channel.send(embed=embed) class LabelSelect(miru.View): @@ -164,7 +163,7 @@ async def label_message_text(ctx: lightbulb.MessageContext): msg.content, timeout=60, ) - resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL) + resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL) await label_select_view.start(await resp.message()) await label_select_view.wait() diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index c905e7a0..6b7f8ea4 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -1,14 +1,27 @@ """Work plugin for collecting user data.""" import asyncio import typing as t -from datetime import datetime +from uuid import UUID import hikari import lightbulb import lightbulb.decorators import miru from aiosqlite import Connection -from bot.utils import EMPTY +from bot.messages import ( + assistant_reply_message, + confirm_ranking_response_message, + confirm_text_response_message, + initial_prompt_message, + invalid_user_input_embed, + plain_embed, + prompter_reply_message, + rank_assistant_reply_message, + rank_initial_prompts_message, + rank_prompter_reply_message, + task_complete_embed, +) +from bot.settings import Settings from loguru import logger from oasst_shared.api_client import OasstApiClient, TaskType from oasst_shared.schemas import protocol as protocol_schema @@ -19,6 +32,8 @@ plugin = lightbulb.Plugin("WorkPlugin") MAX_TASK_TIME = 60 * 60 # 1 hour MAX_TASK_ACCEPT_TIME = 60 # 1 minute +settings = Settings() + @plugin.command @lightbulb.option( @@ -33,25 +48,50 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute @lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) async def work(ctx: lightbulb.Context): """Create and handle a task.""" - # make sure the user isn't currently doing a task - currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working + # Only send this message if started from a server + if ctx.guild_id is not None: + await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL) + + # make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it + currently_working: dict[ + hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None] + ] = ctx.bot.d.currently_working + + oasst_api: OasstApiClient = ctx.bot.d.oasst_api if ctx.author.id in currently_working: - await ctx.respond( - "You are already performing a task. Please complete that one first.", flags=hikari.MessageFlag.EPHEMERAL + yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) + msg = await ctx.author.send( + embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"), + flags=hikari.MessageFlag.EPHEMERAL, + components=yn_view, ) - return + await yn_view.start(msg) + await yn_view.wait() - currently_working.add(ctx.author.id) + match yn_view.choice: + case False | None: + return + case True: + old_msg, task_id = currently_working[ctx.author.id] + if old_msg is not None: + logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}") + map(lambda c: c, old_msg.components) + await old_msg.delete() + if task_id is not None: + await oasst_api.nack_task(task_id, reason="user cancelled") + await msg.delete() + + currently_working[ctx.author.id] = (None, None) + + # Create a TaskRequestType from the stringified enum value task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) - await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) logger.debug(f"Starting task_type: {task_type!r}") - try: await _handle_task(ctx, task_type) finally: - currently_working.remove(ctx.author.id) + del currently_working[ctx.author.id] async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None: @@ -71,38 +111,79 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No task, msg_id = await _select_task(ctx, task_type) if task is None: + # User cancelled return # Task action loop completed = False while not completed: - await ctx.author.send("Please type your response here:") + await ctx.author.send(embed=plain_embed("Please type your response here")) try: event = await ctx.bot.wait_for( - hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id + hikari.DMMessageCreateEvent, + timeout=MAX_TASK_TIME, + predicate=lambda e: e.author.id == ctx.author.id + and not (e.message.content or "").startswith(settings.prefix), ) except asyncio.TimeoutError: - await ctx.author.send("Task timed out. Exiting") + await ctx.author.send(embed=plain_embed("Task timed out. Exiting")) await oasst_api.nack_task(task.id, reason="timed out") logger.info(f"Task {task.id} timed out") return # Invalid response - if event.content is None or not _validate_user_input(event.content, task): - await ctx.author.send("Invalid response") + valid, err_msg = _validate_user_input(event.content, task) + if not valid or event.content is None: + + await ctx.author.send(embed=invalid_user_input_embed(err_msg)) continue logger.debug(f"Successful user input received: {event.content}") + # Confirm user input + if isinstance(task, protocol_schema.RankConversationRepliesTask): + content = confirm_ranking_response_message(event.content, task.replies) + elif isinstance(task, protocol_schema.RankInitialPromptsTask): + content = confirm_ranking_response_message(event.content, task.prompts) + elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): + content = confirm_text_response_message(event.content) + else: + logger.critical(f"Unknown task type: {task.type}") + raise ValueError(f"Unknown task type: {task.type}") + + confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME) + msg = await ctx.author.send(content, components=confirm_resp_view) + await confirm_resp_view.start(msg) + await confirm_resp_view.wait() + + match confirm_resp_view.choice: + case False | None: + continue + case True: + await msg.delete() # buttons are already gone + # Send the response to the backend - reply = protocol_schema.TextReplyToMessage( - message_id=str(msg_id), - user_message_id=str(event.message_id), - user=protocol_schema.User( - auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username - ), - text=event.content, - ) + if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask): + reply = protocol_schema.MessageRanking( + message_id=str(msg_id), + ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")], + user=protocol_schema.User( + auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username + ), + ) + elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask): + reply = protocol_schema.TextReplyToMessage( + message_id=str(msg_id), + user_message_id=str(event.message_id), + user=protocol_schema.User( + auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username + ), + text=event.content, + ) + else: + logger.critical(f"Unexpected task type received: {task.type}") + raise ValueError(f"Unexpected task type received: {task.type}") + logger.debug(f"Sending reply to backend: {reply!r}") # Get next task @@ -110,7 +191,7 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No logger.info(f"New task {new_task}") if new_task.type == TaskType.done: - await ctx.author.send("Task completed") + await ctx.author.send(embed=plain_embed("Task completed")) completed = True continue else: @@ -127,33 +208,20 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No for id in log_channel_ids ] - done_embed = ( - hikari.Embed( - title="Task Completion", - description=f"`{task.type}` completed by {ctx.author.mention}", - color=hikari.Color(0x00FF00), - timestamp=datetime.now().astimezone(), - ) - .add_field("Total Tasks", "0", inline=True) - .add_field("Server Ranking", "0/0", inline=True) - .add_field("Global Ranking", "0/0", inline=True) - .set_footer(f"Task ID: {task.id}") - ) + done_embed = task_complete_embed(task, ctx.author.mention) # This will definitely get the bot rate limited, but that's a future problem - asyncio.gather( - *(ch.send(EMPTY, embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)) - ) + asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel))) # ask the user if they want to do another task - choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) - msg = await ctx.author.send("Would you like another task?", components=choice_view) - await choice_view.start(msg) - await choice_view.wait() + another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) + msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view) + await another_task_view.start(msg) + await another_task_view.wait() - match choice_view.choice: + match another_task_view.choice: case False | None: done = True - await ctx.author.send("Exiting, goodbye!") + await msg.edit(embed=plain_embed("Exiting, goodbye!")) case True: pass @@ -166,10 +234,12 @@ async def _select_task( logger.debug(f"Starting task selection for {task_type}") # Loop until the user accepts a task, cancels, or times out + msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED while True: logger.debug(f"Requesting task of type {task_type}") task = await oasst_api.fetch_task(task_type, user) - resp, msg_id = await _send_task(ctx, task) + resp, msg = await _send_task(ctx, task, msg) + msg_id = str(msg.id) logger.debug(f"User choice: {resp}") match resp: @@ -181,25 +251,24 @@ async def _select_task( case "next": logger.info(f"Task {task.id} rejected, sending NACK") await oasst_api.nack_task(task.id, "rejected") - await ctx.author.send("Sending next task...") continue case "cancel": logger.info(f"Task {task.id} canceled, sending NACK") await oasst_api.nack_task(task.id, "canceled") - await ctx.author.send("Task canceled. Exiting") + await ctx.author.send(embed=plain_embed("Task canceled. Exiting")) return None, msg_id case None: logger.info(f"Task {task.id} timed out, sending NACK") await oasst_api.nack_task(task.id, "timed out") - await ctx.author.send("Task timed out. Exiting") + await ctx.author.send(embed=plain_embed("Task timed out. Exiting")) return None, msg_id async def _send_task( - ctx: lightbulb.Context, task: protocol_schema.Task -) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]: + ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message] +) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]: """Send a task to the user. Returns the user's choice and the message ID of the task message. @@ -208,37 +277,38 @@ async def _send_task( # but the tasks aren't discord specific so that doesn't really make sense. embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED + content: hikari.UndefinedOr[str] = hikari.UNDEFINED # Create an embed based on the task's type if task.type == TaskRequestType.initial_prompt: assert isinstance(task, protocol_schema.InitialPromptTask) logger.debug("sending initial prompt task") - embed = _initial_prompt_embed(task) + content = initial_prompt_message(task) elif task.type == TaskRequestType.rank_initial_prompts: assert isinstance(task, protocol_schema.RankInitialPromptsTask) logger.debug("sending rank initial prompt task") - embed = _rank_initial_prompt_embed(task) + content = rank_initial_prompts_message(task) elif task.type == TaskRequestType.rank_prompter_replies: assert isinstance(task, protocol_schema.RankPrompterRepliesTask) logger.debug("sending rank user reply task") - embed = _rank_prompter_reply_embed(task) + content = rank_prompter_reply_message(task) elif task.type == TaskRequestType.rank_assistant_replies: assert isinstance(task, protocol_schema.RankAssistantRepliesTask) logger.debug("sending rank assistant reply task") - embed = _rank_assistant_reply_embed(task) + content = rank_assistant_reply_message(task) elif task.type == TaskRequestType.prompter_reply: assert isinstance(task, protocol_schema.PrompterReplyTask) logger.debug("sending user reply task") - embed = _prompter_reply_embed(task) + content = prompter_reply_message(task) elif task.type == TaskRequestType.assistant_reply: assert isinstance(task, protocol_schema.AssistantReplyTask) logger.debug("sending assistant reply task") - embed = _assistant_reply_embed(task) + content = assistant_reply_message(task) elif task.type == TaskRequestType.summarize_story: raise NotImplementedError @@ -250,24 +320,34 @@ async def _send_task( raise ValueError(f"unknown task type {task.type}") view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) - msg = await ctx.author.send( - EMPTY, - embed=embed, - components=view, - ) + if not msg: + msg = await ctx.author.send( + content, + embed=embed, + components=view, + ) + else: + await msg.edit( + content, + embed=embed, + components=view, + ) assert msg is not None + # Set the choice id as the current msg id + ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id) + await view.start(msg) await view.wait() - return view.choice, str(msg.id) + return view.choice, msg -def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool: - """Returns whether the user's input is valid for the task type.""" +def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]: + """Returns whether the user's input is valid for the task type and an error message.""" if content is None: - return False + return False, "No input provided" # User message input if ( @@ -279,22 +359,28 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo task, protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask, ) - return len(content) > 0 + return len(content) > 0, "Message must be at least one character long." # Ranking tasks elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies: assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask) num_replies = len(task.replies) - rankings = content.split(",") - return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies + rankings = content.replace(" ", "").split(",") + return ( + set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies, + "Message must contain numbers for all replies.", + ) elif task.type == TaskRequestType.rank_initial_prompts: assert isinstance(task, protocol_schema.RankInitialPromptsTask) num_prompts = len(task.prompts) - rankings = content.split(",") - return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts + rankings = content.replace(" ", "").split(",") + return ( + set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts, + "Message must contain numbers for all prompts.", + ) elif task.type == TaskRequestType.summarize_story: raise NotImplementedError @@ -318,22 +404,29 @@ class TaskAcceptView(miru.View): async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: logger.info("Accept button pressed") self.choice = "accept" + await ctx.message.edit(component=None) self.stop() @miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY) async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: logger.info("Next button pressed") self.choice = "next" + await ctx.message.edit(component=None) self.stop() @miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER) async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: logger.info("Cancel button pressed") self.choice = "cancel" + await ctx.message.edit(component=None) self.stop() + async def on_timeout(self) -> None: + if self.message is not None: + await self.message.edit(component=None) -class ChoiceView(miru.View): + +class YesNoView(miru.View): """View with two buttons: yes and no. The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute. @@ -344,115 +437,18 @@ class ChoiceView(miru.View): @miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS) async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: self.choice = True + await ctx.message.edit(component=None) self.stop() @miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER) async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: self.choice = False + await ctx.message.edit(component=None) self.stop() - -################################################################ -# Template Embeds # -################################################################ - -# TODO: Maybe implement a better way of creating embeds, like `from_json` or something - - -def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: - return ( - hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - -def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="Rank Initial Prompt", - description="Rank the following tasks from best to worst (1,2,3,4,5)", - timestamp=datetime.now().astimezone(), - ) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for i, prompt in enumerate(task.prompts): - embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False) - - return embed - - -def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="Rank User Reply", - description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4", - timestamp=datetime.now().astimezone(), - ) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for i, reply in enumerate(task.replies): - embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False) - - return embed - - -def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="Rank Assistant Reply", - description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4", - timestamp=datetime.now().astimezone(), - ) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for i, reply in enumerate(task.replies): - embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False) - - return embed - - -def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="User Reply", - description=f"""\ - Send the next message in the conversation as if you were the user. - {'Hint: ' if task.hint else ''} - """, - timestamp=datetime.now().astimezone(), - ) - # .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for message in task.conversation.messages: - embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False) - - return embed - - -def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="User Reply", - description="Send the next message in the conversation as if you were the user.", - timestamp=datetime.now().astimezone(), - ) - # .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for message in task.conversation.messages: - embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False) - - return embed + async def on_timeout(self) -> None: + if self.message is not None: + await self.message.edit(component=None) def load(bot: lightbulb.BotApp): diff --git a/discord-bot/bot/messages.py b/discord-bot/bot/messages.py new file mode 100644 index 00000000..0f29511a --- /dev/null +++ b/discord-bot/bot/messages.py @@ -0,0 +1,207 @@ +"""All user-facing messages and embeds.""" + +from datetime import datetime + +import hikari +from oasst_shared.schemas import protocol as protocol_schema + +NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"] +NL = "\n" + +### +# Reusable 'components' +### + + +def _h1(text: str) -> str: + return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:" + + +def _h2(text: str) -> str: + return f"__**{text}**__" + + +def _h3(text: str) -> str: + return f"__{text}__" + + +def _writing_prompt(text: str) -> str: + return f":pencil: _{text}_" + + +def _ranking_prompt(text: str) -> str: + return f":trophy: _{text}_" + + +def _response_prompt(text: str) -> str: + return f":speech_balloon: _{text}_" + + +def _summarize_prompt(text: str) -> str: + return f":notepad_spiral: _{text}_" + + +def _user(text: str | None) -> str: + return f"""\ +:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""} +""" + + +def _assistant(text: str | None) -> str: + return f"""\ +:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""} +""" + + +def _make_ordered_list(items: list[str]) -> list[str]: + return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)] + + +def _ordered_list(items: list[str]) -> str: + return "\n\n".join(_make_ordered_list(items)) + + +def _conversation(conv: protocol_schema.Conversation) -> str: + return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) + + +def _hint(hint: str | None) -> str: + return f"{NL}Hint: {hint}" if hint else "" + + +### +# Messages +### + + +def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str: + """Creates the message that gets sent to users when they request an `initial_prompt` task.""" + return f"""\ + +{_h1("INITIAL PROMPT")} + +{_writing_prompt("Please provide an initial prompt to the assistant.")} +{_hint(task.hint)} +""" + + +def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str: + """Creates the message that gets sent to users when they request a `rank_initial_prompts` task.""" + return f"""\ + +{_h1("RANK INITIAL PROMPTS")} + +{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} + + +{_ordered_list(task.prompts)} +""" + + +def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str: + """Creates the message that gets sent to users when they request a `rank_prompter_replies` task.""" + return f"""\ + +{_h1("RANK PROMPTER REPLIES")} + +{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} + + +{_conversation(task.conversation)} +{_user(None)} +{_ordered_list(task.replies)} +""" + + +def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str: + """Creates the message that gets sent to users when they request a `rank_assistant_replies` task.""" + return f"""\ + +{_h1("RANK ASSISTANT REPLIES")} + +{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} + + +{_conversation(task.conversation)} +{_assistant(None)} +{_ordered_list(task.replies)} +""" + + +def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: + """Creates the message that gets sent to users when they request a `prompter_reply` task.""" + return f"""\ + +{_h1("PROMPTER REPLY")} + +{_response_prompt("Please provide a reply to the assistant.")} + + +{_conversation(task.conversation)} +{_hint(task.hint)} +""" + + +def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str: + """Creates the message that gets sent to users when they request a `assistant_reply` task.""" + return f"""\ +{_h1("ASSISTANT REPLY")} + +{_response_prompt("Please provide a reply to the assistant.")} + + +{_conversation(task.conversation)} +""" + + +def confirm_text_response_message(content: str) -> str: + return f"""\ +{_h2("CONFIRM RESPONSE")} + +> {content} +""" + + +def confirm_ranking_response_message(content: str, items: list[str]) -> str: + user_rankings = [int(r) for r in content.replace(" ", "").split(",")] + original_list = _make_ordered_list(items) + user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings]) + + return f"""\ +{_h2("CONFIRM RESPONSE")} + +{user_ranked_list} +""" + + +### +# Embeds +### + + +def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed: + return ( + hikari.Embed( + title="Task Completion", + description=f"`{task.type}` completed by {mention}", + color=hikari.Color(0x00FF00), + timestamp=datetime.now().astimezone(), + ) + .add_field("Total Tasks", "0", inline=True) + .add_field("Server Ranking", "0/0", inline=True) + .add_field("Global Ranking", "0/0", inline=True) + .set_footer(f"Task ID: {task.id}") + ) + + +def invalid_user_input_embed(error_message: str) -> hikari.Embed: + return hikari.Embed( + title="Invalid User Input", + description=error_message, + color=hikari.Color(0xFF0000), + timestamp=datetime.now().astimezone(), + ) + + +def plain_embed(text: str) -> hikari.Embed: + return hikari.Embed(color=0x36393F, description=text) diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py index 2d968c93..530f402a 100644 --- a/discord-bot/bot/utils.py +++ b/discord-bot/bot/utils.py @@ -24,13 +24,6 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}") -EMPTY = "\u200d" -"""Zero-width joiner. - -This appears as an empty message in Discord. -""" - - def mention( id: hikari.Snowflakeish, type: t.Literal["channel", "role", "user"],