Files
Open-Assistant/bot/bot.py
T
2022-12-22 21:13:05 +01:00

284 lines
11 KiB
Python

# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
# Version string of the bot; passed to the "boot.msg" template as `version`.
__version__ = "0.0.3"
# Display name of the bot; passed to the "boot.msg" template as `bot_name`.
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)