mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
dfd2c35276
* move task logic to task handlers * rename command * remove test code * rename classes and add missing handler creations * switch task back to random and fetch log_channel_id from db
511 lines
20 KiB
Python
511 lines
20 KiB
Python
"""Work plugin for collecting user data."""
|
|
import asyncio
|
|
import typing as t
|
|
from uuid import UUID
|
|
|
|
import hikari
|
|
import lightbulb
|
|
import lightbulb.decorators
|
|
import miru
|
|
from aiosqlite import Connection
|
|
from bot.messages import (
|
|
assistant_reply_messages,
|
|
confirm_label_response_message,
|
|
confirm_ranking_response_message,
|
|
confirm_text_response_message,
|
|
initial_prompt_messages,
|
|
label_assistant_reply_messages,
|
|
label_prompter_reply_messages,
|
|
plain_embed,
|
|
prompter_reply_messages,
|
|
rank_assistant_reply_message,
|
|
rank_conversation_reply_messages,
|
|
rank_initial_prompts_messages,
|
|
rank_prompter_reply_messages,
|
|
task_complete_embed,
|
|
)
|
|
from bot.settings import Settings
|
|
from loguru import logger
|
|
from oasst_shared.api_client import OasstApiClient
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
|
|
plugin = lightbulb.Plugin("WorkPlugin")
|
|
|
|
MAX_TASK_TIME = 60 * 60 # seconds
|
|
MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
|
|
|
|
settings = Settings()
|
|
|
|
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
|
|
|
|
|
|
class _TaskHandler(t.Generic[_Task_contra]):
|
|
"""Handle user interaction for a task."""
|
|
|
|
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
|
|
"""Create a new `TaskHandler`.
|
|
|
|
Args:
|
|
ctx (lightbulb.Context): The context of the command that started the task.
|
|
task (_Task_contra): The task to handle.
|
|
"""
|
|
self.ctx = ctx
|
|
self.task = task
|
|
self.task_messages = self.get_task_messages(task)
|
|
self.sent_messages: list[hikari.Message] = []
|
|
|
|
@staticmethod
|
|
def get_task_messages(task: _Task_contra) -> list[str]:
|
|
"""Get the messages to send to the user for the task."""
|
|
raise NotImplementedError
|
|
|
|
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
|
|
"""Send the task and wait for the user to accept/skip/cancel it."""
|
|
# Send all but the last message because we need to attach buttons to the last one
|
|
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
|
|
for task_msg in self.task_messages[:-1]:
|
|
if len(task_msg) > 2000:
|
|
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
|
|
task_msg = task_msg[:1999]
|
|
self.sent_messages.append(await self.ctx.author.send(task_msg))
|
|
|
|
# Send the last message with buttons
|
|
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
|
|
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
|
|
|
|
await task_accept_view.start(last_msg)
|
|
await task_accept_view.wait()
|
|
|
|
return task_accept_view.choice
|
|
|
|
async def handle(self) -> None:
|
|
"""Handle the user's response to the task.
|
|
|
|
This method should be called after `send` has been called."""
|
|
# Ack task to the backend
|
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
|
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
|
|
|
|
# Loop until the user's input is accepted
|
|
while True:
|
|
try:
|
|
# Wait for user to send a message
|
|
event = await self.ctx.bot.wait_for(
|
|
hikari.DMMessageCreateEvent,
|
|
predicate=lambda e: (
|
|
e.author_id == self.ctx.author.id
|
|
and e.message.content is not None
|
|
and not e.message.content.startswith(settings.prefix)
|
|
),
|
|
timeout=MAX_TASK_TIME,
|
|
)
|
|
|
|
# Validate the message
|
|
if event.content is None or not self.check_user_input(event.content):
|
|
await self.ctx.author.send("Invalid input")
|
|
continue
|
|
|
|
# Confirm user input
|
|
if not (await self.confirm_user_input(event.content)):
|
|
continue
|
|
|
|
# Message is valid and confirmed by user
|
|
break
|
|
|
|
except asyncio.TimeoutError:
|
|
return
|
|
|
|
next_task = await self.notify(event.content, event)
|
|
if not isinstance(next_task, protocol_schema.TaskDone):
|
|
raise TypeError(f"Unknown task type: {next_task!r}")
|
|
|
|
return
|
|
|
|
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
|
"""Notify the backend that the user completed the task."""
|
|
raise NotImplementedError
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
|
|
raise NotImplementedError
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
"""Check the user's response to the task. Returns True if the response is valid."""
|
|
raise NotImplementedError
|
|
|
|
async def cancel(self, reason: str = "not specified") -> None:
|
|
"""Cancel the task."""
|
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
|
await oasst_api.nack_task(self.task.id, reason)
|
|
|
|
|
|
_Ranking_contra = t.TypeVar(
|
|
"_Ranking_contra",
|
|
bound=protocol_schema.RankAssistantRepliesTask
|
|
| protocol_schema.RankInitialPromptsTask
|
|
| protocol_schema.RankPrompterRepliesTask
|
|
| protocol_schema.RankConversationRepliesTask,
|
|
contravariant=True,
|
|
)
|
|
|
|
|
|
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
|
|
"""This should not be used directly. Use its subclasses instead."""
|
|
|
|
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
|
|
|
task = await oasst_api.post_interaction(
|
|
protocol_schema.MessageRanking(
|
|
user=protocol_schema.User(
|
|
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
|
|
),
|
|
ranking=[int(r) - 1 for r in content.split(",")],
|
|
message_id=f"{self.sent_messages[0].id}",
|
|
)
|
|
)
|
|
|
|
db: Connection = self.ctx.bot.d.db
|
|
async with db.cursor() as cursor:
|
|
row = await (
|
|
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
|
|
).fetchone()
|
|
log_channel = row[0] if row else None
|
|
log_messages: list[hikari.Message] = []
|
|
|
|
if log_channel is not None:
|
|
for message in self.task_messages[:-1]:
|
|
msg = await self.ctx.bot.rest.create_message(log_channel, message)
|
|
log_messages.append(msg)
|
|
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
|
|
|
|
return task
|
|
|
|
|
|
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
|
return rank_assistant_reply_message(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
|
)
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(
|
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
|
)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
|
|
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
|
|
super().__init__(ctx, task)
|
|
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
|
return rank_initial_prompts_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content.split(",")) == len(self.task.prompt_messages) and all(
|
|
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
|
|
)
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(
|
|
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
|
|
)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
|
return rank_prompter_reply_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
|
)
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(
|
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
|
)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
|
return rank_conversation_reply_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
|
)
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(
|
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
|
)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
|
return initial_prompt_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content) > 0
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
|
return prompter_reply_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content) > 0
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
|
return assistant_reply_messages(task)
|
|
|
|
def check_user_input(self, content: str) -> bool:
|
|
return len(content) > 0
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
|
|
|
|
|
|
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
|
|
def check_user_input(self, content: str) -> bool:
|
|
user_labels = content.split(",")
|
|
return (
|
|
all([l in self.task.valid_labels for l in user_labels])
|
|
and self.task.mandatory_labels is not None
|
|
and all([m in user_labels for m in self.task.mandatory_labels])
|
|
)
|
|
|
|
async def confirm_user_input(self, content: str) -> bool:
|
|
confirm_input_view = YesNoView()
|
|
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
|
|
await confirm_input_view.start(msg)
|
|
await confirm_input_view.wait()
|
|
|
|
return bool(confirm_input_view.choice)
|
|
|
|
|
|
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
|
return label_assistant_reply_messages(task)
|
|
|
|
|
|
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
|
|
@staticmethod
|
|
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
|
return label_prompter_reply_messages(task)
|
|
|
|
|
|
summarize_story = "summarize_story"
|
|
rate_summary = "rate_summary"
|
|
|
|
|
|
@plugin.command
|
|
@lightbulb.command("work", "Complete a task.")
|
|
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
|
async def work2(ctx: lightbulb.Context) -> None:
|
|
"""Complete a task."""
|
|
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
|
|
|
|
# Check if the user is already working on a task
|
|
if ctx.author.id in currently_working:
|
|
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
msg = await ctx.author.send(
|
|
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
|
|
flags=hikari.MessageFlag.EPHEMERAL,
|
|
components=yn_view,
|
|
)
|
|
await yn_view.start(msg)
|
|
await yn_view.wait()
|
|
|
|
match yn_view.choice:
|
|
case False | None:
|
|
return
|
|
case True:
|
|
task_id = currently_working[ctx.author.id]
|
|
await oasst_api.nack_task(task_id, reason="user cancelled")
|
|
|
|
if ctx.guild_id:
|
|
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
|
|
|
# Keep sending tasks until the user doesn't want more
|
|
try:
|
|
while True:
|
|
task = await oasst_api.fetch_random_task(
|
|
user=protocol_schema.User(
|
|
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
|
|
),
|
|
)
|
|
|
|
# Ranking tasks
|
|
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
|
|
task_handler = RankAssistantRepliesHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
|
task_handler = RankInitialPromptHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
|
|
task_handler = RankPrompterReplyHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
|
|
task_handler = RankConversationReplyHandler(ctx, task)
|
|
|
|
# Text input tasks
|
|
elif isinstance(task, protocol_schema.InitialPromptTask):
|
|
task_handler = InitialPromptHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.PrompterReplyTask):
|
|
task_handler = PrompterReplyHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.AssistantReplyTask):
|
|
task_handler = AssistantReplyHandler(ctx, task)
|
|
|
|
# Label tasks
|
|
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
|
|
task_handler = LabelAssistantReplyHandler(ctx, task)
|
|
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
|
|
task_handler = LabelPrompterReplyHandler(ctx, task)
|
|
|
|
else:
|
|
raise ValueError(f"Unknown task type: {type(task)}")
|
|
|
|
resp = await task_handler.send()
|
|
|
|
match resp:
|
|
case "accept":
|
|
currently_working[ctx.author.id] = task.id
|
|
await task_handler.handle()
|
|
case "next":
|
|
await task_handler.cancel("user skipped task")
|
|
case "cancel":
|
|
await task_handler.cancel("user canceled work")
|
|
break
|
|
case None:
|
|
await task_handler.cancel("select timed out")
|
|
break
|
|
finally:
|
|
del currently_working[ctx.author.id]
|
|
|
|
|
|
class TaskAcceptView(miru.View):
|
|
"""View with three buttons: accept, next, and cancel.
|
|
|
|
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
|
"""
|
|
|
|
choice: t.Literal["accept", "next", "cancel"] | None = None
|
|
|
|
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
|
|
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Accept button pressed")
|
|
self.choice = "accept"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
|
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Next button pressed")
|
|
self.choice = "next"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
|
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Cancel button pressed")
|
|
self.choice = "cancel"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
async def on_timeout(self) -> None:
|
|
if self.message is not None:
|
|
await self.message.edit(component=None)
|
|
|
|
|
|
class YesNoView(miru.View):
|
|
"""View with two buttons: yes and no.
|
|
|
|
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
|
"""
|
|
|
|
choice: bool | None = None
|
|
|
|
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
|
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
self.choice = True
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
|
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
self.choice = False
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
async def on_timeout(self) -> None:
|
|
if self.message is not None:
|
|
await self.message.edit(component=None)
|
|
|
|
|
|
def load(bot: lightbulb.BotApp):
|
|
"""Add the plugin to the bot."""
|
|
bot.add_plugin(plugin)
|
|
|
|
|
|
def unload(bot: lightbulb.BotApp):
|
|
"""Remove the plugin to the bot."""
|
|
bot.remove_plugin(plugin)
|