diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py deleted file mode 100644 index 0caa1595..00000000 --- a/discord-bot/api_client.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -from typing import Optional, Type -import typing as t - -import requests -from oasst_shared.schemas import protocol as protocol_schema - - -class TaskType(str, enum.Enum): - summarize_story = "summarize_story" - rate_summary = "rate_summary" - initial_prompt = "initial_prompt" - user_reply = "user_reply" - assistant_reply = "assistant_reply" - rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" - rank_assistant_replies = "rank_assistant_replies" - done = "task_done" - - -class ApiClient: - def __init__(self, backend_url: str, api_key: str): - self.backend_url = backend_url - self.api_key = api_key - - task_models_map: dict[str, Type[protocol_schema.Task]] = { - TaskType.summarize_story: protocol_schema.SummarizeStoryTask, - TaskType.rate_summary: protocol_schema.RateSummaryTask, - TaskType.initial_prompt: protocol_schema.InitialPromptTask, - TaskType.user_reply: protocol_schema.UserReplyTask, - TaskType.assistant_reply: protocol_schema.AssistantReplyTask, - TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, - TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, - TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, - TaskType.done: protocol_schema.TaskDone, - } - self.task_models_map = task_models_map - - def post(self, path: str, json: dict) -> dict: - response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key}) - response.raise_for_status() - return response.json() - - def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: - if not isinstance(data, dict): - raise ValueError("dict expected") - - task_type = data.get("type") - if task_type not in self.task_models_map: - raise RuntimeError(f"Unsupported task type: {task_type}") - - return self.task_models_map[task_type].parse_obj(data) - - def fetch_task( - self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None - ) -> protocol_schema.Task: - req = protocol_schema.TaskRequest(type=task_type, user=user) - data = self.post("/api/v1/tasks/", req.dict()) - return self._parse_task(data) - - def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: - return self.fetch_task(protocol_schema.TaskRequestType.random, user) - - def ack_task(self, task_id: str, post_id: str) -> None: - req = protocol_schema.TaskAck(post_id=post_id) - return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict()) - - def nack_task(self, task_id: str, reason: str) -> None: - req = protocol_schema.TaskNAck(reason=reason) - return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - data = self.post("/api/v1/tasks/interaction", interaction.dict()) - return self._parse_task(data) diff --git a/discord-bot/bot/api_client.py b/discord-bot/bot/api_client.py new file mode 100644 index 00000000..cec1900f --- /dev/null +++ b/discord-bot/bot/api_client.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import asyncio +import enum +import typing as t +from typing import Optional, Type +from uuid import UUID + +import aiohttp +from loguru import logger + +from oasst_shared.schemas import protocol as protocol_schema + + +class TaskType(str, enum.Enum): + summarize_story = "summarize_story" + rate_summary = "rate_summary" + initial_prompt = "initial_prompt" + user_reply = "user_reply" + assistant_reply = "assistant_reply" + rank_initial_prompts = "rank_initial_prompts" + rank_user_replies = "rank_user_replies" + rank_assistant_replies = "rank_assistant_replies" + done = "task_done" + + +class OasstApiClient: + """API Client for interacting with the OASST backend.""" + + def __init__(self, backend_url: str, api_key: str): + """Create a new OasstApiClient. + + Args: + backend_url (str): The base backend URL. + api_key (str): The API key to use for authentication. + """ + logger.debug("Opening OasstApiClient session") + self.session = aiohttp.ClientSession() + self.backend_url = backend_url + self.api_key = api_key + + self.task_models_map: dict[str, Type[protocol_schema.Task]] = { + TaskType.summarize_story: protocol_schema.SummarizeStoryTask, + TaskType.rate_summary: protocol_schema.RateSummaryTask, + TaskType.initial_prompt: protocol_schema.InitialPromptTask, + TaskType.user_reply: protocol_schema.UserReplyTask, + TaskType.assistant_reply: protocol_schema.AssistantReplyTask, + TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, + TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, + TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, + TaskType.done: protocol_schema.TaskDone, + } + + async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]: + """Make a POST request to the backend.""" + logger.debug(f"POST {self.backend_url}{path} DATA: {data}") + response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) + response.raise_for_status() + return await response.json() + + def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: + task_type = data.get("type") + + if not isinstance(task_type, str): + logger.error(f"task type must be a `str`: {task_type}") + raise ValueError(f"task type must be a `str`: {task_type}") + + model = self.task_models_map.get(task_type) + if not model: + logger.error(f"Unsupported task type: {task_type}") + raise ValueError(f"Unsupported task type: {task_type}") + return self.task_models_map[task_type].parse_obj(data) + + async def fetch_task( + self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + ) -> protocol_schema.Task: + """Fetch a task from the backend.""" + logger.debug(f"Fetching task {task_type} for user {user}") + req = protocol_schema.TaskRequest(type=task_type.value, user=user) + resp = await self.post(f"/api/v1/tasks/", data=req.dict()) + print("resp", resp) + return self._parse_task(resp) + + async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: + """Fetch a random task from the backend.""" + logger.debug(f"Fetching random for user {user}") + return await self.fetch_task(protocol_schema.TaskRequestType.random, user) + + async def ack_task(self, task_id: str | UUID, post_id: str): + """Send an ACK for a task to the backend.""" + logger.debug(f"ACK task {task_id} with post {post_id}") + req = protocol_schema.TaskAck(post_id=post_id) + return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict()) + + async def nack_task(self, task_id: str | UUID, reason: str): + """Send a NACK for a task to the backend.""" + logger.debug(f"NACK task {task_id} with reason {reason}") + req = protocol_schema.TaskNAck(reason=reason) + return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict()) + + async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: + """Send a completed task to the backend.""" + logger.debug(f"Interaction: {interaction}") + resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict()) + + return self._parse_task(resp) + + async def close(self): + logger.debug("Closing OasstApiClient session") + await self.session.close() + + +async def main(): + api = OasstApiClient("http://localhost:8080", "test") + try: + task = await api.fetch_task(protocol_schema.TaskRequestType.initial_prompt, None) + print(task) + finally: + + await api.close() + # session = aiohttp.ClientSession() + # try: + # resp = await session.post("http://localhost:8080/api/v1/tasks/", json={"type": "initial_prompt", "user": None}) + # resp.raise_for_status() + # print(await resp.text()) + # finally: + # await session.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index af163545..de8ceacf 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -6,6 +6,7 @@ import lightbulb import miru from bot.config import Config +from bot.api_client import OasstApiClient config = Config.from_env() @@ -29,8 +30,11 @@ async def on_starting(event: hikari.StartingEvent): await bot.d.db.executescript(open("./bot/db/schema.sql").read()) await bot.d.db.commit() + bot.d.oasst_api = OasstApiClient("http://localhost:8080", "any_key") + @bot.listen() async def on_stopping(event: hikari.StoppingEvent): """Cleanup.""" await bot.d.db.close() + await bot.d.oasst_api.close() diff --git a/discord-bot/bot/extensions/_example.py b/discord-bot/bot/extensions/_example.py index 8ac7fe21..330f5909 100644 --- a/discord-bot/bot/extensions/_example.py +++ b/discord-bot/bot/extensions/_example.py @@ -1,5 +1,6 @@ +# TODO: Convert file to markdown # -*- coding: utf-8 -*- -"""Example plugins for reference. +"""Example plugin for reference. Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`. """ @@ -396,6 +397,10 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None: await view.start(await resp.message()) +# TODO: Database example +# TODO: Rest client example + + def load(bot: lightbulb.BotApp): """Add the plugin to the bot.""" bot.add_plugin(plugin) diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py index b70a22fd..28bcede3 100644 --- a/discord-bot/bot/extensions/hot_reload.py +++ b/discord-bot/bot/extensions/hot_reload.py @@ -15,7 +15,7 @@ EXTENSIONS_FOLDER = "bot/extensions" def _get_extensions() -> list[str]: # Recursively get all the .py files in the extensions directory not starting with an `_`. - exts = glob("bot/extensions/**/*[!_].py", recursive=True) + exts = glob("bot/extensions/**/[!_]*.py", recursive=True) # Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension") return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts] diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py new file mode 100644 index 00000000..dfe51160 --- /dev/null +++ b/discord-bot/bot/extensions/tasks.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +"""Task plugin for testing different data collection methods.""" +import asyncio +import logging +import typing as t +from datetime import datetime, timedelta + +import hikari + +import lightbulb +import lightbulb.decorators +import miru +from bot.utils import format_time +from oasst_shared.schemas.protocol import TaskRequestType + +plugin = lightbulb.Plugin("TaskPlugin") + +MAX_TASK_TIME = 60 * 60 +MAX_TASK_ACCEPT_TIME = 60 +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True) +@lightbulb.implements(lightbulb.SlashCommand) +async def task_thread(ctx: lightbulb.SlashContext): + """Request a task from the backend.""" + typ: str = ctx.options.type + + # Create a thread for the task + thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}") + + await ctx.respond(f"Please complete the task in the thread: {thread.mention}") + + # Send the task in the thread + # TODO: Request task from the backend + await thread.send( + f"Please complete the task.\nSample Task\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + ) + + # Wait for the user to respond + try: + event = await ctx.bot.wait_for( + hikari.GuildMessageCreateEvent, + timeout=MAX_TASK_TIME, + predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id, + ) + await ctx.respond(f"Received message: {event.message.content}") + # TODO: Send the message to the backend + except asyncio.TimeoutError: + await ctx.respond("You took too long to respond.") + finally: + await thread.delete() + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True) +@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) +async def task_dm(ctx: lightbulb.Context): + """Request a task from the backend.""" + typ: str = ctx.options.type + + # Create a thread for the task + + await ctx.respond(f"Please complete the task in your DMs") + + # Send the task in the thread + # TODO: Request task from the backend + await ctx.author.send( + f"Please complete the task.\nSample Task ({typ})\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + ) + + # Wait for the user to respond + try: + event = await ctx.bot.wait_for( + hikari.DMMessageCreateEvent, + timeout=MAX_TASK_TIME, + predicate=lambda e: e.author.id == ctx.author.id, + ) + await ctx.respond(f"Received message: {event.message.content}") + # TODO: Send the message to the backend + except asyncio.TimeoutError: + await ctx.respond("You took too long to respond.") + + +class TaskModal(miru.Modal): + """Modal for submitting a task.""" + + response = miru.TextInput( + label="Response", + placeholder="Enter your response!", + required=True, + style=hikari.TextInputStyle.PARAGRAPH, + row=2, + ) + + async def callback(self, context: miru.ModalContext) -> None: + await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL) + # TODO: Send the message to the backend + + +class ModalView(miru.View): + """View for opening a modal.""" + + def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + self.modal_title = modal_title + self.task = task + + @miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY) + async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + modal = TaskModal(title=self.modal_title) + modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1)) + await ctx.respond_with_modal(modal) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True) +@lightbulb.implements(lightbulb.SlashCommand) +async def task_modal(ctx: lightbulb.SlashContext): + """Request a task from the backend.""" + # typ: str = ctx.options.type + view = ModalView( + modal_title=f"Assistant Response", + task="Please explain the moon landing to a six year old.", + timeout=MAX_TASK_TIME, + ) + resp = await ctx.respond( + "Task - Respond to the prompt as if you were the Assistant:", + flags=hikari.MessageFlag.EPHEMERAL, + components=view, + ) + await view.start(await resp.message()) + + +class RatingView(miru.View): + """View for rating a task.""" + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + self.presses: list[str] = [] + + def _close_if_all_pressed(self) -> None: + if len(self.presses) == 5: + self.stop() + + @miru.button(label="1", style=hikari.ButtonStyle.PRIMARY) + async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("1") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="2", style=hikari.ButtonStyle.PRIMARY) + async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("2") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="3", style=hikari.ButtonStyle.PRIMARY) + async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("3") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="4", style=hikari.ButtonStyle.PRIMARY) + async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("4") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="5", style=hikari.ButtonStyle.PRIMARY) + async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("5") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="Reset", style=hikari.ButtonStyle.DANGER) + async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.presses = [] + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + + +class SelectRating(miru.View): + @miru.select( + options=[ + hikari.SelectMenuOption( + label="1", + value="1", + description=None, + emoji=None, + is_default=False, + ), + hikari.SelectMenuOption( + label="2", + value="2", + description=None, + emoji=None, + is_default=False, + ), + hikari.SelectMenuOption( + label="3", + value="3", + description=None, + emoji=None, + is_default=False, + ), + ], + placeholder="Select the good responses", + min_values=0, + max_values=3, + row=3, + ) + async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None: + await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL) + + +@plugin.command +@lightbulb.command("rating_task", "Rate stuff.") +@lightbulb.implements(lightbulb.SlashCommand) +async def rating_task(ctx: lightbulb.SlashContext): + """Rate stuff.""" + + # Message Based rating + await ctx.respond( + "List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL + ) + try: + event = await ctx.bot.wait_for( + hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id + ) + + except asyncio.TimeoutError: + await ctx.respond("Timed out waiting for response") + return + + if event.content is None: + await ctx.respond("No content in message") + return + ratings = event.content.replace(" ", "").split(",") + + # Check if the ratings are valid + if len(ratings) != 5: + await ctx.respond("Invalid number of ratings") + if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]): + await ctx.respond("Invalid rating") + + await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL) + # Button Based rating + view = RatingView(timeout=MAX_TASK_TIME) + + resp = await ctx.respond("Click the buttons in order of best to worst response", components=view) + await view.start(await resp.message()) + await view.wait() + await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL) + await resp.delete() + + # Select Based rating + select_view = SelectRating(timeout=MAX_TASK_TIME) + resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL) + await select_view.start(await resp_2.message()) + await select_view.wait() + await resp_2.delete() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py new file mode 100644 index 00000000..e6ea3d7c --- /dev/null +++ b/discord-bot/bot/extensions/work.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +"""Work plugin for collecting user data.""" +import asyncio +import logging +import typing as t +from datetime import datetime + +import hikari + +import lightbulb +import lightbulb.decorators +import miru +from bot.api_client import OasstApiClient, TaskType +from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TaskRequestType +from bot.utils import ZWJ + +plugin = lightbulb.Plugin("WorkPlugin") + +MAX_TASK_TIME = 60 * 60 # 1 hour +MAX_TASK_ACCEPT_TIME = 60 # 1 minute + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], + required=False, + default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random + type=str, +) +@lightbulb.command("work", "Complete a task.") +@lightbulb.implements(lightbulb.SlashCommand) +async def work(ctx: lightbulb.SlashContext): + """Create and handle a task.""" + task_type: TaskRequestType = TaskRequestType(ctx.options.type) + + await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) + logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}") + + await _handle_task(ctx, task_type) + + +async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None: + """Handle creating and collecting user input for a task. + + Continually present tasks to the user until they select one, cancel, or time out. + If they select one, present the task steps until a `task_done` task is received. + Finally, ask the user if they want to perform another task (of the same type). + """ + + oasst_api: OasstApiClient = ctx.bot.d.oasst_api + + # Continue to complete tasks until the user doesn't want to do another + done = False + while not done: + + # Loop until the user accepts a task + task, msg_id = await _select_task(ctx, task_type) + + if task is None: + return + + # Task action loop + completed = False + while not completed: + await ctx.author.send("Please type your response here:") + try: + event = await ctx.bot.wait_for( + hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id + ) + except asyncio.TimeoutError: + await ctx.author.send("Task timed out. Exiting") + # TODO: NACK task maybe? + return + + # Invalid response + if event.content is None: + await ctx.author.send("No content in message") + continue + + logger.info(f"User input received: {event.content}") + + # Send the response to the backend + reply = protocol_schema.TextReplyToPost( + post_id=str(msg_id), + user_post_id=str(event.message_id), + user=protocol_schema.User( + auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username + ), + text=event.content, + ) + logger.debug(f"Sending reply to backend: {reply!r}") + + # Get next task + new_task = await oasst_api.post_interaction(reply) + logger.info(f"New task {new_task}") + + if new_task.type == TaskType.done: + await ctx.author.send("Task completed") + completed = True + continue + else: + logger.fatal(f"Unexpected task type received: {new_task.type}") + + # ask the user if they want to do another task + choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) + msg = await ctx.author.send("Would you like another task?", components=choice_view) + await choice_view.start(msg) + await choice_view.wait() + + match choice_view.choice: + case False | None: + done = True + await ctx.author.send("Exiting, goodbye!") + case True: + pass + + +async def _select_task( + ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None +) -> tuple[protocol_schema.Task | None, str]: + """Present tasks to the user until they accept one, cancel, or time out.""" + oasst_api: OasstApiClient = ctx.bot.d.oasst_api + logger.debug(f"Starting task selection for {task_type}") + + # Loop until the user accepts a task, cancels, or times out + while True: + logger.debug(f"Requesting task of type {task_type}") + task = await oasst_api.fetch_task(task_type, user) + resp, msg_id = await _send_task(ctx, task) + + logger.debug(f"user choice: {resp}") + match resp: + case "accept": + logger.info(f"Task {task.id} accepted, sending ACK") + await oasst_api.ack_task(task.id, msg_id) + return task, msg_id + + case "next": + logger.info(f"Task {task.id} rejected, sending NACK") + await oasst_api.nack_task(task.id, "rejected") + await ctx.author.send("Sending next task...") + continue + + case "cancel": + logger.info(f"Task {task.id} canceled, sending NACK") + await oasst_api.nack_task(task.id, "canceled") + await ctx.author.send("Task canceled. Exiting") + return None, msg_id + + case None: + logger.info(f"Task {task.id} timed out, sending NACK") + await oasst_api.nack_task(task.id, "timed out") + await ctx.author.send("Task timed out. Exiting") + return None, msg_id + + +async def _send_task( + ctx: lightbulb.SlashContext, task: protocol_schema.Task +) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]: + """Send a task to the user. + + Returns the user's choice and the message ID of the task message.""" + + # The clean way to do this would be to attach a `to_embed` method to the task classes + # but the tasks aren't discord specific so that doesn't really make sense. + + view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) + embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED + + # Create an embed based on the task's type + if task.type == TaskRequestType.initial_prompt: + assert isinstance(task, protocol_schema.InitialPromptTask) + logger.info("sending initial prompt task") + embed = _initial_prompt_embed(task) + + elif task.type == TaskRequestType.rank_initial_prompts: + assert isinstance(task, protocol_schema.RankInitialPromptsTask) + logger.info("sending rank initial prompt task") + embed = _rank_initial_prompt_embed(task) + + else: + logger.error(f"unknown task type {task.type}") + + msg = await ctx.author.send( + ZWJ, + embed=embed, + components=view, + ) + + assert msg is not None + + await view.start(msg) + await view.wait() + + return view.choice, str(msg.id) + + +def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: + return ( + hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) + .set_image( + "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", + ) + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + +def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="Rank Initial Prompt", + description=f"Rank the following tasks from best to worst (1,2,3,4,5)", + timestamp=datetime.now().astimezone(), + ) + .set_image( + "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", + ) + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for i, prompt in enumerate(task.prompts): + embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False) + + return embed + + +class TaskAcceptView(miru.View): + """View with three buttons: accept, next, and cancel. + + The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute. + """ + + choice: t.Literal["accept", "next", "cancel"] | None = None + + @miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS) + async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Accept button pressed") + self.choice = "accept" + self.stop() + + @miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY) + async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Next button pressed") + self.choice = "next" + self.stop() + + @miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER) + async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Cancel button pressed") + self.choice = "cancel" + self.stop() + + +class ChoiceView(miru.View): + choice: bool | None = None + + @miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS) + async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.choice = True + self.stop() + + @miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER) + async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.choice = False + self.stop() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py index beb81c36..1ff6ef1f 100644 --- a/discord-bot/bot/utils.py +++ b/discord-bot/bot/utils.py @@ -21,3 +21,10 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s return f"" case _: raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}") + + +ZWJ = "\u200d" +"""Zero-width joiner. + +This appears as an empty message in Discord. +""" diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 49c5e1ba..17348c12 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -7,4 +7,5 @@ hikari-miru # modals and buttons python-dotenv # .env file support aiosqlite # database aiohttp # http client -aiohttp[speedups] # speedups for aiohttp \ No newline at end of file +aiohttp[speedups] # speedups for aiohttp +loguru \ No newline at end of file