mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Feat/better messages (#318)
* improve messages and UX * move log statement Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="/" # Don't change, this allows for slash commands in DMs
|
||||
PREFIX="/" # DO NOT LEAVE EMPTY, slash command prefix in DMs
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
|
||||
+13
-6
@@ -6,7 +6,7 @@ import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from bot.settings import Settings
|
||||
from bot.utils import EMPTY, mention
|
||||
from bot.utils import mention
|
||||
from oasst_shared.api_client import OasstApiClient
|
||||
|
||||
settings = Settings()
|
||||
@@ -34,8 +34,11 @@ async def on_starting(event: hikari.StartingEvent):
|
||||
|
||||
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
|
||||
|
||||
# A set of user id's that are currently doing work.
|
||||
bot.d.currently_working = set()
|
||||
# A `dict[hikari.Message | None, UUID | None]]` that maps user IDs to (task msg ID, task UUIDs).
|
||||
# Either both are `None` or both are not `None`.
|
||||
# If both are `None`, the user is not currently selecting a task.
|
||||
# TODO: Grow this on startup so we don't have to re-allocate memory every time it needs to grow
|
||||
bot.d.currently_working = {}
|
||||
|
||||
|
||||
@bot.listen()
|
||||
@@ -50,13 +53,13 @@ async def _send_error_embed(
|
||||
) -> None:
|
||||
ctx.command
|
||||
embed = hikari.Embed(
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `{ctx.command.name}`' if ctx.command else '' }",
|
||||
title=f"`{exception.__class__.__name__}` Error{f' in `/{ctx.command.name}`' if ctx.command else '' }",
|
||||
description=content,
|
||||
color=0xFF0000,
|
||||
timestamp=datetime.now().astimezone(),
|
||||
).set_author(name=ctx.author.username, url=str(ctx.author.avatar_url))
|
||||
|
||||
await ctx.respond(EMPTY, embed=embed)
|
||||
await ctx.respond(embed=embed)
|
||||
|
||||
|
||||
@bot.listen(lightbulb.CommandErrorEvent)
|
||||
@@ -65,6 +68,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
# Unwrap the exception to get the original cause
|
||||
exc = event.exception.__cause__ or event.exception
|
||||
ctx = event.context
|
||||
if not ctx.bot.rest.is_alive:
|
||||
return
|
||||
|
||||
if isinstance(event.exception, lightbulb.CommandInvocationError):
|
||||
if not event.context.command:
|
||||
@@ -114,6 +119,8 @@ async def on_error(event: lightbulb.CommandErrorEvent) -> None:
|
||||
ctx,
|
||||
)
|
||||
elif isinstance(exc, lightbulb.errors.MissingRequiredAttachment):
|
||||
await _send_error_embed("Not enough attachemnts were supplied to this command.", exc, ctx)
|
||||
await _send_error_embed("Not enough attachments were supplied to this command.", exc, ctx)
|
||||
elif isinstance(exc, lightbulb.errors.CommandNotFound):
|
||||
await ctx.respond(f"`/{exc.invoked_with}` is not a valid command. Use `/help` to see a list of commands.")
|
||||
else:
|
||||
raise exc
|
||||
|
||||
@@ -78,7 +78,6 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
|
||||
# if the bot's permissions for this channel don't contain SEND_MESSAGE
|
||||
# This will also filter out categories and voice channels
|
||||
print(permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES)
|
||||
if not permissions_in(ch, own_member) & hikari.Permissions.SEND_MESSAGES:
|
||||
await ctx.respond(f"I don't have permission to send messages in {ch.mention}.")
|
||||
return
|
||||
|
||||
@@ -7,7 +7,6 @@ import lightbulb
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
@@ -74,7 +73,7 @@ class LabelModal(miru.Modal):
|
||||
)
|
||||
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel)
|
||||
await channel.send(EMPTY, embed=embed)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
|
||||
class LabelSelect(miru.View):
|
||||
@@ -164,7 +163,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
msg.content,
|
||||
timeout=60,
|
||||
)
|
||||
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
resp = await ctx.respond(embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
await label_select_view.start(await resp.message())
|
||||
await label_select_view.wait()
|
||||
|
||||
+171
-175
@@ -1,14 +1,27 @@
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.utils import EMPTY
|
||||
from bot.messages import (
|
||||
assistant_reply_message,
|
||||
confirm_ranking_response_message,
|
||||
confirm_text_response_message,
|
||||
initial_prompt_message,
|
||||
invalid_user_input_embed,
|
||||
plain_embed,
|
||||
prompter_reply_message,
|
||||
rank_assistant_reply_message,
|
||||
rank_initial_prompts_message,
|
||||
rank_prompter_reply_message,
|
||||
task_complete_embed,
|
||||
)
|
||||
from bot.settings import Settings
|
||||
from loguru import logger
|
||||
from oasst_shared.api_client import OasstApiClient, TaskType
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -19,6 +32,8 @@ plugin = lightbulb.Plugin("WorkPlugin")
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
settings = Settings()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
@@ -33,25 +48,50 @@ MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def work(ctx: lightbulb.Context):
|
||||
"""Create and handle a task."""
|
||||
# make sure the user isn't currently doing a task
|
||||
currently_working: set[hikari.Snowflakeish] = ctx.bot.d.currently_working
|
||||
# Only send this message if started from a server
|
||||
if ctx.guild_id is not None:
|
||||
await ctx.respond(embed=plain_embed("Sending you a task, check your DMs"), flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
# make sure the user isn't currently doing a task, and if they are, ask if they want to cancel it
|
||||
currently_working: dict[
|
||||
hikari.Snowflakeish, tuple[hikari.Message | None, UUID | None]
|
||||
] = ctx.bot.d.currently_working
|
||||
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
if ctx.author.id in currently_working:
|
||||
await ctx.respond(
|
||||
"You are already performing a task. Please complete that one first.", flags=hikari.MessageFlag.EPHEMERAL
|
||||
yn_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
embed=plain_embed("You are already working. Would you like to cancel your old task start a new one?"),
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=yn_view,
|
||||
)
|
||||
return
|
||||
await yn_view.start(msg)
|
||||
await yn_view.wait()
|
||||
|
||||
currently_working.add(ctx.author.id)
|
||||
match yn_view.choice:
|
||||
case False | None:
|
||||
return
|
||||
case True:
|
||||
old_msg, task_id = currently_working[ctx.author.id]
|
||||
if old_msg is not None:
|
||||
logger.info(f"User {ctx.author.id} cancelled task {task_id}, deleting message {old_msg.id}")
|
||||
map(lambda c: c, old_msg.components)
|
||||
await old_msg.delete()
|
||||
if task_id is not None:
|
||||
await oasst_api.nack_task(task_id, reason="user cancelled")
|
||||
|
||||
await msg.delete()
|
||||
|
||||
currently_working[ctx.author.id] = (None, None)
|
||||
|
||||
# Create a TaskRequestType from the stringified enum value
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
try:
|
||||
await _handle_task(ctx, task_type)
|
||||
finally:
|
||||
currently_working.remove(ctx.author.id)
|
||||
del currently_working[ctx.author.id]
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> None:
|
||||
@@ -71,38 +111,79 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
# User cancelled
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
await ctx.author.send(embed=plain_embed("Please type your response here"))
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id
|
||||
and not (e.message.content or "").startswith(settings.prefix),
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
valid, err_msg = _validate_user_input(event.content, task)
|
||||
if not valid or event.content is None:
|
||||
|
||||
await ctx.author.send(embed=invalid_user_input_embed(err_msg))
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Confirm user input
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask):
|
||||
content = confirm_ranking_response_message(event.content, task.replies)
|
||||
elif isinstance(task, protocol_schema.RankInitialPromptsTask):
|
||||
content = confirm_ranking_response_message(event.content, task.prompts)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
content = confirm_text_response_message(event.content)
|
||||
else:
|
||||
logger.critical(f"Unknown task type: {task.type}")
|
||||
raise ValueError(f"Unknown task type: {task.type}")
|
||||
|
||||
confirm_resp_view = YesNoView(timeout=MAX_TASK_TIME)
|
||||
msg = await ctx.author.send(content, components=confirm_resp_view)
|
||||
await confirm_resp_view.start(msg)
|
||||
await confirm_resp_view.wait()
|
||||
|
||||
match confirm_resp_view.choice:
|
||||
case False | None:
|
||||
continue
|
||||
case True:
|
||||
await msg.delete() # buttons are already gone
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
if isinstance(task, protocol_schema.RankConversationRepliesTask | protocol_schema.RankInitialPromptsTask):
|
||||
reply = protocol_schema.MessageRanking(
|
||||
message_id=str(msg_id),
|
||||
ranking=[int(r) - 1 for r in event.content.replace(" ", "").split(",")],
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
)
|
||||
elif isinstance(task, protocol_schema.ReplyToConversationTask | protocol_schema.InitialPromptTask):
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {task.type}")
|
||||
raise ValueError(f"Unexpected task type received: {task.type}")
|
||||
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
@@ -110,7 +191,7 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
await ctx.author.send(embed=plain_embed("Task completed"))
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
@@ -127,33 +208,20 @@ async def _handle_task(ctx: lightbulb.Context, task_type: TaskRequestType) -> No
|
||||
for id in log_channel_ids
|
||||
]
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {ctx.author.mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
done_embed = task_complete_embed(task, ctx.author.mention)
|
||||
# This will definitely get the bot rate limited, but that's a future problem
|
||||
asyncio.gather(
|
||||
*(ch.send(EMPTY, embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel))
|
||||
)
|
||||
asyncio.gather(*(ch.send(embed=done_embed) for ch in channels if isinstance(ch, hikari.TextableChannel)))
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
another_task_view = YesNoView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(embed=plain_embed("Would you like another task?"), components=another_task_view)
|
||||
await another_task_view.start(msg)
|
||||
await another_task_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
match another_task_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
await msg.edit(embed=plain_embed("Exiting, goodbye!"))
|
||||
case True:
|
||||
pass
|
||||
|
||||
@@ -166,10 +234,12 @@ async def _select_task(
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
msg: hikari.UndefinedOr[hikari.Message] = hikari.UNDEFINED
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
resp, msg = await _send_task(ctx, task, msg)
|
||||
msg_id = str(msg.id)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
@@ -181,25 +251,24 @@ async def _select_task(
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task canceled. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await ctx.author.send(embed=plain_embed("Task timed out. Exiting"))
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
ctx: lightbulb.Context, task: protocol_schema.Task, msg: hikari.UndefinedOr[hikari.Message]
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, hikari.Message]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
@@ -208,37 +277,38 @@ async def _send_task(
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
content: hikari.UndefinedOr[str] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
content = initial_prompt_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
content = rank_initial_prompts_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
content = rank_prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
content = rank_assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _prompter_reply_embed(task)
|
||||
content = prompter_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
content = assistant_reply_message(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -250,24 +320,34 @@ async def _send_task(
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
EMPTY,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
if not msg:
|
||||
msg = await ctx.author.send(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
else:
|
||||
await msg.edit(
|
||||
content,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
# Set the choice id as the current msg id
|
||||
ctx.bot.d.currently_working[ctx.author.id] = (msg, task.id)
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
return view.choice, msg
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> tuple[bool, str]:
|
||||
"""Returns whether the user's input is valid for the task type and an error message."""
|
||||
if content is None:
|
||||
return False
|
||||
return False, "No input provided"
|
||||
|
||||
# User message input
|
||||
if (
|
||||
@@ -279,22 +359,28 @@ def _validate_user_input(content: str | None, task: protocol_schema.Task) -> boo
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
return len(content) > 0, "Message must be at least one character long."
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies,
|
||||
"Message must contain numbers for all replies.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
rankings = content.replace(" ", "").split(",")
|
||||
return (
|
||||
set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts,
|
||||
"Message must contain numbers for all prompts.",
|
||||
)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
@@ -318,22 +404,29 @@ class TaskAcceptView(miru.View):
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
|
||||
class YesNoView(miru.View):
|
||||
"""View with two buttons: yes and no.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
@@ -344,115 +437,18 @@ class ChoiceView(miru.View):
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
await ctx.message.edit(component=None)
|
||||
self.stop()
|
||||
|
||||
|
||||
################################################################
|
||||
# Template Embeds #
|
||||
################################################################
|
||||
|
||||
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description="Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
description="Rank the following user replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Assistant Reply",
|
||||
description="Rank the following assistant replies from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description=f"""\
|
||||
Send the next message in the conversation as if you were the user.
|
||||
{'Hint: ' if task.hint else ''}
|
||||
""",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description="Send the next message in the conversation as if you were the user.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
async def on_timeout(self) -> None:
|
||||
if self.message is not None:
|
||||
await self.message.edit(component=None)
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""All user-facing messages and embeds."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
NUMBER_EMOJIS = [":one:", ":two:", ":three:", ":four:", ":five:", ":six:", ":seven:", ":eight:", ":nine:", ":ten:"]
|
||||
NL = "\n"
|
||||
|
||||
###
|
||||
# Reusable 'components'
|
||||
###
|
||||
|
||||
|
||||
def _h1(text: str) -> str:
|
||||
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
|
||||
|
||||
|
||||
def _h2(text: str) -> str:
|
||||
return f"__**{text}**__"
|
||||
|
||||
|
||||
def _h3(text: str) -> str:
|
||||
return f"__{text}__"
|
||||
|
||||
|
||||
def _writing_prompt(text: str) -> str:
|
||||
return f":pencil: _{text}_"
|
||||
|
||||
|
||||
def _ranking_prompt(text: str) -> str:
|
||||
return f":trophy: _{text}_"
|
||||
|
||||
|
||||
def _response_prompt(text: str) -> str:
|
||||
return f":speech_balloon: _{text}_"
|
||||
|
||||
|
||||
def _summarize_prompt(text: str) -> str:
|
||||
return f":notepad_spiral: _{text}_"
|
||||
|
||||
|
||||
def _user(text: str | None) -> str:
|
||||
return f"""\
|
||||
:person_red_hair: {_h3("User")}:{f"{NL}> **{text}**" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _assistant(text: str | None) -> str:
|
||||
return f"""\
|
||||
:robot: {_h3("Assistant")}:{f"{NL}> {text}" if text is not None else ""}
|
||||
"""
|
||||
|
||||
|
||||
def _make_ordered_list(items: list[str]) -> list[str]:
|
||||
return [f"{num} {item}" for num, item in zip(NUMBER_EMOJIS, items)]
|
||||
|
||||
|
||||
def _ordered_list(items: list[str]) -> str:
|
||||
return "\n\n".join(_make_ordered_list(items))
|
||||
|
||||
|
||||
def _conversation(conv: protocol_schema.Conversation) -> str:
|
||||
return "\n".join([_assistant(msg.text) if msg.is_assistant else _user(msg.text) for msg in conv.messages])
|
||||
|
||||
|
||||
def _hint(hint: str | None) -> str:
|
||||
return f"{NL}Hint: {hint}" if hint else ""
|
||||
|
||||
|
||||
###
|
||||
# Messages
|
||||
###
|
||||
|
||||
|
||||
def initial_prompt_message(task: protocol_schema.InitialPromptTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request an `initial_prompt` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("INITIAL PROMPT")}
|
||||
|
||||
{_writing_prompt("Please provide an initial prompt to the assistant.")}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_initial_prompts_message(task: protocol_schema.RankInitialPromptsTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_initial_prompts` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK INITIAL PROMPTS")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_ordered_list(task.prompts)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_prompter_reply_message(task: protocol_schema.RankPrompterRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_prompter_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK PROMPTER REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_user(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `rank_assistant_replies` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("RANK ASSISTANT REPLIES")}
|
||||
|
||||
{_ranking_prompt("Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_assistant(None)}
|
||||
{_ordered_list(task.replies)}
|
||||
"""
|
||||
|
||||
|
||||
def prompter_reply_message(task: protocol_schema.PrompterReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `prompter_reply` task."""
|
||||
return f"""\
|
||||
|
||||
{_h1("PROMPTER REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
{_hint(task.hint)}
|
||||
"""
|
||||
|
||||
|
||||
def assistant_reply_message(task: protocol_schema.AssistantReplyTask) -> str:
|
||||
"""Creates the message that gets sent to users when they request a `assistant_reply` task."""
|
||||
return f"""\
|
||||
{_h1("ASSISTANT REPLY")}
|
||||
|
||||
{_response_prompt("Please provide a reply to the assistant.")}
|
||||
|
||||
|
||||
{_conversation(task.conversation)}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_text_response_message(content: str) -> str:
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
> {content}
|
||||
"""
|
||||
|
||||
|
||||
def confirm_ranking_response_message(content: str, items: list[str]) -> str:
|
||||
user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
|
||||
original_list = _make_ordered_list(items)
|
||||
user_ranked_list = "\n\n".join([original_list[r - 1] for r in user_rankings])
|
||||
|
||||
return f"""\
|
||||
{_h2("CONFIRM RESPONSE")}
|
||||
|
||||
{user_ranked_list}
|
||||
"""
|
||||
|
||||
|
||||
###
|
||||
# Embeds
|
||||
###
|
||||
|
||||
|
||||
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
|
||||
return hikari.Embed(
|
||||
title="Invalid User Input",
|
||||
description=error_message,
|
||||
color=hikari.Color(0xFF0000),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
|
||||
|
||||
def plain_embed(text: str) -> hikari.Embed:
|
||||
return hikari.Embed(color=0x36393F, description=text)
|
||||
@@ -24,13 +24,6 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
EMPTY = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
|
||||
def mention(
|
||||
id: hikari.Snowflakeish,
|
||||
type: t.Literal["channel", "role", "user"],
|
||||
|
||||
Reference in New Issue
Block a user