Files
Open-Assistant/discord-bot/bot.py
T
Yannic Kilcher 6452bb860d moved bot to discord-bot
added playbook to deploy dev machine

added playbook to deploy dev machine

added next.js font module, updated fonts, updated login page

replaced logos, changed logo on login and header

added 404 and email confirmation pages, changed discord and github buttons to show conditionally

also removed unused imports, fixed one spelling error, and minor styling changes

Quick format to the authenticated user page, updated header with user profile, styling updates

fixed html encoding

added checkout for release

re-vamped release config and ports
2022-12-25 11:35:54 +01:00

284 lines
11 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)