mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
6452bb860d
added playbook to deploy dev machine added playbook to deploy dev machine added next.js font module, updated fonts, updated login page replaced logos, changed logo on login and header added 404 and email confirmation pages, changed discord and github buttons to show conditionally also removed unused imports, fixed one spelling error, and minor styling changes Quick format to the authenticated user page, updated header with user profile, styling updates fixed html encoding added checkout for release re-vamped release config and ports
268 lines
10 KiB
Python
268 lines
10 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import annotations
|
|
|
|
from abc import abstractmethod
|
|
from datetime import timedelta
|
|
|
|
import discord
|
|
from api_client import ApiClient
|
|
from bot_base import BotBase
|
|
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
|
|
from loguru import logger
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
|
|
|
|
|
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
    """Modal dialog with a name field and a free-text answer field."""

    # Field order is kept as declared: name first, then the answer.
    name = discord.ui.TextInput(label="Name")
    answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)

    async def on_submit(self, interaction: discord.Interaction):
        """Acknowledge the submission with a message visible only to the submitter."""
        thanks = f"Thanks for your response, {self.name}!"
        await interaction.response.send_message(thanks, ephemeral=True)
|
|
|
|
|
|
class ChannelTaskBase(AutoDestructThreadHandler):
    """Common workflow for tasks presented as a first message plus a reply thread.

    ``start()`` posts the first message, opens a thread on it in the bot
    channel and registers this object as reply handler for that message.
    Subclasses provide ``send_first_message()`` and may override
    ``on_thread_created()``; the ``handle_*`` helpers forward user replies to
    the backend.
    """

    # Name given to the thread created in start().
    thread_name: str = "Replies"
    # Added to utcnow() in start() to compute expiry_date; a falsy value
    # results in expiry_date being None.
    expires_after: timedelta = timedelta(minutes=5)
    # Backend client; assigned from bot.backend in start().
    backend: ApiClient

    async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
        """Post the first message, create the reply thread and register the handler.

        Args:
            bot: Bot providing ``backend``, ``bot_channel`` and handler registration.
            task: Task this handler presents to users.

        Returns:
            The first message, as returned by ``send_first_message()``.

        Raises:
            Exception: Any error raised during setup is logged, ``cleanup()``
                is awaited, and the error is re-raised.
        """
        try:
            self.bot = bot
            self.task = task
            self.backend = bot.backend
            self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
            msg = await self.send_first_message()
            self.first_message = msg
            self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
            await self.on_thread_created(self.thread)
        except Exception:
            logger.exception("start task failed")
            await self.cleanup()  # try to clean up message or thread
            raise

        # Only register once setup fully succeeded.
        bot.register_reply_handler(msg_id=msg.id, handler=self)
        return msg

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Hook awaited by start() right after the thread exists; no-op by default."""
        pass

    @abstractmethod
    async def send_first_message(self) -> discord.Message:
        """Post and return the message the reply thread will be attached to."""
        ...

    def to_api_user(self, user: discord.User) -> protocol_schema.User:
        """Build the backend user object for a Discord user (auth method "discord")."""
        return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)

    async def post_teaser_msg(self, template_name: str):
        """Post ``template_name`` via the bot with task and expiry information.

        The expiry date is passed to the template twice: formatted in the
        long-time style and in the relative-time style.
        """
        # NOTE(review): expiry_date is None when expires_after is falsy;
        # confirm discord_timestamp() accepts None.
        expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
        expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
        return await self.bot.post_template(
            template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
        )

    async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
        """Send ``interaction`` to the backend and return its response.

        Raises:
            RuntimeError: If the response type is anything other than "task_done".
        """
        api_response = await self.backend.post_interaction(interaction)
        if api_response.type != "task_done":
            # multi-step tasks are not supported yet
            logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
            raise RuntimeError("Unexpected response from backend received")
        return api_response

    async def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
        """Send the text of ``user_msg`` to the backend as a reply to the first message.

        Must be awaited; the result is whatever the backend call returns.
        """
        return await self.backend.post_interaction(
            protocol_schema.TextReplyToPost(
                post_id=str(self.first_message.id),
                user_post_id=str(user_msg.id),
                user=self.to_api_user(user_msg.author),
                text=user_msg.content,
            )
        )

    async def handle_text_reply_to_post(self, user_msg: discord.Message) -> None:
        """Forward a text reply to the backend and mark the outcome on the message.

        Adds a check-mark reaction on success. On any error other than
        ``ChannelExpiredException`` (which is re-raised) the error is logged,
        a cross reaction is added and an error reply is posted.
        """
        try:
            # The backend call is awaited so that the check mark is only
            # added after the interaction has actually been sent.
            await self.post_text_reply_to_post(user_msg)
            await user_msg.add_reaction("✅")
        except ChannelExpiredException:
            raise
        except Exception as e:
            logger.exception("Error in handle_text_reply_to_post()")
            await user_msg.add_reaction("❌")
            await user_msg.reply(f"❌ Error communicating with backend: {e}")

    async def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
        """Send ``ranking`` to the backend on behalf of the author of ``user_msg``.

        Must be awaited; the result is whatever the backend call returns.
        """
        return await self.backend.post_interaction(
            protocol_schema.PostRanking(
                post_id=str(self.first_message.id),
                user_post_id=str(user_msg.id),
                user=self.to_api_user(user_msg.author),
                ranking=ranking,
            )
        )

    async def handle_ranking(self, user_msg: discord.Message) -> None:
        """Parse a comma-separated ranking from ``user_msg`` and post it to the backend.

        Each comma-separated number is reduced by one before sending
        (e.g. "2,1,3" becomes [1, 0, 2]). Outcome reporting matches
        ``handle_text_reply_to_post``.
        """
        try:
            ranking_str = user_msg.content
            # NOTE(review): a malformed number raises ValueError here and is
            # reported to the user with the generic backend error text below.
            ranking = [int(x) - 1 for x in ranking_str.split(",")]
            # Awaited so the check mark reflects a completed backend call.
            await self.post_ranking(user_msg, ranking=ranking)
            await user_msg.add_reaction("✅")
        except ChannelExpiredException:
            raise
        except Exception as e:
            logger.exception("Error in handle_ranking()")
            await user_msg.add_reaction("❌")
            await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
|
|
|
|
|
class SummarizeStoryHandler(ChannelTaskBase):
    """Handler for summarize-story tasks; thread replies are forwarded as text replies."""

    task: protocol_schema.SummarizeStoryTask
    thread_name: str = "Summaries"

    async def send_first_message(self) -> discord.Message:
        """Post the summarize-story teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_summarize_story.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the text-reply handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_text_reply_to_post(reply)
