mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
284 lines
11 KiB
Python
284 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from datetime import timedelta
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
import discord
|
|
import task_handlers
|
|
from api_client import ApiClient, TaskType
|
|
from bot_base import BotBase
|
|
from discord import app_commands
|
|
from loguru import logger
|
|
from message_templates import MessageTemplates
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from utils import get_git_head_hash, utcnow
|
|
|
|
# Version of this bot; passed to the boot message template.
__version__ = "0.0.3"
# Human-readable bot name; passed to the boot message template.
BOT_NAME = "Open-Assistant Junior"
|
|
|
|
|
|
class OpenAssistantBot(BotBase):
    """Discord bot that fetches tasks from the backend and runs them in a bot channel.

    The bot owns a ``discord.Client``, an ``ApiClient`` for the backend and an
    ``app_commands.CommandTree`` with the ``tutorial``, ``help`` and ``work``
    slash commands. Incoming messages are routed either to ``!``-prefixed text
    commands or to the reply handlers of running tasks.
    """
|
|
def __init__(
    self,
    bot_token: str,
    bot_channel_name: str,
    backend_url: str,
    api_key: str,
    owner_id: Optional[Union[int, str]] = None,
    template_dir: str = "./templates",
    debug: bool = False,
):
    """Set up the Discord client, backend client, event handlers and slash commands.

    Args:
        bot_token: Token later passed to ``discord.Client.run`` by :meth:`run`.
        bot_channel_name: Name of the text channel the bot works in. It is
            resolved to a channel object in the ``on_ready`` event.
        backend_url: URL handed to ``ApiClient``.
        api_key: API key handed to ``ApiClient``.
        owner_id: Discord user id of the bot owner, as ``int`` or numeric
            string. Compared against message authors to gate owner-only
            commands.
        template_dir: Directory handed to ``MessageTemplates``.
        debug: Debug flag; stored on the instance and forwarded to the boot
            message template.

    Raises:
        ValueError: If ``owner_id`` is a string that ``int()`` cannot parse.
    """
    super().__init__()

    self.template_dir = Path(template_dir)
    self.bot_channel_name = bot_channel_name
    self.templates = MessageTemplates(template_dir)
    self.debug = debug

    intents = discord.Intents.default()
    intents.message_content = True

    if isinstance(owner_id, str):
        owner_id = int(owner_id)
    self.owner_id = owner_id

    self.bot_token = bot_token
    client = discord.Client(intents=intents)
    self.client = client
    self.loop = client.loop

    # Stays None until on_ready has resolved the channel by name.
    self.bot_channel: Optional[discord.TextChannel] = None
    self.backend = ApiClient(backend_url, api_key)

    self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)

    # Strong reference to the periodic worker: the event loop only keeps weak
    # references to tasks, and the reference lets on_ready detect a running timer.
    self._background_task: Optional[asyncio.Task] = None

    @client.event
    async def on_ready():
        self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
        logger.info(f"{client.user} is now running!")

        if self.bot_channel is None:
            # Without a channel every following step would fail on None.
            logger.error(f"Bot channel '{bot_channel_name}' not found, bot stays idle.")
            return

        await self.delete_all_old_bot_messages()
        # if self.debug:
        #     await self.post_boot_message()
        await self.post_welcome_message()

        # on_ready may be dispatched again after a reconnect; never run two timers.
        if self._background_task is None or self._background_task.done():
            self._background_task = client.loop.create_task(
                self.background_timer(), name="OpenAssistantBot.background_timer()"
            )

    @client.event
    async def on_message(message: discord.Message):
        # ignore own messages
        if message.author != client.user:
            await self.handle_message(message)

    # NOTE: the function names and docstrings below are left untouched on purpose;
    # the command tree derives the slash command name and description from them.
    @self.tree.command()
    async def tutorial(interaction: discord.Interaction):
        """Start the Open-Assistant tutorial via DMs."""

        dm = await self.client.create_dm(discord.Object(interaction.user.id))
        await dm.send("Tutorial coming soon... :-)")
        await interaction.response.send_message(f"tutorial command by {interaction.user.name}")

    @self.tree.command()
    async def help(interaction: discord.Interaction):
        """Sends the user a list of all available commands"""
        await self.post_help(interaction.user)
        await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")

    @self.tree.command()
    async def work(interaction: discord.Interaction):
        """Request a new personalized task"""

        # task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
        # task = self.backend.fetch_random_task(user=None)
        q = task_handlers.Questionnaire()
        await interaction.response.send_modal(q)
