Feat/task handler (#1056)

* move task logic to task handlers

* rename command

* remove test code

* rename classes and add missing handler creations

* switch task back to random and fetch log_channel_id from db
This commit is contained in:
Alex Ott
2023-02-02 08:28:00 -08:00
committed by GitHub
parent 2db3450e9a
commit dfd2c35276
4 changed files with 526 additions and 465 deletions
+1 -1
View File
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
OASST_API_URL="http://localhost:8080" # No trailing '/' OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY="" OASST_API_KEY="1234"
+388 -391
View File
@@ -9,27 +9,25 @@ import lightbulb.decorators
import miru import miru
from aiosqlite import Connection from aiosqlite import Connection
from bot.messages import ( from bot.messages import (
assistant_reply_message, assistant_reply_messages,
confirm_label_response_message, confirm_label_response_message,
confirm_ranking_response_message, confirm_ranking_response_message,
confirm_text_response_message, confirm_text_response_message,
initial_prompt_message, initial_prompt_messages,
invalid_user_input_embed, label_assistant_reply_messages,
label_assistant_reply_message, label_prompter_reply_messages,
label_initial_prompt_message,
label_prompter_reply_message,
plain_embed, plain_embed,
prompter_reply_message, prompter_reply_messages,
rank_assistant_reply_message, rank_assistant_reply_message,
rank_initial_prompts_message, rank_conversation_reply_messages,
rank_prompter_reply_message, rank_initial_prompts_messages,
rank_prompter_reply_messages,
task_complete_embed, task_complete_embed,
) )
from bot.settings import Settings from bot.settings import Settings
from loguru import logger from loguru import logger
from oasst_shared.api_client import OasstApiClient, TaskType from oasst_shared.api_client import OasstApiClient
from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin") plugin = lightbulb.Plugin("WorkPlugin")
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
settings = Settings() settings = Settings()
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
class _TaskHandler(t.Generic[_Task_contra]):
"""Handle user interaction for a task."""
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
"""Create a new `TaskHandler`.
Args:
ctx (lightbulb.Context): The context of the command that started the task.
task (_Task_contra): The task to handle.
"""
self.ctx = ctx
self.task = task
self.task_messages = self.get_task_messages(task)
self.sent_messages: list[hikari.Message] = []
@staticmethod
def get_task_messages(task: _Task_contra) -> list[str]:
"""Get the messages to send to the user for the task."""
raise NotImplementedError
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
"""Send the task and wait for the user to accept/skip/cancel it."""
# Send all but the last message because we need to attach buttons to the last one
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
for task_msg in self.task_messages[:-1]:
if len(task_msg) > 2000:
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
task_msg = task_msg[:1999]
self.sent_messages.append(await self.ctx.author.send(task_msg))
# Send the last message with buttons
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
await task_accept_view.start(last_msg)
await task_accept_view.wait()
return task_accept_view.choice
async def handle(self) -> None:
"""Handle the user's response to the task.
This method should be called after `send` has been called."""
# Ack task to the backend
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
# Loop until the user's input is accepted
while True:
try:
# Wait for user to send a message
event = await self.ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
predicate=lambda e: (
e.author_id == self.ctx.author.id
and e.message.content is not None
and not e.message.content.startswith(settings.prefix)
),
timeout=MAX_TASK_TIME,
)
# Validate the message
if event.content is None or not self.check_user_input(event.content):
await self.ctx.author.send("Invalid input")
continue
# Confirm user input
if not (await self.confirm_user_input(event.content)):
continue
# Message is valid and confirmed by user
break
except asyncio.TimeoutError:
return
next_task = await self.notify(event.content, event)
if not isinstance(next_task, protocol_schema.TaskDone):
raise TypeError(f"Unknown task type: {next_task!r}")
return
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
"""Notify the backend that the user completed the task."""
raise NotImplementedError
async def confirm_user_input(self, content: str) -> bool:
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
raise NotImplementedError
def check_user_input(self, content: str) -> bool:
"""Check the user's response to the task. Returns True if the response is valid."""
raise NotImplementedError
async def cancel(self, reason: str = "not specified") -> None:
"""Cancel the task."""
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
await oasst_api.nack_task(self.task.id, reason)
_Ranking_contra = t.TypeVar(
"_Ranking_contra",
bound=protocol_schema.RankAssistantRepliesTask
| protocol_schema.RankInitialPromptsTask
| protocol_schema.RankPrompterRepliesTask
| protocol_schema.RankConversationRepliesTask,
contravariant=True,
)
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
"""This should not be used directly. Use its subclasses instead."""
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
task = await oasst_api.post_interaction(
protocol_schema.MessageRanking(
user=protocol_schema.User(
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
),
ranking=[int(r) - 1 for r in content.split(",")],
message_id=f"{self.sent_messages[0].id}",
)
)
db: Connection = self.ctx.bot.d.db
async with db.cursor() as cursor:
row = await (
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
).fetchone()
log_channel = row[0] if row else None
log_messages: list[hikari.Message] = []
if log_channel is not None:
for message in self.task_messages[:-1]:
msg = await self.ctx.bot.rest.create_message(log_channel, message)
log_messages.append(msg)
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
return task
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
return rank_assistant_reply_message(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
super().__init__(ctx, task)
@staticmethod
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
return rank_initial_prompts_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.prompt_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
return rank_prompter_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
@staticmethod
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
return rank_conversation_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content.split(",")) == len(self.task.reply_messages) and all(
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
@staticmethod
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
return initial_prompt_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
return prompter_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
return assistant_reply_messages(task)
def check_user_input(self, content: str) -> bool:
return len(content) > 0
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
def check_user_input(self, content: str) -> bool:
user_labels = content.split(",")
return (
all([l in self.task.valid_labels for l in user_labels])
and self.task.mandatory_labels is not None
and all([m in user_labels for m in self.task.mandatory_labels])
)
async def confirm_user_input(self, content: str) -> bool:
confirm_input_view = YesNoView()
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
await confirm_input_view.start(msg)
await confirm_input_view.wait()
return bool(confirm_input_view.choice)
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
return label_assistant_reply_messages(task)
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
@staticmethod
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
return label_prompter_reply_messages(task)
summarize_story = "summarize_story"
rate_summary = "rate_summary"
@plugin.command @plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.random),
type=str,
)
@lightbulb.command("work", "Complete a task.") @lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) @lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def work(ctx: lightbulb.Context): async def work2(ctx: lightbulb.Context) -> None:
"""Create and handle a task.""" """Complete a task."""
# Only send this message if started from a server
if ctx.guild_id is not None:
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
currently_working: dict[
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
] = ctx.bot.d.currently_working
oasst_api: OasstApiClient = ctx.bot.d.oasst_api oasst_api: OasstApiClient = ctx.bot.d.oasst_api
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
# Check if the user is already working on a task
if ctx.author.id in currently_working: if ctx.author.id in currently_working:
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME) yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send( msg = await ctx.author.send(
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
case False | None: case False | None:
return return
case True: case True:
old_msg, task_id = currently_working[ctx.author.id] task_id = currently_working[ctx.author.id]
if old_msg is not None: await oasst_api.nack_task(task_id, reason="user cancelled")
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
map(lambda c: c, old_msg.components)
await old_msg.delete()
if task_id is not None:
await oasst_api.nack_task(task_id, reason="user cancelled")
await msg.delete() if ctx.guild_id:
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
currently_working[ctx.author.id] = (None, None) # Keep sending tasks until the user doesn't want more
# Create a TaskRequestType from the stringified enum value
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
logger.debug(f"Starting task_type: {task_type!r}")
try: try:
await _handle_task(ctx, task_type) while True:
task = await oasst_api.fetch_random_task(
user=protocol_schema.User(
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
),
)
# Ranking tasks
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
task_handler = RankAssistantRepliesHandler(ctx, task)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
task_handler = RankInitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
task_handler = RankPrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
task_handler = RankConversationReplyHandler(ctx, task)
# Text input tasks
elif isinstance(task, protocol_schema.InitialPromptTask):
task_handler = InitialPromptHandler(ctx, task)
elif isinstance(task, protocol_schema.PrompterReplyTask):
task_handler = PrompterReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.AssistantReplyTask):
task_handler = AssistantReplyHandler(ctx, task)
# Label tasks
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
task_handler = LabelAssistantReplyHandler(ctx, task)
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
task_handler = LabelPrompterReplyHandler(ctx, task)
else:
raise ValueError(f"Unknown task type: {type(task)}")
resp = await task_handler.send()
match resp:
case "accept":
currently_working[ctx.author.id] = task.id
await task_handler.handle()
case "next":
await task_handler.cancel("user skipped task")
case "cancel":
await task_handler.cancel("user canceled work")
break
case None:
await task_handler.cancel("select timed out")
break
finally: finally:
del currently_working[ctx.author.id] del currently_working[ctx.author.id]
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
If they select one, present the task steps until a `task_done` task is received.
Finally, ask the user if they want to perform another task (of the same type).
"""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
# Continue to complete tasks until the user doesn't want to do another
done = False
while not done:
# Loop until the user accepts a task
task, msg_id = await _select_task(ctx, task_type)
if task is None:
# User cancelled
return
# Task action loop
completed = False
while not completed:
await ctx.author.send(embed=plain_embed("Please type your response below:"))
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id
and not (e.message.content or "").startswith(settings.prefix),
)
except asyncio.TimeoutError:
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
valid, err_msg = _validate_user_input(event.content, task)
if not valid or event.content is None:
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
continue
logger.debug(f"Successful user input received: {event.content}")
# Confirm user input
if isinstance(task, protocol_schema.RankConversationRepliesTask):
content = confirm_ranking_response_message(event.content, task.replies)
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
content = confirm_ranking_response_message(event.content, task.prompts)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
content = confirm_label_response_message(event.content)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
content = confirm_text_response_message(event.content)
else:
logger.critical(f"Unknown task type: {task.type}")
raise ValueError(f"Unknown task type: {task.type}")
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
msg = await ctx.author.send(content, components=confirm_resp_view)
await confirm_resp_view.start(msg)
await confirm_resp_view.wait()
match confirm_resp_view.choice:
case False | None:
continue
case True:
await msg.delete() # buttons are already gone
# Send the response to the backend
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
reply = protocol_schema.MessageRanking(
message_id=str(msg_id),
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
labels = event.content.replace(" ", "").split(",")
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
reply = protocol_schema.TextLabels(
message_id=task.message_id,
labels=labels_dict,
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
)
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
else:
logger.critical(f"Unexpected task type received: {task.type}")
raise ValueError(f"Unexpected task type received: {task.type}")
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
new_task = await oasst_api.post_interaction(reply)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send(embed=plain_embed("Task completed"))
completed = True
continue
else:
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in all the log channels that the task is complete
conn: Connection = ctx.bot.d.db
async with conn.cursor() as cursor:
await cursor.execute("SELECT log_channel_id FROM guild_settings")
log_channel_ids = await cursor.fetchall()
channels = [
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
for id in log_channel_ids
]
done_embed = task_complete_embed(task, ctx.author.mention)
# This will definitely get the bot rate limited, but that's a future problem
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
# ask the user if they want to do another task
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
await another_task_view.start(msg)
await another_task_view.wait()
match another_task_view.choice:
case False | None:
done = True
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
case True:
pass
async def _select_task(
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg = await _send_task(ctx, task, msg)
msg_id = str(msg.id)
logger.debug(f"User choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
await oasst_api.ack_task(task.id, msg_id)
return task, msg_id
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
return None, msg_id
async def _send_task(
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message.
"""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
content = initial_prompt_message(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
content = rank_initial_prompts_message(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
content = rank_prompter_reply_message(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
content = rank_assistant_reply_message(task)
elif task.type == TaskRequestType.label_initial_prompt:
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
logger.debug("sending label initial prompt task")
content = label_initial_prompt_message(task)
elif task.type == TaskRequestType.label_prompter_reply:
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
logger.debug("sending label prompter reply task")
content = label_prompter_reply_message(task)
elif task.type == TaskRequestType.label_assistant_reply:
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
logger.debug("sending label assistant reply task")
content = label_assistant_reply_message(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
content = prompter_reply_message(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
content = assistant_reply_message(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
if not msg:
msg = await ctx.author.send(
content,
embed=embed,
components=view,
)
else:
await msg.edit(
content,
embed=embed,
components=view,
)
assert msg is not None
# Set the choice id as the current msg id
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
await view.start(msg)
await view.wait()
return view.choice, msg
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
"""Returns whether the user's input is valid for the task type and an error message."""
if content is None:
return False, "No input provided"
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0, "Message must be at least one character long."
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
"Message must contain numbers for all replies.",
)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.replace(" ", "").split(",")
return (
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
"Message must contain numbers for all prompts.",
)
# Labels tasks
elif task.type in (
TaskRequestType.label_initial_prompt,
TaskRequestType.label_prompter_reply,
TaskRequestType.label_assistant_reply,
):
assert isinstance(
task,
protocol_schema.LabelInitialPromptTask
| protocol_schema.LabelPrompterReplyTask
| protocol_schema.LabelAssistantReplyTask,
)
labels = content.replace(" ", "").split(",")
valid_labels = set(task.valid_labels)
return (
set(labels).issubset(valid_labels),
"Message must only contain labels from predefined set of labels.",
)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View): class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel. """View with three buttons: accept, next, and cancel.
+129 -69
View File
@@ -1,4 +1,11 @@
"""All user-facing messages and embeds.""" """All user-facing messages and embeds.
When sending a conversation
- The function will return a list of strings
- use asyncio.gather to send all messages
-
"""
from datetime import datetime from datetime import datetime
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_" return f":trophy: _{text}_"
def _label_prompt(text: str) -> str: def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
return f":question: _{text}" return f""":question: _{text}_
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
Valid labels: {", ".join(valid_labels)}
"""
def _response_prompt(text: str) -> str: def _response_prompt(text: str) -> str:
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
""" """
def _make_ordered_list(items: list[str]) -> list[str]: def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)] return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
def _ordered_list(items: list[str]) -> str: def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
return "\n\n".join(_make_ordered_list(items)) return "\n\n".join(_make_ordered_list(items))
def _conversation(conv: protocol_schema.Conversation) -> str: def _conversation(conv: protocol_schema.Conversation) -> list[str]:
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages]) # return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
messages = map(
lambda m: f"""\
def _hint(hint: str | None) -> str: :robot: __Assistant__:
return f"{NL}Hint: {hint}" if hint else "" {m.text}
"""
if m.is_assistant
else f"""\
:person_red_hair: __User__:
{m.text}
""",
conv.messages,
)
return list(messages)
def _li(text: str) -> str: def _li(text: str) -> str:
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
### ###
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str: def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
"""Creates the message that gets sent to users when they request an `initial_prompt` task.""" """Creates the message that gets sent to users when they request an `initial_prompt` task."""
return f"""\ return [
f"""\
{_h1("INITIAL PROMPT")} :small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
{_writing_prompt("Please provide an initial prompt to the assistant.")} :pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
{_hint(task.hint)}
""" """
]
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str: def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task.""" """Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
return f"""\ return [
f"""\
{_h1("RANK INITIAL PROMPTS")} :small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
{_ordered_list(task.prompts)} {_ordered_list(task.prompt_messages)}
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")} :trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
""" """
]
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str: def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task.""" """Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
return f"""\ return [
"""\
{_h1("RANK PROMPTER REPLIES")} :small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)} def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
{_user(None)}
{_ordered_list(task.replies)}
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
"""
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task.""" """Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
return f"""\ return [
"""\
{_h1("RANK ASSISTANT REPLIES")} :small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
f""":robot: __Assistant__:,
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
]
{_conversation(task.conversation)} def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
{_assistant(None)} """Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
{_ordered_list(task.replies)} return [
"""\
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")} :small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
"""
""",
*_conversation(task.conversation),
f""":person_red_hair: __User__:
{_ordered_list(task.reply_messages)}
""",
]
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str: def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
{task.prompt} {task.prompt}
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")} {_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""" """
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str: def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task.""" """Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
return f"""\ return [
f"""\
{_h1("LABEL PROMPTER REPLY")} {_h1("LABEL PROMPTER REPLY")}
{_conversation(task.conversation)} """,
{_user(None)} *_conversation(task.conversation),
f"""{_user(None)}
{task.reply} {task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} {_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""" """,
]
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str: def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task.""" """Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
return f"""\ return [
f"""\
{_h1("LABEL ASSISTANT REPLY")} {_h1("LABEL ASSISTANT REPLY")}
{_conversation(task.conversation)} """,
*_conversation(task.conversation),
f"""
{_assistant(None)} {_assistant(None)}
{task.reply} {task.reply}
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")} {_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
""" """,
]
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str: def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `prompter_reply` task.""" """Creates the message that gets sent to users when they request a `prompter_reply` task."""
return f"""\ return [
"""\
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
{_h1("PROMPTER REPLY")} """,
*_conversation(task.conversation),
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
:speech_balloon: _Please provide a reply to the assistant._
""",
]
{_conversation(task.conversation)} # def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
{_hint(task.hint)} # """Creates the message that gets sent to users when they request a `prompter_reply` task."""
# return [
{_response_prompt("Please provide a reply to the assistant.")} # message_templates.render("title.msg", "PROMPTER REPLY"),
""" # *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
# message_templates.render("prompter_reply_task.msg", task.hint),
# ]
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str: def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
"""Creates the message that gets sent to users when they request a `assistant_reply` task.""" """Creates the message that gets sent to users when they request a `assistant_reply` task."""
return f"""\ return [
{_h1("ASSISTANT REPLY")} """\
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
""",
*_conversation(task.conversation),
"""\
{_conversation(task.conversation)} :speech_balloon: _Please provide a reply to the user as the assistant._
""",
{_response_prompt("Please provide an assistant reply to the prompter.")} ]
"""
def confirm_text_response_message(content: str) -> str: def confirm_text_response_message(content: str) -> str:
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
""" """
def confirm_ranking_response_message(content: str, items: list[str]) -> str: def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
user_rankings = [int(r) for r in content.replace(" ", "").split(",")] user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
original_list = _make_ordered_list(items) original_list = _make_ordered_list(items)
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings]) user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
+8 -4
View File
@@ -68,12 +68,15 @@ class OasstApiClient:
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]: async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
"""Make a POST request to the backend.""" """Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}") logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
logger.debug(f"response: {response}")
# If the response is not a 2XX, check to see # If the response is not a 2XX, check to see
# if the json has the fields to create an # if the json has the fields to create an
# OasstError. # OasstError.
if response.status >= 300: if response.status >= 300:
text = await response.text()
logger.debug(f"resp text: {text}")
data = await response.json() data = await response.json()
try: try:
oasst_error = protocol_schema.OasstErrorResponse(**(data or {})) oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
@@ -114,20 +117,21 @@ class OasstApiClient:
task_type: protocol_schema.TaskRequestType, task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None, user: Optional[protocol_schema.User] = None,
collective: bool = False, collective: bool = False,
lang: Optional[str] = None,
) -> protocol_schema.Task: ) -> protocol_schema.Task:
"""Fetch a task from the backend.""" """Fetch a task from the backend."""
logger.debug(f"Fetching task {task_type} for user {user}") logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective) req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
resp = await self.post("/api/v1/tasks/", data=req.dict()) resp = await self.post("/api/v1/tasks/", data=req.dict())
logger.debug(f"RESP {resp}") logger.debug(f"RESP {resp}")
return self._parse_task(resp) return self._parse_task(resp)
async def fetch_random_task( async def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
) -> protocol_schema.Task: ) -> protocol_schema.Task:
"""Fetch a random task from the backend.""" """Fetch a random task from the backend."""
logger.debug(f"Fetching random for user {user}") logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective) return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
async def ack_task(self, task_id: str | UUID, message_id: str) -> None: async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
"""Send an ACK for a task to the backend.""" """Send an ACK for a task to the backend."""