From 3ce6ab80d67d6653a892995eea5c0d2cfb65eec6 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Wed, 28 Dec 2022 16:43:14 -0800 Subject: [PATCH 01/27] initial bot structure --- discord-bot/.env.example | 3 + discord-bot/.gitignore | 4 + discord-bot/CONTRIBUTING.md | 43 ++++ discord-bot/README.md | 13 +- discord-bot/__main__.py | 17 -- discord-bot/api_client.py | 3 +- discord-bot/bot.py | 283 ----------------------- discord-bot/bot/__init__.py | 2 + discord-bot/bot/__main__.py | 17 ++ discord-bot/bot/bot.py | 37 +++ discord-bot/bot/config.py | 35 +++ discord-bot/bot/db/database.db | 0 discord-bot/bot/db/schema.sql | 10 + discord-bot/bot/extensions/hot_reload.py | 61 +++++ discord-bot/bot_base.py | 61 ----- discord-bot/bot_settings.py | 15 -- discord-bot/channel_handlers.py | 88 ------- discord-bot/dev-requirements.txt | 8 + discord-bot/flake8-requirements.txt | 26 +++ discord-bot/noxfile.py | 33 +++ discord-bot/pyproject.toml | 47 ++++ discord-bot/requirements.txt | 17 +- discord-bot/task_handlers.py | 267 --------------------- discord-bot/utils.py | 52 ----- 24 files changed, 340 insertions(+), 802 deletions(-) create mode 100644 discord-bot/.env.example create mode 100644 discord-bot/CONTRIBUTING.md delete mode 100644 discord-bot/__main__.py delete mode 100644 discord-bot/bot.py create mode 100644 discord-bot/bot/__init__.py create mode 100644 discord-bot/bot/__main__.py create mode 100644 discord-bot/bot/bot.py create mode 100644 discord-bot/bot/config.py create mode 100644 discord-bot/bot/db/database.db create mode 100644 discord-bot/bot/db/schema.sql create mode 100644 discord-bot/bot/extensions/hot_reload.py delete mode 100644 discord-bot/bot_base.py delete mode 100644 discord-bot/bot_settings.py delete mode 100644 discord-bot/channel_handlers.py create mode 100644 discord-bot/dev-requirements.txt create mode 100644 discord-bot/flake8-requirements.txt create mode 100644 discord-bot/noxfile.py create mode 100644 discord-bot/pyproject.toml delete mode 100644 discord-bot/task_handlers.py delete mode 100644 discord-bot/utils.py diff --git a/discord-bot/.env.example b/discord-bot/.env.example new file mode 100644 index 00000000..89e50c05 --- /dev/null +++ b/discord-bot/.env.example @@ -0,0 +1,3 @@ +TOKEN= +DECLARE_GLOBAL_COMMANDS= +OWNER_IDS= \ No newline at end of file diff --git a/discord-bot/.gitignore b/discord-bot/.gitignore index a7982d60..2842b686 100644 --- a/discord-bot/.gitignore +++ b/discord-bot/.gitignore @@ -1,3 +1,7 @@ .env *.egg-info/ __pycache__/ + +.venv +.nox +.env \ No newline at end of file diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md new file mode 100644 index 00000000..089a0c33 --- /dev/null +++ b/discord-bot/CONTRIBUTING.md @@ -0,0 +1,43 @@ +# Contributing + +## Setup + +To run the bot + +``` +cp .env.example .env + +python -V # 3.10 + +pip install -r requirements.txt +python -m bot +``` + +To test the bot + +``` +python -m pip install -r dev-requirements.txt + +nox +``` + +To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. + +1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) +2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable. + +## Resources + +Main framework + +- [Hikari Repo](https://github.com/hikari-py/hikari) +- [Hikari Docs](https://docs.hikari-py.dev/en/latest/) + +Command handler + +- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb) +- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/) + +Component handler (buttons, modals, etc... ) + +- [Miru Repo](https://github.com/HyperGH/hikari-miru) diff --git a/discord-bot/README.md b/discord-bot/README.md index a585b37f..cde82025 100644 --- a/discord-bot/README.md +++ b/discord-bot/README.md @@ -6,15 +6,6 @@ This bot collects human feedback to create a dataset for RLHF-alignment of an as To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages. -## Bot token for development +## Contributing -To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. - -1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) -2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`. - -The simplest way to configure the token is via an `.env` file: - -``` -BOT_TOKEN=XYZABC123... -``` +To contribute to the bot, please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file. diff --git a/discord-bot/__main__.py b/discord-bot/__main__.py deleted file mode 100644 index 9e5e29c7..00000000 --- a/discord-bot/__main__.py +++ /dev/null @@ -1,17 +0,0 @@ -# -*- coding: utf-8 -*- -from bot import OpenAssistantBot -from bot_settings import settings - -# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot - -if __name__ == "__main__": - bot = OpenAssistantBot( - settings.BOT_TOKEN, - bot_channel_name=settings.BOT_CHANNEL_NAME, - backend_url=settings.BACKEND_URL, - api_key=settings.API_KEY, - owner_id=settings.OWNER_ID, - template_dir=settings.TEMPLATE_DIR, - debug=settings.DEBUG, - ) - bot.run() diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py index 1de6bb17..0caa1595 100644 --- a/discord-bot/api_client.py +++ b/discord-bot/api_client.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import enum from typing import Optional, Type +import typing as t import requests from oasst_shared.schemas import protocol as protocol_schema @@ -41,7 +42,7 @@ class ApiClient: response.raise_for_status() return response.json() - def _parse_task(self, data: dict) -> protocol_schema.Task: + def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: if not isinstance(data, dict): raise ValueError("dict expected") diff --git a/discord-bot/bot.py b/discord-bot/bot.py deleted file mode 100644 index a19fdfe1..00000000 --- a/discord-bot/bot.py +++ /dev/null @@ -1,283 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -import asyncio -from datetime import timedelta -from pathlib import Path -from typing import Optional, Union - -import discord -import task_handlers -from api_client import ApiClient, TaskType -from bot_base import BotBase -from discord import app_commands -from loguru import logger -from message_templates import MessageTemplates -from oasst_shared.schemas import protocol as protocol_schema -from utils import get_git_head_hash, utcnow - -__version__ = "0.0.3" -BOT_NAME = "Open-Assistant Junior" - - -class OpenAssistantBot(BotBase): - def __init__( - self, - bot_token: str, - bot_channel_name: str, - backend_url: str, - api_key: str, - owner_id: Optional[Union[int, str]] = None, - template_dir: str = "./templates", - debug: bool = False, - ): - super().__init__() - - self.template_dir = Path(template_dir) - self.bot_channel_name = bot_channel_name - self.templates = MessageTemplates(template_dir) - self.debug = debug - - intents = discord.Intents.default() - intents.message_content = True - - if isinstance(owner_id, str): - owner_id = int(owner_id) - self.owner_id = owner_id - - self.bot_token = bot_token - client = discord.Client(intents=intents) - self.client = client - self.loop = client.loop - - self.bot_channel: discord.TextChannel = None - self.backend = ApiClient(backend_url, api_key) - - self.tree = app_commands.CommandTree(self.client, fallback_to_global=True) - - @client.event - async def on_ready(): - self.bot_channel = self.get_text_channel_by_name(bot_channel_name) - logger.info(f"{client.user} is now running!") - - await self.delete_all_old_bot_messages() - # if self.debug: - # await self.post_boot_message() - await self.post_welcome_message() - - client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") - - @client.event - async def on_message(message: discord.Message): - # ignore own messages - if message.author != client.user: - await self.handle_message(message) - - @self.tree.command() - async def tutorial(interaction: discord.Interaction): - """Start the Open-Assistant tutorial via DMs.""" - - dm = await self.client.create_dm(discord.Object(interaction.user.id)) - await dm.send("Tutorial coming soon... :-)") - await interaction.response.send_message(f"tutorial command by {interaction.user.name}") - - @self.tree.command() - async def help(interaction: discord.Interaction): - """Sends the user a list of all available commands""" - await self.post_help(interaction.user) - await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.") - - @self.tree.command() - async def work(interaction: discord.Interaction): - """Request a new personalized task""" - - # task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None) - # task = self.backend.fetch_random_task(user=None) - q = task_handlers.Questionnaire() - await interaction.response.send_modal(q) - - async def post_help(self, user: discord.abc.User) -> discord.Message: - is_bot_owner = user.id == self.owner_id - return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner) - - async def post_boot_message(self) -> discord.Message: - return await self.post_template( - "boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug - ) - - async def post_welcome_message(self) -> discord.Message: - return await self.post_template("welcome.msg") - - async def delete_all_old_bot_messages(self) -> None: - logger.info("Deleting old threads...") - for thread in self.bot_channel.threads: - if thread.owner_id == self.client.user.id: - await thread.delete() - logger.info("Completed deleting old theards.") - - logger.info("Deleting old messages...") - look_until = utcnow() - timedelta(days=365) - async for msg in self.bot_channel.history(limit=None): - msg: discord.Message - if msg.created_at < look_until: - break - if msg.author.id == self.client.user.id: - await msg.delete() - logger.info("Completed deleting old messages.") - - async def next_task(self): - task_type = protocol_schema.TaskRequestType.random - task = self.backend.fetch_task(task_type, user=None) - - handler: task_handlers.ChannelTaskBase = None - match task.type: - case TaskType.summarize_story: - handler = task_handlers.SummarizeStoryHandler() - case TaskType.rate_summary: - handler = task_handlers.RateSummaryHandler() - case TaskType.initial_prompt: - handler = task_handlers.InitialPromptHandler() - case TaskType.user_reply: - handler = task_handlers.UserReplyHandler() - case TaskType.assistant_reply: - handler = task_handlers.AssistantReplyHandler() - case TaskType.rank_initial_prompts: - handler = task_handlers.RankInitialPromptsHandler() - case TaskType.rank_user_replies | TaskType.rank_assistant_replies: - handler = task_handlers.RankConversationsHandler() - case _: - logger.warning(f"Unsupported task type received: {task.type}") - self.backend.nack_task(task.id, "not supported") - - if handler: - try: - logger.info(f"strarting task {task.id}") - msg = await handler.start(self, task) - self.backend.ack_task(task.id, msg.id) - except Exception: - logger.exception("Starting task failed.") - self.backend.nack_task(task.id, "faled") - - async def background_timer(self): - next_remove_completed = utcnow() + timedelta(seconds=10) - next_fetch_task = utcnow() + timedelta(seconds=1) - while True: - now = utcnow() - - if self.bot_channel: - if now > next_fetch_task: - next_fetch_task = utcnow() + timedelta(seconds=60) - - try: - await self.next_task() - except Exception: - logger.exception("fetching next task failed") - - for x in self.reply_handlers.values(): - x.handler.tick(now) - - if now > next_remove_completed: - next_remove_completed = utcnow() + timedelta(seconds=10) - await self.remove_completed_handlers() - - await asyncio.sleep(1) - - async def _sync(self, command: str, message: discord.Message): - - logger.info(f"sync tree command received: {command}") - - if command == "sync.copy_global": - await self.tree.copy_global_to(guild=message.guild) - synced = await self.tree.sync(guild=message.guild) - elif command == "sync.clear_guild": - self.tree.clear_commands(guild=message.guild) - synced = await self.tree.sync(guild=message.guild) - elif command == "sync.guild": - synced = await self.tree.sync(guild=message.guild) - else: - synced = await self.tree.sync() - - logger.info(f"Synced {len(synced)} commands") - await message.reply(f"Synced {len(synced)} commands") - - async def handle_command(self, message: discord.Message, is_owner: bool): - command_text: str = message.content - command_text = command_text[1:] - match command_text: - case "help" | "?": - await self.post_help(user=message.author) - case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild": - if is_owner: - await self._sync(command_text, message) - case _: - await message.reply(f"unknown command: {command_text}") - - def recipient_filter(self, message: discord.Message) -> bool: - channel = message.channel - - if ( - message.channel.type == discord.ChannelType.private - or message.channel.type == discord.ChannelType.private_thread - ): - return True - - if ( - message.channel.type == discord.ChannelType.text - or message.channel.type == discord.ChannelType.public_thread - ): - while channel: - if self.bot_channel and channel.id == self.bot_channel.id: - return True - channel = channel.parent - - return False - - async def handle_message(self, message: discord.Message): - if not self.recipient_filter(message): - return - - user_id = message.author.id - user_display_name = message.author.name - - logger.debug( - f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})" - ) - - command_prefix = "!" - if message.type == discord.MessageType.default and message.content.startswith(command_prefix): - is_owner = self.owner_id and user_id == self.owner_id - await self.handle_command(message, is_owner) - - if isinstance(message.channel, discord.Thread): - handler = self.reply_handlers.get(message.channel.id) - if handler and not handler.handler.completed: - handler.handler.on_reply(message) - - if message.reference: - handler = self.reply_handlers.get(message.reference.message_id) - if handler and not handler.handler.completed: - handler.handler.on_reply(message) - - async def remove_completed_handlers(self): - completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed] - if len(completed) == 0: - return - - for c in completed: - handler = self.reply_handlers[c] - del self.reply_handlers[c] - try: - await handler.handler.finalize() - except Exception: - logger.exception("handler finalize failed") - - logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})") - - def get_text_channel_by_name(self, channel_name) -> discord.TextChannel: - for channel in self.client.get_all_channels(): - if channel.type == discord.ChannelType.text and channel.name == channel_name: - return channel - - def run(self): - """Run bot loop blocking.""" - self.client.run(self.bot_token) diff --git a/discord-bot/bot/__init__.py b/discord-bot/bot/__init__.py new file mode 100644 index 00000000..3d04718d --- /dev/null +++ b/discord-bot/bot/__init__.py @@ -0,0 +1,2 @@ +# -*- coding=utf-8 -*- +"""The official Open-Assistant Discord Bot.""" diff --git a/discord-bot/bot/__main__.py b/discord-bot/bot/__main__.py new file mode 100644 index 00000000..f258d148 --- /dev/null +++ b/discord-bot/bot/__main__.py @@ -0,0 +1,17 @@ +# -*- coding=utf-8 -*- +"""Entry point for the bot.""" +import logging +import os + +from bot.bot import bot + +logger = logging.getLogger(__name__) + +if __name__ == "__main__": + if os.name != "nt": + import uvloop + + uvloop.install() + + logger.info("Starting bot") + bot.run() diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py new file mode 100644 index 00000000..e529cf75 --- /dev/null +++ b/discord-bot/bot/bot.py @@ -0,0 +1,37 @@ +# -*- coding=utf-8 -*- +"""Bot logic.""" +import hikari + +import aiosqlite +import lightbulb +import miru +from bot.config import Config + +config = Config.from_env() + +bot = lightbulb.BotApp( + token=config.token, + logs="DEBUG", + prefix="./", + default_enabled_guilds=config.declare_global_commands, + owner_ids=config.owner_ids, + intents=hikari.Intents.ALL, +) + + +@bot.listen() +async def on_starting(event: hikari.StartingEvent): + """Setup.""" + + miru.install(bot) # component handler + bot.load_extensions_from("./bot/extensions") # load extensions + + bot.d.db = await aiosqlite.connect(":memory:") # TODO: Update + await bot.d.db.executescript(open("./bot/db/schema.sql").read()) + await bot.d.db.commit() + + +@bot.listen() +async def on_stopping(event: hikari.StoppingEvent): + """Cleanup.""" + await bot.d.db.close() diff --git a/discord-bot/bot/config.py b/discord-bot/bot/config.py new file mode 100644 index 00000000..5905301c --- /dev/null +++ b/discord-bot/bot/config.py @@ -0,0 +1,35 @@ +# -*- coding=utf-8 -*- +"""Configuration for the bot.""" + +import logging +from dataclasses import dataclass +from os import getenv + +from dotenv import load_dotenv + +load_dotenv() + +logger = logging.getLogger(__name__) + + +@dataclass +class Config: + """Configuration for the bot.""" + + token: str + declare_global_commands: int + owner_ids: list[int] + + @classmethod + def from_env(cls): + token = getenv("TOKEN", None) + + if token is None: + logger.error("Invalid token, please set the TOKEN environment variable.") + exit(1) + + return cls( + token=token, + declare_global_commands=int(getenv("DECLARE_GLOBAL_COMMANDS", 0)), + owner_ids=[int(x) for x in getenv("OWNER_IDS", "").split(",")], + ) diff --git a/discord-bot/bot/db/database.db b/discord-bot/bot/db/database.db new file mode 100644 index 00000000..e69de29b diff --git a/discord-bot/bot/db/schema.sql b/discord-bot/bot/db/schema.sql new file mode 100644 index 00000000..9fedf1da --- /dev/null +++ b/discord-bot/bot/db/schema.sql @@ -0,0 +1,10 @@ +-- Sqlite3 schema for the bot +CREATE TABLE IF NOT EXISTS guild_settings ( + guild_id BIGINT NOT NULL PRIMARY KEY, + log_channel_id BIGINT +); + +CREATE TABLE IF NOT EXISTS example ( + id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, + name VARCHAR(255) NOT NULL +); diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py new file mode 100644 index 00000000..ffb7ea70 --- /dev/null +++ b/discord-bot/bot/extensions/hot_reload.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +"""Hot reload plugin.""" +from glob import glob + +import hikari +import lightbulb + +plugin = lightbulb.Plugin( + "HotReloadPlugin", +) +plugin.add_checks(lightbulb.owner_only) + +EXTENSIONS_FOLDER = "bot/extensions" + + +def _get_extensions() -> list[str]: + # Recursively get all the .py files in the extensions directory. + exts = glob("bot/extensions/**/*.py", recursive=True) + # Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension") + return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts] + + +async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]: + # Check that the option is a string. + if not isinstance(option.value, str): + raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`") + + exts = _get_extensions() + return [ext for ext in exts if option.value in ext] + + +@plugin.command +@lightbulb.option( + "plugin", + "The plugin to reload. Leave empty to reload all plugins.", + autocomplete=_plugin_autocomplete, + required=False, + default=None, +) +@lightbulb.command("reload", "Reload a plugin") +@lightbulb.implements(lightbulb.SlashCommand) +async def reload(ctx: lightbulb.SlashContext): + """Reload a plugin or all plugins.""" + # If the plugin option is None, reload all plugins. + if ctx.options.plugin is None: + ctx.bot.reload_extensions(*_get_extensions()) + await ctx.respond("Reloaded all plugins.") + # Otherwise, reload the specified plugin. + else: + ctx.bot.reload_extensions(ctx.options.plugin) + await ctx.respond(f"Reloaded `{ctx.options.plugin}`.") + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot_base.py b/discord-bot/bot_base.py deleted file mode 100644 index 76eca22d..00000000 --- a/discord-bot/bot_base.py +++ /dev/null @@ -1,61 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -import asyncio -from abc import ABC -from dataclasses import dataclass -from typing import Any - -import discord -from api_client import ApiClient -from channel_handlers import ChannelHandlerBase -from loguru import logger -from message_templates import MessageTemplates - - -@dataclass -class ReplyHandlerInfo: - msg_id: int - handler_task: asyncio.Task - handler: ChannelHandlerBase - - -class BotBase(ABC): - bot_channel_name: str - debug: bool - backend: ApiClient - client: discord.Client - loop: asyncio.BaseEventLoop - owner_id: int - bot_channel: discord.TextChannel - templates: MessageTemplates - reply_handlers: dict[int, ReplyHandlerInfo] - - def __init__(self): - self.reply_handlers = {} # handlers by msg_id - - def ensure_bot_channel(self) -> None: - if self.bot_channel is None: - raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found") - - async def post( - self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None - ) -> discord.Message: - if channel is None: - self.ensure_bot_channel() - channel = self.bot_channel - return await channel.send(content=content, view=view) - - async def post_template( - self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any - ) -> discord.Message: - logger.debug(f"rendering {name}") - text = self.templates.render(name, **kwargs) - return await self.post(text, view=view, channel=channel) - - def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase): - if msg_id in self.reply_handlers: - raise RuntimeError(f"Handler already registered for msg_id: {msg_id}") - task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})") - task.add_done_callback(lambda t: handler.on_completed()) - self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler) diff --git a/discord-bot/bot_settings.py b/discord-bot/bot_settings.py deleted file mode 100644 index c976d7cd..00000000 --- a/discord-bot/bot_settings.py +++ /dev/null @@ -1,15 +0,0 @@ -# -*- coding: utf-8 -*- -from pydantic import AnyHttpUrl, BaseSettings - - -class BotSettings(BaseSettings): - BACKEND_URL: AnyHttpUrl = "http://localhost:8080" - API_KEY: str = "any_key" - BOT_TOKEN: str - BOT_CHANNEL_NAME: str = "bot" - OWNER_ID: int = None - TEMPLATE_DIR: str = "./templates" - DEBUG: bool = True - - -settings = BotSettings(_env_file=".env") diff --git a/discord-bot/channel_handlers.py b/discord-bot/channel_handlers.py deleted file mode 100644 index 75f03c0e..00000000 --- a/discord-bot/channel_handlers.py +++ /dev/null @@ -1,88 +0,0 @@ -# -*- coding: utf-8 -*- -import asyncio -from abc import ABC, abstractmethod -from datetime import datetime - -import discord -from loguru import logger - - -class ChannelExpiredException(Exception): - pass - - -class ChannelHandlerBase(ABC): - queue: asyncio.Queue - completed: bool = False - expiry_date: datetime - expired: bool = False - - def __init__(self, *, expiry_date: datetime = None): - self.expiry_date = expiry_date - self.queue = asyncio.Queue() - - async def read(self) -> discord.Message: - """Call this method to read the next message from the user in the handler method.""" - if self.expired: - raise ChannelExpiredException() - - msg = await self.queue.get() - if msg is None: - if self.expired: - raise ChannelExpiredException() - else: - raise RuntimeError("Unexpected None message read") - return msg - - def on_reply(self, message: discord.Message) -> None: - self.queue.put_nowait(message) - - def on_expire(self) -> None: - logger.info("ChannelHandler: on_expire") - self.expired = True - self.queue.put_nowait(None) - - def on_completed(self) -> None: - logger.info("ChannelHandler: on_completed") - self.completed = True - - def tick(self, now: datetime): - if now > self.expiry_date and not self.expired: - self.on_expire() - - @abstractmethod - async def handler_loop(self): - ... - - async def finalize(self): - pass - - -class AutoDestructThreadHandler(ChannelHandlerBase): - first_message: discord.Message = None - thread: discord.Thread = None - - def __init__(self, *, expiry_date: datetime = None): - super().__init__(expiry_date=expiry_date) - - async def read(self) -> discord.Message: - try: - return await super().read() - except ChannelExpiredException: - await self.cleanup() - raise - - async def cleanup(self): - logger.debug("AutoDestructThreadHandler.cleanup") - if self.thread: - logger.debug(f"deleting thread: {self.thread.name}") - await self.thread.delete() - self.thread = None - if self.first_message: - logger.debug(f"deleting first_message: {self.first_message.content}") - await self.first_message.delete() - self.first_message = None - - async def finalize(self): - await self.cleanup() - return await super().finalize() diff --git a/discord-bot/dev-requirements.txt b/discord-bot/dev-requirements.txt new file mode 100644 index 00000000..44a8d2cc --- /dev/null +++ b/discord-bot/dev-requirements.txt @@ -0,0 +1,8 @@ +nox + +black +isort + +codespell +flake8 +pyright \ No newline at end of file diff --git a/discord-bot/flake8-requirements.txt b/discord-bot/flake8-requirements.txt new file mode 100644 index 00000000..3509207e --- /dev/null +++ b/discord-bot/flake8-requirements.txt @@ -0,0 +1,26 @@ +flake8==6.0.0 + +# Plugins + +Flake8-pyproject # use the pyproject.toml as the config file +flake8-bandit # runs bandit +flake8-black # runs black +# flake8-broken-line # forbey "\" linebreaks +flake8-builtins # builtin shadowing checks +flake8-coding # coding magic-comment detection +flake8-comprehensions # comprehension checks +flake8-deprecated # deprecated call checks +flake8-docstrings # pydocstyle support +flake8-executable # shebangs +flake8-fixme # "fix me" counter +flake8-functions # function linting +flake8-html # html output +flake8-if-statements # condition linting +flake8-isort # runs isort +flake8-mutable # mutable default argument detection +flake8-pep3101 # new-style format strings only +flake8-print # complain about print statements in code +flake8-printf-formatting # forbey printf-style python2 string formatting +flake8-pytest-style # pytest checks +flake8-raise # exception raising linting +flake8-use-fstring # format string checking diff --git a/discord-bot/noxfile.py b/discord-bot/noxfile.py new file mode 100644 index 00000000..37226787 --- /dev/null +++ b/discord-bot/noxfile.py @@ -0,0 +1,33 @@ +# -*- coding=utf-8 -*- +"""Automated linting, formatting, and typechecking.""" +import nox +from nox.sessions import Session + + +@nox.session(reuse_venv=True) +def format_code(session: Session): + """Format the codebase.""" + session.install("isort", "-U") + session.install("black", "-U") + + session.run("isort", "bot") + session.run("black", "bot") + + +@nox.session(reuse_venv=True) +def lint_code(session: Session): + """Lint the codebase.""" + session.install("codespell", "-U") + session.install("flake8", "-U") + session.install("-r", "flake8-requirements.txt", "-U") + + session.run("codespell", "bot") + session.run("flake8", "bot") + + +@nox.session(reuse_venv=True) +def typecheck_code(session: Session): + session.install("-r", "requirements.txt", "-U") + session.install("pyright", "-U") + + session.run("pyright", "bot") diff --git a/discord-bot/pyproject.toml b/discord-bot/pyproject.toml new file mode 100644 index 00000000..7a1e8d82 --- /dev/null +++ b/discord-bot/pyproject.toml @@ -0,0 +1,47 @@ +[project] +name = "Open-Assistant Discord Bot" +version = "0.0.1" + +[tool.black] +line-length = 120 +target-version = ["py310"] + +[tool.pyright] +include = ["ottbot", "noxfile.py"] +pythonVersion="3.10" +reportMissingImports=false +# reportInvalidTypeVarUse=false +# reportMissingModuleSource=false +reportUnknownVariableType=false +pythonPlatform="Linux" + +[tool.isort] +profile="black" +sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER'] +skip_glob = "**/__init__.pyi" + +[tool.flake8] +max-function-length = 130 +max-line-length = 130 +# Technically this is 120, but black has a policy of "1 or 2 over is fine if it is tidier", so we have to raise this. +accept-encodings = "utf-8" +docstring-convention = "numpy" +ignore = [ + "A002", # Argument is shadowing a python builtin. + "A003", # Class attribute is shadowing a python builtin. + "CFQ002", # Function has too many arguments. + "CFQ004", # Function has too many returns. + "D001", # False positive for depreciated functions. + "D102", # Missing docstring in public method. + "D105", # Magic methods not having a docstring. + "D412", # No blank lines allowed between a section header and its content + "E203", # Whitespace after : (to match how black formats it) + "E402", # Module level import not at top of file (isn't compatible with our import style). + "T101", # TO-DO comment detection (T102 is FIX-ME and T103 is XXX). + "W503", # line break before binary operator. + "W504", # line break before binary operator (again, I guess). + "S101", # Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. + "S105", # Possible hardcoded password. + "EXE002", # Executable file with not shebang + "D401", # Imperative mood +] diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 927ebcf2..49c5e1ba 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -1,7 +1,10 @@ -discord.py==2.1.0 -Jinja2==3.1.2 -pydantic==1.9.1 -python-dotenv==0.21.0 -pytz==2022.7 -requests==2.28.1 -schedule==1.1.0 +hikari # discord framework +hikari[speedups] +uvloop; os_name != 'nt' +hikari-lightbulb # command handler +hikari-miru # modals and buttons + +python-dotenv # .env file support +aiosqlite # database +aiohttp # http client +aiohttp[speedups] # speedups for aiohttp \ No newline at end of file diff --git a/discord-bot/task_handlers.py b/discord-bot/task_handlers.py deleted file mode 100644 index 1434d17c..00000000 --- a/discord-bot/task_handlers.py +++ /dev/null @@ -1,267 +0,0 @@ -# -*- coding: utf-8 -*- -from __future__ import annotations - -from abc import abstractmethod -from datetime import timedelta - -import discord -from api_client import ApiClient -from bot_base import BotBase -from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException -from loguru import logger -from oasst_shared.schemas import protocol as protocol_schema -from utils import DiscordTimestampStyle, discord_timestamp, utcnow - - -class Questionnaire(discord.ui.Modal, title="Questionnaire Response"): - name = discord.ui.TextInput(label="Name") - answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph) - - async def on_submit(self, interaction: discord.Interaction): - await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True) - - -class ChannelTaskBase(AutoDestructThreadHandler): - thread_name: str = "Replies" - expires_after: timedelta = timedelta(minutes=5) - backend: ApiClient - - async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message: - try: - self.bot = bot - self.task = task - self.backend = bot.backend - self.expiry_date = utcnow() + self.expires_after if self.expires_after else None - msg = await self.send_first_message() - self.first_message = msg - self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name) - await self.on_thread_created(self.thread) - except Exception: - logger.exception("start task failed") - await self.cleanup() # try to cleanup messag or thread - raise - - bot.register_reply_handler(msg_id=msg.id, handler=self) - return msg - - async def on_thread_created(self, thread: discord.Thread) -> None: - pass - - @abstractmethod - async def send_first_message(self) -> discord.message: - ... - - def to_api_user(self, user: discord.User) -> protocol_schema.User: - return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name) - - async def post_teaser_msg(self, template_name: str): - expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time) - expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time) - return await self.bot.post_template( - template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative - ) - - async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - api_response = await self.backend.post_interaction(interaction) - if api_response.type != "task_done": - # multi-step tasks are not supported yet - logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})") - raise RuntimeError("Unexpected response from backend received") - return api_response - - def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: - return self.backend.post_interaction( - protocol_schema.TextReplyToPost( - post_id=str(self.first_message.id), - user_post_id=str(user_msg.id), - user=self.to_api_user(user_msg.author), - text=user_msg.content, - ) - ) - - async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: - try: - self.post_text_reply_to_post(user_msg) - await user_msg.add_reaction("✅") - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in handle_text_reply_to_post()") - await user_msg.add_reaction("❌") - await user_msg.reply(f"❌ Error communicating with backend: {e}") - - def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task: - return self.backend.post_interaction( - protocol_schema.PostRanking( - post_id=str(self.first_message.id), - user_post_id=str(user_msg.id), - user=self.to_api_user(user_msg.author), - ranking=ranking, - ) - ) - - async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task: - try: - ranking_str = user_msg.content - ranking = [int(x) - 1 for x in ranking_str.split(",")] - self.post_ranking(user_msg, ranking=ranking) - await user_msg.add_reaction("✅") - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in handle_ranking()") - await user_msg.add_reaction("❌") - await user_msg.reply(f"❌ Error communicating with backend: {e}") - - -class SummarizeStoryHandler(ChannelTaskBase): - task: protocol_schema.SummarizeStoryTask - thread_name: str = "Summaries" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_summarize_story.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class InitialPromptHandler(ChannelTaskBase): - task: protocol_schema.InitialPromptTask - thread_name: str = "Prompts" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_initial_prompt.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class UserReplyHandler(ChannelTaskBase): - task: protocol_schema.UserReplyTask - thread_name: str = "User replies" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_user_reply.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class AssistantReplyHandler(ChannelTaskBase): - task: protocol_schema.AssistantReplyTask - thread_name: str = "Assistant replies" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_assistant_reply.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_text_reply_to_post(msg) - - -class RankInitialPromptsHandler(ChannelTaskBase): - task: protocol_schema.RankInitialPromptsTask - thread_name: str = "User Responses" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rank_initial_prompts.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_ranking(msg) - - -class RankConversationsHandler(ChannelTaskBase): - task: protocol_schema.RankConversationRepliesTask - thread_name: str = "Rankings" - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rank_conversation_replies.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - await self.handle_ranking(msg) - - -class RatingButton(discord.ui.Button): - def __init__(self, label, value, response_handler): - super().__init__(label=label, style=discord.ButtonStyle.green) - self.value = value - self.response_handler = response_handler - - async def callback(self, interaction): - await self.response_handler(self.value, interaction) - - -def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: - view = discord.ui.View() - for i in range(lo, hi + 1): - view.add_item(RatingButton(str(i), i, response_handler)) - return view - - -class RateSummaryHandler(ChannelTaskBase): - task: protocol_schema.RateSummaryTask - thread_name: str = "Ratings" - - async def _rating_response_handler(self, score, interaction: discord.Interaction): - logger.info("rating_response_handler", score) - if self.thread: - try: - self.backend.post_interaction( - protocol_schema.PostRating( - post_id=str(self.first_message.id), - user_post_id=str(interaction.id), - user=self.to_api_user(interaction.user), - rating=score, - ) - ) - await interaction.response.send_message( - f"Thanks {interaction.user.display_name}, got your feedback: {score}!" - ) - except ChannelExpiredException: - raise - except Exception as e: - logger.exception("Error in _rating_response_handler()") - interaction.response.send_message(f"❌ Error communicating with backend: {e}") - - async def send_first_message(self) -> discord.message: - return await self.post_teaser_msg("teaser_rate_summary.msg") - - async def on_thread_created(self, thread: discord.Thread) -> None: - view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler) - return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task) - - async def handler_loop(self): - while True: - msg = await self.read() - logger.info(f"on_rate_summary_reply: {msg.content}") - await msg.add_reaction("❌") - await msg.reply("❌ Text intput not supported.") diff --git a/discord-bot/utils.py b/discord-bot/utils.py deleted file mode 100644 index 968e4498..00000000 --- a/discord-bot/utils.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -import subprocess -from datetime import datetime - -import pytz - - -def get_git_head_hash(): - # get current git hash - x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True) - if x.returncode == 0: - return x.stdout.replace("\n", "") - return None - - -def utcnow() -> datetime: - return datetime.now(pytz.UTC) - - -class DiscordTimestampStyle(str, enum.Enum): - """ - Timestamp Styles - - t 16:20 Short Time - T 16:20:30 Long Time - d 20/04/2021 Short Date - D 20 April 2021 Long Date - f * 20 April 2021 16:20 Short Date/Time - F Tuesday, 20 April 2021 16:20 Long Date/Time - R 2 months ago Relative Time - - See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles - """ - - default = "" - short_time = "t" - long_time = "T" - short_date = "d" - long_date = "D" - short_date_time = "f" - long_date_time = "F" - relative_time = "R" - - -def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default): - parts = ["") - return "".join(parts) From c8834aa9e336fb57bef0b22593f082da6d1575c8 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Wed, 28 Dec 2022 21:24:53 -0800 Subject: [PATCH 02/27] add a lot of examples --- discord-bot/.env.example | 3 +- discord-bot/CONTRIBUTING.md | 61 ++++ discord-bot/bot/bot.py | 7 +- discord-bot/bot/config.py | 2 + discord-bot/bot/extensions/__init__.py | 5 + discord-bot/bot/extensions/example.py | 406 +++++++++++++++++++++++ discord-bot/bot/extensions/hot_reload.py | 4 +- discord-bot/bot/utils.py | 23 ++ 8 files changed, 504 insertions(+), 7 deletions(-) create mode 100644 discord-bot/bot/extensions/__init__.py create mode 100644 discord-bot/bot/extensions/example.py create mode 100644 discord-bot/bot/utils.py diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 89e50c05..c518010d 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,3 +1,4 @@ TOKEN= DECLARE_GLOBAL_COMMANDS= -OWNER_IDS= \ No newline at end of file +OWNER_IDS= +PREFIX="./" \ No newline at end of file diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md index 089a0c33..d4d8ad3b 100644 --- a/discord-bot/CONTRIBUTING.md +++ b/discord-bot/CONTRIBUTING.md @@ -28,6 +28,67 @@ To test the bot on your own discord server you need to register a discord applic ## Resources +### Structure + +```graphql +.env # Environment variables +.env.example # Example environment variables +CONTRIBUTING.md # This file +dev-requirements.txt # Development requirements +flake8-requirements.txt # Flake8 extensions (for linting) +noxfile.py # Nox session definitions (for formatting, typechecking, linting) +pyproject.toml # Project configuration +README.md # Project readme +requirements.txt # Requirements +templates/ # Message templates + +bot/ +├─ __init__.py +├─ __main__.py # Entrypoint +├─ bot.py # Main bot class +├─ config.py # Configuration and secrets +├─ utils.py # Utility Functions +│ +├─ db/ # Database related code +│ ├─ database.db # SQLite database +│ └─ schema.sql # Database schema +│ +└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html + └─ hot_reload.py # Utility for hot reload extension +``` + +### Adding a new command/listener + +1. Create a new file in the `extensions` folder +2. Copy the template below + +```py +# -*- coding: utf-8 -*- +"""My plugin.""" +import lightbulb + +plugin = lightbulb.Plugin("MyPlugin") + +# Add your commands here + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) +``` + +For example commands and listeners, see [here](/discord-bot/bot/extensions/_example.py) + +### Docs + +Discord + +- [Discord API Reference](https://discord.com/developers/docs/intro) + Main framework - [Hikari Repo](https://github.com/hikari-py/hikari) diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index e529cf75..af163545 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -1,10 +1,10 @@ # -*- coding=utf-8 -*- """Bot logic.""" -import hikari - import aiosqlite +import hikari import lightbulb import miru + from bot.config import Config config = Config.from_env() @@ -12,7 +12,7 @@ config = Config.from_env() bot = lightbulb.BotApp( token=config.token, logs="DEBUG", - prefix="./", + prefix=config.prefix, default_enabled_guilds=config.declare_global_commands, owner_ids=config.owner_ids, intents=hikari.Intents.ALL, @@ -22,7 +22,6 @@ bot = lightbulb.BotApp( @bot.listen() async def on_starting(event: hikari.StartingEvent): """Setup.""" - miru.install(bot) # component handler bot.load_extensions_from("./bot/extensions") # load extensions diff --git a/discord-bot/bot/config.py b/discord-bot/bot/config.py index 5905301c..e3addac9 100644 --- a/discord-bot/bot/config.py +++ b/discord-bot/bot/config.py @@ -19,6 +19,7 @@ class Config: token: str declare_global_commands: int owner_ids: list[int] + prefix: str @classmethod def from_env(cls): @@ -32,4 +33,5 @@ class Config: token=token, declare_global_commands=int(getenv("DECLARE_GLOBAL_COMMANDS", 0)), owner_ids=[int(x) for x in getenv("OWNER_IDS", "").split(",")], + prefix=getenv("PREFIX", "./"), ) diff --git a/discord-bot/bot/extensions/__init__.py b/discord-bot/bot/extensions/__init__.py new file mode 100644 index 00000000..87295d9a --- /dev/null +++ b/discord-bot/bot/extensions/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +"""Extensions for the bot. + +See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html +""" diff --git a/discord-bot/bot/extensions/example.py b/discord-bot/bot/extensions/example.py new file mode 100644 index 00000000..8ac7fe21 --- /dev/null +++ b/discord-bot/bot/extensions/example.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- +"""Example plugins for reference. + +Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`. +""" +import asyncio + +import hikari +import lightbulb +import lightbulb.decorators +import miru +from miru.ext import nav + +plugin = lightbulb.Plugin("ExamplePlugin") + +# To add checks to a plugin, you can use the `@plugin.check` decorator +# or the `plugin.add_check` method. Lightbulb has some built-in checks. +# The check will be called before any command in the plugin is called. +plugin.add_checks(lightbulb.guild_only) + + +# To create a slash command, use the template below +@plugin.command +@lightbulb.command("example", "Example command.") +@lightbulb.implements(lightbulb.SlashCommand) +async def example(ctx: lightbulb.SlashContext): + """Example command.""" + # To send a message, use the `respond` method on `ctx`. + # !!! Be sure to use `await` when calling `respond` !!! + await ctx.respond("Hello, world!") + + +# To add arguments, use the `@lightbulb.option` decorator. +@plugin.command +@lightbulb.option( + "name", # The name of the option. This is what you will use to access the value in `ctx.options.name` + "Your name.", # The description of the option. This will be shown in the slash command menu. + # Whether or not the option is required. + # If `required` is `True`, the user will not be able to use the command without providing a value for this option. + required=False, + default=None, # The default value for the option. If `required` is `True`, this will be ignored. + type=str | None, # The type of the option. This is used to convert the value to the correct type. + # https://hikari-lightbulb.readthedocs.io/en/latest/guides/commands.html#converters-and-slash-command-option-types +) +@lightbulb.option( + "age", + "Your age.", + type=int, + # These are enforced on the client side, so the user won't be able to enter a value outside of the range. + min_value=0, + max_value=100, +) +@lightbulb.option( + "gender", + "Your gender.", + # You can also use `choices` to limit the user to a specific set of values. + # This can be a list of `str`, `int, or `float` + # choices=["Male", "Female", "Other"], + # or a list of `hikari.CommandChoice` objects to have separate option names and values + choices=[ + hikari.CommandChoice(name="male", value="M"), + hikari.CommandChoice(name="female", value="F"), + hikari.CommandChoice(name="other", value="Other"), + ], + type=str, +) +@lightbulb.command("args_example", "Example command with arguments.") +@lightbulb.implements(lightbulb.SlashCommand) +async def args_example(ctx: lightbulb.SlashContext): + """Example command with arguments.""" + name: str | None = ctx.options.name + if name is None: + name = ctx.author.username + age: int = ctx.options.age + gender: str = ctx.options.gender + + await ctx.respond( + f"Hello {ctx.author.mention}! Your name is {name}, you are {age} years old, and your gender is {gender}.", + # in order to actually mention the user, you must pass `user_mentions=True` + # otherwise, the user won't get a notification + user_mentions=True, + ) + + +# To have autocomplete options, add the +# pass `autocomplete=function` to `@lightbulb.option` +# or `autocomplete=True` and mark the function with `@command.autocomplete("option_name")`. +# @autocomplete_example.autocomplete("language") +async def _programming_language_autocomplete( + option: hikari.CommandInteractionOption, interaction: hikari.AutocompleteInteraction +) -> list[str]: + # The `option` argument is the current text that the user typed in. + if not isinstance(option.value, str): + # This will raise a TypeError if `option.value` cannot be converted + option.value = str(option.value) + + # You can query a database, fetch an api, or return any list of strings + # !!! You can return a max of 25 options !!! + langs = [ + "C", + "C++", + "C#", + "CSS", + "Go", + "HTML", + "Java", + "Javascript", + "Kotlin", + "Matlab", + "NoSQL", + "PHP", + "Perl", + "Python", + "R", + "Ruby", + "Rust", + "SQL", + "Scala", + "Swift", + "TypeScript", + "Zig", + ] + return [lang for lang in langs if option.value.lower() in lang.lower()] + + +@plugin.command +@lightbulb.option( + "language", + "Your favorite programming language.", + autocomplete=_programming_language_autocomplete, +) +@lightbulb.command("autocomplete_example", "Autocomplete example.") +@lightbulb.implements(lightbulb.SlashCommand) +async def autocomplete_example(ctx: lightbulb.SlashContext): + """Autocomplete example.""" + await ctx.respond("Your favorite programming language is " + ctx.options.language) + + +# Command groups are like trees +# You can have subcommands, subcommand groups, and subcommand groups with subcommands +# Here is an example diagram: +# /group_example (group) +# subcommand (executable) +# subcommand_group (group) +# subsubcommand (executable) + +# Because those are slash commands, only the leaves (/subcommand and /subsubcommand) are callable. + +# To create a group, use the template below +# 1. Create the command group +@plugin.command +@lightbulb.command("group_example", "Example command group.") +@lightbulb.implements(lightbulb.SlashCommandGroup) +async def group_example(_: lightbulb.SlashContext) -> None: + """Group example.""" + # This will never execute because it is a group + pass + + +# 2. Add a child command +@group_example.child +@lightbulb.command("subcommand", "Example subcommand.") +@lightbulb.implements(lightbulb.SlashSubCommand) +async def subcommand(ctx: lightbulb.SlashContext) -> None: + """An example subcommand.""" + await ctx.respond("invoked `/group_example subcommand`") + + +# 3. Add a sub-group +@group_example.child +@lightbulb.command("subcommand_group", "Example subcommand group.") +@lightbulb.implements(lightbulb.SlashSubGroup) +async def subcommand_group(_: lightbulb.SlashContext) -> None: + """Subcommand group example.""" + # This will never execute because it is a sub-group + pass + + +# 4. Add a child to the sub-group +@subcommand_group.child +@lightbulb.command("subsubcommand", "Example subsubcommand.") +@lightbulb.implements(lightbulb.SlashSubCommand) +async def subsubcommand(ctx: lightbulb.SlashContext) -> None: + """An example subsubcommand.""" + await ctx.respond("invoked `/group_example subcommand_group subsubcommand`") + + +# Event listeners are a way to listen to events from the gateway. +# You can have stand alone event listeners or use `wait_for` to wait for a specific event inside a command / listener. +@plugin.listener(hikari.MemberCreateEvent) +async def on_member_join(event: hikari.MemberCreateEvent) -> None: + """Event listener to welcome new members.""" + guild = event.get_guild() + await event.member.send(f"Welcome to {guild.name if guild else 'the server'}!") + + +# You can also use `wait_for` to wait for a specific event +@plugin.command +@lightbulb.command("wait_for_example", "Example command with `wait_for` and `stream`.") +@lightbulb.implements(lightbulb.SlashCommand) +async def wait_for_example(ctx: lightbulb.SlashContext) -> None: + """Wait for example.""" + await ctx.respond("Send a message!") + + # We can add a predicate to `wait_for` to filter out events + def author_check(e: hikari.MessageCreateEvent) -> bool: + return e.author_id == ctx.author.id + + # You need to wrap wait_for in a try/catch block because it can raise `asyncio.TimeoutError` + try: + event = await ctx.bot.wait_for(hikari.MessageCreateEvent, timeout=10, predicate=author_check) + await ctx.respond(f"You sent: {event.message.content}") + except asyncio.TimeoutError: + await ctx.respond("Too slow!") + # remember to use try/except/finally if you need to clean up any resources + + # You can also use `stream` to listen for events + await ctx.respond("Waiting for guild events...") + with ctx.bot.stream(hikari.Event, timeout=5).filter( + # Only listen for events that have a guild_id and are not bots + lambda e: getattr(e, "guild_id", None) == ctx.guild_id + and getattr(e, "is_human", False) + ) as stream: + async for event in stream: + await ctx.respond(f"New `{event.__class__.__name__}`") + + await ctx.respond("Done!") + + +# You can interact with discord's API using the `rest` attribute on the bot +# This allows you to +# - fetch information about users, channels, guilds, etc. +# - create, edit, and delete messages, channels, threads, roles, categories, etc. +# - add, remove, and edit reactions +@plugin.command +@lightbulb.command("rest_example", "Example command using the `rest` attribute.") +@lightbulb.implements(lightbulb.SlashCommand) +async def rest_example(ctx: lightbulb.SlashContext) -> None: + """Example command using the `rest` attribute.""" + rest = ctx.bot.rest + your_messages = await rest.fetch_messages(ctx.channel_id).filter(lambda m: m.author.id == ctx.author.id).count() + await ctx.respond(f"{your_messages} out of the last 10 messages in this channel were sent by you.") + + +# Context Menus are a way to attach a command to a user or a message. +# By right clicking a user or a User, you can select to execute a command under the "Apps" menu item. +@plugin.command +@lightbulb.command("user_context_menu_example", "Example context menu on a user.") +@lightbulb.implements(lightbulb.UserCommand) +async def user_context_menu_example(ctx: lightbulb.UserContext) -> None: + """User context menu example.""" + user: hikari.Member = ctx.options.target + await ctx.respond(f"Hello {user.mention}!", user_mentions=True) + + +# Same with messages +@plugin.command +@lightbulb.command("message_context_menu_example", "Example context menu on a message.") +@lightbulb.implements(lightbulb.MessageCommand) +async def message_context_menu_example(ctx: lightbulb.MessageContext) -> None: + """Message context menu example.""" + message: hikari.Message = ctx.options.target + await ctx.respond(f"The message length is: {len(message.content or '')}", flags=hikari.MessageFlag.EPHEMERAL) + + +# Components are a way to add interactive buttons to your slash commands. +# We use `miru` to manage components and their callbacks. + +# To create a component, use the template below +# 1. Create the view +class MyView(miru.View): + """An example view with buttons.""" + + @miru.button(label="Rock", emoji="\N{ROCK}", style=hikari.ButtonStyle.PRIMARY) + async def rock_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + await ctx.respond("Paper!") + + @miru.button(label="Paper", emoji="\N{SCROLL}", style=hikari.ButtonStyle.PRIMARY) + async def paper_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + await ctx.respond("Scissors!") + + @miru.button(label="Scissors", emoji="\N{BLACK SCISSORS}", style=hikari.ButtonStyle.PRIMARY) + async def scissors_button(self, button: miru.Button, ctx: miru.ViewContext): + await ctx.respond("Rock!") + + @miru.button(emoji="\N{BLACK SQUARE FOR STOP}", style=hikari.ButtonStyle.DANGER, row=2) + async def stop_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.stop() # Stop listening for interactions + + @miru.select( + options=[ + hikari.SelectMenuOption( + label="Thing 1", + value="1", + description="This is a thing", + emoji=hikari.UnicodeEmoji("🗿"), + is_default=True, + ), + hikari.SelectMenuOption( + label="Thing 2", + value="2", + description="This is another thing", + emoji=hikari.UnicodeEmoji("🗿"), + is_default=False, + ), + hikari.SelectMenuOption( + label="Thing 3", + value="3", + description="This is a different thing", + emoji=hikari.UnicodeEmoji("🗿"), + is_default=False, + ), + ], + placeholder="Select some stuff!", + min_values=0, + max_values=2, + row=3, + ) + async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None: + await ctx.respond(f"You selected {select.values}") + + +# 2. Create a command to use the view +@plugin.command +@lightbulb.command("button_example", "Example command with buttons.") +@lightbulb.implements(lightbulb.SlashCommand) +async def button_example(ctx: lightbulb.SlashContext) -> None: + """Wait for example.""" + # 3. Create an instance of the view and start it + view = MyView(timeout=60) + resp = await ctx.respond("Rock Paper Scissors!", components=view) + msg = await resp.message() + await view.start(msg) + await view.wait() + + await ctx.respond("Thank you for playing!") + + +# You can use buttons to create a navigation menu +@plugin.command +@lightbulb.command("nav_example", "Example command with button navigation.", auto_defer=True) +@lightbulb.implements(lightbulb.SlashCommand) +async def navigation_example(ctx: lightbulb.SlashContext) -> None: + """Navigation example.""" + # await ctx.respond(response_type=hikari.ResponseType.DEFERRED_MESSAGE_UPDATE) + embed = hikari.Embed(title="I'm the second page!", description="Also an embed!") + pages = ["I'm the first page!", embed, "I'm the last page!"] + + navigator = nav.NavigatorView(pages=pages, timeout=10) + # You may also pass an interaction object to this function + await navigator.send(ctx.channel_id) + + await navigator.wait() # This is not necessary, but we want to wait anyway + await ctx.respond("Done!") + + +# Miru also has modal support +class MyModal(miru.Modal): + """An example modal.""" + + # Define our modal items + # You can also use Modal.add_item() to add items to the modal after instantiation, just like with views. + name = miru.TextInput(label="Name", placeholder="Enter your name!", required=True) + bio = miru.TextInput(label="Biography", value="Pre-filled content!", style=hikari.TextInputStyle.PARAGRAPH) + + # You can currently only use TextInputs + # https://discord.com/developers/docs/interactions/receiving-and-responding#interaction-response-object-modal + + # The callback function is called after the user hits 'Submit' + async def callback(self, context: miru.ModalContext) -> None: + # You can also access the values using ctx.values, Modal.values, or use ctx.get_value_by_id() + await context.respond(f"Your name: `{self.name.value}`\nYour bio: ```{self.bio.value}```") + + +class ModalView(miru.View): + """An example view that opens a modal.""" + + # Create a new button that will invoke our modal + @miru.button(label="Click me!", style=hikari.ButtonStyle.PRIMARY) + async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + modal = MyModal(title="Example Title") + # You may also use Modal.send(interaction) if not working with a miru context object. (e.g. slash commands) + # Keep in mind that modals can only be sent in response to interactions. + await ctx.respond_with_modal(modal) + # OR + # await modal.send(ctx.interaction) + + +@plugin.command +@lightbulb.command("modal_example", "Example command with a modal.") +@lightbulb.implements(lightbulb.SlashCommand) +async def modal_example(ctx: lightbulb.SlashContext) -> None: + """Navigation example.""" + view = ModalView() + resp = await ctx.respond("This button triggers a modal!", components=view) + await view.start(await resp.message()) + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py index ffb7ea70..b70a22fd 100644 --- a/discord-bot/bot/extensions/hot_reload.py +++ b/discord-bot/bot/extensions/hot_reload.py @@ -14,8 +14,8 @@ EXTENSIONS_FOLDER = "bot/extensions" def _get_extensions() -> list[str]: - # Recursively get all the .py files in the extensions directory. - exts = glob("bot/extensions/**/*.py", recursive=True) + # Recursively get all the .py files in the extensions directory not starting with an `_`. + exts = glob("bot/extensions/**/*[!_].py", recursive=True) # Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension") return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts] diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py new file mode 100644 index 00000000..beb81c36 --- /dev/null +++ b/discord-bot/bot/utils.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +"""Utility functions.""" +import typing as t +from datetime import datetime + + +def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str: + """Format a datetime object into the discord time format. + + ``` + | t | HH:MM | 16:20 + | T | HH:MM:SS | 16:20:11 + | D | D Mo Yr | 20 April 2022 + | f | D Mo Yr HH:MM | 20 April 2022 16:20 + | F | W, D Mo Yr HH:MM | Wednesday, 20 April 2022 16:20 + | R | relative | in an hour + ``` + """ + match fmt: + case "t" | "T" | "D" | "f" | "F" | "R": + return f"" + case _: + raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}") From 99303ed26585282cb79a4aa94dc4eb5d4c4f3a21 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Wed, 28 Dec 2022 21:29:19 -0800 Subject: [PATCH 03/27] move example.py to _example.py so it doesn't load on startup --- discord-bot/bot/extensions/{example.py => _example.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord-bot/bot/extensions/{example.py => _example.py} (100%) diff --git a/discord-bot/bot/extensions/example.py b/discord-bot/bot/extensions/_example.py similarity index 100% rename from discord-bot/bot/extensions/example.py rename to discord-bot/bot/extensions/_example.py From 9fd2e769175979c3034e133162262e2129bb1bc3 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Thu, 29 Dec 2022 14:20:56 -0800 Subject: [PATCH 04/27] add initial task loop for initial_prompt and rank_initial_prompts --- discord-bot/api_client.py | 75 ------ discord-bot/bot/api_client.py | 130 ++++++++++ discord-bot/bot/bot.py | 4 + discord-bot/bot/extensions/_example.py | 7 +- discord-bot/bot/extensions/hot_reload.py | 2 +- discord-bot/bot/extensions/tasks.py | 302 +++++++++++++++++++++++ discord-bot/bot/extensions/work.py | 281 +++++++++++++++++++++ discord-bot/bot/utils.py | 7 + discord-bot/requirements.txt | 3 +- 9 files changed, 733 insertions(+), 78 deletions(-) delete mode 100644 discord-bot/api_client.py create mode 100644 discord-bot/bot/api_client.py create mode 100644 discord-bot/bot/extensions/tasks.py create mode 100644 discord-bot/bot/extensions/work.py diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py deleted file mode 100644 index 0caa1595..00000000 --- a/discord-bot/api_client.py +++ /dev/null @@ -1,75 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -from typing import Optional, Type -import typing as t - -import requests -from oasst_shared.schemas import protocol as protocol_schema - - -class TaskType(str, enum.Enum): - summarize_story = "summarize_story" - rate_summary = "rate_summary" - initial_prompt = "initial_prompt" - user_reply = "user_reply" - assistant_reply = "assistant_reply" - rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" - rank_assistant_replies = "rank_assistant_replies" - done = "task_done" - - -class ApiClient: - def __init__(self, backend_url: str, api_key: str): - self.backend_url = backend_url - self.api_key = api_key - - task_models_map: dict[str, Type[protocol_schema.Task]] = { - TaskType.summarize_story: protocol_schema.SummarizeStoryTask, - TaskType.rate_summary: protocol_schema.RateSummaryTask, - TaskType.initial_prompt: protocol_schema.InitialPromptTask, - TaskType.user_reply: protocol_schema.UserReplyTask, - TaskType.assistant_reply: protocol_schema.AssistantReplyTask, - TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, - TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, - TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, - TaskType.done: protocol_schema.TaskDone, - } - self.task_models_map = task_models_map - - def post(self, path: str, json: dict) -> dict: - response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key}) - response.raise_for_status() - return response.json() - - def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: - if not isinstance(data, dict): - raise ValueError("dict expected") - - task_type = data.get("type") - if task_type not in self.task_models_map: - raise RuntimeError(f"Unsupported task type: {task_type}") - - return self.task_models_map[task_type].parse_obj(data) - - def fetch_task( - self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None - ) -> protocol_schema.Task: - req = protocol_schema.TaskRequest(type=task_type, user=user) - data = self.post("/api/v1/tasks/", req.dict()) - return self._parse_task(data) - - def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: - return self.fetch_task(protocol_schema.TaskRequestType.random, user) - - def ack_task(self, task_id: str, post_id: str) -> None: - req = protocol_schema.TaskAck(post_id=post_id) - return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict()) - - def nack_task(self, task_id: str, reason: str) -> None: - req = protocol_schema.TaskNAck(reason=reason) - return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - data = self.post("/api/v1/tasks/interaction", interaction.dict()) - return self._parse_task(data) diff --git a/discord-bot/bot/api_client.py b/discord-bot/bot/api_client.py new file mode 100644 index 00000000..cec1900f --- /dev/null +++ b/discord-bot/bot/api_client.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +import asyncio +import enum +import typing as t +from typing import Optional, Type +from uuid import UUID + +import aiohttp +from loguru import logger + +from oasst_shared.schemas import protocol as protocol_schema + + +class TaskType(str, enum.Enum): + summarize_story = "summarize_story" + rate_summary = "rate_summary" + initial_prompt = "initial_prompt" + user_reply = "user_reply" + assistant_reply = "assistant_reply" + rank_initial_prompts = "rank_initial_prompts" + rank_user_replies = "rank_user_replies" + rank_assistant_replies = "rank_assistant_replies" + done = "task_done" + + +class OasstApiClient: + """API Client for interacting with the OASST backend.""" + + def __init__(self, backend_url: str, api_key: str): + """Create a new OasstApiClient. + + Args: + backend_url (str): The base backend URL. + api_key (str): The API key to use for authentication. + """ + logger.debug("Opening OasstApiClient session") + self.session = aiohttp.ClientSession() + self.backend_url = backend_url + self.api_key = api_key + + self.task_models_map: dict[str, Type[protocol_schema.Task]] = { + TaskType.summarize_story: protocol_schema.SummarizeStoryTask, + TaskType.rate_summary: protocol_schema.RateSummaryTask, + TaskType.initial_prompt: protocol_schema.InitialPromptTask, + TaskType.user_reply: protocol_schema.UserReplyTask, + TaskType.assistant_reply: protocol_schema.AssistantReplyTask, + TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, + TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, + TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, + TaskType.done: protocol_schema.TaskDone, + } + + async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]: + """Make a POST request to the backend.""" + logger.debug(f"POST {self.backend_url}{path} DATA: {data}") + response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key}) + response.raise_for_status() + return await response.json() + + def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: + task_type = data.get("type") + + if not isinstance(task_type, str): + logger.error(f"task type must be a `str`: {task_type}") + raise ValueError(f"task type must be a `str`: {task_type}") + + model = self.task_models_map.get(task_type) + if not model: + logger.error(f"Unsupported task type: {task_type}") + raise ValueError(f"Unsupported task type: {task_type}") + return self.task_models_map[task_type].parse_obj(data) + + async def fetch_task( + self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + ) -> protocol_schema.Task: + """Fetch a task from the backend.""" + logger.debug(f"Fetching task {task_type} for user {user}") + req = protocol_schema.TaskRequest(type=task_type.value, user=user) + resp = await self.post(f"/api/v1/tasks/", data=req.dict()) + print("resp", resp) + return self._parse_task(resp) + + async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: + """Fetch a random task from the backend.""" + logger.debug(f"Fetching random for user {user}") + return await self.fetch_task(protocol_schema.TaskRequestType.random, user) + + async def ack_task(self, task_id: str | UUID, post_id: str): + """Send an ACK for a task to the backend.""" + logger.debug(f"ACK task {task_id} with post {post_id}") + req = protocol_schema.TaskAck(post_id=post_id) + return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict()) + + async def nack_task(self, task_id: str | UUID, reason: str): + """Send a NACK for a task to the backend.""" + logger.debug(f"NACK task {task_id} with reason {reason}") + req = protocol_schema.TaskNAck(reason=reason) + return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict()) + + async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: + """Send a completed task to the backend.""" + logger.debug(f"Interaction: {interaction}") + resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict()) + + return self._parse_task(resp) + + async def close(self): + logger.debug("Closing OasstApiClient session") + await self.session.close() + + +async def main(): + api = OasstApiClient("http://localhost:8080", "test") + try: + task = await api.fetch_task(protocol_schema.TaskRequestType.initial_prompt, None) + print(task) + finally: + + await api.close() + # session = aiohttp.ClientSession() + # try: + # resp = await session.post("http://localhost:8080/api/v1/tasks/", json={"type": "initial_prompt", "user": None}) + # resp.raise_for_status() + # print(await resp.text()) + # finally: + # await session.close() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index af163545..de8ceacf 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -6,6 +6,7 @@ import lightbulb import miru from bot.config import Config +from bot.api_client import OasstApiClient config = Config.from_env() @@ -29,8 +30,11 @@ async def on_starting(event: hikari.StartingEvent): await bot.d.db.executescript(open("./bot/db/schema.sql").read()) await bot.d.db.commit() + bot.d.oasst_api = OasstApiClient("http://localhost:8080", "any_key") + @bot.listen() async def on_stopping(event: hikari.StoppingEvent): """Cleanup.""" await bot.d.db.close() + await bot.d.oasst_api.close() diff --git a/discord-bot/bot/extensions/_example.py b/discord-bot/bot/extensions/_example.py index 8ac7fe21..330f5909 100644 --- a/discord-bot/bot/extensions/_example.py +++ b/discord-bot/bot/extensions/_example.py @@ -1,5 +1,6 @@ +# TODO: Convert file to markdown # -*- coding: utf-8 -*- -"""Example plugins for reference. +"""Example plugin for reference. Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`. """ @@ -396,6 +397,10 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None: await view.start(await resp.message()) +# TODO: Database example +# TODO: Rest client example + + def load(bot: lightbulb.BotApp): """Add the plugin to the bot.""" bot.add_plugin(plugin) diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py index b70a22fd..28bcede3 100644 --- a/discord-bot/bot/extensions/hot_reload.py +++ b/discord-bot/bot/extensions/hot_reload.py @@ -15,7 +15,7 @@ EXTENSIONS_FOLDER = "bot/extensions" def _get_extensions() -> list[str]: # Recursively get all the .py files in the extensions directory not starting with an `_`. - exts = glob("bot/extensions/**/*[!_].py", recursive=True) + exts = glob("bot/extensions/**/[!_]*.py", recursive=True) # Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension") return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts] diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py new file mode 100644 index 00000000..dfe51160 --- /dev/null +++ b/discord-bot/bot/extensions/tasks.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +"""Task plugin for testing different data collection methods.""" +import asyncio +import logging +import typing as t +from datetime import datetime, timedelta + +import hikari + +import lightbulb +import lightbulb.decorators +import miru +from bot.utils import format_time +from oasst_shared.schemas.protocol import TaskRequestType + +plugin = lightbulb.Plugin("TaskPlugin") + +MAX_TASK_TIME = 60 * 60 +MAX_TASK_ACCEPT_TIME = 60 +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True) +@lightbulb.implements(lightbulb.SlashCommand) +async def task_thread(ctx: lightbulb.SlashContext): + """Request a task from the backend.""" + typ: str = ctx.options.type + + # Create a thread for the task + thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}") + + await ctx.respond(f"Please complete the task in the thread: {thread.mention}") + + # Send the task in the thread + # TODO: Request task from the backend + await thread.send( + f"Please complete the task.\nSample Task\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + ) + + # Wait for the user to respond + try: + event = await ctx.bot.wait_for( + hikari.GuildMessageCreateEvent, + timeout=MAX_TASK_TIME, + predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id, + ) + await ctx.respond(f"Received message: {event.message.content}") + # TODO: Send the message to the backend + except asyncio.TimeoutError: + await ctx.respond("You took too long to respond.") + finally: + await thread.delete() + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True) +@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) +async def task_dm(ctx: lightbulb.Context): + """Request a task from the backend.""" + typ: str = ctx.options.type + + # Create a thread for the task + + await ctx.respond(f"Please complete the task in your DMs") + + # Send the task in the thread + # TODO: Request task from the backend + await ctx.author.send( + f"Please complete the task.\nSample Task ({typ})\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + ) + + # Wait for the user to respond + try: + event = await ctx.bot.wait_for( + hikari.DMMessageCreateEvent, + timeout=MAX_TASK_TIME, + predicate=lambda e: e.author.id == ctx.author.id, + ) + await ctx.respond(f"Received message: {event.message.content}") + # TODO: Send the message to the backend + except asyncio.TimeoutError: + await ctx.respond("You took too long to respond.") + + +class TaskModal(miru.Modal): + """Modal for submitting a task.""" + + response = miru.TextInput( + label="Response", + placeholder="Enter your response!", + required=True, + style=hikari.TextInputStyle.PARAGRAPH, + row=2, + ) + + async def callback(self, context: miru.ModalContext) -> None: + await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL) + # TODO: Send the message to the backend + + +class ModalView(miru.View): + """View for opening a modal.""" + + def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + self.modal_title = modal_title + self.task = task + + @miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY) + async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + modal = TaskModal(title=self.modal_title) + modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1)) + await ctx.respond_with_modal(modal) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType], + required=False, + default=TaskRequestType.summarize_story, + type=str, +) +@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True) +@lightbulb.implements(lightbulb.SlashCommand) +async def task_modal(ctx: lightbulb.SlashContext): + """Request a task from the backend.""" + # typ: str = ctx.options.type + view = ModalView( + modal_title=f"Assistant Response", + task="Please explain the moon landing to a six year old.", + timeout=MAX_TASK_TIME, + ) + resp = await ctx.respond( + "Task - Respond to the prompt as if you were the Assistant:", + flags=hikari.MessageFlag.EPHEMERAL, + components=view, + ) + await view.start(await resp.message()) + + +class RatingView(miru.View): + """View for rating a task.""" + + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: + super().__init__(*args, **kwargs) + self.presses: list[str] = [] + + def _close_if_all_pressed(self) -> None: + if len(self.presses) == 5: + self.stop() + + @miru.button(label="1", style=hikari.ButtonStyle.PRIMARY) + async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("1") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="2", style=hikari.ButtonStyle.PRIMARY) + async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("2") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="3", style=hikari.ButtonStyle.PRIMARY) + async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("3") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="4", style=hikari.ButtonStyle.PRIMARY) + async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("4") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="5", style=hikari.ButtonStyle.PRIMARY) + async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None: + if button.label not in self.presses: + self.presses.append("5") + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + self._close_if_all_pressed() + + @miru.button(label="Reset", style=hikari.ButtonStyle.DANGER) + async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.presses = [] + await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL) + + +class SelectRating(miru.View): + @miru.select( + options=[ + hikari.SelectMenuOption( + label="1", + value="1", + description=None, + emoji=None, + is_default=False, + ), + hikari.SelectMenuOption( + label="2", + value="2", + description=None, + emoji=None, + is_default=False, + ), + hikari.SelectMenuOption( + label="3", + value="3", + description=None, + emoji=None, + is_default=False, + ), + ], + placeholder="Select the good responses", + min_values=0, + max_values=3, + row=3, + ) + async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None: + await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL) + + +@plugin.command +@lightbulb.command("rating_task", "Rate stuff.") +@lightbulb.implements(lightbulb.SlashCommand) +async def rating_task(ctx: lightbulb.SlashContext): + """Rate stuff.""" + + # Message Based rating + await ctx.respond( + "List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL + ) + try: + event = await ctx.bot.wait_for( + hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id + ) + + except asyncio.TimeoutError: + await ctx.respond("Timed out waiting for response") + return + + if event.content is None: + await ctx.respond("No content in message") + return + ratings = event.content.replace(" ", "").split(",") + + # Check if the ratings are valid + if len(ratings) != 5: + await ctx.respond("Invalid number of ratings") + if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]): + await ctx.respond("Invalid rating") + + await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL) + # Button Based rating + view = RatingView(timeout=MAX_TASK_TIME) + + resp = await ctx.respond("Click the buttons in order of best to worst response", components=view) + await view.start(await resp.message()) + await view.wait() + await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL) + await resp.delete() + + # Select Based rating + select_view = SelectRating(timeout=MAX_TASK_TIME) + resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL) + await select_view.start(await resp_2.message()) + await select_view.wait() + await resp_2.delete() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py new file mode 100644 index 00000000..e6ea3d7c --- /dev/null +++ b/discord-bot/bot/extensions/work.py @@ -0,0 +1,281 @@ +# -*- coding: utf-8 -*- +"""Work plugin for collecting user data.""" +import asyncio +import logging +import typing as t +from datetime import datetime + +import hikari + +import lightbulb +import lightbulb.decorators +import miru +from bot.api_client import OasstApiClient, TaskType +from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TaskRequestType +from bot.utils import ZWJ + +plugin = lightbulb.Plugin("WorkPlugin") + +MAX_TASK_TIME = 60 * 60 # 1 hour +MAX_TASK_ACCEPT_TIME = 60 # 1 minute + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + + +@plugin.command +@lightbulb.option( + "type", + "The type of task to request.", + choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], + required=False, + default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random + type=str, +) +@lightbulb.command("work", "Complete a task.") +@lightbulb.implements(lightbulb.SlashCommand) +async def work(ctx: lightbulb.SlashContext): + """Create and handle a task.""" + task_type: TaskRequestType = TaskRequestType(ctx.options.type) + + await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) + logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}") + + await _handle_task(ctx, task_type) + + +async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None: + """Handle creating and collecting user input for a task. + + Continually present tasks to the user until they select one, cancel, or time out. + If they select one, present the task steps until a `task_done` task is received. + Finally, ask the user if they want to perform another task (of the same type). + """ + + oasst_api: OasstApiClient = ctx.bot.d.oasst_api + + # Continue to complete tasks until the user doesn't want to do another + done = False + while not done: + + # Loop until the user accepts a task + task, msg_id = await _select_task(ctx, task_type) + + if task is None: + return + + # Task action loop + completed = False + while not completed: + await ctx.author.send("Please type your response here:") + try: + event = await ctx.bot.wait_for( + hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id + ) + except asyncio.TimeoutError: + await ctx.author.send("Task timed out. Exiting") + # TODO: NACK task maybe? + return + + # Invalid response + if event.content is None: + await ctx.author.send("No content in message") + continue + + logger.info(f"User input received: {event.content}") + + # Send the response to the backend + reply = protocol_schema.TextReplyToPost( + post_id=str(msg_id), + user_post_id=str(event.message_id), + user=protocol_schema.User( + auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username + ), + text=event.content, + ) + logger.debug(f"Sending reply to backend: {reply!r}") + + # Get next task + new_task = await oasst_api.post_interaction(reply) + logger.info(f"New task {new_task}") + + if new_task.type == TaskType.done: + await ctx.author.send("Task completed") + completed = True + continue + else: + logger.fatal(f"Unexpected task type received: {new_task.type}") + + # ask the user if they want to do another task + choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) + msg = await ctx.author.send("Would you like another task?", components=choice_view) + await choice_view.start(msg) + await choice_view.wait() + + match choice_view.choice: + case False | None: + done = True + await ctx.author.send("Exiting, goodbye!") + case True: + pass + + +async def _select_task( + ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None +) -> tuple[protocol_schema.Task | None, str]: + """Present tasks to the user until they accept one, cancel, or time out.""" + oasst_api: OasstApiClient = ctx.bot.d.oasst_api + logger.debug(f"Starting task selection for {task_type}") + + # Loop until the user accepts a task, cancels, or times out + while True: + logger.debug(f"Requesting task of type {task_type}") + task = await oasst_api.fetch_task(task_type, user) + resp, msg_id = await _send_task(ctx, task) + + logger.debug(f"user choice: {resp}") + match resp: + case "accept": + logger.info(f"Task {task.id} accepted, sending ACK") + await oasst_api.ack_task(task.id, msg_id) + return task, msg_id + + case "next": + logger.info(f"Task {task.id} rejected, sending NACK") + await oasst_api.nack_task(task.id, "rejected") + await ctx.author.send("Sending next task...") + continue + + case "cancel": + logger.info(f"Task {task.id} canceled, sending NACK") + await oasst_api.nack_task(task.id, "canceled") + await ctx.author.send("Task canceled. Exiting") + return None, msg_id + + case None: + logger.info(f"Task {task.id} timed out, sending NACK") + await oasst_api.nack_task(task.id, "timed out") + await ctx.author.send("Task timed out. Exiting") + return None, msg_id + + +async def _send_task( + ctx: lightbulb.SlashContext, task: protocol_schema.Task +) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]: + """Send a task to the user. + + Returns the user's choice and the message ID of the task message.""" + + # The clean way to do this would be to attach a `to_embed` method to the task classes + # but the tasks aren't discord specific so that doesn't really make sense. + + view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) + embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED + + # Create an embed based on the task's type + if task.type == TaskRequestType.initial_prompt: + assert isinstance(task, protocol_schema.InitialPromptTask) + logger.info("sending initial prompt task") + embed = _initial_prompt_embed(task) + + elif task.type == TaskRequestType.rank_initial_prompts: + assert isinstance(task, protocol_schema.RankInitialPromptsTask) + logger.info("sending rank initial prompt task") + embed = _rank_initial_prompt_embed(task) + + else: + logger.error(f"unknown task type {task.type}") + + msg = await ctx.author.send( + ZWJ, + embed=embed, + components=view, + ) + + assert msg is not None + + await view.start(msg) + await view.wait() + + return view.choice, str(msg.id) + + +def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: + return ( + hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) + .set_image( + "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", + ) + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + +def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="Rank Initial Prompt", + description=f"Rank the following tasks from best to worst (1,2,3,4,5)", + timestamp=datetime.now().astimezone(), + ) + .set_image( + "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", + ) + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for i, prompt in enumerate(task.prompts): + embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False) + + return embed + + +class TaskAcceptView(miru.View): + """View with three buttons: accept, next, and cancel. + + The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute. + """ + + choice: t.Literal["accept", "next", "cancel"] | None = None + + @miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS) + async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Accept button pressed") + self.choice = "accept" + self.stop() + + @miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY) + async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Next button pressed") + self.choice = "next" + self.stop() + + @miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER) + async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + logger.info("Cancel button pressed") + self.choice = "cancel" + self.stop() + + +class ChoiceView(miru.View): + choice: bool | None = None + + @miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS) + async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.choice = True + self.stop() + + @miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER) + async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: + self.choice = False + self.stop() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py index beb81c36..1ff6ef1f 100644 --- a/discord-bot/bot/utils.py +++ b/discord-bot/bot/utils.py @@ -21,3 +21,10 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s return f"" case _: raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}") + + +ZWJ = "\u200d" +"""Zero-width joiner. + +This appears as an empty message in Discord. +""" diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 49c5e1ba..17348c12 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -7,4 +7,5 @@ hikari-miru # modals and buttons python-dotenv # .env file support aiosqlite # database aiohttp # http client -aiohttp[speedups] # speedups for aiohttp \ No newline at end of file +aiohttp[speedups] # speedups for aiohttp +loguru \ No newline at end of file From 221d3396f7e97a6978ecdc3b4852973f3779cd52 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Thu, 29 Dec 2022 14:39:02 -0800 Subject: [PATCH 05/27] clean up code --- discord-bot/bot/api_client.py | 41 ++++++-------------------- discord-bot/bot/bot.py | 2 +- discord-bot/bot/extensions/_example.py | 3 +- discord-bot/bot/extensions/tasks.py | 36 +++++++++++----------- discord-bot/bot/extensions/work.py | 24 +++++++-------- discord-bot/message_templates.py | 10 +++++-- discord-bot/noxfile.py | 1 + 7 files changed, 51 insertions(+), 66 deletions(-) diff --git a/discord-bot/bot/api_client.py b/discord-bot/bot/api_client.py index cec1900f..9f319869 100644 --- a/discord-bot/bot/api_client.py +++ b/discord-bot/bot/api_client.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -import asyncio +"""API Client for interacting with the OASST backend.""" import enum import typing as t from typing import Optional, Type @@ -7,11 +7,12 @@ from uuid import UUID import aiohttp from loguru import logger - from oasst_shared.schemas import protocol as protocol_schema class TaskType(str, enum.Enum): + """Task types.""" + summarize_story = "summarize_story" rate_summary = "rate_summary" initial_prompt = "initial_prompt" @@ -30,6 +31,7 @@ class OasstApiClient: """Create a new OasstApiClient. Args: + ---- backend_url (str): The base backend URL. api_key (str): The API key to use for authentication. """ @@ -38,7 +40,7 @@ class OasstApiClient: self.backend_url = backend_url self.api_key = api_key - self.task_models_map: dict[str, Type[protocol_schema.Task]] = { + self.task_models_map: dict[TaskType, Type[protocol_schema.Task]] = { TaskType.summarize_story: protocol_schema.SummarizeStoryTask, TaskType.rate_summary: protocol_schema.RateSummaryTask, TaskType.initial_prompt: protocol_schema.InitialPromptTask, @@ -58,17 +60,13 @@ class OasstApiClient: return await response.json() def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task: - task_type = data.get("type") - - if not isinstance(task_type, str): - logger.error(f"task type must be a `str`: {task_type}") - raise ValueError(f"task type must be a `str`: {task_type}") + task_type = TaskType(data.get("type")) model = self.task_models_map.get(task_type) if not model: logger.error(f"Unsupported task type: {task_type}") raise ValueError(f"Unsupported task type: {task_type}") - return self.task_models_map[task_type].parse_obj(data) + return self.task_models_map[task_type].parse_obj(data) # type: ignore async def fetch_task( self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None @@ -76,8 +74,8 @@ class OasstApiClient: """Fetch a task from the backend.""" logger.debug(f"Fetching task {task_type} for user {user}") req = protocol_schema.TaskRequest(type=task_type.value, user=user) - resp = await self.post(f"/api/v1/tasks/", data=req.dict()) - print("resp", resp) + resp = await self.post("/api/v1/tasks/", data=req.dict()) + logger.debug(f"Fetch task response: {resp}") return self._parse_task(resp) async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: @@ -107,24 +105,3 @@ class OasstApiClient: async def close(self): logger.debug("Closing OasstApiClient session") await self.session.close() - - -async def main(): - api = OasstApiClient("http://localhost:8080", "test") - try: - task = await api.fetch_task(protocol_schema.TaskRequestType.initial_prompt, None) - print(task) - finally: - - await api.close() - # session = aiohttp.ClientSession() - # try: - # resp = await session.post("http://localhost:8080/api/v1/tasks/", json={"type": "initial_prompt", "user": None}) - # resp.raise_for_status() - # print(await resp.text()) - # finally: - # await session.close() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index de8ceacf..e189b765 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -5,8 +5,8 @@ import hikari import lightbulb import miru -from bot.config import Config from bot.api_client import OasstApiClient +from bot.config import Config config = Config.from_env() diff --git a/discord-bot/bot/extensions/_example.py b/discord-bot/bot/extensions/_example.py index 330f5909..37783e43 100644 --- a/discord-bot/bot/extensions/_example.py +++ b/discord-bot/bot/extensions/_example.py @@ -2,7 +2,8 @@ # -*- coding: utf-8 -*- """Example plugin for reference. -Because this file starts with an `_`, it cannot be loaded by the bot. To see the example plugin in action, rename this file to `example.py`. +Because this file starts with an `_`, it cannot be loaded by the bot. +To see the example plugin in action, rename this file to `example.py`. """ import asyncio diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py index dfe51160..71f47f52 100644 --- a/discord-bot/bot/extensions/tasks.py +++ b/discord-bot/bot/extensions/tasks.py @@ -6,13 +6,13 @@ import typing as t from datetime import datetime, timedelta import hikari - import lightbulb import lightbulb.decorators import miru -from bot.utils import format_time from oasst_shared.schemas.protocol import TaskRequestType +from bot.utils import format_time + plugin = lightbulb.Plugin("TaskPlugin") MAX_TASK_TIME = 60 * 60 @@ -42,9 +42,13 @@ async def task_thread(ctx: lightbulb.SlashContext): await ctx.respond(f"Please complete the task in the thread: {thread.mention}") # Send the task in the thread - # TODO: Request task from the backend await thread.send( - f"Please complete the task.\nSample Task\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + f"""\ +Please complete the task. +Sample Task + +Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')} +""" ) # Wait for the user to respond @@ -55,7 +59,6 @@ async def task_thread(ctx: lightbulb.SlashContext): predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id, ) await ctx.respond(f"Received message: {event.message.content}") - # TODO: Send the message to the backend except asyncio.TimeoutError: await ctx.respond("You took too long to respond.") finally: @@ -75,16 +78,16 @@ async def task_thread(ctx: lightbulb.SlashContext): @lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand) async def task_dm(ctx: lightbulb.Context): """Request a task from the backend.""" - typ: str = ctx.options.type + await ctx.respond("Please complete the task in your DMs") - # Create a thread for the task - - await ctx.respond(f"Please complete the task in your DMs") - - # Send the task in the thread - # TODO: Request task from the backend + # Send the task in the dm await ctx.author.send( - f"Please complete the task.\nSample Task ({typ})\n\nSelf destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}" + f"""\ +Please complete the task. +Sample Task + +Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')} +""" ) # Wait for the user to respond @@ -95,7 +98,6 @@ async def task_dm(ctx: lightbulb.Context): predicate=lambda e: e.author.id == ctx.author.id, ) await ctx.respond(f"Received message: {event.message.content}") - # TODO: Send the message to the backend except asyncio.TimeoutError: await ctx.respond("You took too long to respond.") @@ -113,7 +115,6 @@ class TaskModal(miru.Modal): async def callback(self, context: miru.ModalContext) -> None: await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL) - # TODO: Send the message to the backend class ModalView(miru.View): @@ -146,7 +147,7 @@ async def task_modal(ctx: lightbulb.SlashContext): """Request a task from the backend.""" # typ: str = ctx.options.type view = ModalView( - modal_title=f"Assistant Response", + modal_title="Assistant Response", task="Please explain the moon landing to a six year old.", timeout=MAX_TASK_TIME, ) @@ -211,6 +212,8 @@ class RatingView(miru.View): class SelectRating(miru.View): + """View for rating a task with a select menu.""" + @miru.select( options=[ hikari.SelectMenuOption( @@ -249,7 +252,6 @@ class SelectRating(miru.View): @lightbulb.implements(lightbulb.SlashCommand) async def rating_task(ctx: lightbulb.SlashContext): """Rate stuff.""" - # Message Based rating await ctx.respond( "List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index e6ea3d7c..5c191481 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -6,13 +6,13 @@ import typing as t from datetime import datetime import hikari - import lightbulb import lightbulb.decorators import miru -from bot.api_client import OasstApiClient, TaskType from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TaskRequestType + +from bot.api_client import OasstApiClient, TaskType from bot.utils import ZWJ plugin = lightbulb.Plugin("WorkPlugin") @@ -52,7 +52,6 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) If they select one, present the task steps until a `task_done` task is received. Finally, ask the user if they want to perform another task (of the same type). """ - oasst_api: OasstApiClient = ctx.bot.d.oasst_api # Continue to complete tasks until the user doesn't want to do another @@ -165,8 +164,8 @@ async def _send_task( ) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]: """Send a task to the user. - Returns the user's choice and the message ID of the task message.""" - + Returns the user's choice and the message ID of the task message. + """ # The clean way to do this would be to attach a `to_embed` method to the task classes # but the tasks aren't discord specific so that doesn't really make sense. @@ -204,9 +203,7 @@ async def _send_task( def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: return ( hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) - .set_image( - "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", - ) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") .set_footer(text=f"OASST Assistant | {task.id}") ) @@ -215,12 +212,10 @@ def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> embed = ( hikari.Embed( title="Rank Initial Prompt", - description=f"Rank the following tasks from best to worst (1,2,3,4,5)", + description="Rank the following tasks from best to worst (1,2,3,4,5)", timestamp=datetime.now().astimezone(), ) - .set_image( - "https://images.unsplash.com/photo-1455390582262-044cdead277a?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1073&q=80", - ) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") .set_footer(text=f"OASST Assistant | {task.id}") ) @@ -258,6 +253,11 @@ class TaskAcceptView(miru.View): class ChoiceView(miru.View): + """View with two buttons: yes and no. + + The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute. + """ + choice: bool | None = None @miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS) diff --git a/discord-bot/message_templates.py b/discord-bot/message_templates.py index df3ef1ac..dcb84c94 100644 --- a/discord-bot/message_templates.py +++ b/discord-bot/message_templates.py @@ -1,16 +1,20 @@ # -*- coding: utf-8 -*- +"""Message templates for the discord bot.""" import jinja2 +import typing from loguru import logger class MessageTemplates: - def __init__(self, template_dir="./templates"): - self.env = jinja2.Environment( + """Create message templates for the discord bot.""" + + def __init__(self, template_dir: str = "./templates"): + self.env = jinja2.Environment( # noqa: S701 loader=jinja2.FileSystemLoader(template_dir), autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False), ) - def render(self, template_name, **kwargs): + def render(self, template_name: str, **kwargs: typing.Any): template = self.env.get_template(template_name) txt = template.render(kwargs) logger.debug(txt) diff --git a/discord-bot/noxfile.py b/discord-bot/noxfile.py index 37226787..f85fc60c 100644 --- a/discord-bot/noxfile.py +++ b/discord-bot/noxfile.py @@ -27,6 +27,7 @@ def lint_code(session: Session): @nox.session(reuse_venv=True) def typecheck_code(session: Session): + """Typecheck the codebase.""" session.install("-r", "requirements.txt", "-U") session.install("pyright", "-U") From 84146f23960dcc7d00ca7516537667004f4003a2 Mon Sep 17 00:00:00 2001 From: AlexanderHOtt Date: Thu, 29 Dec 2022 14:42:55 -0800 Subject: [PATCH 06/27] remove database file (luckly it was empty) --- discord-bot/.gitignore | 5 ++++- discord-bot/bot/db/database.db | 0 2 files changed, 4 insertions(+), 1 deletion(-) delete mode 100644 discord-bot/bot/db/database.db diff --git a/discord-bot/.gitignore b/discord-bot/.gitignore index 2842b686..499012d2 100644 --- a/discord-bot/.gitignore +++ b/discord-bot/.gitignore @@ -4,4 +4,7 @@ __pycache__/ .venv .nox -.env \ No newline at end of file +.env + +# Database files +*.db \ No newline at end of file diff --git a/discord-bot/bot/db/database.db b/discord-bot/bot/db/database.db deleted file mode 100644 index e69de29b..00000000 From bb3b0e739781c8d26fb6d2f1d474fda69f9d7260 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 29 Dec 2022 15:00:20 -0800 Subject: [PATCH 07/27] update api client to upstream version --- discord-bot/api_client.py | 79 ----------------------------------- discord-bot/bot/api_client.py | 13 ++++-- 2 files changed, 9 insertions(+), 83 deletions(-) delete mode 100644 discord-bot/api_client.py diff --git a/discord-bot/api_client.py b/discord-bot/api_client.py deleted file mode 100644 index 0c88258e..00000000 --- a/discord-bot/api_client.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -import enum -from typing import Optional, Type - -import requests -from oasst_shared.schemas import protocol as protocol_schema - - -class TaskType(str, enum.Enum): - summarize_story = "summarize_story" - rate_summary = "rate_summary" - initial_prompt = "initial_prompt" - user_reply = "user_reply" - assistant_reply = "assistant_reply" - rank_initial_prompts = "rank_initial_prompts" - rank_user_replies = "rank_user_replies" - rank_assistant_replies = "rank_assistant_replies" - done = "task_done" - - -class ApiClient: - def __init__(self, backend_url: str, api_key: str): - self.backend_url = backend_url - self.api_key = api_key - - task_models_map: dict[str, Type[protocol_schema.Task]] = { - TaskType.summarize_story: protocol_schema.SummarizeStoryTask, - TaskType.rate_summary: protocol_schema.RateSummaryTask, - TaskType.initial_prompt: protocol_schema.InitialPromptTask, - TaskType.user_reply: protocol_schema.UserReplyTask, - TaskType.assistant_reply: protocol_schema.AssistantReplyTask, - TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, - TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, - TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, - TaskType.done: protocol_schema.TaskDone, - } - self.task_models_map = task_models_map - - def post(self, path: str, json: dict) -> dict: - response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key}) - response.raise_for_status() - return response.json() - - def _parse_task(self, data: dict) -> protocol_schema.Task: - if not isinstance(data, dict): - raise ValueError("dict expected") - - task_type = data.get("type") - if task_type not in self.task_models_map: - raise RuntimeError(f"Unsupported task type: {task_type}") - - return self.task_models_map[task_type].parse_obj(data) - - def fetch_task( - self, - task_type: protocol_schema.TaskRequestType, - user: Optional[protocol_schema.User] = None, - collective: bool = False, - ) -> protocol_schema.Task: - req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective) - data = self.post("/api/v1/tasks/", req.dict()) - return self._parse_task(data) - - def fetch_random_task( - self, user: Optional[protocol_schema.User] = None, collective: bool = False - ) -> protocol_schema.Task: - return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective) - - def ack_task(self, task_id: str, post_id: str) -> None: - req = protocol_schema.TaskAck(post_id=post_id) - return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict()) - - def nack_task(self, task_id: str, reason: str) -> None: - req = protocol_schema.TaskNAck(reason=reason) - return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: - data = self.post("/api/v1/tasks/interaction", interaction.dict()) - return self._parse_task(data) diff --git a/discord-bot/bot/api_client.py b/discord-bot/bot/api_client.py index 9f319869..b5c96505 100644 --- a/discord-bot/bot/api_client.py +++ b/discord-bot/bot/api_client.py @@ -69,19 +69,24 @@ class OasstApiClient: return self.task_models_map[task_type].parse_obj(data) # type: ignore async def fetch_task( - self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + self, + task_type: protocol_schema.TaskRequestType, + user: Optional[protocol_schema.User] = None, + collective: bool = False, ) -> protocol_schema.Task: """Fetch a task from the backend.""" logger.debug(f"Fetching task {task_type} for user {user}") - req = protocol_schema.TaskRequest(type=task_type.value, user=user) + req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective) resp = await self.post("/api/v1/tasks/", data=req.dict()) logger.debug(f"Fetch task response: {resp}") return self._parse_task(resp) - async def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: + async def fetch_random_task( + self, user: Optional[protocol_schema.User] = None, collective: bool = False + ) -> protocol_schema.Task: """Fetch a random task from the backend.""" logger.debug(f"Fetching random for user {user}") - return await self.fetch_task(protocol_schema.TaskRequestType.random, user) + return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective) async def ack_task(self, task_id: str | UUID, post_id: str): """Send an ACK for a task to the backend.""" From a2a0e1608d42231ff6403bdd3bf918a0998fad62 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 29 Dec 2022 15:45:01 -0800 Subject: [PATCH 08/27] fix: parse TaskRequestType enum correctly --- discord-bot/bot/extensions/work.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 5c191481..1175e4a5 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -37,7 +37,7 @@ logger.setLevel(logging.DEBUG) @lightbulb.implements(lightbulb.SlashCommand) async def work(ctx: lightbulb.SlashContext): """Create and handle a task.""" - task_type: TaskRequestType = TaskRequestType(ctx.options.type) + task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}") From 26c1b4eaab1280c1af0ca84747214ef06492b315 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 29 Dec 2022 16:24:55 -0800 Subject: [PATCH 09/27] merge upstream main --- discord-bot/{bot => }/api_client.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord-bot/{bot => }/api_client.py (100%) diff --git a/discord-bot/bot/api_client.py b/discord-bot/api_client.py similarity index 100% rename from discord-bot/bot/api_client.py rename to discord-bot/api_client.py From 9c15258fd1008eadf801dffab2ab0d611cf9ab2d Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Thu, 29 Dec 2022 16:25:37 -0800 Subject: [PATCH 10/27] move api_client.py back to the correct position --- discord-bot/{ => bot}/api_client.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord-bot/{ => bot}/api_client.py (100%) diff --git a/discord-bot/api_client.py b/discord-bot/bot/api_client.py similarity index 100% rename from discord-bot/api_client.py rename to discord-bot/bot/api_client.py From 63e2120825b9b3565f32b3b329d386d1baea472e Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 02:02:46 -0800 Subject: [PATCH 11/27] add completion message to task and add message labeling --- discord-bot/bot/extensions/text_labels.py | 166 ++++++++++++++++++++ discord-bot/bot/extensions/work.py | 178 ++++++++++++++++++---- discord-bot/bot/utils.py | 2 +- 3 files changed, 317 insertions(+), 29 deletions(-) create mode 100644 discord-bot/bot/extensions/text_labels.py diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py new file mode 100644 index 00000000..ab7d109d --- /dev/null +++ b/discord-bot/bot/extensions/text_labels.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +"""Hot reload plugin.""" +import hikari +import lightbulb +import miru +from datetime import datetime + +import typing as t + +plugin = lightbulb.Plugin( + "HotReloadPlugin", +) + +from bot.utils import EMPTY + +DISCORD_GRAY = 0x2F3136 + + +def clamp(num: float) -> float: + """Clamp a number between 0 and 1.""" + return min(max(0.0, num), 1.0) + + +class LabelModal(miru.Modal): + """Modal for submitting text labels.""" + + def __init__(self, label: str, content: str, *args: t.Any, **kwargs: t.Any): + super().__init__(*args, **kwargs) + self.label = label + self.original_content = content + + # Add the text of the message to the modal + self.content = miru.TextInput( + label="Text", style=hikari.TextInputStyle.PARAGRAPH, value=content, required=True, row=1 + ) + self.add_item(self.content) + + value = miru.TextInput(label="Value", placeholder="Enter a value between 0 and 1", required=True, row=2) + + async def callback(self, context: miru.ModalContext) -> None: + val = float(self.value.value) if self.value.value else 0.0 + val = clamp(val) + + edited = self.content.value != self.original_content + await context.respond( + f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.", + flags=hikari.MessageFlag.EPHEMERAL, + ) + + # Send a notification to the log channel + embed = ( + hikari.Embed( + title="Message Label", + description=f"{context.author.mention} labeled a message as `{self.label}`.", + timestamp=datetime.now().astimezone(), + color=0x00FF00, + ) + .set_author(name=context.author.username, icon=context.author.avatar_url) + .add_field("Total Labeled Message", "0", inline=True) + .add_field("Server Ranking", "0/0", inline=True) + .add_field("Global Ranking", "0/0", inline=True) + .set_footer(f"Message ID: TODO") + ) + channel = await context.bot.rest.fetch_channel(1058299131115872297) + assert isinstance(channel, hikari.TextableChannel) + await channel.send(EMPTY, embed=embed) + + +class LabelSelect(miru.View): + """Select menu for selecting a label. + + The current labels are: + - contains toxic language + - encourages illegal activity + - good quality + - bad quality + - is spam + """ + + def __init__(self, content: str, *args: t.Any, **kwargs: t.Any): + super().__init__(*args, **kwargs) + self.content = content + + @miru.select( + options=[ + hikari.SelectMenuOption( + label="Toxic Language", + value="toxic_language", + description="The message contains toxic language.", + is_default=False, + emoji=None, + ), + hikari.SelectMenuOption( + label="Illegal Activity", + value="illegal_activity", + description="The message encourages illegal activity.", + is_default=False, + emoji=None, + ), + hikari.SelectMenuOption( + label="Good Quality", + value="good_quality", + description="The message is good quality.", + is_default=False, + emoji=None, + ), + hikari.SelectMenuOption( + label="Bad Quality", + value="bad_quality", + description="The message is bad quality.", + is_default=False, + emoji=None, + ), + hikari.SelectMenuOption( + label="Spam", + value="spam", + description="The message is spam.", + is_default=False, + emoji=None, + ), + ], + min_values=1, + max_values=1, + ) + async def label_select(self, select: miru.Select, ctx: miru.ViewContext) -> None: + """Handle the select menu.""" + label = select.values[0] + modal = LabelModal(label, self.content, title=f"Text Label: {label}", timeout=60) + await modal.send(ctx.interaction) + await modal.wait() + + self.stop() + + +@plugin.command +@lightbulb.command("Label Message", "Label a message") +@lightbulb.implements(lightbulb.MessageCommand) +async def label_message_text(ctx: lightbulb.MessageContext): + """Label a message.""" + msg: hikari.Message = ctx.options.target + # Exit if the message is empty + if not msg.content: + await ctx.respond("Cannot label an empty message.", flags=hikari.MessageFlag.EPHEMERAL) + return + + # Send the select menu + # The modal will be opened from the select menu interaction + embed = hikari.Embed(title="Label Message", description="Select a label for the message.", color=DISCORD_GRAY) + label_select_view = LabelSelect( + msg.content, + timeout=60, + ) + resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL) + + await label_select_view.start(await resp.message()) + await label_select_view.wait() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 1175e4a5..86462f87 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -13,7 +13,7 @@ from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TaskRequestType from bot.api_client import OasstApiClient, TaskType -from bot.utils import ZWJ +from bot.utils import EMPTY plugin = lightbulb.Plugin("WorkPlugin") @@ -106,6 +106,24 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) else: logger.fatal(f"Unexpected task type received: {new_task.type}") + # Send a message in the log channel that the task is complete + # TODO: Maybe do something with the msg ID + done_embed = ( + hikari.Embed( + title="Task Completion", + description=f"`{task.type}` completed by {ctx.author.mention}", + color=hikari.Color(0x00FF00), + timestamp=datetime.now().astimezone(), + ) + .add_field("Total Tasks", "0", inline=True) + .add_field("Server Ranking", "0/0", inline=True) + .add_field("Global Ranking", "0/0", inline=True) + .set_footer(f"Task ID: {task.id}") + ) + channel = await ctx.bot.rest.fetch_channel(1058299131115872297) + assert isinstance(channel, hikari.TextableChannel) + await channel.send(EMPTY, embed=done_embed) + # ask the user if they want to do another task choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) msg = await ctx.author.send("Would you like another task?", components=choice_view) @@ -169,7 +187,6 @@ async def _send_task( # The clean way to do this would be to attach a `to_embed` method to the task classes # but the tasks aren't discord specific so that doesn't really make sense. - view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED # Create an embed based on the task's type @@ -183,11 +200,38 @@ async def _send_task( logger.info("sending rank initial prompt task") embed = _rank_initial_prompt_embed(task) + elif task.type == TaskRequestType.rank_user_replies: + assert isinstance(task, protocol_schema.RankUserRepliesTask) + logger.info("sending rank user reply task") + embed = _rank_user_reply_embed(task) + + elif task.type == TaskRequestType.rank_assistant_replies: + assert isinstance(task, protocol_schema.RankAssistantRepliesTask) + logger.info("sending rank assistant reply task") + embed = _rank_assistant_reply_embed(task) + + elif task.type == TaskRequestType.user_reply: + assert isinstance(task, protocol_schema.UserReplyTask) + logger.info("sending user reply task") + embed = _user_reply_embed(task) + + elif task.type == TaskRequestType.assistant_reply: + assert isinstance(task, protocol_schema.AssistantReplyTask) + logger.info("sending assistant reply task") + embed = _assistant_reply_embed(task) + + elif task.type == TaskRequestType.summarize_story: + raise NotImplementedError + elif task.type == TaskRequestType.rate_summary: + raise NotImplementedError + else: logger.error(f"unknown task type {task.type}") + raise ValueError(f"unknown task type {task.type}") + view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) msg = await ctx.author.send( - ZWJ, + EMPTY, embed=embed, components=view, ) @@ -200,31 +244,6 @@ async def _send_task( return view.choice, str(msg.id) -def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: - return ( - hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - -def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed: - embed = ( - hikari.Embed( - title="Rank Initial Prompt", - description="Rank the following tasks from best to worst (1,2,3,4,5)", - timestamp=datetime.now().astimezone(), - ) - .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") - .set_footer(text=f"OASST Assistant | {task.id}") - ) - - for i, prompt in enumerate(task.prompts): - embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False) - - return embed - - class TaskAcceptView(miru.View): """View with three buttons: accept, next, and cancel. @@ -271,6 +290,109 @@ class ChoiceView(miru.View): self.stop() +################################################################ +# Template Embeds # +################################################################ + +# TODO: Maybe implement a better way of creating embeds, like `from_json` or something + + +def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed: + return ( + hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone()) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + +def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="Rank Initial Prompt", + description="Rank the following tasks from best to worst (1,2,3,4,5)", + timestamp=datetime.now().astimezone(), + ) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for i, prompt in enumerate(task.prompts): + embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False) + + return embed + + +def _rank_user_reply_embed(task: protocol_schema.RankUserRepliesTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="Rank User Reply", + description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4", + timestamp=datetime.now().astimezone(), + ) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for i, reply in enumerate(task.replies): + embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False) + + return embed + + +def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="Rank Assistant Reply", + description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4", + timestamp=datetime.now().astimezone(), + ) + .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for i, reply in enumerate(task.replies): + embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False) + + return embed + + +def _user_reply_embed(task: protocol_schema.UserReplyTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="User Reply", + description=f"""\ + Send the next message in the conversation as if you were the user. + {'Hint: ' if task.hint else ''} + """, + timestamp=datetime.now().astimezone(), + ) + # .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for message in task.conversation.messages: + embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False) + + return embed + + +def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed: + embed = ( + hikari.Embed( + title="User Reply", + description="Send the next message in the conversation as if you were the user.", + timestamp=datetime.now().astimezone(), + ) + # .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image + .set_footer(text=f"OASST Assistant | {task.id}") + ) + + for message in task.conversation.messages: + embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False) + + return embed + + def load(bot: lightbulb.BotApp): """Add the plugin to the bot.""" bot.add_plugin(plugin) diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py index 1ff6ef1f..1ce99560 100644 --- a/discord-bot/bot/utils.py +++ b/discord-bot/bot/utils.py @@ -23,7 +23,7 @@ def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> s raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}") -ZWJ = "\u200d" +EMPTY = "\u200d" """Zero-width joiner. This appears as an empty message in Discord. From e4b097edff5caa986760ee91284902139835ee07 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 02:59:52 -0800 Subject: [PATCH 12/27] add database schema and guild setting --- discord-bot/bot/bot.py | 2 +- discord-bot/bot/db/schemas.py | 16 ++++ discord-bot/bot/extensions/guild_settings.py | 83 ++++++++++++++++++++ discord-bot/bot/extensions/text_labels.py | 8 +- discord-bot/bot/utils.py | 18 +++++ 5 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 discord-bot/bot/db/schemas.py create mode 100644 discord-bot/bot/extensions/guild_settings.py diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index e189b765..9b443986 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -26,7 +26,7 @@ async def on_starting(event: hikari.StartingEvent): miru.install(bot) # component handler bot.load_extensions_from("./bot/extensions") # load extensions - bot.d.db = await aiosqlite.connect(":memory:") # TODO: Update + bot.d.db = await aiosqlite.connect("./bot/db/database.db") # TODO: Update await bot.d.db.executescript(open("./bot/db/schema.sql").read()) await bot.d.db.commit() diff --git a/discord-bot/bot/db/schemas.py b/discord-bot/bot/db/schemas.py new file mode 100644 index 00000000..e3ce9032 --- /dev/null +++ b/discord-bot/bot/db/schemas.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +"""Database schemas.""" +from aiosqlite import Row +from pydantic import BaseModel + + +class GuildSettings(BaseModel): + """Guild settings.""" + + guild_id: int + log_channel_id: int | None + + @classmethod + def parse_obj(cls, obj: Row) -> "GuildSettings": + """Deserialize a Row object from aiosqlite into a GuildSettings object.""" + return cls(guild_id=obj[0], log_channel_id=obj[1]) diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py new file mode 100644 index 00000000..8c9cded4 --- /dev/null +++ b/discord-bot/bot/extensions/guild_settings.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +"""Guild settings.""" +import hikari +import lightbulb +from aiosqlite import Connection + +from bot.db.schemas import GuildSettings +from bot.utils import mention + +plugin = lightbulb.Plugin("GuildSettings") +plugin.add_checks(lightbulb.guild_only) +plugin.add_checks(lightbulb.has_guild_permissions(hikari.Permissions.MANAGE_GUILD)) + + +@plugin.command +@lightbulb.command("settings", "Bot settings for the server.") +@lightbulb.implements(lightbulb.SlashCommandGroup) +async def settings(_: lightbulb.SlashContext) -> None: + """Bot settings for the server.""" + # This will never execute because it is a group + pass + + +@settings.child +@lightbulb.command("get", "Get all the guild settings.") +@lightbulb.implements(lightbulb.SlashSubCommand) +async def get(ctx: lightbulb.SlashContext) -> None: + """Get one of or all the guild settings.""" + conn: Connection = ctx.bot.d.db + assert ctx.guild_id is not None # `guild_only` check + + async with conn.cursor() as cursor: + # Get all settings + await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (ctx.guild_id,)) + row = await cursor.fetchone() + + if row is None: + await ctx.respond("No settings found for this guild.") + return + + guild_settings = GuildSettings.parse_obj(row) + + # Respond with all + # TODO: Embed + await ctx.respond( + f"""\ +**Guild Settings** +`log_channel`: { +mention(guild_settings.log_channel_id, "channel") +if guild_settings.log_channel_id else 'not set'} +""" + ) + + +@settings.child +@lightbulb.option("channel", "The channel to use.", hikari.TextableGuildChannel) +@lightbulb.command("log_channel", "Set the channel that the bot logs task and label completions in.") +@lightbulb.implements(lightbulb.SlashSubCommand) +async def log_channel(ctx: lightbulb.SlashContext) -> None: + """Set the channel that the bot logs task and label completions in.""" + channel: hikari.TextableGuildChannel = ctx.options.channel + conn: Connection = ctx.bot.d.db + assert ctx.guild_id is not None # `guild_only` check + + await ctx.respond(f"Setting `log_channel` to {channel.mention}.") + + async with conn.cursor() as cursor: + await cursor.execute( + "INSERT OR REPLACE INTO guild_settings (guild_id, log_channel_id) VALUES (?, ?)", + (ctx.guild_id, channel.id), + ) + + await conn.commit() + + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index ab7d109d..e9d08c86 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -1,11 +1,11 @@ # -*- coding: utf-8 -*- """Hot reload plugin.""" +import typing as t +from datetime import datetime + import hikari import lightbulb import miru -from datetime import datetime - -import typing as t plugin = lightbulb.Plugin( "HotReloadPlugin", @@ -59,7 +59,7 @@ class LabelModal(miru.Modal): .add_field("Total Labeled Message", "0", inline=True) .add_field("Server Ranking", "0/0", inline=True) .add_field("Global Ranking", "0/0", inline=True) - .set_footer(f"Message ID: TODO") + .set_footer("Message ID: TODO") ) channel = await context.bot.rest.fetch_channel(1058299131115872297) assert isinstance(channel, hikari.TextableChannel) diff --git a/discord-bot/bot/utils.py b/discord-bot/bot/utils.py index 1ce99560..03dfea3d 100644 --- a/discord-bot/bot/utils.py +++ b/discord-bot/bot/utils.py @@ -3,6 +3,8 @@ import typing as t from datetime import datetime +import hikari + def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str: """Format a datetime object into the discord time format. @@ -28,3 +30,19 @@ EMPTY = "\u200d" This appears as an empty message in Discord. """ + + +def mention( + id: hikari.Snowflakeish, + type: t.Literal["channel", "role", "user"], +) -> str: + """Mention an object.""" + match type: + case "channel": + return f"<#{id}>" + + case "user": + return f"<@{id}>" + + case "role": + return f"<@&{id}>" From d71ded13644047d46c7c76b8a14f8ad6e55594a2 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 03:15:22 -0800 Subject: [PATCH 13/27] send the task completion message in the guild's configured channel --- discord-bot/bot/db/schemas.py | 13 +++++++- discord-bot/bot/extensions/text_labels.py | 15 +++++++-- discord-bot/bot/extensions/work.py | 37 ++++++++++++++--------- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/discord-bot/bot/db/schemas.py b/discord-bot/bot/db/schemas.py index e3ce9032..efb903b8 100644 --- a/discord-bot/bot/db/schemas.py +++ b/discord-bot/bot/db/schemas.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """Database schemas.""" -from aiosqlite import Row +from aiosqlite import Row, Connection from pydantic import BaseModel +import typing as t class GuildSettings(BaseModel): @@ -14,3 +15,13 @@ class GuildSettings(BaseModel): def parse_obj(cls, obj: Row) -> "GuildSettings": """Deserialize a Row object from aiosqlite into a GuildSettings object.""" return cls(guild_id=obj[0], log_channel_id=obj[1]) + + @classmethod + async def from_db(cls, conn: Connection, guild_id: int) -> t.Optional["GuildSettings"]: + async with conn.cursor() as cursor: + await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,)) + row = await cursor.fetchone() + if row is None: + raise ValueError("No settings found for this guild.") + + return cls.parse_obj(row) diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index e9d08c86..212aa04b 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -6,12 +6,15 @@ from datetime import datetime import hikari import lightbulb import miru +from aiosqlite import Connection plugin = lightbulb.Plugin( - "HotReloadPlugin", + "TextLabels", ) +plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds from bot.utils import EMPTY +from bot.db.schemas import GuildSettings DISCORD_GRAY = 0x2F3136 @@ -48,6 +51,14 @@ class LabelModal(miru.Modal): ) # Send a notification to the log channel + assert context.guild_id is not None # `guild_only` check + conn: Connection = context.bot.d.db # type: ignore + guild_settings = await GuildSettings.from_db(conn, context.guild_id) + + + if guild_settings is None or guild_settings.log_channel_id is None: + return + embed = ( hikari.Embed( title="Message Label", @@ -61,7 +72,7 @@ class LabelModal(miru.Modal): .add_field("Global Ranking", "0/0", inline=True) .set_footer("Message ID: TODO") ) - channel = await context.bot.rest.fetch_channel(1058299131115872297) + channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id) assert isinstance(channel, hikari.TextableChannel) await channel.send(EMPTY, embed=embed) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 86462f87..1c1f38de 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -4,6 +4,7 @@ import asyncio import logging import typing as t from datetime import datetime +from aiosqlite import Connection import hikari import lightbulb @@ -14,6 +15,7 @@ from oasst_shared.schemas.protocol import TaskRequestType from bot.api_client import OasstApiClient, TaskType from bot.utils import EMPTY +from bot.db.schemas import GuildSettings plugin = lightbulb.Plugin("WorkPlugin") @@ -108,21 +110,28 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) # Send a message in the log channel that the task is complete # TODO: Maybe do something with the msg ID - done_embed = ( - hikari.Embed( - title="Task Completion", - description=f"`{task.type}` completed by {ctx.author.mention}", - color=hikari.Color(0x00FF00), - timestamp=datetime.now().astimezone(), + assert ctx.guild_id is not None + conn: Connection = ctx.bot.d.db + guild_settings = await GuildSettings.from_db(conn, ctx.guild_id) + + if guild_settings is not None and guild_settings.log_channel_id is not None: + + channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id) + assert isinstance(channel, hikari.TextableChannel) # option converter + + done_embed = ( + hikari.Embed( + title="Task Completion", + description=f"`{task.type}` completed by {ctx.author.mention}", + color=hikari.Color(0x00FF00), + timestamp=datetime.now().astimezone(), + ) + .add_field("Total Tasks", "0", inline=True) + .add_field("Server Ranking", "0/0", inline=True) + .add_field("Global Ranking", "0/0", inline=True) + .set_footer(f"Task ID: {task.id}") ) - .add_field("Total Tasks", "0", inline=True) - .add_field("Server Ranking", "0/0", inline=True) - .add_field("Global Ranking", "0/0", inline=True) - .set_footer(f"Task ID: {task.id}") - ) - channel = await ctx.bot.rest.fetch_channel(1058299131115872297) - assert isinstance(channel, hikari.TextableChannel) - await channel.send(EMPTY, embed=done_embed) + await channel.send(EMPTY, embed=done_embed) # ask the user if they want to do another task choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME) From 98955441d134ae8a436938e5b890228c884a189c Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 03:21:21 -0800 Subject: [PATCH 14/27] explain text_label logic --- discord-bot/bot/extensions/text_labels.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 212aa04b..325929d2 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -55,7 +55,6 @@ class LabelModal(miru.Modal): conn: Connection = context.bot.d.db # type: ignore guild_settings = await GuildSettings.from_db(conn, context.guild_id) - if guild_settings is None or guild_settings.log_channel_id is None: return @@ -148,6 +147,8 @@ class LabelSelect(miru.View): @lightbulb.implements(lightbulb.MessageCommand) async def label_message_text(ctx: lightbulb.MessageContext): """Label a message.""" + # We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction + # so the select menu will open the modal msg: hikari.Message = ctx.options.target # Exit if the message is empty if not msg.content: From 65c078fb9f8f4ec7f3b43d177ff9b801db6a4526 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 04:12:54 -0800 Subject: [PATCH 15/27] pre-commit changes --- backend/oasst_backend/prompt_repository.py | 9 ++++++++- discord-bot/.env.example | 2 +- discord-bot/.gitignore | 2 +- discord-bot/bot/__init__.py | 2 +- discord-bot/bot/__main__.py | 2 +- discord-bot/bot/bot.py | 4 ++-- discord-bot/bot/config.py | 4 ++-- discord-bot/bot/db/schemas.py | 5 +++-- discord-bot/bot/extensions/_example.py | 1 + discord-bot/bot/extensions/text_labels.py | 5 +++-- discord-bot/bot/extensions/work.py | 6 +++--- discord-bot/dev-requirements.txt | 6 +++--- discord-bot/flake8-requirements.txt | 8 ++++---- discord-bot/message_templates.py | 3 ++- discord-bot/noxfile.py | 2 +- discord-bot/requirements.txt | 18 +++++++++--------- 16 files changed, 45 insertions(+), 34 deletions(-) diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 13c6cd23..0a6c193c 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -7,7 +7,14 @@ import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage +from oasst_backend.models import ( + ApiClient, + Person, + Post, + PostReaction, + TextLabels, + WorkPackage, +) from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session, func diff --git a/discord-bot/.env.example b/discord-bot/.env.example index c518010d..7c414a53 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,4 +1,4 @@ TOKEN= DECLARE_GLOBAL_COMMANDS= OWNER_IDS= -PREFIX="./" \ No newline at end of file +PREFIX="./" diff --git a/discord-bot/.gitignore b/discord-bot/.gitignore index 499012d2..ee1e23f2 100644 --- a/discord-bot/.gitignore +++ b/discord-bot/.gitignore @@ -7,4 +7,4 @@ __pycache__/ .env # Database files -*.db \ No newline at end of file +*.db diff --git a/discord-bot/bot/__init__.py b/discord-bot/bot/__init__.py index 3d04718d..66779a9c 100644 --- a/discord-bot/bot/__init__.py +++ b/discord-bot/bot/__init__.py @@ -1,2 +1,2 @@ -# -*- coding=utf-8 -*- +# -*- coding: utf-8 -*- """The official Open-Assistant Discord Bot.""" diff --git a/discord-bot/bot/__main__.py b/discord-bot/bot/__main__.py index f258d148..87032e40 100644 --- a/discord-bot/bot/__main__.py +++ b/discord-bot/bot/__main__.py @@ -1,4 +1,4 @@ -# -*- coding=utf-8 -*- +# -*- coding: utf-8 -*- """Entry point for the bot.""" import logging import os diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index 9b443986..1f801413 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -1,4 +1,4 @@ -# -*- coding=utf-8 -*- +# -*- coding: utf-8 -*- """Bot logic.""" import aiosqlite import hikari @@ -26,7 +26,7 @@ async def on_starting(event: hikari.StartingEvent): miru.install(bot) # component handler bot.load_extensions_from("./bot/extensions") # load extensions - bot.d.db = await aiosqlite.connect("./bot/db/database.db") # TODO: Update + bot.d.db = await aiosqlite.connect("./bot/db/database.db") await bot.d.db.executescript(open("./bot/db/schema.sql").read()) await bot.d.db.commit() diff --git a/discord-bot/bot/config.py b/discord-bot/bot/config.py index e3addac9..fafcb308 100644 --- a/discord-bot/bot/config.py +++ b/discord-bot/bot/config.py @@ -1,11 +1,11 @@ -# -*- coding=utf-8 -*- +# -*- coding: utf-8 -*- """Configuration for the bot.""" import logging from dataclasses import dataclass from os import getenv -from dotenv import load_dotenv +from dotenv import load_dotenv # type: ignore load_dotenv() diff --git a/discord-bot/bot/db/schemas.py b/discord-bot/bot/db/schemas.py index efb903b8..9c548e9f 100644 --- a/discord-bot/bot/db/schemas.py +++ b/discord-bot/bot/db/schemas.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- """Database schemas.""" -from aiosqlite import Row, Connection -from pydantic import BaseModel import typing as t +from aiosqlite import Connection, Row +from pydantic import BaseModel + class GuildSettings(BaseModel): """Guild settings.""" diff --git a/discord-bot/bot/extensions/_example.py b/discord-bot/bot/extensions/_example.py index 37783e43..76398881 100644 --- a/discord-bot/bot/extensions/_example.py +++ b/discord-bot/bot/extensions/_example.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # TODO: Convert file to markdown # -*- coding: utf-8 -*- """Example plugin for reference. diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 325929d2..48e32763 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -8,13 +8,14 @@ import lightbulb import miru from aiosqlite import Connection +from bot.db.schemas import GuildSettings +from bot.utils import EMPTY + plugin = lightbulb.Plugin( "TextLabels", ) plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds -from bot.utils import EMPTY -from bot.db.schemas import GuildSettings DISCORD_GRAY = 0x2F3136 diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 1c1f38de..ecbe1710 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -4,18 +4,18 @@ import asyncio import logging import typing as t from datetime import datetime -from aiosqlite import Connection import hikari import lightbulb import lightbulb.decorators import miru +from aiosqlite import Connection from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TaskRequestType from bot.api_client import OasstApiClient, TaskType -from bot.utils import EMPTY from bot.db.schemas import GuildSettings +from bot.utils import EMPTY plugin = lightbulb.Plugin("WorkPlugin") @@ -76,7 +76,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) ) except asyncio.TimeoutError: await ctx.author.send("Task timed out. Exiting") - # TODO: NACK task maybe? + await oasst_api.nack_task(task.id, reason="timed out") return # Invalid response diff --git a/discord-bot/dev-requirements.txt b/discord-bot/dev-requirements.txt index 44a8d2cc..56393fdd 100644 --- a/discord-bot/dev-requirements.txt +++ b/discord-bot/dev-requirements.txt @@ -1,8 +1,8 @@ -nox black -isort codespell flake8 -pyright \ No newline at end of file +isort +nox +pyright diff --git a/discord-bot/flake8-requirements.txt b/discord-bot/flake8-requirements.txt index 3509207e..a022d8c5 100644 --- a/discord-bot/flake8-requirements.txt +++ b/discord-bot/flake8-requirements.txt @@ -1,8 +1,4 @@ flake8==6.0.0 - -# Plugins - -Flake8-pyproject # use the pyproject.toml as the config file flake8-bandit # runs bandit flake8-black # runs black # flake8-broken-line # forbey "\" linebreaks @@ -21,6 +17,10 @@ flake8-mutable # mutable default argument detection flake8-pep3101 # new-style format strings only flake8-print # complain about print statements in code flake8-printf-formatting # forbey printf-style python2 string formatting + +# Plugins + +Flake8-pyproject # use the pyproject.toml as the config file flake8-pytest-style # pytest checks flake8-raise # exception raising linting flake8-use-fstring # format string checking diff --git a/discord-bot/message_templates.py b/discord-bot/message_templates.py index dcb84c94..256f93d3 100644 --- a/discord-bot/message_templates.py +++ b/discord-bot/message_templates.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- """Message templates for the discord bot.""" -import jinja2 import typing + +import jinja2 from loguru import logger diff --git a/discord-bot/noxfile.py b/discord-bot/noxfile.py index f85fc60c..891e87fb 100644 --- a/discord-bot/noxfile.py +++ b/discord-bot/noxfile.py @@ -1,4 +1,4 @@ -# -*- coding=utf-8 -*- +# -*- coding: utf-8 -*- """Automated linting, formatting, and typechecking.""" import nox from nox.sessions import Session diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 17348c12..62b9b931 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -1,11 +1,11 @@ -hikari # discord framework -hikari[speedups] -uvloop; os_name != 'nt' -hikari-lightbulb # command handler -hikari-miru # modals and buttons - -python-dotenv # .env file support -aiosqlite # database aiohttp # http client aiohttp[speedups] # speedups for aiohttp -loguru \ No newline at end of file +aiosqlite # database +hikari # discord framework +hikari-lightbulb # command handler +hikari-miru # modals and buttons +hikari[speedups] +loguru + +python-dotenv # .env file support +uvloop; os_name != 'nt' From b81eeebe9e578f03a2b1b2975409a79a479979ee Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 04:35:45 -0800 Subject: [PATCH 16/27] switch to using pre-commit --- backend/oasst_backend/prompt_repository.py | 9 +--- discord-bot/CONTRIBUTING.md | 10 ++-- discord-bot/bot/bot.py | 1 - .../extensions/{_example.py => EXAMPLES.md} | 12 ++--- discord-bot/bot/extensions/guild_settings.py | 1 - discord-bot/bot/extensions/tasks.py | 3 +- discord-bot/bot/extensions/text_labels.py | 1 - discord-bot/bot/extensions/work.py | 5 +- discord-bot/dev-requirements.txt | 8 ---- discord-bot/flake8-requirements.txt | 26 ---------- discord-bot/noxfile.py | 34 -------------- discord-bot/pyproject.toml | 47 ------------------- 12 files changed, 14 insertions(+), 143 deletions(-) rename discord-bot/bot/extensions/{_example.py => EXAMPLES.md} (98%) delete mode 100644 discord-bot/dev-requirements.txt delete mode 100644 discord-bot/flake8-requirements.txt delete mode 100644 discord-bot/noxfile.py delete mode 100644 discord-bot/pyproject.toml diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 0a6c193c..13c6cd23 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -7,14 +7,7 @@ import oasst_backend.models.db_payload as db_payload from loguru import logger from oasst_backend.exceptions import OasstError, OasstErrorCode from oasst_backend.journal_writer import JournalWriter -from oasst_backend.models import ( - ApiClient, - Person, - Post, - PostReaction, - TextLabels, - WorkPackage, -) +from oasst_backend.models import ApiClient, Person, Post, PostReaction, TextLabels, WorkPackage from oasst_backend.models.payload_column_type import PayloadContainer from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session, func diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md index d4d8ad3b..44484354 100644 --- a/discord-bot/CONTRIBUTING.md +++ b/discord-bot/CONTRIBUTING.md @@ -13,12 +13,12 @@ pip install -r requirements.txt python -m bot ``` -To test the bot +Before you push, make sure the `pre-commit` hooks are installed and run successfully. ``` -python -m pip install -r dev-requirements.txt - -nox +pip install pre-commit +pre-commit install +pre-commit run --all-files ``` To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. @@ -81,7 +81,7 @@ def unload(bot: lightbulb.BotApp): bot.remove_plugin(plugin) ``` -For example commands and listeners, see [here](/discord-bot/bot/extensions/_example.py) +For example commands and listeners, see [EXAMPLES.md](/discord-bot/EXAMPLES.md) ### Docs diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index 1f801413..a328300a 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -4,7 +4,6 @@ import aiosqlite import hikari import lightbulb import miru - from bot.api_client import OasstApiClient from bot.config import Config diff --git a/discord-bot/bot/extensions/_example.py b/discord-bot/bot/extensions/EXAMPLES.md similarity index 98% rename from discord-bot/bot/extensions/_example.py rename to discord-bot/bot/extensions/EXAMPLES.md index 76398881..f031cd72 100644 --- a/discord-bot/bot/extensions/_example.py +++ b/discord-bot/bot/extensions/EXAMPLES.md @@ -1,11 +1,8 @@ -# -*- coding: utf-8 -*- -# TODO: Convert file to markdown -# -*- coding: utf-8 -*- -"""Example plugin for reference. +# `hikari`, `lightbulb`, and `muri` examples -Because this file starts with an `_`, it cannot be loaded by the bot. -To see the example plugin in action, rename this file to `example.py`. -""" +Example plugin for reference. + +````py import asyncio import hikari @@ -411,3 +408,4 @@ def load(bot: lightbulb.BotApp): def unload(bot: lightbulb.BotApp): """Remove the plugin to the bot.""" bot.remove_plugin(plugin) +```` diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py index 8c9cded4..5623cd5a 100644 --- a/discord-bot/bot/extensions/guild_settings.py +++ b/discord-bot/bot/extensions/guild_settings.py @@ -3,7 +3,6 @@ import hikari import lightbulb from aiosqlite import Connection - from bot.db.schemas import GuildSettings from bot.utils import mention diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py index 71f47f52..70fa5257 100644 --- a/discord-bot/bot/extensions/tasks.py +++ b/discord-bot/bot/extensions/tasks.py @@ -9,9 +9,8 @@ import hikari import lightbulb import lightbulb.decorators import miru -from oasst_shared.schemas.protocol import TaskRequestType - from bot.utils import format_time +from oasst_shared.schemas.protocol import TaskRequestType plugin = lightbulb.Plugin("TaskPlugin") diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 48e32763..1f278ca4 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -7,7 +7,6 @@ import hikari import lightbulb import miru from aiosqlite import Connection - from bot.db.schemas import GuildSettings from bot.utils import EMPTY diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index ecbe1710..ba71f41b 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -10,12 +10,11 @@ import lightbulb import lightbulb.decorators import miru from aiosqlite import Connection -from oasst_shared.schemas import protocol as protocol_schema -from oasst_shared.schemas.protocol import TaskRequestType - from bot.api_client import OasstApiClient, TaskType from bot.db.schemas import GuildSettings from bot.utils import EMPTY +from oasst_shared.schemas import protocol as protocol_schema +from oasst_shared.schemas.protocol import TaskRequestType plugin = lightbulb.Plugin("WorkPlugin") diff --git a/discord-bot/dev-requirements.txt b/discord-bot/dev-requirements.txt deleted file mode 100644 index 56393fdd..00000000 --- a/discord-bot/dev-requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ - -black - -codespell -flake8 -isort -nox -pyright diff --git a/discord-bot/flake8-requirements.txt b/discord-bot/flake8-requirements.txt deleted file mode 100644 index a022d8c5..00000000 --- a/discord-bot/flake8-requirements.txt +++ /dev/null @@ -1,26 +0,0 @@ -flake8==6.0.0 -flake8-bandit # runs bandit -flake8-black # runs black -# flake8-broken-line # forbey "\" linebreaks -flake8-builtins # builtin shadowing checks -flake8-coding # coding magic-comment detection -flake8-comprehensions # comprehension checks -flake8-deprecated # deprecated call checks -flake8-docstrings # pydocstyle support -flake8-executable # shebangs -flake8-fixme # "fix me" counter -flake8-functions # function linting -flake8-html # html output -flake8-if-statements # condition linting -flake8-isort # runs isort -flake8-mutable # mutable default argument detection -flake8-pep3101 # new-style format strings only -flake8-print # complain about print statements in code -flake8-printf-formatting # forbey printf-style python2 string formatting - -# Plugins - -Flake8-pyproject # use the pyproject.toml as the config file -flake8-pytest-style # pytest checks -flake8-raise # exception raising linting -flake8-use-fstring # format string checking diff --git a/discord-bot/noxfile.py b/discord-bot/noxfile.py deleted file mode 100644 index 891e87fb..00000000 --- a/discord-bot/noxfile.py +++ /dev/null @@ -1,34 +0,0 @@ -# -*- coding: utf-8 -*- -"""Automated linting, formatting, and typechecking.""" -import nox -from nox.sessions import Session - - -@nox.session(reuse_venv=True) -def format_code(session: Session): - """Format the codebase.""" - session.install("isort", "-U") - session.install("black", "-U") - - session.run("isort", "bot") - session.run("black", "bot") - - -@nox.session(reuse_venv=True) -def lint_code(session: Session): - """Lint the codebase.""" - session.install("codespell", "-U") - session.install("flake8", "-U") - session.install("-r", "flake8-requirements.txt", "-U") - - session.run("codespell", "bot") - session.run("flake8", "bot") - - -@nox.session(reuse_venv=True) -def typecheck_code(session: Session): - """Typecheck the codebase.""" - session.install("-r", "requirements.txt", "-U") - session.install("pyright", "-U") - - session.run("pyright", "bot") diff --git a/discord-bot/pyproject.toml b/discord-bot/pyproject.toml deleted file mode 100644 index 7a1e8d82..00000000 --- a/discord-bot/pyproject.toml +++ /dev/null @@ -1,47 +0,0 @@ -[project] -name = "Open-Assistant Discord Bot" -version = "0.0.1" - -[tool.black] -line-length = 120 -target-version = ["py310"] - -[tool.pyright] -include = ["ottbot", "noxfile.py"] -pythonVersion="3.10" -reportMissingImports=false -# reportInvalidTypeVarUse=false -# reportMissingModuleSource=false -reportUnknownVariableType=false -pythonPlatform="Linux" - -[tool.isort] -profile="black" -sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER'] -skip_glob = "**/__init__.pyi" - -[tool.flake8] -max-function-length = 130 -max-line-length = 130 -# Technically this is 120, but black has a policy of "1 or 2 over is fine if it is tidier", so we have to raise this. -accept-encodings = "utf-8" -docstring-convention = "numpy" -ignore = [ - "A002", # Argument is shadowing a python builtin. - "A003", # Class attribute is shadowing a python builtin. - "CFQ002", # Function has too many arguments. - "CFQ004", # Function has too many returns. - "D001", # False positive for depreciated functions. - "D102", # Missing docstring in public method. - "D105", # Magic methods not having a docstring. - "D412", # No blank lines allowed between a section header and its content - "E203", # Whitespace after : (to match how black formats it) - "E402", # Module level import not at top of file (isn't compatible with our import style). - "T101", # TO-DO comment detection (T102 is FIX-ME and T103 is XXX). - "W503", # line break before binary operator. - "W504", # line break before binary operator (again, I guess). - "S101", # Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. - "S105", # Possible hardcoded password. - "EXE002", # Executable file with not shebang - "D401", # Imperative mood -] From 84d52effee129d8232aba2a40cc6c6f7c597d355 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:11:36 -0800 Subject: [PATCH 17/27] remove address todo comments --- discord-bot/bot/extensions/EXAMPLES.md | 3 -- discord-bot/bot/extensions/text_labels.py | 1 - discord-bot/bot/extensions/work.py | 41 ++++++++++++++++++++--- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/discord-bot/bot/extensions/EXAMPLES.md b/discord-bot/bot/extensions/EXAMPLES.md index f031cd72..29598fde 100644 --- a/discord-bot/bot/extensions/EXAMPLES.md +++ b/discord-bot/bot/extensions/EXAMPLES.md @@ -396,9 +396,6 @@ async def modal_example(ctx: lightbulb.SlashContext) -> None: await view.start(await resp.message()) -# TODO: Database example -# TODO: Rest client example - def load(bot: lightbulb.BotApp): """Add the plugin to the bot.""" diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 1f278ca4..53d0a1fd 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -69,7 +69,6 @@ class LabelModal(miru.Modal): .add_field("Total Labeled Message", "0", inline=True) .add_field("Server Ranking", "0/0", inline=True) .add_field("Global Ranking", "0/0", inline=True) - .set_footer("Message ID: TODO") ) channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id) assert isinstance(channel, hikari.TextableChannel) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index ba71f41b..8e3ad7b5 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -31,7 +31,7 @@ logger.setLevel(logging.DEBUG) "The type of task to request.", choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType], required=False, - default=str(TaskRequestType.rank_initial_prompts), # TODO: change back to random + default=str(TaskRequestType.random), type=str, ) @lightbulb.command("work", "Complete a task.") @@ -79,11 +79,11 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) return # Invalid response - if event.content is None: - await ctx.author.send("No content in message") + if event.content is None or not _validate_user_input(event.content, task.type): + await ctx.author.send("Invalid response") continue - logger.info(f"User input received: {event.content}") + logger.info(f"Successful user input received: {event.content}") # Send the response to the backend reply = protocol_schema.TextReplyToPost( @@ -108,7 +108,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) logger.fatal(f"Unexpected task type received: {new_task.type}") # Send a message in the log channel that the task is complete - # TODO: Maybe do something with the msg ID + # TODO: Maybe do something with the msg ID so users can rate the "answer" assert ctx.guild_id is not None conn: Connection = ctx.bot.d.db guild_settings = await GuildSettings.from_db(conn, ctx.guild_id) @@ -252,6 +252,37 @@ async def _send_task( return view.choice, str(msg.id) +# TODO check what the backend expects +def _validate_user_input(content: str | None, task_type: str) -> bool: + """Returns whether the user's input is valid for the task type.""" + if content is None: + return False + + if ( + task_type == TaskRequestType.initial_prompt + or task_type == TaskRequestType.user_reply + or task_type == TaskRequestType.assistant_reply + ): + return len(content) > 0 + + elif ( + task_type == TaskRequestType.rank_initial_prompts + or task_type == TaskRequestType.rank_user_replies + or task_type == TaskRequestType.rank_assistant_replies + ): + rankings = [int(r) for r in content.split(",")] + return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5 + + elif task_type == TaskRequestType.summarize_story: + raise NotImplementedError + elif task_type == TaskRequestType.rate_summary: + raise NotImplementedError + + else: + logger.fatal(f"Unknown task type {task_type}") + raise ValueError(f"Unknown task type {task_type}") + + class TaskAcceptView(miru.View): """View with three buttons: accept, next, and cancel. From 708011e6a0f65c919c86f960e8da786765b995e6 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:14:16 -0800 Subject: [PATCH 18/27] move EXAMPLEs.md --- discord-bot/{bot/extensions => }/EXAMPLES.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord-bot/{bot/extensions => }/EXAMPLES.md (100%) diff --git a/discord-bot/bot/extensions/EXAMPLES.md b/discord-bot/EXAMPLES.md similarity index 100% rename from discord-bot/bot/extensions/EXAMPLES.md rename to discord-bot/EXAMPLES.md From 6cccd74e3491efec03b4ffed5841ec510cae0e50 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:28:51 -0800 Subject: [PATCH 19/27] switch to loguru --- discord-bot/bot/extensions/guild_settings.py | 3 ++ discord-bot/bot/extensions/hot_reload.py | 5 +++- discord-bot/bot/extensions/tasks.py | 4 +-- discord-bot/bot/extensions/text_labels.py | 4 +++ discord-bot/bot/extensions/work.py | 30 +++++++++----------- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py index 5623cd5a..f5785b8d 100644 --- a/discord-bot/bot/extensions/guild_settings.py +++ b/discord-bot/bot/extensions/guild_settings.py @@ -5,6 +5,7 @@ import lightbulb from aiosqlite import Connection from bot.db.schemas import GuildSettings from bot.utils import mention +from loguru import logger plugin = lightbulb.Plugin("GuildSettings") plugin.add_checks(lightbulb.guild_only) @@ -34,6 +35,7 @@ async def get(ctx: lightbulb.SlashContext) -> None: row = await cursor.fetchone() if row is None: + logger.warning(f"No guild settings for {ctx.guild_id}") await ctx.respond("No settings found for this guild.") return @@ -70,6 +72,7 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None: ) await conn.commit() + logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.") def load(bot: lightbulb.BotApp): diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py index 28bcede3..ad2cd730 100644 --- a/discord-bot/bot/extensions/hot_reload.py +++ b/discord-bot/bot/extensions/hot_reload.py @@ -4,6 +4,7 @@ from glob import glob import hikari import lightbulb +from loguru import logger plugin = lightbulb.Plugin( "HotReloadPlugin", @@ -37,7 +38,7 @@ async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikar required=False, default=None, ) -@lightbulb.command("reload", "Reload a plugin") +@lightbulb.command("reload", "Reload a plugin", ephemeral=True) @lightbulb.implements(lightbulb.SlashCommand) async def reload(ctx: lightbulb.SlashContext): """Reload a plugin or all plugins.""" @@ -45,10 +46,12 @@ async def reload(ctx: lightbulb.SlashContext): if ctx.options.plugin is None: ctx.bot.reload_extensions(*_get_extensions()) await ctx.respond("Reloaded all plugins.") + logger.info("Reloaded all plugins.") # Otherwise, reload the specified plugin. else: ctx.bot.reload_extensions(ctx.options.plugin) await ctx.respond(f"Reloaded `{ctx.options.plugin}`.") + logger.info(f"Reloaded `{ctx.options.plugin}`.") def load(bot: lightbulb.BotApp): diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py index 70fa5257..94ddb973 100644 --- a/discord-bot/bot/extensions/tasks.py +++ b/discord-bot/bot/extensions/tasks.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """Task plugin for testing different data collection methods.""" +# TODO: Delete this once user input method has been decided for final bot. import asyncio -import logging import typing as t from datetime import datetime, timedelta @@ -16,8 +16,6 @@ plugin = lightbulb.Plugin("TaskPlugin") MAX_TASK_TIME = 60 * 60 MAX_TASK_ACCEPT_TIME = 60 -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) @plugin.command diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 53d0a1fd..618e6642 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -9,6 +9,7 @@ import miru from aiosqlite import Connection from bot.db.schemas import GuildSettings from bot.utils import EMPTY +from loguru import logger plugin = lightbulb.Plugin( "TextLabels", @@ -49,6 +50,7 @@ class LabelModal(miru.Modal): f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.", flags=hikari.MessageFlag.EPHEMERAL, ) + logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.") # Send a notification to the log channel assert context.guild_id is not None # `guild_only` check @@ -56,6 +58,7 @@ class LabelModal(miru.Modal): guild_settings = await GuildSettings.from_db(conn, context.guild_id) if guild_settings is None or guild_settings.log_channel_id is None: + logger.warning(f"No guild settings or log channel for guild {context.guild_id}") return embed = ( @@ -148,6 +151,7 @@ async def label_message_text(ctx: lightbulb.MessageContext): """Label a message.""" # We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction # so the select menu will open the modal + msg: hikari.Message = ctx.options.target # Exit if the message is empty if not msg.content: diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 8e3ad7b5..5244920b 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- """Work plugin for collecting user data.""" import asyncio -import logging import typing as t from datetime import datetime @@ -13,6 +12,7 @@ from aiosqlite import Connection from bot.api_client import OasstApiClient, TaskType from bot.db.schemas import GuildSettings from bot.utils import EMPTY +from loguru import logger from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TaskRequestType @@ -21,9 +21,6 @@ plugin = lightbulb.Plugin("WorkPlugin") MAX_TASK_TIME = 60 * 60 # 1 hour MAX_TASK_ACCEPT_TIME = 60 # 1 minute -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - @plugin.command @lightbulb.option( @@ -41,7 +38,7 @@ async def work(ctx: lightbulb.SlashContext): task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) - logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}") + logger.debug(f"Starting task_type: {task_type!r}") await _handle_task(ctx, task_type) @@ -76,6 +73,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) except asyncio.TimeoutError: await ctx.author.send("Task timed out. Exiting") await oasst_api.nack_task(task.id, reason="timed out") + logger.info(f"Task {task.id} timed out") return # Invalid response @@ -83,7 +81,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) await ctx.author.send("Invalid response") continue - logger.info(f"Successful user input received: {event.content}") + logger.debug(f"Successful user input received: {event.content}") # Send the response to the backend reply = protocol_schema.TextReplyToPost( @@ -105,7 +103,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) completed = True continue else: - logger.fatal(f"Unexpected task type received: {new_task.type}") + logger.critical(f"Unexpected task type received: {new_task.type}") # Send a message in the log channel that the task is complete # TODO: Maybe do something with the msg ID so users can rate the "answer" @@ -159,7 +157,7 @@ async def _select_task( task = await oasst_api.fetch_task(task_type, user) resp, msg_id = await _send_task(ctx, task) - logger.debug(f"user choice: {resp}") + logger.debug(f"User choice: {resp}") match resp: case "accept": logger.info(f"Task {task.id} accepted, sending ACK") @@ -200,32 +198,32 @@ async def _send_task( # Create an embed based on the task's type if task.type == TaskRequestType.initial_prompt: assert isinstance(task, protocol_schema.InitialPromptTask) - logger.info("sending initial prompt task") + logger.debug("sending initial prompt task") embed = _initial_prompt_embed(task) elif task.type == TaskRequestType.rank_initial_prompts: assert isinstance(task, protocol_schema.RankInitialPromptsTask) - logger.info("sending rank initial prompt task") + logger.debug("sending rank initial prompt task") embed = _rank_initial_prompt_embed(task) elif task.type == TaskRequestType.rank_user_replies: assert isinstance(task, protocol_schema.RankUserRepliesTask) - logger.info("sending rank user reply task") + logger.debug("sending rank user reply task") embed = _rank_user_reply_embed(task) elif task.type == TaskRequestType.rank_assistant_replies: assert isinstance(task, protocol_schema.RankAssistantRepliesTask) - logger.info("sending rank assistant reply task") + logger.debug("sending rank assistant reply task") embed = _rank_assistant_reply_embed(task) elif task.type == TaskRequestType.user_reply: assert isinstance(task, protocol_schema.UserReplyTask) - logger.info("sending user reply task") + logger.debug("sending user reply task") embed = _user_reply_embed(task) elif task.type == TaskRequestType.assistant_reply: assert isinstance(task, protocol_schema.AssistantReplyTask) - logger.info("sending assistant reply task") + logger.debug("sending assistant reply task") embed = _assistant_reply_embed(task) elif task.type == TaskRequestType.summarize_story: @@ -234,7 +232,7 @@ async def _send_task( raise NotImplementedError else: - logger.error(f"unknown task type {task.type}") + logger.critical(f"unknown task type {task.type}") raise ValueError(f"unknown task type {task.type}") view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) @@ -279,7 +277,7 @@ def _validate_user_input(content: str | None, task_type: str) -> bool: raise NotImplementedError else: - logger.fatal(f"Unknown task type {task_type}") + logger.critical(f"Unknown task type {task_type}") raise ValueError(f"Unknown task type {task_type}") From 150fc67bfdf417a38488e6266f6f9cd85c966a28 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:48:56 -0800 Subject: [PATCH 20/27] switch to using Pydantic for bot config --- discord-bot/.env.example | 2 +- discord-bot/bot/bot.py | 12 ++++++------ discord-bot/bot/config.py | 37 ------------------------------------ discord-bot/bot/settings.py | 15 +++++++++++++++ discord-bot/requirements.txt | 4 ++-- 5 files changed, 24 insertions(+), 46 deletions(-) delete mode 100644 discord-bot/bot/config.py create mode 100644 discord-bot/bot/settings.py diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 7c414a53..d32e80d1 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,4 +1,4 @@ TOKEN= DECLARE_GLOBAL_COMMANDS= -OWNER_IDS= +OWNER_IDS=[, ] PREFIX="./" diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index a328300a..2cf3d663 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -5,16 +5,16 @@ import hikari import lightbulb import miru from bot.api_client import OasstApiClient -from bot.config import Config +from bot.settings import Settings -config = Config.from_env() +settings = Settings() bot = lightbulb.BotApp( - token=config.token, + token=settings.token, logs="DEBUG", - prefix=config.prefix, - default_enabled_guilds=config.declare_global_commands, - owner_ids=config.owner_ids, + prefix=settings.prefix, + default_enabled_guilds=settings.declare_global_commands, + owner_ids=settings.owner_ids, intents=hikari.Intents.ALL, ) diff --git a/discord-bot/bot/config.py b/discord-bot/bot/config.py deleted file mode 100644 index fafcb308..00000000 --- a/discord-bot/bot/config.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -"""Configuration for the bot.""" - -import logging -from dataclasses import dataclass -from os import getenv - -from dotenv import load_dotenv # type: ignore - -load_dotenv() - -logger = logging.getLogger(__name__) - - -@dataclass -class Config: - """Configuration for the bot.""" - - token: str - declare_global_commands: int - owner_ids: list[int] - prefix: str - - @classmethod - def from_env(cls): - token = getenv("TOKEN", None) - - if token is None: - logger.error("Invalid token, please set the TOKEN environment variable.") - exit(1) - - return cls( - token=token, - declare_global_commands=int(getenv("DECLARE_GLOBAL_COMMANDS", 0)), - owner_ids=[int(x) for x in getenv("OWNER_IDS", "").split(",")], - prefix=getenv("PREFIX", "./"), - ) diff --git a/discord-bot/bot/settings.py b/discord-bot/bot/settings.py new file mode 100644 index 00000000..41c6ae52 --- /dev/null +++ b/discord-bot/bot/settings.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +"""Configuration for the bot.""" +from pydantic import BaseSettings, Field + + +class Settings(BaseSettings): + """Settings for the bot.""" + + token: str = Field(env="TOKEN", default="") + declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0) + owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list) + prefix: str = Field(env="PREFIX", default="./") + + class Config(BaseSettings.Config): + env_file = ".env" diff --git a/discord-bot/requirements.txt b/discord-bot/requirements.txt index 62b9b931..372bbd59 100644 --- a/discord-bot/requirements.txt +++ b/discord-bot/requirements.txt @@ -6,6 +6,6 @@ hikari-lightbulb # command handler hikari-miru # modals and buttons hikari[speedups] loguru +pydantic -python-dotenv # .env file support -uvloop; os_name != 'nt' +uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop From aff3f18b07c7312dcc7112a0102694181e843202 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:56:04 -0800 Subject: [PATCH 21/27] document code structure --- discord-bot/CONTRIBUTING.md | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md index 44484354..33b2d435 100644 --- a/discord-bot/CONTRIBUTING.md +++ b/discord-bot/CONTRIBUTING.md @@ -30,31 +30,32 @@ To test the bot on your own discord server you need to register a discord applic ### Structure +Important files + ```graphql .env # Environment variables .env.example # Example environment variables CONTRIBUTING.md # This file -dev-requirements.txt # Development requirements -flake8-requirements.txt # Flake8 extensions (for linting) -noxfile.py # Nox session definitions (for formatting, typechecking, linting) -pyproject.toml # Project configuration README.md # Project readme +EXAMPLES.md # Examples for commands and listeners requirements.txt # Requirements -templates/ # Message templates bot/ -├─ __init__.py -├─ __main__.py # Entrypoint -├─ bot.py # Main bot class -├─ config.py # Configuration and secrets -├─ utils.py # Utility Functions +├─ __main__.py # Entrypoint +├─ api_client.py # API Client for interacting with the backend +├─ bot.py # Main bot class +├─ settings.py # Settings and secrets +├─ utils.py # Utility Functions │ -├─ db/ # Database related code -│ ├─ database.db # SQLite database -│ └─ schema.sql # Database schema +├─ db/ # Database related code +│ ├─ database.db # SQLite database +│ ├─ schema.sql # SQL schema +│ └─ schemas.py # Python table schemas │ -└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html - └─ hot_reload.py # Utility for hot reload extension +└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html + ├─ work.py # Task handling logic <-- most important file + ├─ guild_settings.py # Server specific settings + └─ hot_reload.py # Utility for hot reload extensions during development ``` ### Adding a new command/listener From 0fb7bfd27ac3765379b387f8b0c5f039f4ddbbc8 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 17:30:56 -0800 Subject: [PATCH 22/27] Add more settings and refactor other md files into README.md --- discord-bot/.env.example | 3 + discord-bot/CONTRIBUTING.md | 105 ---------- discord-bot/EXAMPLES.md | 408 ------------------------------------ discord-bot/README.md | 106 +++++++++- discord-bot/bot/bot.py | 2 +- discord-bot/bot/settings.py | 3 + 6 files changed, 111 insertions(+), 516 deletions(-) delete mode 100644 discord-bot/CONTRIBUTING.md delete mode 100644 discord-bot/EXAMPLES.md diff --git a/discord-bot/.env.example b/discord-bot/.env.example index d32e80d1..4fcb23b3 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -2,3 +2,6 @@ TOKEN= DECLARE_GLOBAL_COMMANDS= OWNER_IDS=[, ] PREFIX="./" + +OASST_API_URL="http://localhost:8080" # No trailing '/' +OASST_API_KEY="" diff --git a/discord-bot/CONTRIBUTING.md b/discord-bot/CONTRIBUTING.md deleted file mode 100644 index 33b2d435..00000000 --- a/discord-bot/CONTRIBUTING.md +++ /dev/null @@ -1,105 +0,0 @@ -# Contributing - -## Setup - -To run the bot - -``` -cp .env.example .env - -python -V # 3.10 - -pip install -r requirements.txt -python -m bot -``` - -Before you push, make sure the `pre-commit` hooks are installed and run successfully. - -``` -pip install pre-commit -pre-commit install -pre-commit run --all-files -``` - -To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. - -1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) -2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable. - -## Resources - -### Structure - -Important files - -```graphql -.env # Environment variables -.env.example # Example environment variables -CONTRIBUTING.md # This file -README.md # Project readme -EXAMPLES.md # Examples for commands and listeners -requirements.txt # Requirements - -bot/ -├─ __main__.py # Entrypoint -├─ api_client.py # API Client for interacting with the backend -├─ bot.py # Main bot class -├─ settings.py # Settings and secrets -├─ utils.py # Utility Functions -│ -├─ db/ # Database related code -│ ├─ database.db # SQLite database -│ ├─ schema.sql # SQL schema -│ └─ schemas.py # Python table schemas -│ -└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html - ├─ work.py # Task handling logic <-- most important file - ├─ guild_settings.py # Server specific settings - └─ hot_reload.py # Utility for hot reload extensions during development -``` - -### Adding a new command/listener - -1. Create a new file in the `extensions` folder -2. Copy the template below - -```py -# -*- coding: utf-8 -*- -"""My plugin.""" -import lightbulb - -plugin = lightbulb.Plugin("MyPlugin") - -# Add your commands here - -def load(bot: lightbulb.BotApp): - """Add the plugin to the bot.""" - bot.add_plugin(plugin) - - -def unload(bot: lightbulb.BotApp): - """Remove the plugin to the bot.""" - bot.remove_plugin(plugin) -``` - -For example commands and listeners, see [EXAMPLES.md](/discord-bot/EXAMPLES.md) - -### Docs - -Discord - -- [Discord API Reference](https://discord.com/developers/docs/intro) - -Main framework - -- [Hikari Repo](https://github.com/hikari-py/hikari) -- [Hikari Docs](https://docs.hikari-py.dev/en/latest/) - -Command handler - -- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb) -- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/) - -Component handler (buttons, modals, etc... ) - -- [Miru Repo](https://github.com/HyperGH/hikari-miru) diff --git a/discord-bot/EXAMPLES.md b/discord-bot/EXAMPLES.md deleted file mode 100644 index 29598fde..00000000 --- a/discord-bot/EXAMPLES.md +++ /dev/null @@ -1,408 +0,0 @@ -# `hikari`, `lightbulb`, and `muri` examples - -Example plugin for reference. - -````py -import asyncio - -import hikari -import lightbulb -import lightbulb.decorators -import miru -from miru.ext import nav - -plugin = lightbulb.Plugin("ExamplePlugin") - -# To add checks to a plugin, you can use the `@plugin.check` decorator -# or the `plugin.add_check` method. Lightbulb has some built-in checks. -# The check will be called before any command in the plugin is called. -plugin.add_checks(lightbulb.guild_only) - - -# To create a slash command, use the template below -@plugin.command -@lightbulb.command("example", "Example command.") -@lightbulb.implements(lightbulb.SlashCommand) -async def example(ctx: lightbulb.SlashContext): - """Example command.""" - # To send a message, use the `respond` method on `ctx`. - # !!! Be sure to use `await` when calling `respond` !!! - await ctx.respond("Hello, world!") - - -# To add arguments, use the `@lightbulb.option` decorator. -@plugin.command -@lightbulb.option( - "name", # The name of the option. This is what you will use to access the value in `ctx.options.name` - "Your name.", # The description of the option. This will be shown in the slash command menu. - # Whether or not the option is required. - # If `required` is `True`, the user will not be able to use the command without providing a value for this option. - required=False, - default=None, # The default value for the option. If `required` is `True`, this will be ignored. - type=str | None, # The type of the option. This is used to convert the value to the correct type. - # https://hikari-lightbulb.readthedocs.io/en/latest/guides/commands.html#converters-and-slash-command-option-types -) -@lightbulb.option( - "age", - "Your age.", - type=int, - # These are enforced on the client side, so the user won't be able to enter a value outside of the range. - min_value=0, - max_value=100, -) -@lightbulb.option( - "gender", - "Your gender.", - # You can also use `choices` to limit the user to a specific set of values. - # This can be a list of `str`, `int, or `float` - # choices=["Male", "Female", "Other"], - # or a list of `hikari.CommandChoice` objects to have separate option names and values - choices=[ - hikari.CommandChoice(name="male", value="M"), - hikari.CommandChoice(name="female", value="F"), - hikari.CommandChoice(name="other", value="Other"), - ], - type=str, -) -@lightbulb.command("args_example", "Example command with arguments.") -@lightbulb.implements(lightbulb.SlashCommand) -async def args_example(ctx: lightbulb.SlashContext): - """Example command with arguments.""" - name: str | None = ctx.options.name - if name is None: - name = ctx.author.username - age: int = ctx.options.age - gender: str = ctx.options.gender - - await ctx.respond( - f"Hello {ctx.author.mention}! Your name is {name}, you are {age} years old, and your gender is {gender}.", - # in order to actually mention the user, you must pass `user_mentions=True` - # otherwise, the user won't get a notification - user_mentions=True, - ) - - -# To have autocomplete options, add the -# pass `autocomplete=function` to `@lightbulb.option` -# or `autocomplete=True` and mark the function with `@command.autocomplete("option_name")`. -# @autocomplete_example.autocomplete("language") -async def _programming_language_autocomplete( - option: hikari.CommandInteractionOption, interaction: hikari.AutocompleteInteraction -) -> list[str]: - # The `option` argument is the current text that the user typed in. - if not isinstance(option.value, str): - # This will raise a TypeError if `option.value` cannot be converted - option.value = str(option.value) - - # You can query a database, fetch an api, or return any list of strings - # !!! You can return a max of 25 options !!! - langs = [ - "C", - "C++", - "C#", - "CSS", - "Go", - "HTML", - "Java", - "Javascript", - "Kotlin", - "Matlab", - "NoSQL", - "PHP", - "Perl", - "Python", - "R", - "Ruby", - "Rust", - "SQL", - "Scala", - "Swift", - "TypeScript", - "Zig", - ] - return [lang for lang in langs if option.value.lower() in lang.lower()] - - -@plugin.command -@lightbulb.option( - "language", - "Your favorite programming language.", - autocomplete=_programming_language_autocomplete, -) -@lightbulb.command("autocomplete_example", "Autocomplete example.") -@lightbulb.implements(lightbulb.SlashCommand) -async def autocomplete_example(ctx: lightbulb.SlashContext): - """Autocomplete example.""" - await ctx.respond("Your favorite programming language is " + ctx.options.language) - - -# Command groups are like trees -# You can have subcommands, subcommand groups, and subcommand groups with subcommands -# Here is an example diagram: -# /group_example (group) -# subcommand (executable) -# subcommand_group (group) -# subsubcommand (executable) - -# Because those are slash commands, only the leaves (/subcommand and /subsubcommand) are callable. - -# To create a group, use the template below -# 1. Create the command group -@plugin.command -@lightbulb.command("group_example", "Example command group.") -@lightbulb.implements(lightbulb.SlashCommandGroup) -async def group_example(_: lightbulb.SlashContext) -> None: - """Group example.""" - # This will never execute because it is a group - pass - - -# 2. Add a child command -@group_example.child -@lightbulb.command("subcommand", "Example subcommand.") -@lightbulb.implements(lightbulb.SlashSubCommand) -async def subcommand(ctx: lightbulb.SlashContext) -> None: - """An example subcommand.""" - await ctx.respond("invoked `/group_example subcommand`") - - -# 3. Add a sub-group -@group_example.child -@lightbulb.command("subcommand_group", "Example subcommand group.") -@lightbulb.implements(lightbulb.SlashSubGroup) -async def subcommand_group(_: lightbulb.SlashContext) -> None: - """Subcommand group example.""" - # This will never execute because it is a sub-group - pass - - -# 4. Add a child to the sub-group -@subcommand_group.child -@lightbulb.command("subsubcommand", "Example subsubcommand.") -@lightbulb.implements(lightbulb.SlashSubCommand) -async def subsubcommand(ctx: lightbulb.SlashContext) -> None: - """An example subsubcommand.""" - await ctx.respond("invoked `/group_example subcommand_group subsubcommand`") - - -# Event listeners are a way to listen to events from the gateway. -# You can have stand alone event listeners or use `wait_for` to wait for a specific event inside a command / listener. -@plugin.listener(hikari.MemberCreateEvent) -async def on_member_join(event: hikari.MemberCreateEvent) -> None: - """Event listener to welcome new members.""" - guild = event.get_guild() - await event.member.send(f"Welcome to {guild.name if guild else 'the server'}!") - - -# You can also use `wait_for` to wait for a specific event -@plugin.command -@lightbulb.command("wait_for_example", "Example command with `wait_for` and `stream`.") -@lightbulb.implements(lightbulb.SlashCommand) -async def wait_for_example(ctx: lightbulb.SlashContext) -> None: - """Wait for example.""" - await ctx.respond("Send a message!") - - # We can add a predicate to `wait_for` to filter out events - def author_check(e: hikari.MessageCreateEvent) -> bool: - return e.author_id == ctx.author.id - - # You need to wrap wait_for in a try/catch block because it can raise `asyncio.TimeoutError` - try: - event = await ctx.bot.wait_for(hikari.MessageCreateEvent, timeout=10, predicate=author_check) - await ctx.respond(f"You sent: {event.message.content}") - except asyncio.TimeoutError: - await ctx.respond("Too slow!") - # remember to use try/except/finally if you need to clean up any resources - - # You can also use `stream` to listen for events - await ctx.respond("Waiting for guild events...") - with ctx.bot.stream(hikari.Event, timeout=5).filter( - # Only listen for events that have a guild_id and are not bots - lambda e: getattr(e, "guild_id", None) == ctx.guild_id - and getattr(e, "is_human", False) - ) as stream: - async for event in stream: - await ctx.respond(f"New `{event.__class__.__name__}`") - - await ctx.respond("Done!") - - -# You can interact with discord's API using the `rest` attribute on the bot -# This allows you to -# - fetch information about users, channels, guilds, etc. -# - create, edit, and delete messages, channels, threads, roles, categories, etc. -# - add, remove, and edit reactions -@plugin.command -@lightbulb.command("rest_example", "Example command using the `rest` attribute.") -@lightbulb.implements(lightbulb.SlashCommand) -async def rest_example(ctx: lightbulb.SlashContext) -> None: - """Example command using the `rest` attribute.""" - rest = ctx.bot.rest - your_messages = await rest.fetch_messages(ctx.channel_id).filter(lambda m: m.author.id == ctx.author.id).count() - await ctx.respond(f"{your_messages} out of the last 10 messages in this channel were sent by you.") - - -# Context Menus are a way to attach a command to a user or a message. -# By right clicking a user or a User, you can select to execute a command under the "Apps" menu item. -@plugin.command -@lightbulb.command("user_context_menu_example", "Example context menu on a user.") -@lightbulb.implements(lightbulb.UserCommand) -async def user_context_menu_example(ctx: lightbulb.UserContext) -> None: - """User context menu example.""" - user: hikari.Member = ctx.options.target - await ctx.respond(f"Hello {user.mention}!", user_mentions=True) - - -# Same with messages -@plugin.command -@lightbulb.command("message_context_menu_example", "Example context menu on a message.") -@lightbulb.implements(lightbulb.MessageCommand) -async def message_context_menu_example(ctx: lightbulb.MessageContext) -> None: - """Message context menu example.""" - message: hikari.Message = ctx.options.target - await ctx.respond(f"The message length is: {len(message.content or '')}", flags=hikari.MessageFlag.EPHEMERAL) - - -# Components are a way to add interactive buttons to your slash commands. -# We use `miru` to manage components and their callbacks. - -# To create a component, use the template below -# 1. Create the view -class MyView(miru.View): - """An example view with buttons.""" - - @miru.button(label="Rock", emoji="\N{ROCK}", style=hikari.ButtonStyle.PRIMARY) - async def rock_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: - await ctx.respond("Paper!") - - @miru.button(label="Paper", emoji="\N{SCROLL}", style=hikari.ButtonStyle.PRIMARY) - async def paper_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: - await ctx.respond("Scissors!") - - @miru.button(label="Scissors", emoji="\N{BLACK SCISSORS}", style=hikari.ButtonStyle.PRIMARY) - async def scissors_button(self, button: miru.Button, ctx: miru.ViewContext): - await ctx.respond("Rock!") - - @miru.button(emoji="\N{BLACK SQUARE FOR STOP}", style=hikari.ButtonStyle.DANGER, row=2) - async def stop_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: - self.stop() # Stop listening for interactions - - @miru.select( - options=[ - hikari.SelectMenuOption( - label="Thing 1", - value="1", - description="This is a thing", - emoji=hikari.UnicodeEmoji("🗿"), - is_default=True, - ), - hikari.SelectMenuOption( - label="Thing 2", - value="2", - description="This is another thing", - emoji=hikari.UnicodeEmoji("🗿"), - is_default=False, - ), - hikari.SelectMenuOption( - label="Thing 3", - value="3", - description="This is a different thing", - emoji=hikari.UnicodeEmoji("🗿"), - is_default=False, - ), - ], - placeholder="Select some stuff!", - min_values=0, - max_values=2, - row=3, - ) - async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None: - await ctx.respond(f"You selected {select.values}") - - -# 2. Create a command to use the view -@plugin.command -@lightbulb.command("button_example", "Example command with buttons.") -@lightbulb.implements(lightbulb.SlashCommand) -async def button_example(ctx: lightbulb.SlashContext) -> None: - """Wait for example.""" - # 3. Create an instance of the view and start it - view = MyView(timeout=60) - resp = await ctx.respond("Rock Paper Scissors!", components=view) - msg = await resp.message() - await view.start(msg) - await view.wait() - - await ctx.respond("Thank you for playing!") - - -# You can use buttons to create a navigation menu -@plugin.command -@lightbulb.command("nav_example", "Example command with button navigation.", auto_defer=True) -@lightbulb.implements(lightbulb.SlashCommand) -async def navigation_example(ctx: lightbulb.SlashContext) -> None: - """Navigation example.""" - # await ctx.respond(response_type=hikari.ResponseType.DEFERRED_MESSAGE_UPDATE) - embed = hikari.Embed(title="I'm the second page!", description="Also an embed!") - pages = ["I'm the first page!", embed, "I'm the last page!"] - - navigator = nav.NavigatorView(pages=pages, timeout=10) - # You may also pass an interaction object to this function - await navigator.send(ctx.channel_id) - - await navigator.wait() # This is not necessary, but we want to wait anyway - await ctx.respond("Done!") - - -# Miru also has modal support -class MyModal(miru.Modal): - """An example modal.""" - - # Define our modal items - # You can also use Modal.add_item() to add items to the modal after instantiation, just like with views. - name = miru.TextInput(label="Name", placeholder="Enter your name!", required=True) - bio = miru.TextInput(label="Biography", value="Pre-filled content!", style=hikari.TextInputStyle.PARAGRAPH) - - # You can currently only use TextInputs - # https://discord.com/developers/docs/interactions/receiving-and-responding#interaction-response-object-modal - - # The callback function is called after the user hits 'Submit' - async def callback(self, context: miru.ModalContext) -> None: - # You can also access the values using ctx.values, Modal.values, or use ctx.get_value_by_id() - await context.respond(f"Your name: `{self.name.value}`\nYour bio: ```{self.bio.value}```") - - -class ModalView(miru.View): - """An example view that opens a modal.""" - - # Create a new button that will invoke our modal - @miru.button(label="Click me!", style=hikari.ButtonStyle.PRIMARY) - async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None: - modal = MyModal(title="Example Title") - # You may also use Modal.send(interaction) if not working with a miru context object. (e.g. slash commands) - # Keep in mind that modals can only be sent in response to interactions. - await ctx.respond_with_modal(modal) - # OR - # await modal.send(ctx.interaction) - - -@plugin.command -@lightbulb.command("modal_example", "Example command with a modal.") -@lightbulb.implements(lightbulb.SlashCommand) -async def modal_example(ctx: lightbulb.SlashContext) -> None: - """Navigation example.""" - view = ModalView() - resp = await ctx.respond("This button triggers a modal!", components=view) - await view.start(await resp.message()) - - - -def load(bot: lightbulb.BotApp): - """Add the plugin to the bot.""" - bot.add_plugin(plugin) - - -def unload(bot: lightbulb.BotApp): - """Remove the plugin to the bot.""" - bot.remove_plugin(plugin) -```` diff --git a/discord-bot/README.md b/discord-bot/README.md index cde82025..f8b9e433 100644 --- a/discord-bot/README.md +++ b/discord-bot/README.md @@ -1,6 +1,6 @@ # Open-Assistant Data Collection Discord Bot -This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/). +This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large language model. You and other people can teach the bot how to respond to user requests by demonstration and by ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/). ## Invite official bot @@ -8,4 +8,106 @@ To add the official Open-Assistant data collection bot to your discord server [c ## Contributing -To contribute to the bot, please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file. +If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the [large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7) + +### Setup + +To run the bot + +``` +cp .env.example .env + +python -V # 3.10 + +pip install -r requirements.txt +python -m bot +``` + +Before you push, make sure the `pre-commit` hooks are installed and run successfully. + +``` +pip install pre-commit +pre-commit install +pre-commit run --all-files +``` + +To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. + +1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token) +2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable. + +### Resources + +#### Structure + +Important files + +```graphql +.env # Environment variables +.env.example # Example environment variables +CONTRIBUTING.md # This file +README.md # Project readme +EXAMPLES.md # Examples for commands and listeners +requirements.txt # Requirements + +bot/ +├─ __main__.py # Entrypoint +├─ api_client.py # API Client for interacting with the backend +├─ bot.py # Main bot class +├─ settings.py # Settings and secrets +├─ utils.py # Utility Functions +│ +├─ db/ # Database related code +│ ├─ database.db # SQLite database +│ ├─ schema.sql # SQL schema +│ └─ schemas.py # Python table schemas +│ +└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html + ├─ work.py # Task handling logic <-- most important file + ├─ guild_settings.py # Server specific settings + └─ hot_reload.py # Utility for hot reload extensions during development +``` + +#### Adding a new command/listener + +1. Create a new file in the `extensions` folder +2. Copy the template below + +```py +# -*- coding: utf-8 -*- +"""My plugin.""" +import lightbulb + +plugin = lightbulb.Plugin("MyPlugin") + +# Add your commands here + +def load(bot: lightbulb.BotApp): + """Add the plugin to the bot.""" + bot.add_plugin(plugin) + + +def unload(bot: lightbulb.BotApp): + """Remove the plugin to the bot.""" + bot.remove_plugin(plugin) +``` + +#### Docs + +Discord + +- [Discord API Reference](https://discord.com/developers/docs/intro) + +`hikari` (main framework) + +- [Hikari Repo](https://github.com/hikari-py/hikari) +- [Hikari Docs](https://docs.hikari-py.dev/en/latest/) + +`lightbulb` (command handler) + +- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb) +- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/) + +`miru` (component handler: buttons, modals, etc... ) + +- [Miru Repo](https://github.com/HyperGH/hikari-miru) diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index 2cf3d663..4e3bd12c 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -29,7 +29,7 @@ async def on_starting(event: hikari.StartingEvent): await bot.d.db.executescript(open("./bot/db/schema.sql").read()) await bot.d.db.commit() - bot.d.oasst_api = OasstApiClient("http://localhost:8080", "any_key") + bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key) @bot.listen() diff --git a/discord-bot/bot/settings.py b/discord-bot/bot/settings.py index 41c6ae52..200ab54b 100644 --- a/discord-bot/bot/settings.py +++ b/discord-bot/bot/settings.py @@ -10,6 +10,9 @@ class Settings(BaseSettings): declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0) owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list) prefix: str = Field(env="PREFIX", default="./") + oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080") + oasst_api_key: str = Field(env="OASST_API_KEY", default="") class Config(BaseSettings.Config): env_file = ".env" + case_sensitive = False From 37f30f4e3176de357ba25b0d18637c52e4b540ad Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 17:34:09 -0800 Subject: [PATCH 23/27] update readme --- discord-bot/README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/discord-bot/README.md b/discord-bot/README.md index f8b9e433..1ff47c31 100644 --- a/discord-bot/README.md +++ b/discord-bot/README.md @@ -14,7 +14,7 @@ If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the To run the bot -``` +```bash cp .env.example .env python -V # 3.10 @@ -25,10 +25,17 @@ python -m bot Before you push, make sure the `pre-commit` hooks are installed and run successfully. -``` +```bash pip install pre-commit pre-commit install -pre-commit run --all-files + +... + +git add . +git commit -m "" +# if the pre-commit fails +git add . +git commit -m "" ``` To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token. From 004a868cb4cf32ed2be0b43e0e05715f211d145d Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 17:44:20 -0800 Subject: [PATCH 24/27] update user input validator --- discord-bot/bot/extensions/work.py | 44 ++++++++++++++++++------------ 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 5244920b..28ef64c2 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -77,7 +77,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) return # Invalid response - if event.content is None or not _validate_user_input(event.content, task.type): + if event.content is None or not _validate_user_input(event.content, task): await ctx.author.send("Invalid response") continue @@ -250,35 +250,45 @@ async def _send_task( return view.choice, str(msg.id) -# TODO check what the backend expects -def _validate_user_input(content: str | None, task_type: str) -> bool: +def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool: """Returns whether the user's input is valid for the task type.""" if content is None: return False + # User message input if ( - task_type == TaskRequestType.initial_prompt - or task_type == TaskRequestType.user_reply - or task_type == TaskRequestType.assistant_reply + task.type == TaskRequestType.initial_prompt + or task.type == TaskRequestType.user_reply + or task.type == TaskRequestType.assistant_reply ): + assert isinstance( + task, protocol_schema.InitialPromptTask | protocol_schema.UserReplyTask | protocol_schema.AssistantReplyTask + ) return len(content) > 0 - elif ( - task_type == TaskRequestType.rank_initial_prompts - or task_type == TaskRequestType.rank_user_replies - or task_type == TaskRequestType.rank_assistant_replies - ): - rankings = [int(r) for r in content.split(",")] - return all([r in (1, 2, 3, 4, 5) for r in rankings]) and len(rankings) == 5 + # Ranking tasks + elif task.type == TaskRequestType.rank_user_replies or task.type == TaskRequestType.rank_assistant_replies: + assert isinstance(task, protocol_schema.RankUserRepliesTask | protocol_schema.RankAssistantRepliesTask) + num_replies = len(task.replies) - elif task_type == TaskRequestType.summarize_story: + rankings = [int(r) for r in content.split(",")] + return all([r in range(1, num_replies + 1) for r in rankings]) and len(rankings) == num_replies + + elif task.type == TaskRequestType.rank_initial_prompts: + assert isinstance(task, protocol_schema.RankInitialPromptsTask) + num_prompts = len(task.prompts) + + rankings = [int(r) for r in content.split(",")] + return all([r in range(1, num_prompts + 1) for r in rankings]) and len(rankings) == num_prompts + + elif task.type == TaskRequestType.summarize_story: raise NotImplementedError - elif task_type == TaskRequestType.rate_summary: + elif task.type == TaskRequestType.rate_summary: raise NotImplementedError else: - logger.critical(f"Unknown task type {task_type}") - raise ValueError(f"Unknown task type {task_type}") + logger.critical(f"Unknown task type {task.type}") + raise ValueError(f"Unknown task type {task.type}") class TaskAcceptView(miru.View): From 6c3a2eac0316aba7d7de4961498157bfeda6fa2a Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 17:48:22 -0800 Subject: [PATCH 25/27] rename task.py to test_user_input.py --- discord-bot/bot/extensions/{tasks.py => user_input_test.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename discord-bot/bot/extensions/{tasks.py => user_input_test.py} (100%) diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/user_input_test.py similarity index 100% rename from discord-bot/bot/extensions/tasks.py rename to discord-bot/bot/extensions/user_input_test.py From a7b7487611eb22dd4da89447e85e3f3275b2bf80 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 22:55:24 -0800 Subject: [PATCH 26/27] remove example table from sql schema --- discord-bot/bot/db/schema.sql | 5 ----- 1 file changed, 5 deletions(-) diff --git a/discord-bot/bot/db/schema.sql b/discord-bot/bot/db/schema.sql index 9fedf1da..0a710f95 100644 --- a/discord-bot/bot/db/schema.sql +++ b/discord-bot/bot/db/schema.sql @@ -3,8 +3,3 @@ CREATE TABLE IF NOT EXISTS guild_settings ( guild_id BIGINT NOT NULL PRIMARY KEY, log_channel_id BIGINT ); - -CREATE TABLE IF NOT EXISTS example ( - id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, - name VARCHAR(255) NOT NULL -); From 8067dc8f78d0831d41df252477f4f2d65227beec Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Sat, 31 Dec 2022 03:52:09 -0800 Subject: [PATCH 27/27] rename TOKEN env var to BOT_TOKEN --- discord-bot/.env.example | 2 +- discord-bot/bot/bot.py | 3 ++- discord-bot/bot/extensions/guild_settings.py | 11 +++++++++++ discord-bot/bot/settings.py | 2 +- 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/discord-bot/.env.example b/discord-bot/.env.example index 4fcb23b3..5cd18fac 100644 --- a/discord-bot/.env.example +++ b/discord-bot/.env.example @@ -1,4 +1,4 @@ -TOKEN= +BOT_TOKEN= DECLARE_GLOBAL_COMMANDS= OWNER_IDS=[, ] PREFIX="./" diff --git a/discord-bot/bot/bot.py b/discord-bot/bot/bot.py index 4e3bd12c..a305946f 100644 --- a/discord-bot/bot/bot.py +++ b/discord-bot/bot/bot.py @@ -9,8 +9,9 @@ from bot.settings import Settings settings = Settings() +# TODO: Revisit cache settings bot = lightbulb.BotApp( - token=settings.token, + token=settings.bot_token, logs="DEBUG", prefix=settings.prefix, default_enabled_guilds=settings.declare_global_commands, diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py index f5785b8d..1aba9f47 100644 --- a/discord-bot/bot/extensions/guild_settings.py +++ b/discord-bot/bot/extensions/guild_settings.py @@ -5,6 +5,7 @@ import lightbulb from aiosqlite import Connection from bot.db.schemas import GuildSettings from bot.utils import mention +from lightbulb.utils.permissions import permissions_in from loguru import logger plugin = lightbulb.Plugin("GuildSettings") @@ -62,6 +63,16 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None: channel: hikari.TextableGuildChannel = ctx.options.channel conn: Connection = ctx.bot.d.db assert ctx.guild_id is not None # `guild_only` check + assert isinstance(channel, hikari.PermissibleGuildChannel) + + # Check if the bot can send messages in that channel + assert (me := ctx.bot.get_me()) is not None # non-None after `StartedEvent` + if (own_member := ctx.bot.cache.get_member(ctx.guild_id, me.id)) is None: + own_member = await ctx.bot.rest.fetch_member(ctx.guild_id, me.id) + perms = permissions_in(channel, own_member) + if perms & ~hikari.Permissions.SEND_MESSAGES: + await ctx.respond("I don't have permission to send messages in that channel.") + return await ctx.respond(f"Setting `log_channel` to {channel.mention}.") diff --git a/discord-bot/bot/settings.py b/discord-bot/bot/settings.py index 200ab54b..136c2b22 100644 --- a/discord-bot/bot/settings.py +++ b/discord-bot/bot/settings.py @@ -6,7 +6,7 @@ from pydantic import BaseSettings, Field class Settings(BaseSettings): """Settings for the bot.""" - token: str = Field(env="TOKEN", default="") + bot_token: str = Field(env="BOT_TOKEN", default="") declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0) owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list) prefix: str = Field(env="PREFIX", default="./")