|
|
|
|
async def post_help(self, user: discord.abc.User) -> discord.Message:
    """Post the ``help.msg`` template with *user* as the target channel.

    The template is told whether *user* is the configured bot owner.
    Returns whatever ``post_template`` returns.
    """
    return await self.post_template(
        "help.msg",
        channel=user,
        is_bot_owner=(user.id == self.owner_id),
    )
|
|
|
|
async def post_boot_message(self) -> discord.Message:
    """Post the ``boot.msg`` template with name, version, git hash and debug flag.

    Returns whatever ``post_template`` returns.
    """
    template_args = {
        "bot_name": BOT_NAME,
        "version": __version__,
        "git_hash": get_git_head_hash(),
        "debug": self.debug,
    }
    return await self.post_template("boot.msg", **template_args)
|
|
|
|
async def post_welcome_message(self) -> discord.Message:
    """Post the ``welcome.msg`` template and return what ``post_template`` returns."""
    welcome = await self.post_template("welcome.msg")
    return welcome
|
|
|
|
async def delete_all_old_bot_messages(self) -> None:
    """Delete the bot's own threads and messages from the bot channel.

    Threads owned by the bot are deleted first. The channel history is then
    walked and every message authored by the bot is deleted; the walk stops
    at the first message older than 365 days.

    A delete rejected by Discord is logged and skipped, so a single stale
    item cannot abort the cleanup (and with it the caller's start-up).
    Does nothing except log a warning when no bot channel is set.
    """
    channel = self.bot_channel
    if channel is None:
        logger.warning("No bot channel available, skipping deletion of old threads and messages.")
        return

    bot_user_id = self.client.user.id

    logger.info("Deleting old threads...")
    for thread in channel.threads:
        if thread.owner_id != bot_user_id:
            continue
        try:
            await thread.delete()
        except discord.HTTPException:
            logger.exception(f"Deleting thread {thread.id} failed.")
    logger.info("Completed deleting old threads.")

    logger.info("Deleting old messages...")
    look_until = utcnow() - timedelta(days=365)
    # The early break relies on history() yielding newest messages first,
    # which is its default order when no ``after`` is given.
    async for msg in channel.history(limit=None):
        if msg.created_at < look_until:
            break
        if msg.author.id != bot_user_id:
            continue
        try:
            await msg.delete()
        except discord.HTTPException:
            logger.exception(f"Deleting message {msg.id} failed.")
    logger.info("Completed deleting old messages.")
