mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
69bc799cd9
* Implement label task for initial prompts and replies * Resolve formatting * Include missing argument * Modify text_labels API to match new model, update DB schema accordingly * Send valid labels as part of label tasks * Send correctly formatted valid_labels list * Fix request format * Fix request details for text-frontend reply label task * Include message_id in tasks * Address review comments * Fix alembic tree
514 lines
20 KiB
Python
514 lines
20 KiB
Python
"""Work plugin for collecting user data."""
|
|
import asyncio
|
|
import typing as t
|
|
from uuid import UUID
|
|
|
|
import hikari
|
|
import lightbulb
|
|
import lightbulb.decorators
|
|
import miru
|
|
from aiosqlite import Connection
|
|
from bot.messages import (
|
|
assistant_reply_message,
|
|
confirm_label_response_message,
|
|
confirm_ranking_response_message,
|
|
confirm_text_response_message,
|
|
initial_prompt_message,
|
|
invalid_user_input_embed,
|
|
label_assistant_reply_message,
|
|
label_initial_prompt_message,
|
|
label_prompter_reply_message,
|
|
plain_embed,
|
|
prompter_reply_message,
|
|
rank_assistant_reply_message,
|
|
rank_initial_prompts_message,
|
|
rank_prompter_reply_message,
|
|
task_complete_embed,
|
|
)
|
|
from bot.settings import Settings
|
|
from loguru import logger
|
|
from oasst_shared.api_client import OasstApiClient, TaskType
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from oasst_shared.schemas.protocol import TaskRequestType
|
|
|
|
plugin = lightbulb.Plugin("WorkPlugin")
|
|
|
|
MAX_TASK_TIME = 60 * 60 # seconds
|
|
MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
|
|
|
|
settings = Settings()
|
|
|
|
|
|
@plugin.command
|
|
@lightbulb.option(
|
|
"type",
|
|
"The type of task to request.",
|
|
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
|
required=False,
|
|
default=str(TaskRequestType.random),
|
|
type=str,
|
|
)
|
|
@lightbulb.command("work", "Complete a task.")
|
|
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
|
async def work(ctx: lightbulb.Context):
|
|
"""Create and handle a task."""
|
|
# Only send this message if started from a server
|
|
if ctx.guild_id is not None:
|
|
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
|
|
|
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
|
currently_working: dict[
|
|
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
|
] = ctx.bot.d.currently_working
|
|
|
|
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
if ctx.author.id in currently_working:
|
|
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
msg = await ctx.author.send(
|
|
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
|
|
flags=hikari.MessageFlag.EPHEMERAL,
|
|
components=yn_view,
|
|
)
|
|
await yn_view.start(msg)
|
|
await yn_view.wait()
|
|
|
|
match yn_view.choice:
|
|
case False | None:
|
|
return
|
|
case True:
|
|
old_msg, task_id = currently_working[ctx.author.id]
|
|
if old_msg is not None:
|
|
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
|
map(lambda c: c, old_msg.components)
|
|
await old_msg.delete()
|
|
if task_id is not None:
|
|
await oasst_api.nack_task(task_id, reason="user cancelled")
|
|
|
|
await msg.delete()
|
|
|
|
currently_working[ctx.author.id] = (None, None)
|
|
|
|
# Create a TaskRequestType from the stringified enum value
|
|
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
|
|
|
logger.debug(f"Starting task_type: {task_type!r}")
|
|
try:
|
|
await _handle_task(ctx, task_type)
|
|
finally:
|
|
del currently_working[ctx.author.id]
|
|
|
|
|
|
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
|
"""Handle creating and collecting user input for a task.
|
|
|
|
Continually present tasks to the user until they select one, cancel, or time out.
|
|
If they select one, present the task steps until a `task_done` task is received.
|
|
Finally, ask the user if they want to perform another task (of the same type).
|
|
"""
|
|
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
|
|
# Continue to complete tasks until the user doesn't want to do another
|
|
done = False
|
|
while not done:
|
|
|
|
# Loop until the user accepts a task
|
|
task, msg_id = await _select_task(ctx, task_type)
|
|
|
|
if task is None:
|
|
# User cancelled
|
|
return
|
|
|
|
# Task action loop
|
|
completed = False
|
|
while not completed:
|
|
await ctx.author.send(embed=plain_embed("Please type your response below:"))
|
|
try:
|
|
event = await ctx.bot.wait_for(
|
|
hikari.DMMessageCreateEvent,
|
|
timeout=MAX_TASK_TIME,
|
|
predicate=lambda e: e.author.id == ctx.author.id
|
|
and not (e.message.content or "").startswith(settings.prefix),
|
|
)
|
|
except asyncio.TimeoutError:
|
|
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
|
await oasst_api.nack_task(task.id, reason="timed out")
|
|
logger.info(f"Task {task.id} timed out")
|
|
return
|
|
|
|
# Invalid response
|
|
valid, err_msg = _validate_user_input(event.content, task)
|
|
if not valid or event.content is None:
|
|
|
|
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
|
continue
|
|
|
|
logger.debug(f"Successful user input received: {event.content}")
|
|
|
|
# Confirm user input
|
|
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
|
content = confirm_ranking_response_message(event.content, task.replies)
|
|
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
|
content = confirm_ranking_response_message(event.content, task.prompts)
|
|
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
|
content = confirm_label_response_message(event.content)
|
|
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
|
content = confirm_text_response_message(event.content)
|
|
else:
|
|
logger.critical(f"Unknown task type: {task.type}")
|
|
raise ValueError(f"Unknown task type: {task.type}")
|
|
|
|
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
|
msg = await ctx.author.send(content, components=confirm_resp_view)
|
|
await confirm_resp_view.start(msg)
|
|
await confirm_resp_view.wait()
|
|
|
|
match confirm_resp_view.choice:
|
|
case False | None:
|
|
continue
|
|
case True:
|
|
await msg.delete() # buttons are already gone
|
|
|
|
# Send the response to the backend
|
|
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
|
reply = protocol_schema.MessageRanking(
|
|
message_id=str(msg_id),
|
|
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
|
user=protocol_schema.User(
|
|
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
),
|
|
)
|
|
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
|
labels = event.content.replace(" ", "").split(",")
|
|
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
|
|
|
|
reply = protocol_schema.TextLabels(
|
|
message_id=task.message_id,
|
|
labels=labels_dict,
|
|
user=protocol_schema.User(
|
|
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
),
|
|
)
|
|
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
|
reply = protocol_schema.TextReplyToMessage(
|
|
message_id=str(msg_id),
|
|
user_message_id=str(event.message_id),
|
|
user=protocol_schema.User(
|
|
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
),
|
|
text=event.content,
|
|
)
|
|
else:
|
|
logger.critical(f"Unexpected task type received: {task.type}")
|
|
raise ValueError(f"Unexpected task type received: {task.type}")
|
|
|
|
logger.debug(f"Sending reply to backend: {reply!r}")
|
|
|
|
# Get next task
|
|
new_task = await oasst_api.post_interaction(reply)
|
|
logger.info(f"New task {new_task}")
|
|
|
|
if new_task.type == TaskType.done:
|
|
await ctx.author.send(embed=plain_embed("Task completed"))
|
|
completed = True
|
|
continue
|
|
else:
|
|
logger.critical(f"Unexpected task type received: {new_task.type}")
|
|
|
|
# Send a message in all the log channels that the task is complete
|
|
conn: Connection = ctx.bot.d.db
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
|
log_channel_ids = await cursor.fetchall()
|
|
|
|
channels = [
|
|
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
|
for id in log_channel_ids
|
|
]
|
|
|
|
done_embed = task_complete_embed(task, ctx.author.mention)
|
|
# This will definitely get the bot rate limited, but that's a future problem
|
|
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
|
|
|
# ask the user if they want to do another task
|
|
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
|
await another_task_view.start(msg)
|
|
await another_task_view.wait()
|
|
|
|
match another_task_view.choice:
|
|
case False | None:
|
|
done = True
|
|
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
|
case True:
|
|
pass
|
|
|
|
|
|
async def _select_task(
|
|
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
|
) -> tuple[protocol_schema.Task | None, str]:
|
|
"""Present tasks to the user until they accept one, cancel, or time out."""
|
|
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
logger.debug(f"Starting task selection for {task_type}")
|
|
|
|
# Loop until the user accepts a task, cancels, or times out
|
|
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
|
while True:
|
|
logger.debug(f"Requesting task of type {task_type}")
|
|
task = await oasst_api.fetch_task(task_type, user)
|
|
resp, msg = await _send_task(ctx, task, msg)
|
|
msg_id = str(msg.id)
|
|
|
|
logger.debug(f"User choice: {resp}")
|
|
match resp:
|
|
case "accept":
|
|
logger.info(f"Task {task.id} accepted, sending ACK")
|
|
await oasst_api.ack_task(task.id, msg_id)
|
|
return task, msg_id
|
|
|
|
case "next":
|
|
logger.info(f"Task {task.id} rejected, sending NACK")
|
|
await oasst_api.nack_task(task.id, "rejected")
|
|
continue
|
|
|
|
case "cancel":
|
|
logger.info(f"Task {task.id} canceled, sending NACK")
|
|
await oasst_api.nack_task(task.id, "canceled")
|
|
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
|
return None, msg_id
|
|
|
|
case None:
|
|
logger.info(f"Task {task.id} timed out, sending NACK")
|
|
await oasst_api.nack_task(task.id, "timed out")
|
|
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
|
return None, msg_id
|
|
|
|
|
|
async def _send_task(
|
|
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
|
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
|
"""Send a task to the user.
|
|
|
|
Returns the user's choice and the message ID of the task message.
|
|
"""
|
|
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
|
# but the tasks aren't discord specific so that doesn't really make sense.
|
|
|
|
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
|
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
|
|
|
# Create an embed based on the task's type
|
|
if task.type == TaskRequestType.initial_prompt:
|
|
assert isinstance(task, protocol_schema.InitialPromptTask)
|
|
logger.debug("sending initial prompt task")
|
|
content = initial_prompt_message(task)
|
|
|
|
elif task.type == TaskRequestType.rank_initial_prompts:
|
|
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
|
logger.debug("sending rank initial prompt task")
|
|
content = rank_initial_prompts_message(task)
|
|
|
|
elif task.type == TaskRequestType.rank_prompter_replies:
|
|
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
|
logger.debug("sending rank user reply task")
|
|
content = rank_prompter_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.rank_assistant_replies:
|
|
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
|
logger.debug("sending rank assistant reply task")
|
|
content = rank_assistant_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.label_initial_prompt:
|
|
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
|
|
logger.debug("sending label initial prompt task")
|
|
content = label_initial_prompt_message(task)
|
|
|
|
elif task.type == TaskRequestType.label_prompter_reply:
|
|
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
|
|
logger.debug("sending label prompter reply task")
|
|
content = label_prompter_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.label_assistant_reply:
|
|
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
|
|
logger.debug("sending label assistant reply task")
|
|
content = label_assistant_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.prompter_reply:
|
|
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
|
logger.debug("sending user reply task")
|
|
content = prompter_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.assistant_reply:
|
|
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
|
logger.debug("sending assistant reply task")
|
|
content = assistant_reply_message(task)
|
|
|
|
elif task.type == TaskRequestType.summarize_story:
|
|
raise NotImplementedError
|
|
elif task.type == TaskRequestType.rate_summary:
|
|
raise NotImplementedError
|
|
|
|
else:
|
|
logger.critical(f"unknown task type {task.type}")
|
|
raise ValueError(f"unknown task type {task.type}")
|
|
|
|
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
if not msg:
|
|
msg = await ctx.author.send(
|
|
content,
|
|
embed=embed,
|
|
components=view,
|
|
)
|
|
else:
|
|
await msg.edit(
|
|
content,
|
|
embed=embed,
|
|
components=view,
|
|
)
|
|
|
|
assert msg is not None
|
|
|
|
# Set the choice id as the current msg id
|
|
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
|
|
|
await view.start(msg)
|
|
await view.wait()
|
|
|
|
return view.choice, msg
|
|
|
|
|
|
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
|
"""Returns whether the user's input is valid for the task type and an error message."""
|
|
if content is None:
|
|
return False, "No input provided"
|
|
|
|
# User message input
|
|
if (
|
|
task.type == TaskRequestType.initial_prompt
|
|
or task.type == TaskRequestType.prompter_reply
|
|
or task.type == TaskRequestType.assistant_reply
|
|
):
|
|
assert isinstance(
|
|
task,
|
|
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
|
)
|
|
return len(content) > 0, "Message must be at least one character long."
|
|
|
|
# Ranking tasks
|
|
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
|
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
|
num_replies = len(task.replies)
|
|
|
|
rankings = content.replace(" ", "").split(",")
|
|
return (
|
|
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
|
"Message must contain numbers for all replies.",
|
|
)
|
|
|
|
elif task.type == TaskRequestType.rank_initial_prompts:
|
|
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
|
num_prompts = len(task.prompts)
|
|
|
|
rankings = content.replace(" ", "").split(",")
|
|
return (
|
|
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
|
"Message must contain numbers for all prompts.",
|
|
)
|
|
|
|
# Labels tasks
|
|
elif task.type in (
|
|
TaskRequestType.label_initial_prompt,
|
|
TaskRequestType.label_prompter_reply,
|
|
TaskRequestType.label_assistant_reply,
|
|
):
|
|
assert isinstance(
|
|
task,
|
|
protocol_schema.LabelInitialPromptTask
|
|
| protocol_schema.LabelPrompterReplyTask
|
|
| protocol_schema.LabelAssistantReplyTask,
|
|
)
|
|
|
|
labels = content.replace(" ", "").split(",")
|
|
valid_labels = set(task.valid_labels)
|
|
return (
|
|
set(labels).issubset(valid_labels),
|
|
"Message must only contain labels from predefined set of labels.",
|
|
)
|
|
|
|
elif task.type == TaskRequestType.summarize_story:
|
|
raise NotImplementedError
|
|
elif task.type == TaskRequestType.rate_summary:
|
|
raise NotImplementedError
|
|
|
|
else:
|
|
logger.critical(f"Unknown task type {task.type}")
|
|
raise ValueError(f"Unknown task type {task.type}")
|
|
|
|
|
|
class TaskAcceptView(miru.View):
|
|
"""View with three buttons: accept, next, and cancel.
|
|
|
|
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
|
"""
|
|
|
|
choice: t.Literal["accept", "next", "cancel"] | None = None
|
|
|
|
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
|
|
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Accept button pressed")
|
|
self.choice = "accept"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
|
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Next button pressed")
|
|
self.choice = "next"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
|
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
logger.info("Cancel button pressed")
|
|
self.choice = "cancel"
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
async def on_timeout(self) -> None:
|
|
if self.message is not None:
|
|
await self.message.edit(component=None)
|
|
|
|
|
|
class YesNoView(miru.View):
|
|
"""View with two buttons: yes and no.
|
|
|
|
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
|
"""
|
|
|
|
choice: bool | None = None
|
|
|
|
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
|
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
self.choice = True
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
|
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
|
self.choice = False
|
|
await ctx.message.edit(component=None)
|
|
self.stop()
|
|
|
|
async def on_timeout(self) -> None:
|
|
if self.message is not None:
|
|
await self.message.edit(component=None)
|
|
|
|
|
|
def load(bot: lightbulb.BotApp):
|
|
"""Add the plugin to the bot."""
|
|
bot.add_plugin(plugin)
|
|
|
|
|
|
def unload(bot: lightbulb.BotApp):
|
|
"""Remove the plugin to the bot."""
|
|
bot.remove_plugin(plugin)
|