mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Feat/task handler (#1056)
* move task logic to task handlers * rename command * remove test code * rename classes and add missing handler creations * switch task back to random and fetch log_channel_id from db
This commit is contained in:
@@ -4,4 +4,4 @@ OWNER_IDS=[<your user id>, <other user ids>]
|
|||||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||||
|
|
||||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||||
OASST_API_KEY=""
|
OASST_API_KEY="1234"
|
||||||
|
|||||||
+388
-391
@@ -9,27 +9,25 @@ import lightbulb.decorators
|
|||||||
import miru
|
import miru
|
||||||
from aiosqlite import Connection
|
from aiosqlite import Connection
|
||||||
from bot.messages import (
|
from bot.messages import (
|
||||||
assistant_reply_message,
|
assistant_reply_messages,
|
||||||
confirm_label_response_message,
|
confirm_label_response_message,
|
||||||
confirm_ranking_response_message,
|
confirm_ranking_response_message,
|
||||||
confirm_text_response_message,
|
confirm_text_response_message,
|
||||||
initial_prompt_message,
|
initial_prompt_messages,
|
||||||
invalid_user_input_embed,
|
label_assistant_reply_messages,
|
||||||
label_assistant_reply_message,
|
label_prompter_reply_messages,
|
||||||
label_initial_prompt_message,
|
|
||||||
label_prompter_reply_message,
|
|
||||||
plain_embed,
|
plain_embed,
|
||||||
prompter_reply_message,
|
prompter_reply_messages,
|
||||||
rank_assistant_reply_message,
|
rank_assistant_reply_message,
|
||||||
rank_initial_prompts_message,
|
rank_conversation_reply_messages,
|
||||||
rank_prompter_reply_message,
|
rank_initial_prompts_messages,
|
||||||
|
rank_prompter_reply_messages,
|
||||||
task_complete_embed,
|
task_complete_embed,
|
||||||
)
|
)
|
||||||
from bot.settings import Settings
|
from bot.settings import Settings
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
from oasst_shared.api_client import OasstApiClient
|
||||||
from oasst_shared.schemas import protocol as protocol_schema
|
from oasst_shared.schemas import protocol as protocol_schema
|
||||||
from oasst_shared.schemas.protocol import TaskRequestType
|
|
||||||
|
|
||||||
plugin = lightbulb.Plugin("WorkPlugin")
|
plugin = lightbulb.Plugin("WorkPlugin")
|
||||||
|
|
||||||
@@ -38,30 +36,337 @@ MAX_TASK_ACCEPT_TIME = 60 * 10 # seconds
|
|||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|
||||||
|
_Task_contra = t.TypeVar("_Task_contra", bound=protocol_schema.Task, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _TaskHandler(t.Generic[_Task_contra]):
|
||||||
|
"""Handle user interaction for a task."""
|
||||||
|
|
||||||
|
def __init__(self, ctx: lightbulb.Context, task: _Task_contra) -> None:
|
||||||
|
"""Create a new `TaskHandler`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ctx (lightbulb.Context): The context of the command that started the task.
|
||||||
|
task (_Task_contra): The task to handle.
|
||||||
|
"""
|
||||||
|
self.ctx = ctx
|
||||||
|
self.task = task
|
||||||
|
self.task_messages = self.get_task_messages(task)
|
||||||
|
self.sent_messages: list[hikari.Message] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: _Task_contra) -> list[str]:
|
||||||
|
"""Get the messages to send to the user for the task."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def send(self) -> t.Literal["accept", "next", "cancel"] | None:
|
||||||
|
"""Send the task and wait for the user to accept/skip/cancel it."""
|
||||||
|
# Send all but the last message because we need to attach buttons to the last one
|
||||||
|
logger.debug(f"Sending {len(self.task_messages)} messages\n{self.task_messages!r}")
|
||||||
|
for task_msg in self.task_messages[:-1]:
|
||||||
|
if len(task_msg) > 2000:
|
||||||
|
logger.warning(f"Attempting to send a message <2000 characters in length. Task id: {self.task.id}")
|
||||||
|
task_msg = task_msg[:1999]
|
||||||
|
self.sent_messages.append(await self.ctx.author.send(task_msg))
|
||||||
|
|
||||||
|
# Send the last message with buttons
|
||||||
|
task_accept_view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||||
|
logger.debug(f"TH Message length {len(self.task_messages[-1])}")
|
||||||
|
last_msg = await self.ctx.author.send(self.task_messages[-1][:1999], components=task_accept_view)
|
||||||
|
|
||||||
|
await task_accept_view.start(last_msg)
|
||||||
|
await task_accept_view.wait()
|
||||||
|
|
||||||
|
return task_accept_view.choice
|
||||||
|
|
||||||
|
async def handle(self) -> None:
|
||||||
|
"""Handle the user's response to the task.
|
||||||
|
|
||||||
|
This method should be called after `send` has been called."""
|
||||||
|
# Ack task to the backend
|
||||||
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||||
|
await oasst_api.ack_task(self.task.id, message_id=f"{self.sent_messages[0].id}")
|
||||||
|
|
||||||
|
# Loop until the user's input is accepted
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Wait for user to send a message
|
||||||
|
event = await self.ctx.bot.wait_for(
|
||||||
|
hikari.DMMessageCreateEvent,
|
||||||
|
predicate=lambda e: (
|
||||||
|
e.author_id == self.ctx.author.id
|
||||||
|
and e.message.content is not None
|
||||||
|
and not e.message.content.startswith(settings.prefix)
|
||||||
|
),
|
||||||
|
timeout=MAX_TASK_TIME,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate the message
|
||||||
|
if event.content is None or not self.check_user_input(event.content):
|
||||||
|
await self.ctx.author.send("Invalid input")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Confirm user input
|
||||||
|
if not (await self.confirm_user_input(event.content)):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Message is valid and confirmed by user
|
||||||
|
break
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
return
|
||||||
|
|
||||||
|
next_task = await self.notify(event.content, event)
|
||||||
|
if not isinstance(next_task, protocol_schema.TaskDone):
|
||||||
|
raise TypeError(f"Unknown task type: {next_task!r}")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||||
|
"""Notify the backend that the user completed the task."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
"""Send the user's response back to the user and ask them to confirm it. Returns True if the user confirms."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
"""Check the user's response to the task. Returns True if the response is valid."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def cancel(self, reason: str = "not specified") -> None:
|
||||||
|
"""Cancel the task."""
|
||||||
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||||
|
await oasst_api.nack_task(self.task.id, reason)
|
||||||
|
|
||||||
|
|
||||||
|
_Ranking_contra = t.TypeVar(
|
||||||
|
"_Ranking_contra",
|
||||||
|
bound=protocol_schema.RankAssistantRepliesTask
|
||||||
|
| protocol_schema.RankInitialPromptsTask
|
||||||
|
| protocol_schema.RankPrompterRepliesTask
|
||||||
|
| protocol_schema.RankConversationRepliesTask,
|
||||||
|
contravariant=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RankingTaskHandler(_TaskHandler[_Ranking_contra]):
|
||||||
|
"""This should not be used directly. Use its subclasses instead."""
|
||||||
|
|
||||||
|
async def notify(self, content: str, event: hikari.DMMessageCreateEvent) -> protocol_schema.Task:
|
||||||
|
oasst_api: OasstApiClient = self.ctx.bot.d.oasst_api
|
||||||
|
|
||||||
|
task = await oasst_api.post_interaction(
|
||||||
|
protocol_schema.MessageRanking(
|
||||||
|
user=protocol_schema.User(
|
||||||
|
id=f"{self.ctx.author.id}", auth_method="discord", display_name=self.ctx.author.username
|
||||||
|
),
|
||||||
|
ranking=[int(r) - 1 for r in content.split(",")],
|
||||||
|
message_id=f"{self.sent_messages[0].id}",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
db: Connection = self.ctx.bot.d.db
|
||||||
|
async with db.cursor() as cursor:
|
||||||
|
row = await (
|
||||||
|
await cursor.execute("SELECT log_channel_id FROM guilds WHERE guild_id = ?", (self.ctx.guild_id,))
|
||||||
|
).fetchone()
|
||||||
|
log_channel = row[0] if row else None
|
||||||
|
log_messages: list[hikari.Message] = []
|
||||||
|
|
||||||
|
if log_channel is not None:
|
||||||
|
for message in self.task_messages[:-1]:
|
||||||
|
msg = await self.ctx.bot.rest.create_message(log_channel, message)
|
||||||
|
log_messages.append(msg)
|
||||||
|
await self.ctx.bot.rest.create_message(log_channel, task_complete_embed(self.task, self.ctx.author.mention))
|
||||||
|
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
class RankAssistantRepliesHandler(_RankingTaskHandler[protocol_schema.RankAssistantRepliesTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||||
|
return rank_assistant_reply_message(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||||
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(
|
||||||
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||||
|
)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class RankInitialPromptHandler(_RankingTaskHandler[protocol_schema.RankInitialPromptsTask]):
|
||||||
|
def __init__(self, ctx: lightbulb.Context, task: protocol_schema.RankInitialPromptsTask) -> None:
|
||||||
|
super().__init__(ctx, task)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||||
|
return rank_initial_prompts_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content.split(",")) == len(self.task.prompt_messages) and all(
|
||||||
|
[r.isdigit() and int(r) in range(1, len(self.task.prompt_messages) + 1) for r in content.split(",")]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(
|
||||||
|
confirm_ranking_response_message(content, self.task.prompt_messages), components=confirm_input_view
|
||||||
|
)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class RankPrompterReplyHandler(_RankingTaskHandler[protocol_schema.RankPrompterRepliesTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||||
|
return rank_prompter_reply_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||||
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(
|
||||||
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||||
|
)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class RankConversationReplyHandler(_RankingTaskHandler[protocol_schema.RankConversationRepliesTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||||
|
return rank_conversation_reply_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content.split(",")) == len(self.task.reply_messages) and all(
|
||||||
|
[r.isdigit() and int(r) in range(1, len(self.task.reply_messages) + 1) for r in content.split(",")]
|
||||||
|
)
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(
|
||||||
|
confirm_ranking_response_message(content, self.task.reply_messages), components=confirm_input_view
|
||||||
|
)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class InitialPromptHandler(_TaskHandler[protocol_schema.InitialPromptTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||||
|
return initial_prompt_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content) > 0
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class PrompterReplyHandler(_TaskHandler[protocol_schema.PrompterReplyTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||||
|
return prompter_reply_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content) > 0
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantReplyHandler(_TaskHandler[protocol_schema.AssistantReplyTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||||
|
return assistant_reply_messages(task)
|
||||||
|
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
return len(content) > 0
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(confirm_text_response_message(content), components=confirm_input_view)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
_Label_contra = t.TypeVar("_Label_contra", bound=protocol_schema.LabelConversationReplyTask, contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class _LabelConversationReplyHandler(_TaskHandler[_Label_contra]):
|
||||||
|
def check_user_input(self, content: str) -> bool:
|
||||||
|
user_labels = content.split(",")
|
||||||
|
return (
|
||||||
|
all([l in self.task.valid_labels for l in user_labels])
|
||||||
|
and self.task.mandatory_labels is not None
|
||||||
|
and all([m in user_labels for m in self.task.mandatory_labels])
|
||||||
|
)
|
||||||
|
|
||||||
|
async def confirm_user_input(self, content: str) -> bool:
|
||||||
|
confirm_input_view = YesNoView()
|
||||||
|
msg = await self.ctx.author.send(confirm_label_response_message(content), components=confirm_input_view)
|
||||||
|
await confirm_input_view.start(msg)
|
||||||
|
await confirm_input_view.wait()
|
||||||
|
|
||||||
|
return bool(confirm_input_view.choice)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelAssistantReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelAssistantReplyTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||||
|
return label_assistant_reply_messages(task)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelPrompterReplyHandler(_LabelConversationReplyHandler[protocol_schema.LabelPrompterReplyTask]):
|
||||||
|
@staticmethod
|
||||||
|
def get_task_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||||
|
return label_prompter_reply_messages(task)
|
||||||
|
|
||||||
|
|
||||||
|
summarize_story = "summarize_story"
|
||||||
|
rate_summary = "rate_summary"
|
||||||
|
|
||||||
|
|
||||||
@plugin.command
|
@plugin.command
|
||||||
@lightbulb.option(
|
|
||||||
"type",
|
|
||||||
"The type of task to request.",
|
|
||||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
|
||||||
required=False,
|
|
||||||
default=str(TaskRequestType.random),
|
|
||||||
type=str,
|
|
||||||
)
|
|
||||||
@lightbulb.command("work", "Complete a task.")
|
@lightbulb.command("work", "Complete a task.")
|
||||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||||
async def work(ctx: lightbulb.Context):
|
async def work2(ctx: lightbulb.Context) -> None:
|
||||||
"""Create and handle a task."""
|
"""Complete a task."""
|
||||||
# Only send this message if started from a server
|
|
||||||
if ctx.guild_id is not None:
|
|
||||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
|
||||||
|
|
||||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
|
||||||
currently_working: dict[
|
|
||||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
|
||||||
] = ctx.bot.d.currently_working
|
|
||||||
|
|
||||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||||
|
currently_working: dict[hikari.Snowflake, UUID] = ctx.bot.d.currently_working
|
||||||
|
|
||||||
|
# Check if the user is already working on a task
|
||||||
if ctx.author.id in currently_working:
|
if ctx.author.id in currently_working:
|
||||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||||
msg = await ctx.author.send(
|
msg = await ctx.author.send(
|
||||||
@@ -76,374 +381,66 @@ async def work(ctx: lightbulb.Context):
|
|||||||
case False | None:
|
case False | None:
|
||||||
return
|
return
|
||||||
case True:
|
case True:
|
||||||
old_msg, task_id = currently_working[ctx.author.id]
|
task_id = currently_working[ctx.author.id]
|
||||||
if old_msg is not None:
|
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
|
||||||
map(lambda c: c, old_msg.components)
|
|
||||||
await old_msg.delete()
|
|
||||||
if task_id is not None:
|
|
||||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
|
||||||
|
|
||||||
await msg.delete()
|
if ctx.guild_id:
|
||||||
|
await ctx.respond("check DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||||
|
|
||||||
currently_working[ctx.author.id] = (None, None)
|
# Keep sending tasks until the user doesn't want more
|
||||||
|
|
||||||
# Create a TaskRequestType from the stringified enum value
|
|
||||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
|
||||||
|
|
||||||
logger.debug(f"Starting task_type: {task_type!r}")
|
|
||||||
try:
|
try:
|
||||||
await _handle_task(ctx, task_type)
|
while True:
|
||||||
|
task = await oasst_api.fetch_random_task(
|
||||||
|
user=protocol_schema.User(
|
||||||
|
id=f"{ctx.author.id}", display_name=ctx.author.username, auth_method="discord"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ranking tasks
|
||||||
|
if isinstance(task, protocol_schema.RankAssistantRepliesTask):
|
||||||
|
task_handler = RankAssistantRepliesHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||||
|
task_handler = RankInitialPromptHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.RankPrompterRepliesTask):
|
||||||
|
task_handler = RankPrompterReplyHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||||
|
task_handler = RankConversationReplyHandler(ctx, task)
|
||||||
|
|
||||||
|
# Text input tasks
|
||||||
|
elif isinstance(task, protocol_schema.InitialPromptTask):
|
||||||
|
task_handler = InitialPromptHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.PrompterReplyTask):
|
||||||
|
task_handler = PrompterReplyHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.AssistantReplyTask):
|
||||||
|
task_handler = AssistantReplyHandler(ctx, task)
|
||||||
|
|
||||||
|
# Label tasks
|
||||||
|
elif isinstance(task, protocol_schema.LabelAssistantReplyTask):
|
||||||
|
task_handler = LabelAssistantReplyHandler(ctx, task)
|
||||||
|
elif isinstance(task, protocol_schema.LabelPrompterReplyTask):
|
||||||
|
task_handler = LabelPrompterReplyHandler(ctx, task)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown task type: {type(task)}")
|
||||||
|
|
||||||
|
resp = await task_handler.send()
|
||||||
|
|
||||||
|
match resp:
|
||||||
|
case "accept":
|
||||||
|
currently_working[ctx.author.id] = task.id
|
||||||
|
await task_handler.handle()
|
||||||
|
case "next":
|
||||||
|
await task_handler.cancel("user skipped task")
|
||||||
|
case "cancel":
|
||||||
|
await task_handler.cancel("user canceled work")
|
||||||
|
break
|
||||||
|
case None:
|
||||||
|
await task_handler.cancel("select timed out")
|
||||||
|
break
|
||||||
finally:
|
finally:
|
||||||
del currently_working[ctx.author.id]
|
del currently_working[ctx.author.id]
|
||||||
|
|
||||||
|
|
||||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
|
||||||
"""Handle creating and collecting user input for a task.
|
|
||||||
|
|
||||||
Continually present tasks to the user until they select one, cancel, or time out.
|
|
||||||
If they select one, present the task steps until a `task_done` task is received.
|
|
||||||
Finally, ask the user if they want to perform another task (of the same type).
|
|
||||||
"""
|
|
||||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
||||||
|
|
||||||
# Continue to complete tasks until the user doesn't want to do another
|
|
||||||
done = False
|
|
||||||
while not done:
|
|
||||||
|
|
||||||
# Loop until the user accepts a task
|
|
||||||
task, msg_id = await _select_task(ctx, task_type)
|
|
||||||
|
|
||||||
if task is None:
|
|
||||||
# User cancelled
|
|
||||||
return
|
|
||||||
|
|
||||||
# Task action loop
|
|
||||||
completed = False
|
|
||||||
while not completed:
|
|
||||||
await ctx.author.send(embed=plain_embed("Please type your response below:"))
|
|
||||||
try:
|
|
||||||
event = await ctx.bot.wait_for(
|
|
||||||
hikari.DMMessageCreateEvent,
|
|
||||||
timeout=MAX_TASK_TIME,
|
|
||||||
predicate=lambda e: e.author.id == ctx.author.id
|
|
||||||
and not (e.message.content or "").startswith(settings.prefix),
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
|
||||||
await oasst_api.nack_task(task.id, reason="timed out")
|
|
||||||
logger.info(f"Task {task.id} timed out")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Invalid response
|
|
||||||
valid, err_msg = _validate_user_input(event.content, task)
|
|
||||||
if not valid or event.content is None:
|
|
||||||
|
|
||||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.debug(f"Successful user input received: {event.content}")
|
|
||||||
|
|
||||||
# Confirm user input
|
|
||||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
|
||||||
content = confirm_ranking_response_message(event.content, task.replies)
|
|
||||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
|
||||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
|
||||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
|
||||||
content = confirm_label_response_message(event.content)
|
|
||||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
|
||||||
content = confirm_text_response_message(event.content)
|
|
||||||
else:
|
|
||||||
logger.critical(f"Unknown task type: {task.type}")
|
|
||||||
raise ValueError(f"Unknown task type: {task.type}")
|
|
||||||
|
|
||||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
|
||||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
|
||||||
await confirm_resp_view.start(msg)
|
|
||||||
await confirm_resp_view.wait()
|
|
||||||
|
|
||||||
match confirm_resp_view.choice:
|
|
||||||
case False | None:
|
|
||||||
continue
|
|
||||||
case True:
|
|
||||||
await msg.delete() # buttons are already gone
|
|
||||||
|
|
||||||
# Send the response to the backend
|
|
||||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
|
||||||
reply = protocol_schema.MessageRanking(
|
|
||||||
message_id=str(msg_id),
|
|
||||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
|
||||||
user=protocol_schema.User(
|
|
||||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif isinstance(task, protocol_schema.LabelConversationReplyTask | protocol_schema.LabelInitialPromptTask):
|
|
||||||
labels = event.content.replace(" ", "").split(",")
|
|
||||||
labels_dict = {label: 1 if label in labels else 0 for label in task.valid_labels}
|
|
||||||
|
|
||||||
reply = protocol_schema.TextLabels(
|
|
||||||
message_id=task.message_id,
|
|
||||||
labels=labels_dict,
|
|
||||||
user=protocol_schema.User(
|
|
||||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
||||||
),
|
|
||||||
)
|
|
||||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
|
||||||
reply = protocol_schema.TextReplyToMessage(
|
|
||||||
message_id=str(msg_id),
|
|
||||||
user_message_id=str(event.message_id),
|
|
||||||
user=protocol_schema.User(
|
|
||||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
|
||||||
),
|
|
||||||
text=event.content,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.critical(f"Unexpected task type received: {task.type}")
|
|
||||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
|
||||||
|
|
||||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
|
||||||
|
|
||||||
# Get next task
|
|
||||||
new_task = await oasst_api.post_interaction(reply)
|
|
||||||
logger.info(f"New task {new_task}")
|
|
||||||
|
|
||||||
if new_task.type == TaskType.done:
|
|
||||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
|
||||||
completed = True
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
|
||||||
|
|
||||||
# Send a message in all the log channels that the task is complete
|
|
||||||
conn: Connection = ctx.bot.d.db
|
|
||||||
async with conn.cursor() as cursor:
|
|
||||||
await cursor.execute("SELECT log_channel_id FROM guild_settings")
|
|
||||||
log_channel_ids = await cursor.fetchall()
|
|
||||||
|
|
||||||
channels = [
|
|
||||||
ctx.bot.cache.get_guild_channel(id[0]) or await ctx.bot.rest.fetch_channel(id[0])
|
|
||||||
for id in log_channel_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
|
||||||
# This will definitely get the bot rate limited, but that's a future problem
|
|
||||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
|
||||||
|
|
||||||
# ask the user if they want to do another task
|
|
||||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
||||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
|
||||||
await another_task_view.start(msg)
|
|
||||||
await another_task_view.wait()
|
|
||||||
|
|
||||||
match another_task_view.choice:
|
|
||||||
case False | None:
|
|
||||||
done = True
|
|
||||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
|
||||||
case True:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
async def _select_task(
|
|
||||||
ctx: lightbulb.Context, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
|
||||||
) -> tuple[protocol_schema.Task | None, str]:
|
|
||||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
|
||||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
|
||||||
logger.debug(f"Starting task selection for {task_type}")
|
|
||||||
|
|
||||||
# Loop until the user accepts a task, cancels, or times out
|
|
||||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
|
||||||
while True:
|
|
||||||
logger.debug(f"Requesting task of type {task_type}")
|
|
||||||
task = await oasst_api.fetch_task(task_type, user)
|
|
||||||
resp, msg = await _send_task(ctx, task, msg)
|
|
||||||
msg_id = str(msg.id)
|
|
||||||
|
|
||||||
logger.debug(f"User choice: {resp}")
|
|
||||||
match resp:
|
|
||||||
case "accept":
|
|
||||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
|
||||||
await oasst_api.ack_task(task.id, msg_id)
|
|
||||||
return task, msg_id
|
|
||||||
|
|
||||||
case "next":
|
|
||||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
|
||||||
await oasst_api.nack_task(task.id, "rejected")
|
|
||||||
continue
|
|
||||||
|
|
||||||
case "cancel":
|
|
||||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
|
||||||
await oasst_api.nack_task(task.id, "canceled")
|
|
||||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
|
||||||
return None, msg_id
|
|
||||||
|
|
||||||
case None:
|
|
||||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
|
||||||
await oasst_api.nack_task(task.id, "timed out")
|
|
||||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
|
||||||
return None, msg_id
|
|
||||||
|
|
||||||
|
|
||||||
async def _send_task(
|
|
||||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
|
||||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
|
||||||
"""Send a task to the user.
|
|
||||||
|
|
||||||
Returns the user's choice and the message ID of the task message.
|
|
||||||
"""
|
|
||||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
|
||||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
|
||||||
|
|
||||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
|
||||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
|
||||||
|
|
||||||
# Create an embed based on the task's type
|
|
||||||
if task.type == TaskRequestType.initial_prompt:
|
|
||||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
|
||||||
logger.debug("sending initial prompt task")
|
|
||||||
content = initial_prompt_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
|
||||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
|
||||||
logger.debug("sending rank initial prompt task")
|
|
||||||
content = rank_initial_prompts_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
|
||||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
|
||||||
logger.debug("sending rank user reply task")
|
|
||||||
content = rank_prompter_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
|
||||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
|
||||||
logger.debug("sending rank assistant reply task")
|
|
||||||
content = rank_assistant_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.label_initial_prompt:
|
|
||||||
assert isinstance(task, protocol_schema.LabelInitialPromptTask)
|
|
||||||
logger.debug("sending label initial prompt task")
|
|
||||||
content = label_initial_prompt_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.label_prompter_reply:
|
|
||||||
assert isinstance(task, protocol_schema.LabelPrompterReplyTask)
|
|
||||||
logger.debug("sending label prompter reply task")
|
|
||||||
content = label_prompter_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.label_assistant_reply:
|
|
||||||
assert isinstance(task, protocol_schema.LabelAssistantReplyTask)
|
|
||||||
logger.debug("sending label assistant reply task")
|
|
||||||
content = label_assistant_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.prompter_reply:
|
|
||||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
|
||||||
logger.debug("sending user reply task")
|
|
||||||
content = prompter_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.assistant_reply:
|
|
||||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
|
||||||
logger.debug("sending assistant reply task")
|
|
||||||
content = assistant_reply_message(task)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.summarize_story:
|
|
||||||
raise NotImplementedError
|
|
||||||
elif task.type == TaskRequestType.rate_summary:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.critical(f"unknown task type {task.type}")
|
|
||||||
raise ValueError(f"unknown task type {task.type}")
|
|
||||||
|
|
||||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
|
||||||
if not msg:
|
|
||||||
msg = await ctx.author.send(
|
|
||||||
content,
|
|
||||||
embed=embed,
|
|
||||||
components=view,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
await msg.edit(
|
|
||||||
content,
|
|
||||||
embed=embed,
|
|
||||||
components=view,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert msg is not None
|
|
||||||
|
|
||||||
# Set the choice id as the current msg id
|
|
||||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
|
||||||
|
|
||||||
await view.start(msg)
|
|
||||||
await view.wait()
|
|
||||||
|
|
||||||
return view.choice, msg
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
|
||||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
|
||||||
if content is None:
|
|
||||||
return False, "No input provided"
|
|
||||||
|
|
||||||
# User message input
|
|
||||||
if (
|
|
||||||
task.type == TaskRequestType.initial_prompt
|
|
||||||
or task.type == TaskRequestType.prompter_reply
|
|
||||||
or task.type == TaskRequestType.assistant_reply
|
|
||||||
):
|
|
||||||
assert isinstance(
|
|
||||||
task,
|
|
||||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
|
||||||
)
|
|
||||||
return len(content) > 0, "Message must be at least one character long."
|
|
||||||
|
|
||||||
# Ranking tasks
|
|
||||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
|
||||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
|
||||||
num_replies = len(task.replies)
|
|
||||||
|
|
||||||
rankings = content.replace(" ", "").split(",")
|
|
||||||
return (
|
|
||||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
|
||||||
"Message must contain numbers for all replies.",
|
|
||||||
)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
|
||||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
|
||||||
num_prompts = len(task.prompts)
|
|
||||||
|
|
||||||
rankings = content.replace(" ", "").split(",")
|
|
||||||
return (
|
|
||||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
|
||||||
"Message must contain numbers for all prompts.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Labels tasks
|
|
||||||
elif task.type in (
|
|
||||||
TaskRequestType.label_initial_prompt,
|
|
||||||
TaskRequestType.label_prompter_reply,
|
|
||||||
TaskRequestType.label_assistant_reply,
|
|
||||||
):
|
|
||||||
assert isinstance(
|
|
||||||
task,
|
|
||||||
protocol_schema.LabelInitialPromptTask
|
|
||||||
| protocol_schema.LabelPrompterReplyTask
|
|
||||||
| protocol_schema.LabelAssistantReplyTask,
|
|
||||||
)
|
|
||||||
|
|
||||||
labels = content.replace(" ", "").split(",")
|
|
||||||
valid_labels = set(task.valid_labels)
|
|
||||||
return (
|
|
||||||
set(labels).issubset(valid_labels),
|
|
||||||
"Message must only contain labels from predefined set of labels.",
|
|
||||||
)
|
|
||||||
|
|
||||||
elif task.type == TaskRequestType.summarize_story:
|
|
||||||
raise NotImplementedError
|
|
||||||
elif task.type == TaskRequestType.rate_summary:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.critical(f"Unknown task type {task.type}")
|
|
||||||
raise ValueError(f"Unknown task type {task.type}")
|
|
||||||
|
|
||||||
|
|
||||||
class TaskAcceptView(miru.View):
|
class TaskAcceptView(miru.View):
|
||||||
"""View with three buttons: accept, next, and cancel.
|
"""View with three buttons: accept, next, and cancel.
|
||||||
|
|
||||||
|
|||||||
+129
-69
@@ -1,4 +1,11 @@
|
|||||||
"""All user-facing messages and embeds."""
|
"""All user-facing messages and embeds.
|
||||||
|
|
||||||
|
When sending a conversation
|
||||||
|
- The function will return a list of strings
|
||||||
|
- use asyncio.gather to send all messages
|
||||||
|
|
||||||
|
-
|
||||||
|
"""
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
@@ -33,8 +40,11 @@ def _ranking_prompt(text: str) -> str:
|
|||||||
return f":trophy: _{text}_"
|
return f":trophy: _{text}_"
|
||||||
|
|
||||||
|
|
||||||
def _label_prompt(text: str) -> str:
|
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
|
||||||
return f":question: _{text}"
|
return f""":question: _{text}_
|
||||||
|
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
|
||||||
|
Valid labels: {", ".join(valid_labels)}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _response_prompt(text: str) -> str:
|
def _response_prompt(text: str) -> str:
|
||||||
@@ -57,20 +67,29 @@ def _assistant(text: str | None) -> str:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
|
||||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
return [f"{num} {item.text}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||||
|
|
||||||
|
|
||||||
def _ordered_list(items: list[str]) -> str:
|
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
|
||||||
return "\n\n".join(_make_ordered_list(items))
|
return "\n\n".join(_make_ordered_list(items))
|
||||||
|
|
||||||
|
|
||||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
|
||||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
# return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||||
|
messages = map(
|
||||||
|
lambda m: f"""\
|
||||||
def _hint(hint: str | None) -> str:
|
:robot: __Assistant__:
|
||||||
return f"{NL}Hint: {hint}" if hint else ""
|
{m.text}
|
||||||
|
"""
|
||||||
|
if m.is_assistant
|
||||||
|
else f"""\
|
||||||
|
:person_red_hair: __User__:
|
||||||
|
{m.text}
|
||||||
|
""",
|
||||||
|
conv.messages,
|
||||||
|
)
|
||||||
|
return list(messages)
|
||||||
|
|
||||||
|
|
||||||
def _li(text: str) -> str:
|
def _li(text: str) -> str:
|
||||||
@@ -82,59 +101,80 @@ def _li(text: str) -> str:
|
|||||||
###
|
###
|
||||||
|
|
||||||
|
|
||||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
f"""\
|
||||||
|
|
||||||
{_h1("INITIAL PROMPT")}
|
:small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:
|
||||||
|
|
||||||
|
|
||||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
:pencil: _Please provide an initial prompt to the assistant._{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||||
{_hint(task.hint)}
|
|
||||||
"""
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
f"""\
|
||||||
|
|
||||||
{_h1("RANK INITIAL PROMPTS")}
|
:small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:
|
||||||
|
|
||||||
|
|
||||||
{_ordered_list(task.prompts)}
|
{_ordered_list(task.prompt_messages)}
|
||||||
|
|
||||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
:trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_
|
||||||
"""
|
"""
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
"""\
|
||||||
|
|
||||||
{_h1("RANK PROMPTER REPLIES")}
|
:small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:
|
||||||
|
|
||||||
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
f""":person_red_hair: __User__:
|
||||||
|
{_ordered_list(task.reply_messages)}
|
||||||
|
|
||||||
|
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
|
||||||
{_user(None)}
|
|
||||||
{_ordered_list(task.replies)}
|
|
||||||
|
|
||||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
|
||||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
"""\
|
||||||
|
|
||||||
{_h1("RANK ASSISTANT REPLIES")}
|
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
|
||||||
|
|
||||||
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
f""":robot: __Assistant__:,
|
||||||
|
{_ordered_list(task.reply_messages)}
|
||||||
|
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
|
||||||
{_assistant(None)}
|
"""Creates the message that gets sent to users when they request a `rank_conversation_replies` task."""
|
||||||
{_ordered_list(task.replies)}
|
return [
|
||||||
|
"""\
|
||||||
|
|
||||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
:small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:
|
||||||
"""
|
|
||||||
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
f""":person_red_hair: __User__:
|
||||||
|
{_ordered_list(task.reply_messages)}
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
|
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
|
||||||
@@ -146,64 +186,84 @@ def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -
|
|||||||
|
|
||||||
{task.prompt}
|
{task.prompt}
|
||||||
|
|
||||||
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')")}
|
{_label_prompt("Reply with labels for the prompt separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def label_prompter_reply_message(task: protocol_schema.LabelPrompterReplyTask) -> str:
|
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
|
"""Creates the message that gets sent to users when they request a `label_prompter_reply` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
f"""\
|
||||||
|
|
||||||
{_h1("LABEL PROMPTER REPLY")}
|
{_h1("LABEL PROMPTER REPLY")}
|
||||||
|
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
""",
|
||||||
{_user(None)}
|
*_conversation(task.conversation),
|
||||||
|
f"""{_user(None)}
|
||||||
{task.reply}
|
{task.reply}
|
||||||
|
|
||||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||||
"""
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def label_assistant_reply_message(task: protocol_schema.LabelAssistantReplyTask) -> str:
|
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
|
"""Creates the message that gets sent to users when they request a `label_assistant_reply` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
f"""\
|
||||||
|
|
||||||
{_h1("LABEL ASSISTANT REPLY")}
|
{_h1("LABEL ASSISTANT REPLY")}
|
||||||
|
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
f"""
|
||||||
{_assistant(None)}
|
{_assistant(None)}
|
||||||
{task.reply}
|
{task.reply}
|
||||||
|
|
||||||
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')")}
|
{_label_prompt("Reply with labels for the reply separated by commas (example: 'profanity,misleading')", task.mandatory_labels, task.valid_labels)}
|
||||||
"""
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||||
return f"""\
|
return [
|
||||||
|
"""\
|
||||||
|
:small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:
|
||||||
|
|
||||||
{_h1("PROMPTER REPLY")}
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
f"""{f"{NL}Hint: {task.hint}" if task.hint else ""}
|
||||||
|
|
||||||
|
:speech_balloon: _Please provide a reply to the assistant._
|
||||||
|
""",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
|
||||||
{_hint(task.hint)}
|
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||||
|
# return [
|
||||||
{_response_prompt("Please provide a reply to the assistant.")}
|
# message_templates.render("title.msg", "PROMPTER REPLY"),
|
||||||
"""
|
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
|
||||||
|
# message_templates.render("prompter_reply_task.msg", task.hint),
|
||||||
|
# ]
|
||||||
|
|
||||||
|
|
||||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
|
||||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||||
return f"""\
|
return [
|
||||||
{_h1("ASSISTANT REPLY")}
|
"""\
|
||||||
|
:small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:
|
||||||
|
|
||||||
|
""",
|
||||||
|
*_conversation(task.conversation),
|
||||||
|
"""\
|
||||||
|
|
||||||
{_conversation(task.conversation)}
|
:speech_balloon: _Please provide a reply to the user as the assistant._
|
||||||
|
""",
|
||||||
{_response_prompt("Please provide an assistant reply to the prompter.")}
|
]
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def confirm_text_response_message(content: str) -> str:
|
def confirm_text_response_message(content: str) -> str:
|
||||||
@@ -214,7 +274,7 @@ def confirm_text_response_message(content: str) -> str:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
|
||||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||||
original_list = _make_ordered_list(items)
|
original_list = _make_ordered_list(items)
|
||||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||||
|
|||||||
@@ -68,12 +68,15 @@ class OasstApiClient:
|
|||||||
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
|
async def post(self, path: str, data: dict[str, t.Any]) -> Optional[dict[str, t.Any]]:
|
||||||
"""Make a POST request to the backend."""
|
"""Make a POST request to the backend."""
|
||||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"x-api-key": self.api_key})
|
||||||
|
logger.debug(f"response: {response}")
|
||||||
|
|
||||||
# If the response is not a 2XX, check to see
|
# If the response is not a 2XX, check to see
|
||||||
# if the json has the fields to create an
|
# if the json has the fields to create an
|
||||||
# OasstError.
|
# OasstError.
|
||||||
if response.status >= 300:
|
if response.status >= 300:
|
||||||
|
text = await response.text()
|
||||||
|
logger.debug(f"resp text: {text}")
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
try:
|
try:
|
||||||
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
oasst_error = protocol_schema.OasstErrorResponse(**(data or {}))
|
||||||
@@ -114,20 +117,21 @@ class OasstApiClient:
|
|||||||
task_type: protocol_schema.TaskRequestType,
|
task_type: protocol_schema.TaskRequestType,
|
||||||
user: Optional[protocol_schema.User] = None,
|
user: Optional[protocol_schema.User] = None,
|
||||||
collective: bool = False,
|
collective: bool = False,
|
||||||
|
lang: Optional[str] = None,
|
||||||
) -> protocol_schema.Task:
|
) -> protocol_schema.Task:
|
||||||
"""Fetch a task from the backend."""
|
"""Fetch a task from the backend."""
|
||||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective, lang=lang)
|
||||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||||
logger.debug(f"RESP {resp}")
|
logger.debug(f"RESP {resp}")
|
||||||
return self._parse_task(resp)
|
return self._parse_task(resp)
|
||||||
|
|
||||||
async def fetch_random_task(
|
async def fetch_random_task(
|
||||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
self, user: Optional[protocol_schema.User] = None, collective: bool = False, lang: Optional[str] = None
|
||||||
) -> protocol_schema.Task:
|
) -> protocol_schema.Task:
|
||||||
"""Fetch a random task from the backend."""
|
"""Fetch a random task from the backend."""
|
||||||
logger.debug(f"Fetching random for user {user}")
|
logger.debug(f"Fetching random for user {user}")
|
||||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective, lang)
|
||||||
|
|
||||||
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
|
async def ack_task(self, task_id: str | UUID, message_id: str) -> None:
|
||||||
"""Send an ACK for a task to the backend."""
|
"""Send an ACK for a task to the backend."""
|
||||||
|
|||||||
Reference in New Issue
Block a user