Files
Open-Assistant/discord-bot/bot/messages.py
T
Alex Ott dfd2c35276 Feat/task handler (#1056)
* move task logic to task handlers

* rename command

* remove test code

* rename classes and add missing handler creations

* switch task back to random and fetch log_channel_id from db
2023-02-02 17:28:00 +01:00

385 lines
11 KiB
Python

"""All user-facing messages and embeds.

When sending a conversation:
- The function will return a list of strings.
- Use asyncio.gather to send all messages.
"""
from datetime import datetime
import hikari
from oasst_shared.schemas import protocol as protocol_schema
# Emoji shortcodes used to number list entries: index 0 labels the first entry,
# index 9 the tenth.
NUMBER_EMOJIS = [
    ":one:",
    ":two:",
    ":three:",
    ":four:",
    ":five:",
    ":six:",
    ":seven:",
    ":eight:",
    ":nine:",
    ":ten:",
]

# Newline held in a name so f-string replacement fields can emit it without
# writing a backslash inside the braces.
NL = "\n"
###
# Reusable 'components'
###
def _h1(text: str) -> str:
return f"\n:small_blue_diamond: __**{text}**__ :small_blue_diamond:"
def _h2(text: str) -> str:
return f"__**{text}**__"
def _h3(text: str) -> str:
return f"__{text}__"
def _writing_prompt(text: str) -> str:
return f":pencil: _{text}_"
def _ranking_prompt(text: str) -> str:
return f":trophy: _{text}_"
def _label_prompt(text: str, mandatory_label: list[str] | None, valid_labels: list[str]) -> str:
return f""":question: _{text}_
Mandatory labels: {", ".join(mandatory_label) if mandatory_label is not None else "None"}
Valid labels: {", ".join(valid_labels)}
"""
def _response_prompt(text: str) -> str:
return f":speech_balloon: _{text}_"
def _summarize_prompt(text: str) -> str:
return f":notepad_spiral: _{text}_"
def _user(text: str | None) -> str:
    """Return the "User" speaker header, optionally followed by quoted text.

    With ``text`` set, the header is followed by a new line holding
    `> **text**`; with ``None`` only the header is produced. The result
    always ends with a newline.
    """
    quoted = "" if text is None else f"{NL}> **{text}**"
    header = _h3("User")
    return f":person_red_hair: {header}:{quoted}\n"
def _assistant(text: str | None) -> str:
    """Return the "Assistant" speaker header, optionally followed by quoted text.

    With ``text`` set, the header is followed by a new line holding
    `> text`; with ``None`` only the header is produced. The result always
    ends with a newline.
    """
    quoted = "" if text is None else f"{NL}> {text}"
    header = _h3("Assistant")
    return f":robot: {header}:{quoted}\n"
def _make_ordered_list(items: list[protocol_schema.ConversationMessage]) -> list[str]:
    """Return one "<number emoji> <item text>" string per item.

    Items are paired with `NUMBER_EMOJIS` via ``zip``, so items beyond the
    number of available emojis are not included in the result.
    """
    numbered = []
    for emoji, message in zip(NUMBER_EMOJIS, items):
        numbered.append(f"{emoji} {message.text}")
    return numbered
def _ordered_list(items: list[protocol_schema.ConversationMessage]) -> str:
    """Return the numbered entries for `items` as one string, separated by blank lines."""
    entries = _make_ordered_list(items)
    separator = "\n\n"
    return separator.join(entries)
def _conversation(conv: protocol_schema.Conversation) -> list[str]:
    """Return one formatted string per message in `conv.messages`.

    Each string is a speaker header line (assistant or user, chosen by the
    message's `is_assistant` flag), then the message text on the next line,
    then a trailing newline.
    """
    rendered = []
    for message in conv.messages:
        if message.is_assistant:
            speaker = ":robot: __Assistant__:"
        else:
            speaker = ":person_red_hair: __User__:"
        rendered.append(f"{speaker}\n{message.text}\n")
    return rendered
def _li(text: str) -> str:
return f":small_blue_diamond: {text}"
###
# Messages
###
def initial_prompt_messages(task: protocol_schema.InitialPromptTask) -> list[str]:
    """Build the messages shown to a user who requested an `initial_prompt` task.

    Returns a single-element list: the task title, the writing instruction
    and, when `task.hint` is truthy, a "Hint: ..." line.
    """
    title = ":small_blue_diamond: __**INITIAL PROMPT**__ :small_blue_diamond:"
    instruction = ":pencil: _Please provide an initial prompt to the assistant._"
    hint_line = f"{NL}Hint: {task.hint}" if task.hint else ""
    return [f"{title}\n{instruction}{hint_line}\n"]
def rank_initial_prompts_messages(task: protocol_schema.RankInitialPromptsTask) -> list[str]:
    """Build the messages shown to a user who requested a `rank_initial_prompts` task.

    Returns a single-element list: the task title, the numbered prompts from
    `task.prompt_messages`, and the ranking instruction.
    """
    title = ":small_blue_diamond: __**RANK INITIAL PROMPTS**__ :small_blue_diamond:"
    candidates = _ordered_list(task.prompt_messages)
    instruction = (
        ":trophy: _Reply with the numbers of best to worst prompts separated by commas (example: '4,1,3,2')_"
    )
    return ["\n".join([title, candidates, instruction]) + "\n"]
def rank_prompter_reply_messages(task: protocol_schema.RankPrompterRepliesTask) -> list[str]:
    """Build the messages shown to a user who requested a `rank_prompter_replies` task.

    Returns, in order: the task title, one message per turn of
    `task.conversation`, and a final message with the numbered candidate
    replies from `task.reply_messages` plus the ranking instruction.
    """
    messages = [":small_blue_diamond: __**RANK PROMPTER REPLIES**__ :small_blue_diamond:\n"]
    messages.extend(_conversation(task.conversation))

    candidates = _ordered_list(task.reply_messages)
    instruction = (
        ":trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_"
    )
    messages.append(f":person_red_hair: __User__:\n{candidates}\n{instruction}\n")
    return messages
def rank_assistant_reply_message(task: protocol_schema.RankAssistantRepliesTask) -> list[str]:
    """Creates the messages that get sent to users when they request a `rank_assistant_replies` task.

    Returns, in order: the task title, one message per turn of
    `task.conversation`, and a final message with the numbered candidate
    replies from `task.reply_messages` plus the ranking instruction.
    """
    return [
        """\
:small_blue_diamond: __**RANK ASSISTANT REPLIES**__ :small_blue_diamond:
""",
        *_conversation(task.conversation),
        # Speaker header written exactly as `_conversation` writes it for
        # assistant turns (no trailing comma after the colon).
        f""":robot: __Assistant__:
{_ordered_list(task.reply_messages)}
:trophy: _Reply with the numbers of best to worst replies separated by commas (example: '4,1,3,2')_
""",
    ]
def rank_conversation_reply_messages(task: protocol_schema.RankConversationRepliesTask) -> list[str]:
    """Build the messages shown to a user who requested a `rank_conversation_replies` task.

    Returns, in order: the task title, one message per turn of
    `task.conversation`, and a final message with the numbered candidate
    replies from `task.reply_messages` under a user speaker header.
    """
    messages = [":small_blue_diamond: __**RANK CONVERSATION REPLIES**__ :small_blue_diamond:\n"]
    messages.extend(_conversation(task.conversation))

    candidates = _ordered_list(task.reply_messages)
    messages.append(f":person_red_hair: __User__:\n{candidates}\n")
    return messages
def label_initial_prompt_message(task: protocol_schema.LabelInitialPromptTask) -> str:
    """Build the message shown to a user who requested a `label_initial_prompt` task.

    Returns one string: the task heading, `task.prompt`, and the labelling
    instruction listing `task.mandatory_labels` and `task.valid_labels`.
    """
    heading = _h1("LABEL INITIAL PROMPT")
    prompt = task.prompt
    instructions = _label_prompt(
        "Reply with labels for the prompt separated by commas (example: 'profanity,misleading')",
        task.mandatory_labels,
        task.valid_labels,
    )
    return f"{heading}\n{prompt}\n{instructions}\n"
def label_prompter_reply_messages(task: protocol_schema.LabelPrompterReplyTask) -> list[str]:
    """Build the messages shown to a user who requested a `label_prompter_reply` task.

    Returns, in order: the task heading, one message per turn of
    `task.conversation`, and a final message with the user speaker header,
    `task.reply`, and the labelling instruction.
    """
    messages = [_h1("LABEL PROMPTER REPLY") + "\n"]
    messages.extend(_conversation(task.conversation))

    speaker = _user(None)
    reply = task.reply
    instructions = _label_prompt(
        "Reply with labels for the reply separated by commas (example: 'profanity,misleading')",
        task.mandatory_labels,
        task.valid_labels,
    )
    messages.append(f"{speaker}\n{reply}\n{instructions}\n")
    return messages
def label_assistant_reply_messages(task: protocol_schema.LabelAssistantReplyTask) -> list[str]:
    """Build the messages shown to a user who requested a `label_assistant_reply` task.

    Returns, in order: the task heading, one message per turn of
    `task.conversation`, and a final message (starting with a newline) with
    the assistant speaker header, `task.reply`, and the labelling instruction.
    """
    messages = [_h1("LABEL ASSISTANT REPLY") + "\n"]
    messages.extend(_conversation(task.conversation))

    speaker = _assistant(None)
    reply = task.reply
    instructions = _label_prompt(
        "Reply with labels for the reply separated by commas (example: 'profanity,misleading')",
        task.mandatory_labels,
        task.valid_labels,
    )
    messages.append(f"\n{speaker}\n{reply}\n{instructions}\n")
    return messages
def prompter_reply_messages(task: protocol_schema.PrompterReplyTask) -> list[str]:
    """Build the messages shown to a user who requested a `prompter_reply` task.

    Returns, in order: the task title, one message per turn of
    `task.conversation`, and a final message with an optional "Hint: ..."
    line (only when `task.hint` is truthy) followed by the reply instruction.
    """
    messages = [":small_blue_diamond: __**PROMPTER REPLY**__ :small_blue_diamond:\n"]
    messages.extend(_conversation(task.conversation))

    hint_line = f"{NL}Hint: {task.hint}" if task.hint else ""
    messages.append(f"{hint_line}\n:speech_balloon: _Please provide a reply to the assistant._\n")
    return messages
# def prompter_reply_messages2(task: protocol_schema.PrompterReplyTask) -> list[str]:
# """Creates the message that gets sent to users when they request a `prompter_reply` task."""
# return [
# message_templates.render("title.msg", "PROMPTER REPLY"),
# *[message_templates.render("conversation_message.msg", conv) for conv in task.conversation],
# message_templates.render("prompter_reply_task.msg", task.hint),
# ]
def assistant_reply_messages(task: protocol_schema.AssistantReplyTask) -> list[str]:
    """Build the messages shown to a user who requested an `assistant_reply` task.

    Returns, in order: the task title, one message per turn of
    `task.conversation`, and the instruction to reply as the assistant.
    """
    messages = [":small_blue_diamond: __**ASSISTANT REPLY**__ :small_blue_diamond:\n"]
    messages.extend(_conversation(task.conversation))
    messages.append(":speech_balloon: _Please provide a reply to the user as the assistant._\n")
    return messages
def confirm_text_response_message(content: str) -> str:
    """Return a confirmation message: the "CONFIRM RESPONSE" heading, then `content` on a `> ` line."""
    heading = _h2("CONFIRM RESPONSE")
    return "{}\n> {}\n".format(heading, content)
def confirm_ranking_response_message(content: str, items: list[protocol_schema.ConversationMessage]) -> str:
    """Return a confirmation message listing `items` in the order the user ranked them.

    Args:
        content: Comma-separated 1-based rankings, e.g. "4,1,3,2". Spaces are ignored.
        items: The messages that were presented for ranking.

    Returns:
        The "CONFIRM RESPONSE" heading followed by the numbered entries in the
        user's order, separated by blank lines.

    Raises:
        ValueError: If a ranking in `content` is not an integer.
        IndexError: If a ranking does not refer to one of the listed entries.
    """
    user_rankings = [int(r) for r in content.replace(" ", "").split(",")]
    original_list = _make_ordered_list(items)

    # Rankings are 1-based. Without this check, 0 or a negative number would
    # become a negative list index and silently select an entry from the end.
    for rank in user_rankings:
        if not 1 <= rank <= len(original_list):
            raise IndexError(f"ranking {rank} is out of range (expected 1 to {len(original_list)})")

    user_ranked_list = "\n\n".join(original_list[r - 1] for r in user_rankings)
    return f"""\
{_h2("CONFIRM RESPONSE")}
{user_ranked_list}
"""
def help_message(can_manage_guild: bool, is_dev: bool) -> str:
    """Build the text of the /help command.

    Args:
        can_manage_guild: When true, the `/settings` commands are appended.
        is_dev: When true, the `/reload` command is appended.

    Returns:
        The help text; the general section is always present and each section
        ends with a newline.
    """
    general_lines = [
        _h1("HELP"),
        _li("**`/help`**"),
        "Show this message.",
        _li("**`/work [type]`**"),
        "Start a new task.",
        "**`[type]`**:",
        "The type of task to start. If not provided, a random task will be selected. The different types are",
        ":small_orange_diamond: `random`: A random task type",
        ":small_orange_diamond: ~~`summarize_story`~~ (coming soon)",
        ":small_orange_diamond: ~~`rate_summary`~~ (coming soon)",
        ":small_orange_diamond: `initial_prompt`: Ask the assistant something",
        ":small_orange_diamond: `prompter_reply`: Reply to the assistant",
        ":small_orange_diamond: `assistant_reply`: Reply to the user",
        ":small_orange_diamond: `rank_initial_prompts`: Rank some initial prompts",
        ":small_orange_diamond: `rank_prompter_replies`: Rank some prompter replies",
        ":small_orange_diamond: `rank_assistant_replies`: Rank some assistant replies",
        "To learn how to complete tasks, run `/tutorial`.",
    ]
    sections = ["\n".join(general_lines) + "\n"]

    if can_manage_guild:
        settings_lines = [
            _li("**`/settings log_channel <channel>`**"),
            "Set the channel that the bot logs completed task messages in.",
            "**`<channel>`**: The channel to log completed tasks in. The bot needs to be able to send messages in this channel.",
            _li("**`/settings get`**"),
            "Get the current settings.",
        ]
        sections.append("\n".join(settings_lines) + "\n")

    if is_dev:
        dev_lines = [
            _li("**`/reload [plugin]`**"),
            "Hot-reload a plugin. Only code *inside* of function bodies will be updated.",
            "Any changes to __function signatures__, __other files__, __decorators__, or __imports__ will require a restart.",
            "**`[plugin]`**:",
            "The plugin to hot-reload. If no plugin is provided, all plugins are hot-reload.",
        ]
        sections.append("\n".join(dev_lines) + "\n")

    return "".join(sections)
def tutorial_message() -> str:
    """Build the text of the /tutorial command (currently only the heading)."""
    # TODO: Finish message
    heading = _h1("TUTORIAL")
    return heading + "\n"
def confirm_label_response_message(content: str) -> str:
    """Return a confirmation message echoing the labels the user entered.

    `content` is lower-cased, stripped of spaces, and split on commas; the
    resulting labels are shown joined with ", " under the "CONFIRM RESPONSE"
    heading.
    """
    normalized = content.lower().replace(" ", "")
    labels = normalized.split(",")
    heading = _h2("CONFIRM RESPONSE")
    return "{}\n{}\n".format(heading, ", ".join(labels))
###
# Embeds
###
def task_complete_embed(task: protocol_schema.Task, mention: str) -> hikari.Embed:
    """Build the green "Task Completion" embed for a finished task.

    The description names `task.type` and `mention`, the timestamp is the
    current local time, and the footer carries `task.id`. The three statistic
    fields are filled with placeholder values ("0", "0/0").
    """
    embed = hikari.Embed(
        title="Task Completion",
        description=f"`{task.type}` completed by {mention}",
        color=hikari.Color(0x00FF00),
        timestamp=datetime.now().astimezone(),
    )

    placeholder_fields = (
        ("Total Tasks", "0"),
        ("Server Ranking", "0/0"),
        ("Global Ranking", "0/0"),
    )
    for field_name, field_value in placeholder_fields:
        # Keep using each call's return value, as the chained form did.
        embed = embed.add_field(field_name, field_value, inline=True)

    return embed.set_footer(f"Task ID: {task.id}")
def invalid_user_input_embed(error_message: str) -> hikari.Embed:
    """Build the red "Invalid User Input" embed with `error_message` as its description.

    The timestamp is the current local time.
    """
    red = hikari.Color(0xFF0000)
    now = datetime.now().astimezone()
    return hikari.Embed(
        title="Invalid User Input",
        description=error_message,
        color=red,
        timestamp=now,
    )
def plain_embed(text: str) -> hikari.Embed:
    """Build an embed with `text` as its description and the fixed color 0x36393F."""
    embed_color = 0x36393F
    return hikari.Embed(description=text, color=embed_color)