|
|
|
|
|
|
class InitialPromptHandler(ChannelTaskBase):
    """Handler for initial-prompt tasks; thread replies are forwarded as text replies."""

    task: protocol_schema.InitialPromptTask
    thread_name: str = "Prompts"

    async def send_first_message(self) -> discord.Message:
        """Post the initial-prompt teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_initial_prompt.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the text-reply handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_text_reply_to_post(reply)
|
|
|
|
|
|
class UserReplyHandler(ChannelTaskBase):
    """Handler for user-reply tasks; thread replies are forwarded as text replies."""

    task: protocol_schema.UserReplyTask
    thread_name: str = "User replies"

    async def send_first_message(self) -> discord.Message:
        """Post the user-reply teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_user_reply.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the text-reply handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_text_reply_to_post(reply)
|
|
|
|
|
|
class AssistantReplyHandler(ChannelTaskBase):
    """Handler for assistant-reply tasks; thread replies are forwarded as text replies."""

    task: protocol_schema.AssistantReplyTask
    thread_name: str = "Assistant replies"

    async def send_first_message(self) -> discord.Message:
        """Post the assistant-reply teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_assistant_reply.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the text-reply handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_text_reply_to_post(reply)
|
|
|
|
|
|
class RankInitialPromptsHandler(ChannelTaskBase):
    """Handler for rank-initial-prompts tasks; thread replies are parsed as rankings."""

    task: protocol_schema.RankInitialPromptsTask
    thread_name: str = "User Responses"

    async def send_first_message(self) -> discord.Message:
        """Post the rank-initial-prompts teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the ranking handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_ranking(reply)
