Files
Open-Assistant/discord-bot/bot/extensions/work.py
T
Alex Ott dfd2c35276 Feat/task handler (#1056)
* move task logic to task handlers

* rename command

* remove test code

* rename classes and add missing handler creations

* switch task back to random and fetch log_channel_id from db
2023-02-02 17:28:00 +01:00

511 lines
20 KiB
Python

"""Work plugin for collecting user data."""
import asyncio
import typing as t
from uuid import UUID
import hikari
import lightbulb
import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.messages import (
assistant_reply_messages,
confirm_label_response_message,
confirm_ranking_response_message,
confirm_text_response_message,
initial_prompt_messages,
label_assistant_reply_messages,
label_prompter_reply_messages,
plain_embed,
prompter_reply_messages,
rank_assistant_reply_message,
rank_conversation_reply_messages,
rank_initial_prompts_messages,
rank_prompter_reply_messages,
task_complete_embed,
)
from bot.settings import Settings
from loguru import logger
from oasst_shared.api_client import OasstApiClient
from oasst_shared.schemas import protocol as protocol_schema
# Plugin object that the commands of this module are registered on (see `load`/`unload`).
plugin = lightbulb.Plugin("WorkPlugin")

# Seconds `_TaskHandler.handle` waits for each DM reply from the user.
MAX_TASK_TIME: t.Final = 60 * 60
# Seconds the accept/next/cancel prompt and the "already working" prompt stay active.
MAX_TASK_ACCEPT_TIME: t.Final = 60 * 10

# Bot settings; `settings.prefix` is used to ignore command messages while waiting for task input.
settings = Settings()

# Type variable for the concrete task type a `_TaskHandler` subclass handles.
_Task_contra = t.TypeVar(
    "_Task_contra",
    bound=protocol_schema.Task,
    contravariant=True,
)
class _TaskHandler(t.Generic[_Task_contra]):
    """Handle user interaction for a task.

    Subclasses supply the task-specific parts: `get_task_messages`, `check_user_input`,
    `confirm_user_input` and `notify`. The shared flow is:

    1. `send` DMs the task to the user and waits for accept / next / cancel.
    2. `handle` acks the task, collects a valid, confirmed reply and passes it to `notify`.
    3. `cancel` nacks the task when it is skipped or abandoned.
    """

    # Discord rejects message content longer than 2000 characters.
    _MAX_MESSAGE_LENGTH = 2000

    def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
        """Create a new `TaskHandler`.

        Args:
            ctx (lightbulb.Context): The context of the command that started the task.
            task (_Task_contra): The task to handle.
        """
        self.ctx = ctx
        self.task = task
        self.task_messages = self.get_task_messages(task)
        # Every DM sent by `send`, in order. `handle` uses the first one's id as the task's message id.
        self.sent_messages: list[hikari.Message] = []

    @staticmethod
    def get_task_messages(task: _Task_contra) -> list[str]:
        """Get the messages to send to the user for the task."""
        raise NotImplementedError

    def _truncate(self, content: str) -> str:
        """Return `content` cut to Discord's message length limit, warning when it had to be cut."""
        if len(content) > self._MAX_MESSAGE_LENGTH:
            logger.warning(f"Attempting to send a message >2000 characters in length. Task id: {self.task.id}")
            return content[: self._MAX_MESSAGE_LENGTH]
        return content

    async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
        """Send the task and wait for the user to accept/skip/cancel it.

        Returns:
            The button the user pressed, or `None` if the prompt timed out.
        """
        logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
        # Send all but the last message because we need to attach buttons to the last one
        for task_msg in self.task_messages[:-1]:
            self.sent_messages.append(await self.ctx.author.send(self._truncate(task_msg)))

        # Send the last message with buttons
        task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
        logger.debug(f"TH Message length {len(self.task_messages[-1])}")
        last_msg = await self.ctx.author.send(self._truncate(self.task_messages[-1]), components=task_accept_view)
        # Record the button message as well, so `sent_messages` is never empty for single-message tasks.
        self.sent_messages.append(last_msg)

        await task_accept_view.start(last_msg)
        await task_accept_view.wait()
        return task_accept_view.choice

    async def handle(self) -> None:
        """Handle the user's response to the task.

        Acks the task, then waits for DMs from the user until one passes `check_user_input`
        and is confirmed via `confirm_user_input`, and finally reports it through `notify`.
        Returns without notifying if the user does not reply within `MAX_TASK_TIME` seconds.

        This method should be called after `send` has been called.

        Raises:
            TypeError: If `notify` returns anything other than `TaskDone`.
        """
        # Ack task to the backend
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")

        # Loop until the user's input is accepted
        while True:
            try:
                event = await self.ctx.bot.wait_for(
                    hikari.DMMessageCreateEvent,
                    # Only this user's text messages; skip anything starting with the command prefix.
                    predicate=lambda e: (
                        e.author_id == self.ctx.author.id
                        and e.message.content is not None
                        and not e.message.content.startswith(settings.prefix)
                    ),
                    timeout=MAX_TASK_TIME,
                )
            except asyncio.TimeoutError:
                return

            # Validate the message
            if event.content is None or not self.check_user_input(event.content):
                await self.ctx.author.send("Invalid input")
                continue

            # Message is valid; stop looping once the user confirms it
            if await self.confirm_user_input(event.content):
                break

        next_task = await self.notify(event.content, event)
        if not isinstance(next_task, protocol_schema.TaskDone):
            raise TypeError(f"Unknown task type: {next_task!r}")

    async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
        """Notify the backend that the user completed the task."""
        raise NotImplementedError

    async def confirm_user_input(self, content: str) -> bool:
        """Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
        raise NotImplementedError

    def check_user_input(self, content: str) -> bool:
        """Check the user's response to the task. Returns True if the response is valid."""
        raise NotImplementedError

    async def cancel(self, reason: str = "not specified") -> None:
        """Cancel the task by nacking it on the backend with `reason`."""
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        await oasst_api.nack_task(self.task.id, reason)
# Type variable for `_RankingTaskHandler`: bounded by the union of the four ranking task types.
_Ranking_contra = t.TypeVar(
    "_Ranking_contra",
    bound=(
        protocol_schema.RankAssistantRepliesTask
        | protocol_schema.RankInitialPromptsTask
        | protocol_schema.RankPrompterRepliesTask
        | protocol_schema.RankConversationRepliesTask
    ),
    contravariant=True,
)
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
    """Shared backend notification for ranking tasks.

    This should not be used directly. Use its subclasses instead.
    """

    async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
        """Post the user's ranking to the backend and mirror the task to the guild's log channel.

        Args:
            content: The user's comma-separated, 1-based ranking, already accepted by
                `check_user_input`. It is sent to the backend as 0-based indices.
            event: The DM event that carried `content` (unused here).

        Returns:
            The backend's response to the posted interaction.
        """
        oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
        task = await oasst_api.post_interaction(
            protocol_schema.MessageRanking(
                user=protocol_schema.User(
                    id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
                ),
                ranking=[int(r) - 1 for r in content.split(",")],
                message_id=f"{self.sent_messages[0].id}",
            )
        )

        # Look up the log channel configured for the guild the command was invoked in (if any).
        db: Connection = self.ctx.bot.d.db
        async with db.cursor() as cursor:
            row = await (
                await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
            ).fetchone()
        log_channel = row[0] if row else None

        if log_channel is not None:
            for message in self.task_messages[:-1]:
                # Discord rejects message content longer than 2000 characters.
                await self.ctx.bot.rest.create_message(log_channel, message[:2000])
            await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
        return task
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
    """Let the user rank the replies of a `RankAssistantRepliesTask`."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
        """Build the DMs that present the replies to rank."""
        return rank_assistant_reply_message(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma-separated permutation of 1..len(reply_messages).

        Every reply must be ranked exactly once. Whitespace around each number is allowed,
        which matches how `notify` parses the ranking with `int()`.
        """
        try:
            ranking = [int(r) for r in content.split(",")]
        except ValueError:
            return False
        return sorted(ranking) == list(range(1, len(self.task.reply_messages) + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user; return True only if they press "Yes"."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
    """Let the user rank the prompts of a `RankInitialPromptsTask`."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
        """Build the DMs that present the prompts to rank."""
        return rank_initial_prompts_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma-separated permutation of 1..len(prompt_messages).

        Every prompt must be ranked exactly once. Whitespace around each number is allowed,
        which matches how `notify` parses the ranking with `int()`.
        """
        try:
            ranking = [int(r) for r in content.split(",")]
        except ValueError:
            return False
        return sorted(ranking) == list(range(1, len(self.task.prompt_messages) + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user; return True only if they press "Yes"."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
    """Let the user rank the replies of a `RankPrompterRepliesTask`."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
        """Build the DMs that present the replies to rank."""
        return rank_prompter_reply_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma-separated permutation of 1..len(reply_messages).

        Every reply must be ranked exactly once. Whitespace around each number is allowed,
        which matches how `notify` parses the ranking with `int()`.
        """
        try:
            ranking = [int(r) for r in content.split(",")]
        except ValueError:
            return False
        return sorted(ranking) == list(range(1, len(self.task.reply_messages) + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user; return True only if they press "Yes"."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
    """Let the user rank the replies of a `RankConversationRepliesTask`."""

    @staticmethod
    def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
        """Build the DMs that present the replies to rank."""
        return rank_conversation_reply_messages(task)

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` is a comma-separated permutation of 1..len(reply_messages).

        Every reply must be ranked exactly once. Whitespace around each number is allowed,
        which matches how `notify` parses the ranking with `int()`.
        """
        try:
            ranking = [int(r) for r in content.split(",")]
        except ValueError:
            return False
        return sorted(ranking) == list(range(1, len(self.task.reply_messages) + 1))

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the ranking back to the user; return True only if they press "Yes"."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(
            confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
        )
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
    """Collect free-text input from the user for an `InitialPromptTask`.

    NOTE(review): `notify` is not overridden here, so `_TaskHandler.notify` raises
    `NotImplementedError` once the user confirms their input in `handle`.
    """

    @staticmethod
    def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
        """Build the DMs describing the task."""
        messages = initial_prompt_messages(task)
        return messages

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return content != ""

    async def confirm_user_input(self, content: str) -> bool:
        """Show the text back to the user and report whether they pressed "Yes"."""
        view = YesNoView()
        prompt = await self.ctx.author.send(confirm_text_response_message(content), components=view)
        await view.start(prompt)
        await view.wait()
        return view.choice is True
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
    """Collect free-text input from the user for a `PrompterReplyTask`.

    NOTE(review): `notify` is not overridden here, so `_TaskHandler.notify` raises
    `NotImplementedError` once the user confirms their input in `handle`.
    """

    @staticmethod
    def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
        """Build the DMs describing the task."""
        messages = prompter_reply_messages(task)
        return messages

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return content != ""

    async def confirm_user_input(self, content: str) -> bool:
        """Show the text back to the user and report whether they pressed "Yes"."""
        view = YesNoView()
        prompt = await self.ctx.author.send(confirm_text_response_message(content), components=view)
        await view.start(prompt)
        await view.wait()
        return view.choice is True
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
    """Collect free-text input from the user for an `AssistantReplyTask`.

    NOTE(review): `notify` is not overridden here, so `_TaskHandler.notify` raises
    `NotImplementedError` once the user confirms their input in `handle`.
    """

    @staticmethod
    def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
        """Build the DMs describing the task."""
        messages = assistant_reply_messages(task)
        return messages

    def check_user_input(self, content: str) -> bool:
        """Accept any non-empty text."""
        return content != ""

    async def confirm_user_input(self, content: str) -> bool:
        """Show the text back to the user and report whether they pressed "Yes"."""
        view = YesNoView()
        prompt = await self.ctx.author.send(confirm_text_response_message(content), components=view)
        await view.start(prompt)
        await view.wait()
        return view.choice is True
# Type variable for `_LabelConversationReplyHandler`: bounded by `LabelConversationReplyTask`.
_Label_contra = t.TypeVar(
    "_Label_contra",
    bound=protocol_schema.LabelConversationReplyTask,
    contravariant=True,
)
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
    """Shared input validation and confirmation for labelling tasks.

    This should not be used directly. Use its subclasses instead.

    NOTE(review): `notify` is not overridden here, so `_TaskHandler.notify` raises
    `NotImplementedError` once the user confirms their labels in `handle`.
    """

    def check_user_input(self, content: str) -> bool:
        """Return True if `content` lists only valid labels and includes every mandatory one.

        `content` is a comma-separated list; whitespace around each label is ignored.
        When the task has no mandatory labels (`mandatory_labels is None`), only validity is checked.
        """
        user_labels = [label.strip() for label in content.split(",")]
        if not all(label in self.task.valid_labels for label in user_labels):
            return False
        mandatory_labels = self.task.mandatory_labels or []
        return all(m in user_labels for m in mandatory_labels)

    async def confirm_user_input(self, content: str) -> bool:
        """Echo the labels back to the user; return True only if they press "Yes"."""
        confirm_input_view = YesNoView()
        msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
        await confirm_input_view.start(msg)
        await confirm_input_view.wait()
        return bool(confirm_input_view.choice)
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
    """Labelling handler for `LabelAssistantReplyTask`s."""

    @staticmethod
    def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
        """Build the DMs that present the reply to label."""
        messages = label_assistant_reply_messages(task)
        return messages
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
    """Labelling handler for `LabelPrompterReplyTask`s."""

    @staticmethod
    def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
        """Build the DMs that present the reply to label."""
        messages = label_prompter_reply_messages(task)
        return messages
# Task-name string constants; nothing else in this file references them.
summarize_story, rate_summary = "summarize_story", "rate_summary"
@plugin.command
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work2(ctx: lightbulb.Context) -> None:
"""Complete a task."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
# Check if the user is already working on a task
if ctx.author.id in currently_working:
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
flags=hikari.MessageFlag.EPHEMERAL,
components=yn_view,
)
await yn_view.start(msg)
await yn_view.wait()
match yn_view.choice:
case False | None:
return
case True:
task_id = currently_working[ctx.author.id]
await oasst_api.nack_task(task_id, reason="user cancelled")
if ctx.guild_id:
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
# Keep sending tasks until the user doesn't want more
try:
while True:
task = await oasst_api.fetch_random_task(
user=protocol_schema.User(
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
),
)
# Ranking tasks
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
task_handler = RankAssistantRepliesHandler(ctx, task)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
task_handler = RankInitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
task_handler = RankPrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
task_handler = RankConversationReplyHandler(ctx, task)
# Text input tasks
elif isinstance(task, protocol_schema.InitialPromptTask):
task_handler = InitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.PrompterReplyTask):
task_handler = PrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.AssistantReplyTask):
task_handler = AssistantReplyHandler(ctx, task)
# Label tasks
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
task_handler = LabelAssistantReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
task_handler = LabelPrompterReplyHandler(ctx, task)
else:
raise ValueError(f"Unknown task type: {type(task)}")
resp = await task_handler.send()
match resp:
case "accept":
currently_working[ctx.author.id] = task.id
await task_handler.handle()
case "next":
await task_handler.cancel("user skipped task")
case "cancel":
await task_handler.cancel("user canceled work")
break
case None:
await task_handler.cancel("select timed out")
break
finally:
del currently_working[ctx.author.id]
class TaskAcceptView(miru.View):
    """Three-button prompt: accept the task, skip to the next one, or cancel.

    Pressing a button removes the buttons from the message, records which one was
    pressed in `choice` and stops the view. If the view times out instead, the buttons
    are removed and `choice` stays `None`.
    """

    choice: t.Literal["accept", "next", "cancel"] | None = None

    async def _record_choice(self, ctx: miru.ViewContext, picked: t.Literal["accept", "next", "cancel"]) -> None:
        """Store `picked`, strip the buttons from the prompt message and stop the view."""
        self.choice = picked
        await ctx.message.edit(component=None)
        self.stop()

    @miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
    async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
        logger.info("Accept button pressed")
        await self._record_choice(ctx, "accept")

    @miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
    async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
        logger.info("Next button pressed")
        await self._record_choice(ctx, "next")

    @miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
    async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
        logger.info("Cancel button pressed")
        await self._record_choice(ctx, "cancel")

    async def on_timeout(self) -> None:
        """Remove the buttons when nobody answered in time."""
        if self.message is None:
            return
        await self.message.edit(component=None)
class YesNoView(miru.View):
    """Two-button "Yes"/"No" confirmation prompt.

    Pressing a button removes the buttons from the message, stores True (yes) or
    False (no) in `choice` and stops the view. If the view times out instead, the
    buttons are removed and `choice` stays `None`.
    """

    choice: bool | None = None

    async def _record_choice(self, ctx: miru.ViewContext, answer: bool) -> None:
        """Store `answer`, strip the buttons from the prompt message and stop the view."""
        self.choice = answer
        await ctx.message.edit(component=None)
        self.stop()

    @miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
    async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
        await self._record_choice(ctx, True)

    @miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
    async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
        await self._record_choice(ctx, False)

    async def on_timeout(self) -> None:
        """Remove the buttons when nobody answered in time."""
        if self.message is None:
            return
        await self.message.edit(component=None)
def load(bot: lightbulb.BotApp):
    """Extension entry point: register this module's `plugin` on `bot`."""
    bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
    """Extension exit point: remove this module's `plugin` from `bot`."""
    bot.remove_plugin(plugin)