diff --git a/discord-bot/.env.example b/discord-bot/.env.example new file mode 100644 index 00000000..89e50c05 --- /dev/null +++ b/discord-bot/.env.example @@ -0,0 +1,3 @@ +TOKEN= +DECLARE_GLOBAL_COMMANDS= +OWNER_IDS= \ No newline at end of file diff --git a/discord-bot/.gitignore b/discord-bot/.gitignore index a7982d60..2842b686 100644 --- a/discord-bot/.gitignore +++ b/discord-bot/.gitignore @@ -1,3 +1,7 @@ .env *.egg-info/ __pycache__/ + +.venv +.nox +.env \ No newline at end of file diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md new file mode 100644 index 00000000..089a0c33 --- /dev/null +++ b/discord-bot/CONTRIBUTING.md @@ -0,0 +1,43 @@ +# Contributing + +## Setup + +To run the bot + +``` +cp .env.example .env + +python -V # 3.10 + +pip install -r requirements.txt +python -m bot +``` + +To test the bot + +``` +python -m pip install -r dev-requirements.txt + +nox +``` + +To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. + +1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) +2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable. + +## Resources + +Main framework + +- [Hikari Repo](https://github.com/hikari-py/hikari) +- [Hikari Docs](https://docs.hikari-py.dev/en/latest/) + +Command handler + +- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb) +- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/) + +Component handler (buttons, modals, etc... ) + +- [Miru Repo](https://github.com/HyperGH/hikari-miru) diff --git a/discord-bot/README.md b/discord-bot/README.md index a585b37f..cde82025 100644 --- a/discord-bot/README.md +++ b/discord-bot/README.md @@ -6,15 +6,6 @@ This bot collects human feedback to create a dataset for RLHF-alignment of an as To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages. -## Bot token for development +## Contributing -To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. - -1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) -2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`. - -The simplest way to configure the token is via an `.env` file: - -``` -BOT_TOKEN=XYZABC123... -``` +To contribute to the bot, please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file. diff --git a/discord-bot/__main__.py b/discord-bot/__main__.py deleted file mode 100644 index 9e5e29c7..00000000 --- a/discord-bot/__main__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- -from bot import OpenAssistantBot -from bot_settings import settings - -# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot - -if __name__ == "__main__": - bot = OpenAssistantBot( - settings.BOT_TOKEN, - bot_channel_name=settings.BOT_CHANNEL_NAME, - backend_url=settings.BACKEND_URL, - api_key=settings.API_KEY, - owner_id=settings.OWNER_ID, - template_dir=settings.TEMPLATE_DIR, - debug=settings.DEBUG, - ) - bot.run() diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py index 1de6bb17..0caa1595 100644 --- a/discord-bot/api_client.py +++ b/discord-bot/api_client.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import enum from typing import Optional, Type +import typing as t import requests from oasst_shared.schemas import protocol as protocol_schema @@ -41,7 +42,7 @@ class ApiClient: response.raise_for_status() return response.json() - def _parse_task(self, data: dict) -> protocol_schema.Task: + def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: if not isinstance(data, dict): raise ValueError("dict expected") diff --git a/discord-bot/bot.py b/discord-bot/bot.py deleted file mode 100644 index a19fdfe1..00000000 --- a/discord-bot/bot.py +++ /dev/null @@ -1,283 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -import asyncio -from datetime import timedelta -from pathlib import Path -from typing import Optional, Union - -import discord -import task_handlers -from api_client import ApiClient, TaskType -from bot_base import BotBase -from discord import app_commands -from loguru import logger -from message_templates import MessageTemplates -from oasst_shared.schemas import protocol as protocol_schema -from utils import get_git_head_hash, utcnow - -__version__ = "0.0.3" -BOT_NAME = "Open-Assistant Junior" - - -class OpenAssistantBot(BotBase): - def __init__( - self, - bot_token: str, - bot_channel_name: str, - backend_url: str, - api_key: str, - owner_id: Optional[Union[int, str]] = None, - template_dir: str = "./templates", - debug: bool = False, - ): - super().__init__() - - self.template_dir = Path(template_dir) - self.bot_channel_name = bot_channel_name - self.templates = MessageTemplates(template_dir) - self.debug = debug - - intents = discord.Intents.default() - intents.message_content = True - - if isinstance(owner_id, str): - owner_id = int(owner_id) - self.owner_id = owner_id - - self.bot_token = bot_token - client = discord.Client(intents=intents) - self.client = client - self.loop = client.loop - - self.bot_channel: discord.TextChannel = None - self.backend = ApiClient(backend_url, api_key) - - self.tree = app_commands.CommandTree(self.client, fallback_to_global=True) - - @client.event - async def on_ready(): - self.bot_channel = self.get_text_channel_by_name(bot_channel_name) - logger.info(f"{client.user} is now running!") - - await self.delete_all_old_bot_messages() - # if self.debug: - # await self.post_boot_message() - await self.post_welcome_message() - - client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") - - @client.event - async def on_message(message: discord.Message): - # ignore own messages - if message.author != client.user: - await self.handle_message(message) - - @self.tree.command() - async def tutorial(interaction: discord.Interaction): - """Start the Open-Assistant tutorial via DMs.""" - - dm = await self.client.create_dm(discord.Object(interaction.user.id)) - await dm.send("Tutorial coming soon... :-)") - await interaction.response.send_message(f"tutorial command by {interaction.user.name}") - - @self.tree.command() - async def help(interaction: discord.Interaction): - """Sends the user a list of all available commands""" - await self.post_help(interaction.user) - await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.") - - @self.tree.command() - async def work(interaction: discord.Interaction): - """Request a new personalized task""" - - # task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None) - # task = self.backend.fetch_random_task(user=None) - q = task_handlers.Questionnaire() - await interaction.response.send_modal(q) - - async def post_help(self, user: discord.abc.User) -> discord.Message: - is_bot_owner = user.id == self.owner_id - return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner) - - async def post_boot_message(self) -> discord.Message: - return await self.post_template( - "boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug - ) - - async def post_welcome_message(self) -> discord.Message: - return await self.post_template("welcome.msg") - - async def delete_all_old_bot_messages(self) -> None: - logger.info("Deleting old threads...") - for thread in self.bot_channel.threads: - if thread.owner_id == self.client.user.id: - await thread.delete() - logger.info("Completed deleting old theards.") - - logger.info("Deleting old messages...") - look_until = utcnow() - timedelta(days=365) - async for msg in self.bot_channel.history(limit=None): - msg: discord.Message - if msg.created_at < look_until: - break - if msg.author.id == self.client.user.id: - await msg.delete() - logger.info("Completed deleting old messages.") - - async def next_task(self): - task_type = protocol_schema.TaskRequestType.random - task = self.backend.fetch_task(task_type, user=None) - - handler: task_handlers.ChannelTaskBase = None - match task.type: - case TaskType.summarize_story: - handler = task_handlers.SummarizeStoryHandler() - case TaskType.rate_summary: - handler = task_handlers.RateSummaryHandler() - case TaskType.initial_prompt: - handler = task_handlers.InitialPromptHandler() - case TaskType.user_reply: - handler = task_handlers.UserReplyHandler() - case TaskType.assistant_reply: - handler = task_handlers.AssistantReplyHandler() - case TaskType.rank_initial_prompts: - handler = task_handlers.RankInitialPromptsHandler() - case TaskType.rank_user_replies | TaskType.rank_assistant_replies: - handler = task_handlers.RankConversationsHandler() - case _: - logger.warning(f"Unsupported task type received: {task.type}") - self.backend.nack_task(task.id, "not supported") - - if handler: - try: - logger.info(f"strarting task {task.id}") - msg = await handler.start(self, task) - self.backend.ack_task(task.id, msg.id) - except Exception: - logger.exception("Starting task failed.") - self.backend.nack_task(task.id, "faled") - - async def background_timer(self): - next_remove_completed = utcnow() + timedelta(seconds=10) - next_fetch_task = utcnow() + timedelta(seconds=1) - while True: - now = utcnow() - - if self.bot_channel: - if now > next_fetch_task: - next_fetch_task = utcnow() + timedelta(seconds=60) - - try: - await self.next_task() - except Exception: - logger.exception("fetching next task failed") - - for x in self.reply_handlers.values(): - x.handler.tick(now) - - if now > next_remove_completed: - next_remove_completed = utcnow() + timedelta(seconds=10) - await self.remove_completed_handlers() - - await asyncio.sleep(1) - - async def _sync(self, command: str, message: discord.Message): - - logger.info(f"sync tree command received: {command}") - - if command == "sync.copy_global": - await self.tree.copy_global_to(guild=message.guild) - synced = await self.tree.sync(guild=message.guild) - elif command == "sync.clear_guild": - self.tree.clear_commands(guild=message.guild) - synced = await self.tree.sync(guild=message.guild) - elif command == "sync.guild": - synced = await self.tree.sync(guild=message.guild) - else: - synced = await self.tree.sync() - - logger.info(f"Synced {len(synced)} commands") - await message.reply(f"Synced {len(synced)} commands") - - async def handle_command(self, message: discord.Message, is_owner: bool): - command_text: str = message.content - command_text = command_text[1:] - match command_text: - case "help" | "?": - await self.post_help(user=message.author) - case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild": - if is_owner: - await self._sync(command_text, message) - case _: - await message.reply(f"unknown command: {command_text}") - - def recipient_filter(self, message: discord.Message) -> bool: - channel = message.channel - - if ( - message.channel.type == discord.ChannelType.private - or message.channel.type == discord.ChannelType.private_thread - ): - return True - - if ( - message.channel.type == discord.ChannelType.text - or message.channel.type == discord.ChannelType.public_thread - ): - while channel: - if self.bot_channel and channel.id == self.bot_channel.id: - return True - channel = channel.parent - - return False - - async def handle_message(self, message: discord.Message): - if not self.recipient_filter(message): - return - - user_id = message.author.id - user_display_name = message.author.name - - logger.debug( - f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})" - ) - - command_prefix = "!" - if message.type == discord.MessageType.default and message.content.startswith(command_prefix): - is_owner = self.owner_id and user_id == self.owner_id - await self.handle_command(message, is_owner) - - if isinstance(message.channel, discord.Thread): - handler = self.reply_handlers.get(message.channel.id) - if handler and not handler.handler.completed: - handler.handler.on_reply(message) - - if message.reference: - handler = self.reply_handlers.get(message.reference.message_id) - if handler and not handler.handler.completed: - handler.handler.on_reply(message) - - async def remove_completed_handlers(self): - completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed] - if len(completed) == 0: - return - - for c in completed: - handler = self.reply_handlers[c] - del self.reply_handlers[c] - try: - await handler.handler.finalize() - except Exception: - logger.exception("handler finalize failed") - - logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})") - - def get_text_channel_by_name(self, channel_name) -> discord.TextChannel: - for channel in self.client.get_all_channels(): - if channel.type == discord.ChannelType.text and channel.name == channel_name: - return channel - - def run(self): - """Run bot loop blocking.""" - self.client.run(self.bot_token) diff --git a/discord-bot/bot/__init__.py b/discord-bot/bot/__init__.py new file mode 100644 index 00000000..3d04718d --- /dev/null +++ b/discord-bot/bot/__init__.py @@ -0,0 +1,2 @@ +# -*- coding=utf-8 -*- +"""The official Open-Assistant Discord Bot.""" diff --git a/discord-bot/bot/__main__.py b/discord-bot/bot/__main__.py new file mode 100644 index 00000000..f258d148 --- /dev/null +++ b/discord-bot/bot/__main__.py @@ -0,0 +1,17 @@ +# -*- coding=utf-8 -*- +"""Entry point for the bot.""" +import logging +import os + +from bot.bot import bot + +logger = logging.getLogger(__name__) + +if __name__ == "__main__": + if os.name != "nt": + import uvloop + + uvloop.install() + + logger.info("Starting bot") + bot.run() diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py new file mode 100644 index 00000000..e529cf75 --- /dev/null +++ b/discord-bot/bot/bot.py @@ -0,0 +1,37 @@ +# -*- coding=utf-8 -*- +"""Bot logic.""" +import hikari + +import aiosqlite +import lightbulb +import miru +from bot.config import Config + +config = Config.from_env() + +bot = lightbulb.BotApp( + token=config.token, + logs="DEBUG", + prefix="./", + default_enabled_guilds=config.declare_global_commands, + owner_ids=config.owner_ids, + intents=hikari.Intents.ALL, +) + + +@bot.listen() +async def on_starting(event: hikari.StartingEvent): + """Setup.""" + + miru.install(bot) # component handler + bot.load_extensions_from("./bot/extensions") # load extensions + + bot.d.db = await aiosqlite.connect(":memory:") # TODO: Update + await bot.d.db.executescript(open("./bot/db/schema.sql").read()) + await bot.d.db.commit() + + +@bot.listen() +async def on_stopping(event: hikari.StoppingEvent): + """Cleanup.""" + await bot.d.db.close() diff --git a/discord-bot/bot/config.py b/discord-bot/bot/config.py new file mode 100644 index 00000000..5905301c --- /dev/null +++ b/discord-bot/bot/config.py @@ -0,0 +1,35 @@ +# -*- coding=utf-8 -*- +"""Configuration for the bot.""" + +import logging +from dataclasses import dataclass +from os import getenv + +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """Configuration for the bot.""" + + token: str + declare_global_commands: int + owner_ids: list[int] + + @classmethod + def from_env(cls): + token = getenv("TOKEN", None) + + if token is None: + logger.error("Invalid token, please set the TOKEN environment variable.") + exit(1) + + return cls( + token=token, + declare_global_commands=int(getenv("DECLARE_GLOBAL_COMMANDS", 0)), + owner_ids=[int(x) for x in getenv("OWNER_IDS", "").split(",")], + ) diff --git a/discord-bot/bot/db/database.db b/discord-bot/bot/db/database.db new file mode 100644 index 00000000..e69de29b diff --git a/discord-bot/bot/db/schema.sql b/discord-bot/bot/db/schema.sql new file mode 100644 index 00000000..9fedf1da --- /dev/null +++ b/discord-bot/bot/db/schema.sql @@ -0,0 +1,10 @@ +-- Sqlite3 schema for the bot +CREATE TABLE IF NOT EXISTS guild_settings ( + guild_id BIGINT NOT NULL PRIMARY KEY, + log_channel_id BIGINT +); + +CREATE TABLE IF NOT EXISTS example ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL +); diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py new file mode 100644 index 00000000..ffb7ea70 --- /dev/null +++ b/discord-bot/bot/extensions/hot_reload.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +"""Hot reload plugin.""" +from glob import glob + +import hikari +import lightbulb + +plugin = lightbulb.Plugin( + "HotReloadPlugin", +) +plugin.add_checks(lightbulb.owner_only) + +EXTENSIONS_FOLDER = "bot/extensions" + + +def _get_extensions() -> list[str]: + # Recursively get all the .py files in the extensions directory. + exts = glob("bot/extensions/**/*.py", recursive=True) + # Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension") + return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts] + + +async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]: + # Check that the option is a string. + if not isinstance(option.value, str): + raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`") + + exts = _get_extensions() + return [ext for ext in exts if option.value in ext] + + +@plugin.command +@lightbulb.option( + "plugin", + "The plugin to reload. Leave empty to reload all plugins.", + autocomplete=_plugin_autocomplete, + required=False, + default=None, +) +@lightbulb.command("reload", "Reload a plugin") +@lightbulb.implements(lightbulb.SlashCommand) +async def reload(ctx: lightbulb.SlashContext): + """Reload a plugin or all plugins.""" + # If the plugin option is None, reload all plugins. + if ctx.options.plugin is None: + ctx.bot.reload_extensions(*_get_extensions()) + await ctx.respond("Reloaded all plugins.") + # Otherwise, reload the specified plugin. + else: + ctx.bot.reload_extensions(ctx.options.plugin) + await ctx.respond(f"Reloaded `{ctx.options.plugin}`.") + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot_base.py b/discord-bot/bot_base.py deleted file mode 100644 index 76eca22d..00000000 --- a/discord-bot/bot_base.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -import asyncio -from abc import ABC -from dataclasses import dataclass -from typing import Any - -import discord -from api_client import ApiClient -from channel_handlers import ChannelHandlerBase -from loguru import logger -from message_templates import MessageTemplates - - -@dataclass -class ReplyHandlerInfo: - msg_id: int - handler_task: asyncio.Task - handler: ChannelHandlerBase - - -class BotBase(ABC): - bot_channel_name: str - debug: bool - backend: ApiClient - client: discord.Client - loop: asyncio.BaseEventLoop - owner_id: int - bot_channel: discord.TextChannel - templates: MessageTemplates - reply_handlers: dict[int, ReplyHandlerInfo] - - def __init__(self): - self.reply_handlers = {} # handlers by msg_id - - def ensure_bot_channel(self) -> None: - if self.bot_channel is None: - raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found") - - async def post( - self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None - ) -> discord.Message: - if channel is None: - self.ensure_bot_channel() - channel = self.bot_channel - return await channel.send(content=content, view=view) - - async def post_template( - self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any - ) -> discord.Message: - logger.debug(f"rendering {name}") - text = self.templates.render(name, **kwargs) - return await self.post(text, view=view, channel=channel) - - def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase): - if msg_id in self.reply_handlers: - raise RuntimeError(f"Handler already registered for msg_id: {msg_id}") - task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})") - task.add_done_callback(lambda t: handler.on_completed()) - self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler) diff --git a/discord-bot/bot_settings.py b/discord-bot/bot_settings.py deleted file mode 100644 index c976d7cd..00000000 --- a/discord-bot/bot_settings.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -from pydantic import AnyHttpUrl, BaseSettings - - -class BotSettings(BaseSettings): - BACKEND_URL: AnyHttpUrl = "http://localhost:8080" - API_KEY: str = "any_key" - BOT_TOKEN: str - BOT_CHANNEL_NAME: str = "bot" - OWNER_ID: int = None - TEMPLATE_DIR: str = "./templates" - DEBUG: bool = True - - -settings = BotSettings(_env_file=".env") diff --git a/discord-bot/channel_handlers.py b/discord-bot/channel_handlers.py deleted file mode 100644 index 75f03c0e..00000000 --- a/discord-bot/channel_handlers.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- -import asyncio -from abc import ABC, abstractmethod -from datetime import datetime - -import discord -from loguru import logger - - -class ChannelExpiredException(Exception): - pass - - -class ChannelHandlerBase(ABC): - queue: asyncio.Queue - completed: bool = False - expiry_date: datetime - expired: bool = False - - def __init__(self, *, expiry_date: datetime = None): - self.expiry_date = expiry_date - self.queue = asyncio.Queue() - - async def read(self) -> discord.Message: - """Call this method to read the next message from the user in the handler method.""" - if self.expired: - raise ChannelExpiredException() - - msg = await self.queue.get() - if msg is None: - if self.expired: - raise ChannelExpiredException() - else: - raise RuntimeError("Unexpected None message read") - return msg - - def on_reply(self, message: discord.Message) -> None: - self.queue.put_nowait(message) - - def on_expire(self) -> None: - logger.info("ChannelHandler: on_expire") - self.expired = True - self.queue.put_nowait(None) - - def on_completed(self) -> None: - logger.info("ChannelHandler: on_completed") - self.completed = True - - def tick(self, now: datetime): - if now > self.expiry_date and not self.expired: - self.on_expire() - - @abstractmethod - async def handler_loop(self): - ... - - async def finalize(self): - pass - - -class AutoDestructThreadHandler(ChannelHandlerBase): - first_message: discord.Message = None - thread: discord.Thread = None - - def __init__(self, *, expiry_date: datetime = None): - super().__init__(expiry_date=expiry_date) - - async def read(self) -> discord.Message: - try: - return await super().read() - except ChannelExpiredException: - await self.cleanup() - raise - - async def cleanup(self): - logger.debug("AutoDestructThreadHandler.cleanup") - if self.thread: - logger.debug(f"deleting thread: {self.thread.name}") - await self.thread.delete() - self.thread = None - if self.first_message: - logger.debug(f"deleting first_message: {self.first_message.content}") - await self.first_message.delete() - self.first_message = None - - async def finalize(self): - await self.cleanup() - return await super().finalize() diff --git a/discord-bot/dev-requirements.txt b/discord-bot/dev-requirements.txt new file mode 100644 index 00000000..44a8d2cc --- /dev/null +++ b/discord-bot/dev-requirements.txt @@ -0,0 +1,8 @@ +nox + +black +isort + +codespell +flake8 +pyright \ No newline at end of file diff --git a/discord-bot/flake8-requirements.txt b/discord-bot/flake8-requirements.txt new file mode 100644 index 00000000..3509207e --- /dev/null +++ b/discord-bot/flake8-requirements.txt @@ -0,0 +1,26 @@ +flake8==6.0.0 + +# Plugins + +Flake8-pyproject # use the pyproject.toml as the config file +flake8-bandit # runs bandit +flake8-black # runs black +# flake8-broken-line # forbey "\" linebreaks +flake8-builtins # builtin shadowing checks +flake8-coding # coding magic-comment detection +flake8-comprehensions # comprehension checks +flake8-deprecated # deprecated call checks +flake8-docstrings # pydocstyle support +flake8-executable # shebangs +flake8-fixme # "fix me" counter +flake8-functions # function linting +flake8-html # html output +flake8-if-statements # condition linting +flake8-isort # runs isort +flake8-mutable # mutable default argument detection +flake8-pep3101 # new-style format strings only +flake8-print # complain about print statements in code +flake8-printf-formatting # forbey printf-style python2 string formatting +flake8-pytest-style # pytest checks +flake8-raise # exception raising linting +flake8-use-fstring # format string checking diff --git a/discord-bot/noxfile.py b/discord-bot/noxfile.py new file mode 100644 index 00000000..37226787 --- /dev/null +++ b/discord-bot/noxfile.py @@ -0,0 +1,33 @@ +# -*- coding=utf-8 -*- +"""Automated linting, formatting, and typechecking.""" +import nox +from nox.sessions import Session + + +@nox.session(reuse_venv=True) +def format_code(session: Session): + """Format the codebase.""" + session.install("isort", "-U") + session.install("black", "-U") + + session.run("isort", "bot") + session.run("black", "bot") + + +@nox.session(reuse_venv=True) +def lint_code(session: Session): + """Lint the codebase.""" + session.install("codespell", "-U") + session.install("flake8", "-U") + session.install("-r", "flake8-requirements.txt", "-U") + + session.run("codespell", "bot") + session.run("flake8", "bot") + + +@nox.session(reuse_venv=True) +def typecheck_code(session: Session): + session.install("-r", "requirements.txt", "-U") + session.install("pyright", "-U") + + session.run("pyright", "bot") diff --git a/discord-bot/pyproject.toml b/discord-bot/pyproject.toml new file mode 100644 index 00000000..7a1e8d82 --- /dev/null +++ b/discord-bot/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "Open-Assistant Discord Bot" +version = "0.0.1" + +[tool.black] +line-length = 120 +target-version = ["py310"] + +[tool.pyright] +include = ["ottbot", "noxfile.py"] +pythonVersion="3.10" +reportMissingImports=false +# reportInvalidTypeVarUse=false +# reportMissingModuleSource=false +reportUnknownVariableType=false +pythonPlatform="Linux" + +[tool.isort] +profile="black" +sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER'] +skip_glob = "**/__init__.pyi" + +[tool.flake8] +max-function-length = 130 +max-line-length = 130 +# Technically this is 120, but black has a policy of "1 or 2 over is fine if it is tidier", so we have to raise this. +accept-encodings = "utf-8" +docstring-convention = "numpy" +ignore = [ + "A002", # Argument is shadowing a python builtin. + "A003", # Class attribute is shadowing a python builtin. + "CFQ002", # Function has too many arguments. + "CFQ004", # Function has too many returns. + "D001", # False positive for depreciated functions. + "D102", # Missing docstring in public method. + "D105", # Magic methods not having a docstring. + "D412", # No blank lines allowed between a section header and its content + "E203", # Whitespace after : (to match how black formats it) + "E402", # Module level import not at top of file (isn't compatible with our import style). + "T101", # TO-DO comment detection (T102 is FIX-ME and T103 is XXX). + "W503", # line break before binary operator. + "W504", # line break before binary operator (again, I guess). + "S101", # Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. + "S105", # Possible hardcoded password. + "EXE002", # Executable file with not shebang + "D401", # Imperative mood +] diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 927ebcf2..49c5e1ba 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -1,7 +1,10 @@ -discord.py==2.1.0 -Jinja2==3.1.2 -pydantic==1.9.1 -python-dotenv==0.21.0 -pytz==2022.7 -requests==2.28.1 -schedule==1.1.0 +hikari # discord framework +hikari[speedups] +uvloop; os_name != 'nt' +hikari-lightbulb # command handler +hikari-miru # modals and buttons + +python-dotenv # .env file support +aiosqlite # database +aiohttp # http client +aiohttp[speedups] # speedups for aiohttp \ No newline at end of file diff --git a/discord-bot/task_handlers.py b/discord-bot/task_handlers.py deleted file mode 100644 index 1434d17c..00000000 --- a/discord-bot/task_handlers.py +++ /dev/null @@ -1,267 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -from abc import abstractmethod -from datetime import timedelta - -import discord -from api_client import ApiClient -from bot_base import BotBase -from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException -from loguru import logger -from oasst_shared.schemas import protocol as protocol_schema -from utils import DiscordTimestampStyle, discord_timestamp, utcnow - - -class Questionnaire(discord.ui.Modal, title="Questionnaire Response"): - name = discord.ui.TextInput(label="Name") - answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph) - - async def on_submit(self, interaction: discord.Interaction): - await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True) - - -class ChannelTaskBase(AutoDestructThreadHandler): - thread_name: str = "Replies" - expires_after: timedelta = timedelta(minutes=5) - backend: ApiClient - - async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message: - try: - self.bot = bot - self.task = task - self.backend = bot.backend - self.expiry_date = utcnow() + self.expires_after if self.expires_after else None - msg = await self.send_first_message() - self.first_message = msg - self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name) - await self.on_thread_created(self.thread) - except Exception: - logger.exception("start task failed") - await self.cleanup() # try to cleanup messag or thread - raise - - bot.register_reply_handler(msg_id=msg.id, handler=self) - return msg - - async def on_thread_created(self, thread: discord.Thread) -> None: - pass - - @abstractmethod - async def send_first_message(self) -> discord.message: - ... - - def to_api_user(self, user: discord.User) -> protocol_schema.User: - return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name) - - async def post_teaser_msg(self, template_name: str): - expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time) - expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time) - return await self.bot.post_template( - template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative - ) - - async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - api_response = await self.backend.post_interaction(interaction) - if api_response.type != "task_done": - # multi-step tasks are not supported yet - logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})") - raise RuntimeError("Unexpected response from backend received") - return api_response - - def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: - return self.backend.post_interaction( - protocol_schema.TextReplyToPost( - post_id=str(self.first_message.id), - user_post_id=str(user_msg.id), - user=self.to_api_user(user_msg.author), - text=user_msg.content, - ) - ) - - async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: - try: - self.post_text_reply_to_post(user_msg) - await user_msg.add_reaction("✅") - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in handle_text_reply_to_post()") - await user_msg.add_reaction("❌") - await user_msg.reply(f"❌ Error communicating with backend: {e}") - - def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task: - return self.backend.post_interaction( - protocol_schema.PostRanking( - post_id=str(self.first_message.id), - user_post_id=str(user_msg.id), - user=self.to_api_user(user_msg.author), - ranking=ranking, - ) - ) - - async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task: - try: - ranking_str = user_msg.content - ranking = [int(x) - 1 for x in ranking_str.split(",")] - self.post_ranking(user_msg, ranking=ranking) - await user_msg.add_reaction("✅") - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in handle_ranking()") - await user_msg.add_reaction("❌") - await user_msg.reply(f"❌ Error communicating with backend: {e}") - - -class SummarizeStoryHandler(ChannelTaskBase): - task: protocol_schema.SummarizeStoryTask - thread_name: str = "Summaries" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_summarize_story.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class InitialPromptHandler(ChannelTaskBase): - task: protocol_schema.InitialPromptTask - thread_name: str = "Prompts" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_initial_prompt.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class UserReplyHandler(ChannelTaskBase): - task: protocol_schema.UserReplyTask - thread_name: str = "User replies" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_user_reply.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class AssistantReplyHandler(ChannelTaskBase): - task: protocol_schema.AssistantReplyTask - thread_name: str = "Assistant replies" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_assistant_reply.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class RankInitialPromptsHandler(ChannelTaskBase): - task: protocol_schema.RankInitialPromptsTask - thread_name: str = "User Responses" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rank_initial_prompts.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_ranking(msg) - - -class RankConversationsHandler(ChannelTaskBase): - task: protocol_schema.RankConversationRepliesTask - thread_name: str = "Rankings" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rank_conversation_replies.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_ranking(msg) - - -class RatingButton(discord.ui.Button): - def __init__(self, label, value, response_handler): - super().__init__(label=label, style=discord.ButtonStyle.green) - self.value = value - self.response_handler = response_handler - - async def callback(self, interaction): - await self.response_handler(self.value, interaction) - - -def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: - view = discord.ui.View() - for i in range(lo, hi + 1): - view.add_item(RatingButton(str(i), i, response_handler)) - return view - - -class RateSummaryHandler(ChannelTaskBase): - task: protocol_schema.RateSummaryTask - thread_name: str = "Ratings" - - async def _rating_response_handler(self, score, interaction: discord.Interaction): - logger.info("rating_response_handler", score) - if self.thread: - try: - self.backend.post_interaction( - protocol_schema.PostRating( - post_id=str(self.first_message.id), - user_post_id=str(interaction.id), - user=self.to_api_user(interaction.user), - rating=score, - ) - ) - await interaction.response.send_message( - f"Thanks {interaction.user.display_name}, got your feedback: {score}!" - ) - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in _rating_response_handler()") - interaction.response.send_message(f"❌ Error communicating with backend: {e}") - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rate_summary.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler) - return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - logger.info(f"on_rate_summary_reply: {msg.content}") - await msg.add_reaction("❌") - await msg.reply("❌ Text intput not supported.") diff --git a/discord-bot/utils.py b/discord-bot/utils.py deleted file mode 100644 index 968e4498..00000000 --- a/discord-bot/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -import subprocess -from datetime import datetime - -import pytz - - -def get_git_head_hash(): - # get current git hash - x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True) - if x.returncode == 0: - return x.stdout.replace("\n", "") - return None - - -def utcnow() -> datetime: - return datetime.now(pytz.UTC) - - -class DiscordTimestampStyle(str, enum.Enum): - """ - Timestamp Styles - - t 16:20 Short Time - T 16:20:30 Long Time - d 20/04/2021 Short Date - D 20 April 2021 Long Date - f * 20 April 2021 16:20 Short Date/Time - F Tuesday, 20 April 2021 16:20 Long Date/Time - R 2 months ago Relative Time - - See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles - """ - - default = "" - short_time = "t" - long_time = "T" - short_date = "d" - long_date = "D" - short_date_time = "f" - long_date_time = "F" - relative_time = "R" - - -def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default): - parts = ["") - return "".join(parts)