|
|
|
|
async def next_task(self):
|
|
task_type = protocol_schema.TaskRequestType.random
|
|
task = self.backend.fetch_task(task_type, user=None)
|
|
|
|
handler: task_handlers.ChannelTaskBase = None
|
|
match task.type:
|
|
case TaskType.summarize_story:
|
|
handler = task_handlers.SummarizeStoryHandler()
|
|
case TaskType.rate_summary:
|
|
handler = task_handlers.RateSummaryHandler()
|
|
case TaskType.initial_prompt:
|
|
handler = task_handlers.InitialPromptHandler()
|
|
case TaskType.user_reply:
|
|
handler = task_handlers.UserReplyHandler()
|
|
case TaskType.assistant_reply:
|
|
handler = task_handlers.AssistantReplyHandler()
|
|
case TaskType.rank_initial_prompts:
|
|
handler = task_handlers.RankInitialPromptsHandler()
|
|
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
|
handler = task_handlers.RankConversationsHandler()
|
|
case _:
|
|
logger.warning(f"Unsupported task type received: {task.type}")
|
|
self.backend.nack_task(task.id, "not supported")
|
|
|
|
if handler:
|
|
try:
|
|
logger.info(f"strarting task {task.id}")
|
|
msg = await handler.start(self, task)
|
|
self.backend.ack_task(task.id, msg.id)
|
|
except Exception:
|
|
logger.exception("Starting task failed.")
|
|
self.backend.nack_task(task.id, "faled")
|
|
|
|
async def background_timer(self):
    """Run the bot's periodic work forever, waking up once per second.

    Per iteration:

    * when a bot channel is set, fetch a new task at most every 60 seconds
      (the first fetch happens about one second after start),
    * call ``tick(now)`` on every registered reply handler,
    * every 10 seconds remove the completed handlers.

    Failures in any step are logged and the loop keeps running; an escaping
    exception would otherwise end the task and stop all periodic work.
    """
    next_remove_completed = utcnow() + timedelta(seconds=10)
    next_fetch_task = utcnow() + timedelta(seconds=1)
    while True:
        now = utcnow()

        if self.bot_channel and now > next_fetch_task:
            # Re-arm before fetching so a failing fetch is not retried every second.
            next_fetch_task = utcnow() + timedelta(seconds=60)
            try:
                await self.next_task()
            except Exception:
                logger.exception("fetching next task failed")

        # Iterate over a snapshot in case a tick changes the handler registry.
        for entry in list(self.reply_handlers.values()):
            if entry.handler is None:
                # Such entries are dropped by remove_completed_handlers().
                continue
            try:
                entry.handler.tick(now)
            except Exception:
                logger.exception("handler tick failed")

        if now > next_remove_completed:
            next_remove_completed = utcnow() + timedelta(seconds=10)
            try:
                await self.remove_completed_handlers()
            except Exception:
                logger.exception("removing completed handlers failed")

        await asyncio.sleep(1)
|
|
|
|
async def _sync(self, command: str, message: discord.Message):
    """Synchronise the slash-command tree with Discord and report the result.

    Args:
        command: One of
            ``"sync.copy_global"`` - copy the global commands to the guild of
            *message*, then sync that guild;
            ``"sync.clear_guild"`` - clear the guild's commands, then sync it;
            ``"sync.guild"`` - sync the guild;
            anything else - sync the global commands.
        message: The message that issued the command; the reply is sent to it.

    The three guild-scoped variants are refused when *message* has no guild
    (for example in a DM). With ``guild=None`` the tree calls would act on
    the global command set instead - ``clear_commands`` would wipe it.
    """
    logger.info(f"sync tree command received: {command}")

    guild = message.guild
    guild_scoped = command in ("sync.copy_global", "sync.clear_guild", "sync.guild")
    if guild_scoped and guild is None:
        logger.warning(f"{command} was issued outside of a guild, ignoring")
        await message.reply(f"{command} can only be used inside a server")
        return

    if command == "sync.copy_global":
        self.tree.copy_global_to(guild=guild)
        synced = await self.tree.sync(guild=guild)
    elif command == "sync.clear_guild":
        self.tree.clear_commands(guild=guild)
        synced = await self.tree.sync(guild=guild)
    elif command == "sync.guild":
        synced = await self.tree.sync(guild=guild)
    else:
        synced = await self.tree.sync()

    logger.info(f"Synced {len(synced)} commands")
    await message.reply(f"Synced {len(synced)} commands")
|
|
|
|
async def handle_command(self, message: discord.Message, is_owner: bool):
    """Dispatch a prefixed text command contained in *message*.

    ``help`` / ``?`` post the help text to the author. The ``sync*`` commands
    are forwarded to ``_sync`` for the owner and ignored for everybody else.
    Any other text is answered with an "unknown command" reply.
    """
    # Strip the one-character command prefix.
    command_text: str = message.content[1:]

    if command_text in ("help", "?"):
        await self.post_help(user=message.author)
        return

    if command_text in ("sync", "sync.guild", "sync.copy_global", "sync.clear_guild"):
        if is_owner:
            await self._sync(command_text, message)
        return

    await message.reply(f"unknown command: {command_text}")
