initial bot structure

This commit is contained in:
AlexanderHOtt
2022-12-28 16:43:14 -08:00
parent a4e5f566a8
commit 3ce6ab80d6
24 changed files with 340 additions and 802 deletions
+3
View File
@@ -0,0 +1,3 @@
TOKEN=<discord bot token>
DECLARE_GLOBAL_COMMANDS=<testing guild id>
OWNER_IDS=<your user id>
+4
View File
@@ -1,3 +1,7 @@
.env
*.egg-info/
__pycache__/
.venv
.nox
.env
+43
View File
@@ -0,0 +1,43 @@
# Contributing
## Setup
To run the bot
```
cp .env.example .env
python -V # 3.10
pip install -r requirements.txt
python -m bot
```
To test the bot
```
python -m pip install -r dev-requirements.txt
nox
```
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable.
## Resources
Main framework
- [Hikari Repo](https://github.com/hikari-py/hikari)
- [Hikari Docs](https://docs.hikari-py.dev/en/latest/)
Command handler
- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb)
- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/)
Component handler (buttons, modals, etc... )
- [Miru Repo](https://github.com/HyperGH/hikari-miru)
+2 -11
View File
@@ -6,15 +6,6 @@ This bot collects human feedback to create a dataset for RLHF-alignment of an as
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
## Bot token for development
## Contributing
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
The simplest way to configure the token is via an `.env` file:
```
BOT_TOKEN=XYZABC123...
```
To contribute to the bot, please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) file.
-17
View File
@@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
from bot import OpenAssistantBot
from bot_settings import settings
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
if __name__ == "__main__":
bot = OpenAssistantBot(
settings.BOT_TOKEN,
bot_channel_name=settings.BOT_CHANNEL_NAME,
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
+2 -1
View File
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import enum
from typing import Optional, Type
import typing as t
import requests
from oasst_shared.schemas import protocol as protocol_schema
@@ -41,7 +42,7 @@ class ApiClient:
response.raise_for_status()
return response.json()
def _parse_task(self, data: dict) -> protocol_schema.Task:
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
if not isinstance(data, dict):
raise ValueError("dict expected")
-283
View File
@@ -1,283 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)
+2
View File
@@ -0,0 +1,2 @@
# -*- coding=utf-8 -*-
"""The official Open-Assistant Discord Bot."""
+17
View File
@@ -0,0 +1,17 @@
# -*- coding=utf-8 -*-
"""Entry point for the bot."""
import logging
import os
from bot.bot import bot
logger = logging.getLogger(__name__)
if __name__ == "__main__":
if os.name != "nt":
import uvloop
uvloop.install()
logger.info("Starting bot")
bot.run()
+37
View File
@@ -0,0 +1,37 @@
# -*- coding=utf-8 -*-
"""Bot logic."""
import hikari
import aiosqlite
import lightbulb
import miru
from bot.config import Config
config = Config.from_env()
bot = lightbulb.BotApp(
token=config.token,
logs="DEBUG",
prefix="./",
default_enabled_guilds=config.declare_global_commands,
owner_ids=config.owner_ids,
intents=hikari.Intents.ALL,
)
@bot.listen()
async def on_starting(event: hikari.StartingEvent):
"""Setup."""
miru.install(bot) # component handler
bot.load_extensions_from("./bot/extensions") # load extensions
bot.d.db = await aiosqlite.connect(":memory:") # TODO: Update
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
await bot.d.db.commit()
@bot.listen()
async def on_stopping(event: hikari.StoppingEvent):
"""Cleanup."""
await bot.d.db.close()
+35
View File
@@ -0,0 +1,35 @@
# -*- coding=utf-8 -*-
"""Configuration for the bot."""
import logging
from dataclasses import dataclass
from os import getenv
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)
@dataclass
class Config:
"""Configuration for the bot."""
token: str
declare_global_commands: int
owner_ids: list[int]
@classmethod
def from_env(cls):
token = getenv("TOKEN", None)
if token is None:
logger.error("Invalid token, please set the TOKEN environment variable.")
exit(1)
return cls(
token=token,
declare_global_commands=int(getenv("DECLARE_GLOBAL_COMMANDS", 0)),
owner_ids=[int(x) for x in getenv("OWNER_IDS", "").split(",")],
)
View File
+10
View File
@@ -0,0 +1,10 @@
-- Sqlite3 schema for the bot
CREATE TABLE IF NOT EXISTS guild_settings (
guild_id BIGINT NOT NULL PRIMARY KEY,
log_channel_id BIGINT
);
CREATE TABLE IF NOT EXISTS example (
id INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT,
name VARCHAR(255) NOT NULL
);
+61
View File
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
"""Hot reload plugin."""
from glob import glob
import hikari
import lightbulb
plugin = lightbulb.Plugin(
"HotReloadPlugin",
)
plugin.add_checks(lightbulb.owner_only)
EXTENSIONS_FOLDER = "bot/extensions"
def _get_extensions() -> list[str]:
# Recursively get all the .py files in the extensions directory.
exts = glob("bot/extensions/**/*.py", recursive=True)
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]:
# Check that the option is a string.
if not isinstance(option.value, str):
raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`")
exts = _get_extensions()
return [ext for ext in exts if option.value in ext]
@plugin.command
@lightbulb.option(
"plugin",
"The plugin to reload. Leave empty to reload all plugins.",
autocomplete=_plugin_autocomplete,
required=False,
default=None,
)
@lightbulb.command("reload", "Reload a plugin")
@lightbulb.implements(lightbulb.SlashCommand)
async def reload(ctx: lightbulb.SlashContext):
"""Reload a plugin or all plugins."""
# If the plugin option is None, reload all plugins.
if ctx.options.plugin is None:
ctx.bot.reload_extensions(*_get_extensions())
await ctx.respond("Reloaded all plugins.")
# Otherwise, reload the specified plugin.
else:
ctx.bot.reload_extensions(ctx.options.plugin)
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
-61
View File
@@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from abc import ABC
from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@dataclass
class ReplyHandlerInfo:
msg_id: int
handler_task: asyncio.Task
handler: ChannelHandlerBase
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
bot_channel: discord.TextChannel
templates: MessageTemplates
reply_handlers: dict[int, ReplyHandlerInfo]
def __init__(self):
self.reply_handlers = {} # handlers by msg_id
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
) -> discord.Message:
if channel is None:
self.ensure_bot_channel()
channel = self.bot_channel
return await channel.send(content=content, view=view)
async def post_template(
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
) -> discord.Message:
logger.debug(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view=view, channel=channel)
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
if msg_id in self.reply_handlers:
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
task.add_done_callback(lambda t: handler.on_completed())
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
-15
View File
@@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
from pydantic import AnyHttpUrl, BaseSettings
class BotSettings(BaseSettings):
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
API_KEY: str = "any_key"
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
-88
View File
@@ -1,88 +0,0 @@
# -*- coding: utf-8 -*-
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
import discord
from loguru import logger
class ChannelExpiredException(Exception):
pass
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool = False
expiry_date: datetime
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.queue = asyncio.Queue()
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
if self.expired:
raise ChannelExpiredException()
msg = await self.queue.get()
if msg is None:
if self.expired:
raise ChannelExpiredException()
else:
raise RuntimeError("Unexpected None message read")
return msg
def on_reply(self, message: discord.Message) -> None:
self.queue.put_nowait(message)
def on_expire(self) -> None:
logger.info("ChannelHandler: on_expire")
self.expired = True
self.queue.put_nowait(None)
def on_completed(self) -> None:
logger.info("ChannelHandler: on_completed")
self.completed = True
def tick(self, now: datetime):
if now > self.expiry_date and not self.expired:
self.on_expire()
@abstractmethod
async def handler_loop(self):
...
async def finalize(self):
pass
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
async def read(self) -> discord.Message:
try:
return await super().read()
except ChannelExpiredException:
await self.cleanup()
raise
async def cleanup(self):
logger.debug("AutoDestructThreadHandler.cleanup")
if self.thread:
logger.debug(f"deleting thread: {self.thread.name}")
await self.thread.delete()
self.thread = None
if self.first_message:
logger.debug(f"deleting first_message: {self.first_message.content}")
await self.first_message.delete()
self.first_message = None
async def finalize(self):
await self.cleanup()
return await super().finalize()
+8
View File
@@ -0,0 +1,8 @@
nox
black
isort
codespell
flake8
pyright
+26
View File
@@ -0,0 +1,26 @@
flake8==6.0.0
# Plugins
Flake8-pyproject # use the pyproject.toml as the config file
flake8-bandit # runs bandit
flake8-black # runs black
# flake8-broken-line # forbey "\" linebreaks
flake8-builtins # builtin shadowing checks
flake8-coding # coding magic-comment detection
flake8-comprehensions # comprehension checks
flake8-deprecated # deprecated call checks
flake8-docstrings # pydocstyle support
flake8-executable # shebangs
flake8-fixme # "fix me" counter
flake8-functions # function linting
flake8-html # html output
flake8-if-statements # condition linting
flake8-isort # runs isort
flake8-mutable # mutable default argument detection
flake8-pep3101 # new-style format strings only
flake8-print # complain about print statements in code
flake8-printf-formatting # forbey printf-style python2 string formatting
flake8-pytest-style # pytest checks
flake8-raise # exception raising linting
flake8-use-fstring # format string checking
+33
View File
@@ -0,0 +1,33 @@
# -*- coding=utf-8 -*-
"""Automated linting, formatting, and typechecking."""
import nox
from nox.sessions import Session
@nox.session(reuse_venv=True)
def format_code(session: Session):
"""Format the codebase."""
session.install("isort", "-U")
session.install("black", "-U")
session.run("isort", "bot")
session.run("black", "bot")
@nox.session(reuse_venv=True)
def lint_code(session: Session):
"""Lint the codebase."""
session.install("codespell", "-U")
session.install("flake8", "-U")
session.install("-r", "flake8-requirements.txt", "-U")
session.run("codespell", "bot")
session.run("flake8", "bot")
@nox.session(reuse_venv=True)
def typecheck_code(session: Session):
session.install("-r", "requirements.txt", "-U")
session.install("pyright", "-U")
session.run("pyright", "bot")
+47
View File
@@ -0,0 +1,47 @@
[project]
name = "Open-Assistant Discord Bot"
version = "0.0.1"
[tool.black]
line-length = 120
target-version = ["py310"]
[tool.pyright]
include = ["ottbot", "noxfile.py"]
pythonVersion="3.10"
reportMissingImports=false
# reportInvalidTypeVarUse=false
# reportMissingModuleSource=false
reportUnknownVariableType=false
pythonPlatform="Linux"
[tool.isort]
profile="black"
sections = ['FUTURE', 'STDLIB', 'THIRDPARTY', 'FIRSTPARTY', 'LOCALFOLDER']
skip_glob = "**/__init__.pyi"
[tool.flake8]
max-function-length = 130
max-line-length = 130
# Technically this is 120, but black has a policy of "1 or 2 over is fine if it is tidier", so we have to raise this.
accept-encodings = "utf-8"
docstring-convention = "numpy"
ignore = [
"A002", # Argument is shadowing a python builtin.
"A003", # Class attribute is shadowing a python builtin.
"CFQ002", # Function has too many arguments.
"CFQ004", # Function has too many returns.
"D001", # False positive for depreciated functions.
"D102", # Missing docstring in public method.
"D105", # Magic methods not having a docstring.
"D412", # No blank lines allowed between a section header and its content
"E203", # Whitespace after : (to match how black formats it)
"E402", # Module level import not at top of file (isn't compatible with our import style).
"T101", # TO-DO comment detection (T102 is FIX-ME and T103 is XXX).
"W503", # line break before binary operator.
"W504", # line break before binary operator (again, I guess).
"S101", # Use of assert detected. The enclosed code will be removed when compiling to optimised byte code.
"S105", # Possible hardcoded password.
"EXE002", # Executable file with not shebang
"D401", # Imperative mood
]
+10 -7
View File
@@ -1,7 +1,10 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
hikari # discord framework
hikari[speedups]
uvloop; os_name != 'nt'
hikari-lightbulb # command handler
hikari-miru # modals and buttons
python-dotenv # .env file support
aiosqlite # database
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
-267
View File
@@ -1,267 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
name = discord.ui.TextInput(label="Name")
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup() # try to cleanup messag or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
pass
@abstractmethod
async def send_first_message(self) -> discord.message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_teaser_msg(self, template_name: str):
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
return await self.bot.post_template(
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToPost(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_text_reply_to_post()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.PostRanking(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
ranking=ranking,
)
)
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
ranking_str = user_msg.content
ranking = [int(x) - 1 for x in ranking_str.split(",")]
self.post_ranking(user_msg, ranking=ranking)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_ranking()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_summarize_story.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
task: protocol_schema.InitialPromptTask
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_initial_prompt.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
task: protocol_schema.UserReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_user_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
task: protocol_schema.AssistantReplyTask
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_assistant_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class RankInitialPromptsHandler(ChannelTaskBase):
task: protocol_schema.RankInitialPromptsTask
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RankConversationsHandler(ChannelTaskBase):
task: protocol_schema.RankConversationRepliesTask
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class RateSummaryHandler(ChannelTaskBase):
task: protocol_schema.RateSummaryTask
thread_name: str = "Ratings"
async def _rating_response_handler(self, score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
if self.thread:
try:
self.backend.post_interaction(
protocol_schema.PostRating(
post_id=str(self.first_message.id),
user_post_id=str(interaction.id),
user=self.to_api_user(interaction.user),
rating=score,
)
)
await interaction.response.send_message(
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
)
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in _rating_response_handler()")
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rate_summary.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
logger.info(f"on_rate_summary_reply: {msg.content}")
await msg.add_reaction("")
await msg.reply("❌ Text intput not supported.")
-52
View File
@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)