|
|
|
|
|
|
class RankConversationsHandler(ChannelTaskBase):
    """Handler for rank-conversation-replies tasks; thread replies are parsed as rankings."""

    task: protocol_schema.RankConversationRepliesTask
    thread_name: str = "Rankings"

    async def send_first_message(self) -> discord.Message:
        """Post the rank-conversation-replies teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task description template into the newly created thread."""
        await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)

    async def handler_loop(self):
        """Read messages one at a time and hand each to the ranking handler.

        The loop has no exit condition of its own; it ends only when an
        awaited call raises.
        """
        while True:
            reply = await self.read()
            await self.handle_ranking(reply)
|
|
|
|
|
|
class RatingButton(discord.ui.Button):
    """Green button that passes a fixed value to a handler when clicked."""

    def __init__(self, label, value, response_handler):
        """Store the value and handler; ``label`` is the text shown on the button.

        ``response_handler`` is awaited as ``response_handler(value, interaction)``.
        """
        super().__init__(style=discord.ButtonStyle.green, label=label)
        self.response_handler = response_handler
        self.value = value

    async def callback(self, interaction):
        """Forward the stored value and the interaction to the handler."""
        on_click = self.response_handler
        await on_click(self.value, interaction)
|
|
|
|
|
|
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
    """Build a view with one RatingButton per integer score from ``lo`` to ``hi``.

    Both bounds are inclusive. Each button is labelled with its score and
    passes that score to ``response_handler`` when clicked.
    """
    rating_view = discord.ui.View()
    # hi is inclusive, hence the + 1 on the range bound.
    for score in range(lo, hi + 1):
        button = RatingButton(str(score), score, response_handler)
        rating_view.add_item(button)
    return rating_view
|
|
|
|
|
|
class RateSummaryHandler(ChannelTaskBase):
    """Handler for rate-summary tasks; ratings are given via buttons, not text."""

    task: protocol_schema.RateSummaryTask
    thread_name: str = "Ratings"

    async def _rating_response_handler(self, score, interaction: discord.Interaction):
        """Post the clicked rating to the backend and confirm it to the user.

        Does nothing beyond logging when ``self.thread`` is falsy. On any
        error other than ``ChannelExpiredException`` (which is re-raised) the
        error is logged and reported through the interaction response.
        """
        # loguru formats with str.format, so the score needs a placeholder.
        logger.info("rating_response_handler: {}", score)
        if not self.thread:
            return

        try:
            # Awaited so the thank-you message is only sent after the
            # rating has actually been posted.
            await self.backend.post_interaction(
                protocol_schema.PostRating(
                    post_id=str(self.first_message.id),
                    user_post_id=str(interaction.id),
                    user=self.to_api_user(interaction.user),
                    rating=score,
                )
            )
            await interaction.response.send_message(
                f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
            )
        except ChannelExpiredException:
            raise
        except Exception as e:
            logger.exception("Error in _rating_response_handler()")
            # send_message is a coroutine; without await the error would
            # never reach the user.
            await interaction.response.send_message(f"❌ Error communicating with backend: {e}")

    async def send_first_message(self) -> discord.Message:
        """Post the rate-summary teaser and return the resulting message."""
        return await self.post_teaser_msg("teaser_rate_summary.msg")

    async def on_thread_created(self, thread: discord.Thread) -> None:
        """Post the task template into the thread with one button per score on the task's scale."""
        view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
        return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)

    async def handler_loop(self):
        """Reject every text message read from the thread.

        Each message is logged, given a cross reaction and answered with a
        notice that text input is not supported.
        """
        while True:
            msg = await self.read()
            logger.info(f"on_rate_summary_reply: {msg.content}")
            await msg.add_reaction("❌")
            await msg.reply("❌ Text intput not supported.")
|