|
|
|
|
def recipient_filter(self, message: discord.Message) -> bool:
    """Return True when *message* was sent somewhere the bot listens.

    Accepted are private channels and private threads, plus text channels
    and public threads that are the bot channel or have it as an ancestor.
    Everything else, and any text channel or public thread while no bot
    channel is set, is rejected.
    """
    channel = message.channel
    channel_type = channel.type

    if channel_type in (discord.ChannelType.private, discord.ChannelType.private_thread):
        return True

    if channel_type not in (discord.ChannelType.text, discord.ChannelType.public_thread):
        return False

    bot_channel = self.bot_channel
    if bot_channel is None:
        return False

    while channel is not None:
        if channel.id == bot_channel.id:
            return True
        # Only threads have a ``parent``; a plain text channel ends the walk
        # here instead of raising AttributeError.
        channel = getattr(channel, "parent", None)

    return False
|
|
|
|
async def handle_message(self, message: discord.Message):
    """Process a message that was not written by the bot itself.

    Messages rejected by ``recipient_filter`` are dropped. A default-type
    message starting with ``!`` is passed to ``handle_command``. The message
    is then offered to the reply handlers registered under the thread id
    (when it was sent in a thread) and under the id of the message it
    references (when it is a reply). A handler is notified at most once per
    message, and only while it is not completed.
    """
    if not self.recipient_filter(message):
        return

    user_id = message.author.id
    user_display_name = message.author.name

    logger.debug(
        f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
    )

    command_prefix = "!"
    if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
        # Always a real bool, also when no owner is configured.
        is_owner = bool(self.owner_id) and user_id == self.owner_id
        await self.handle_command(message, is_owner)

    lookup_keys = []
    if isinstance(message.channel, discord.Thread):
        lookup_keys.append(message.channel.id)
    if message.reference:
        lookup_keys.append(message.reference.message_id)

    notified = None
    for key in lookup_keys:
        entry = self.reply_handlers.get(key)
        # Entries may carry no handler (remove_completed_handlers() allows for that).
        if not entry or entry.handler is None or entry.handler.completed:
            continue
        if entry.handler is notified:
            # Reached through both the thread and the reference - notify only once.
            continue
        entry.handler.on_reply(message)
        notified = entry.handler
|
|
|
|
async def remove_completed_handlers(self):
    """Unregister reply handlers that are completed or missing, finalizing them.

    Every entry of ``self.reply_handlers`` whose ``handler`` is ``None`` or
    reports ``completed`` is removed. ``finalize()`` is awaited for removed
    entries that have a handler; a failing ``finalize()`` is logged and does
    not stop the remaining clean-up.
    """
    completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
    if not completed:
        return

    for key in completed:
        # The registry can change while finalize() is awaited, hence the default.
        entry = self.reply_handlers.pop(key, None)
        if entry is None or entry.handler is None:
            # Nothing to finalize.
            continue
        try:
            await entry.handler.finalize()
        except Exception:
            logger.exception("handler finalize failed")

    logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
|
|
|
|
def get_text_channel_by_name(self, channel_name: str) -> Optional[discord.TextChannel]:
    """Return the first text channel named *channel_name*, or None if there is none.

    All channels known to the client are searched, in the order
    ``get_all_channels()`` yields them.
    """
    matches = (
        channel
        for channel in self.client.get_all_channels()
        if channel.type == discord.ChannelType.text and channel.name == channel_name
    )
    return next(matches, None)
|
|
|
|
def run(self):
    """Start the Discord client with the stored bot token (blocking call)."""
    token = self.bot_token
    self.client.run(token)
|