mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
6452bb860d
added playbook to deploy dev machine added playbook to deploy dev machine added next.js font module, updated fonts, updated login page replaced logos, changed logo on login and header added 404 and email confirmation pages, changed discord and github buttons to show conditionally also removed unused imports, fixed one spelling error, and minor styling changes Quick format to the authenticated user page, updated header with user profile, styling updates fixed html encoding added checkout for release re-vamped release config and ports
284 lines
11 KiB
Python
284 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from datetime import timedelta
|
|
from pathlib import Path
|
|
from typing import Optional, Union
|
|
|
|
import discord
|
|
import task_handlers
|
|
from api_client import ApiClient, TaskType
|
|
from bot_base import BotBase
|
|
from discord import app_commands
|
|
from loguru import logger
|
|
from message_templates import MessageTemplates
|
|
from oasst_shared.schemas import protocol as protocol_schema
|
|
from utils import get_git_head_hash, utcnow
|
|
|
|
__version__ = "0.0.3"
|
|
BOT_NAME = "Open-Assistant Junior"
|
|
|
|
|
|
class OpenAssistantBot(BotBase):
|
|
def __init__(
|
|
self,
|
|
bot_token: str,
|
|
bot_channel_name: str,
|
|
backend_url: str,
|
|
api_key: str,
|
|
owner_id: Optional[Union[int, str]] = None,
|
|
template_dir: str = "./templates",
|
|
debug: bool = False,
|
|
):
|
|
super().__init__()
|
|
|
|
self.template_dir = Path(template_dir)
|
|
self.bot_channel_name = bot_channel_name
|
|
self.templates = MessageTemplates(template_dir)
|
|
self.debug = debug
|
|
|
|
intents = discord.Intents.default()
|
|
intents.message_content = True
|
|
|
|
if isinstance(owner_id, str):
|
|
owner_id = int(owner_id)
|
|
self.owner_id = owner_id
|
|
|
|
self.bot_token = bot_token
|
|
client = discord.Client(intents=intents)
|
|
self.client = client
|
|
self.loop = client.loop
|
|
|
|
self.bot_channel: discord.TextChannel = None
|
|
self.backend = ApiClient(backend_url, api_key)
|
|
|
|
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
|
|
|
|
@client.event
|
|
async def on_ready():
|
|
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
|
|
logger.info(f"{client.user} is now running!")
|
|
|
|
await self.delete_all_old_bot_messages()
|
|
# if self.debug:
|
|
# await self.post_boot_message()
|
|
await self.post_welcome_message()
|
|
|
|
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
|
|
|
@client.event
|
|
async def on_message(message: discord.Message):
|
|
# ignore own messages
|
|
if message.author != client.user:
|
|
await self.handle_message(message)
|
|
|
|
@self.tree.command()
|
|
async def tutorial(interaction: discord.Interaction):
|
|
"""Start the Open-Assistant tutorial via DMs."""
|
|
|
|
dm = await self.client.create_dm(discord.Object(interaction.user.id))
|
|
await dm.send("Tutorial coming soon... :-)")
|
|
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
|
|
|
|
@self.tree.command()
|
|
async def help(interaction: discord.Interaction):
|
|
"""Sends the user a list of all available commands"""
|
|
await self.post_help(interaction.user)
|
|
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
|
|
|
|
@self.tree.command()
|
|
async def work(interaction: discord.Interaction):
|
|
"""Request a new personalized task"""
|
|
|
|
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
|
# task = self.backend.fetch_random_task(user=None)
|
|
q = task_handlers.Questionnaire()
|
|
await interaction.response.send_modal(q)
|
|
|
|
async def post_help(self, user: discord.abc.User) -> discord.Message:
|
|
is_bot_owner = user.id == self.owner_id
|
|
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
|
|
|
|
async def post_boot_message(self) -> discord.Message:
|
|
return await self.post_template(
|
|
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
|
|
)
|
|
|
|
async def post_welcome_message(self) -> discord.Message:
|
|
return await self.post_template("welcome.msg")
|
|
|
|
async def delete_all_old_bot_messages(self) -> None:
|
|
logger.info("Deleting old threads...")
|
|
for thread in self.bot_channel.threads:
|
|
if thread.owner_id == self.client.user.id:
|
|
await thread.delete()
|
|
logger.info("Completed deleting old theards.")
|
|
|
|
logger.info("Deleting old messages...")
|
|
look_until = utcnow() - timedelta(days=365)
|
|
async for msg in self.bot_channel.history(limit=None):
|
|
msg: discord.Message
|
|
if msg.created_at < look_until:
|
|
break
|
|
if msg.author.id == self.client.user.id:
|
|
await msg.delete()
|
|
logger.info("Completed deleting old messages.")
|
|
|
|
async def next_task(self):
|
|
task_type = protocol_schema.TaskRequestType.random
|
|
task = self.backend.fetch_task(task_type, user=None)
|
|
|
|
handler: task_handlers.ChannelTaskBase = None
|
|
match task.type:
|
|
case TaskType.summarize_story:
|
|
handler = task_handlers.SummarizeStoryHandler()
|
|
case TaskType.rate_summary:
|
|
handler = task_handlers.RateSummaryHandler()
|
|
case TaskType.initial_prompt:
|
|
handler = task_handlers.InitialPromptHandler()
|
|
case TaskType.user_reply:
|
|
handler = task_handlers.UserReplyHandler()
|
|
case TaskType.assistant_reply:
|
|
handler = task_handlers.AssistantReplyHandler()
|
|
case TaskType.rank_initial_prompts:
|
|
handler = task_handlers.RankInitialPromptsHandler()
|
|
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
|
|
handler = task_handlers.RankConversationsHandler()
|
|
case _:
|
|
logger.warning(f"Unsupported task type received: {task.type}")
|
|
self.backend.nack_task(task.id, "not supported")
|
|
|
|
if handler:
|
|
try:
|
|
logger.info(f"strarting task {task.id}")
|
|
msg = await handler.start(self, task)
|
|
self.backend.ack_task(task.id, msg.id)
|
|
except Exception:
|
|
logger.exception("Starting task failed.")
|
|
self.backend.nack_task(task.id, "faled")
|
|
|
|
async def background_timer(self):
|
|
next_remove_completed = utcnow() + timedelta(seconds=10)
|
|
next_fetch_task = utcnow() + timedelta(seconds=1)
|
|
while True:
|
|
now = utcnow()
|
|
|
|
if self.bot_channel:
|
|
if now > next_fetch_task:
|
|
next_fetch_task = utcnow() + timedelta(seconds=60)
|
|
|
|
try:
|
|
await self.next_task()
|
|
except Exception:
|
|
logger.exception("fetching next task failed")
|
|
|
|
for x in self.reply_handlers.values():
|
|
x.handler.tick(now)
|
|
|
|
if now > next_remove_completed:
|
|
next_remove_completed = utcnow() + timedelta(seconds=10)
|
|
await self.remove_completed_handlers()
|
|
|
|
await asyncio.sleep(1)
|
|
|
|
async def _sync(self, command: str, message: discord.Message):
|
|
|
|
logger.info(f"sync tree command received: {command}")
|
|
|
|
if command == "sync.copy_global":
|
|
await self.tree.copy_global_to(guild=message.guild)
|
|
synced = await self.tree.sync(guild=message.guild)
|
|
elif command == "sync.clear_guild":
|
|
self.tree.clear_commands(guild=message.guild)
|
|
synced = await self.tree.sync(guild=message.guild)
|
|
elif command == "sync.guild":
|
|
synced = await self.tree.sync(guild=message.guild)
|
|
else:
|
|
synced = await self.tree.sync()
|
|
|
|
logger.info(f"Synced {len(synced)} commands")
|
|
await message.reply(f"Synced {len(synced)} commands")
|
|
|
|
async def handle_command(self, message: discord.Message, is_owner: bool):
|
|
command_text: str = message.content
|
|
command_text = command_text[1:]
|
|
match command_text:
|
|
case "help" | "?":
|
|
await self.post_help(user=message.author)
|
|
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
|
|
if is_owner:
|
|
await self._sync(command_text, message)
|
|
case _:
|
|
await message.reply(f"unknown command: {command_text}")
|
|
|
|
def recipient_filter(self, message: discord.Message) -> bool:
|
|
channel = message.channel
|
|
|
|
if (
|
|
message.channel.type == discord.ChannelType.private
|
|
or message.channel.type == discord.ChannelType.private_thread
|
|
):
|
|
return True
|
|
|
|
if (
|
|
message.channel.type == discord.ChannelType.text
|
|
or message.channel.type == discord.ChannelType.public_thread
|
|
):
|
|
while channel:
|
|
if self.bot_channel and channel.id == self.bot_channel.id:
|
|
return True
|
|
channel = channel.parent
|
|
|
|
return False
|
|
|
|
async def handle_message(self, message: discord.Message):
|
|
if not self.recipient_filter(message):
|
|
return
|
|
|
|
user_id = message.author.id
|
|
user_display_name = message.author.name
|
|
|
|
logger.debug(
|
|
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
|
|
)
|
|
|
|
command_prefix = "!"
|
|
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
|
|
is_owner = self.owner_id and user_id == self.owner_id
|
|
await self.handle_command(message, is_owner)
|
|
|
|
if isinstance(message.channel, discord.Thread):
|
|
handler = self.reply_handlers.get(message.channel.id)
|
|
if handler and not handler.handler.completed:
|
|
handler.handler.on_reply(message)
|
|
|
|
if message.reference:
|
|
handler = self.reply_handlers.get(message.reference.message_id)
|
|
if handler and not handler.handler.completed:
|
|
handler.handler.on_reply(message)
|
|
|
|
async def remove_completed_handlers(self):
|
|
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
|
|
if len(completed) == 0:
|
|
return
|
|
|
|
for c in completed:
|
|
handler = self.reply_handlers[c]
|
|
del self.reply_handlers[c]
|
|
try:
|
|
await handler.handler.finalize()
|
|
except Exception:
|
|
logger.exception("handler finalize failed")
|
|
|
|
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
|
|
|
|
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
|
|
for channel in self.client.get_all_channels():
|
|
if channel.type == discord.ChannelType.text and channel.name == channel_name:
|
|
return channel
|
|
|
|
def run(self):
|
|
"""Run bot loop blocking."""
|
|
self.client.run(self.bot_token)
|