Merge branch 'main' of https://github.com/AlexanderHOtt/Open-Assistant into AlexanderHOtt-main

This commit is contained in:
Andreas Köpf
2022-12-31 12:54:16 +01:00
28 changed files with 1509 additions and 880 deletions
+2 -2
View File
@@ -55,7 +55,7 @@ def generate_task(
hint="Ask the assistant about a current event." # this is optional
)
case protocol_schema.TaskRequestType.prompter_reply:
logger.info("Generating a UserReplyTask.")
logger.info("Generating a PrompterReplyTask.")
messages = pr.fetch_random_conversation("assistant")
task_messages = [
protocol_schema.ConversationMessage(
@@ -86,7 +86,7 @@ def generate_task(
messages = pr.fetch_random_initial_prompts()
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
case protocol_schema.TaskRequestType.rank_prompter_replies:
logger.info("Generating a RankUserRepliesTask.")
logger.info("Generating a RankPrompterRepliesTask.")
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
task_messages = [
+7
View File
@@ -0,0 +1,7 @@
BOT_TOKEN=<discord bot token>
DECLARE_GLOBAL_COMMANDS=<testing guild id>
OWNER_IDS=[<your user id>, <other user ids>]
PREFIX="./"
OASST_API_URL="http://localhost:8080" # No trailing '/'
OASST_API_KEY=""
+7
View File
@@ -1,3 +1,10 @@
.env
*.egg-info/
__pycache__/
.venv
.nox
.env
# Database files
*.db
+105 -5
View File
@@ -1,20 +1,120 @@
# Open-Assistant Data Collection Discord Bot
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large language model. You and other people can teach the bot how to respond to user requests by demonstration and by ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
## Invite official bot
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
## Bot token for development
## Contributing
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the [large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7)
### Setup
To run the bot
```bash
cp .env.example .env
python -V # 3.10
pip install -r requirements.txt
python -m bot
```
Before you push, make sure the `pre-commit` hooks are installed and run successfully.
```bash
pip install pre-commit
pre-commit install
...
git add .
git commit -m "<good commit message>"
# if the pre-commit fails
git add .
git commit -m "<good commit message>"
```
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable.
The simplest way to configure the token is via an `.env` file:
### Resources
#### Structure
Important files
```graphql
.env # Environment variables
.env.example # Example environment variables
CONTRIBUTING.md # This file
README.md # Project readme
EXAMPLES.md # Examples for commands and listeners
requirements.txt # Requirements
bot/
__main__.py # Entrypoint
api_client.py # API Client for interacting with the backend
bot.py # Main bot class
settings.py # Settings and secrets
utils.py # Utility Functions
db/ # Database related code
database.db # SQLite database
schema.sql # SQL schema
schemas.py # Python table schemas
extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
work.py # Task handling logic <-- most important file
guild_settings.py # Server specific settings
hot_reload.py # Utility for hot reload extensions during development
```
BOT_TOKEN=XYZABC123...
#### Adding a new command/listener
1. Create a new file in the `extensions` folder
2. Copy the template below
```py
# -*- coding: utf-8 -*-
"""My plugin."""
import lightbulb
plugin = lightbulb.Plugin("MyPlugin")
# Add your commands here
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
```
#### Docs
Discord
- [Discord API Reference](https://discord.com/developers/docs/intro)
`hikari` (main framework)
- [Hikari Repo](https://github.com/hikari-py/hikari)
- [Hikari Docs](https://docs.hikari-py.dev/en/latest/)
`lightbulb` (command handler)
- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb)
- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/)
`miru` (component handler: buttons, modals, etc... )
- [Miru Repo](https://github.com/HyperGH/hikari-miru)
-18
View File
@@ -1,18 +0,0 @@
# -*- coding: utf-8 -*-
from bot import OpenAssistantBot
from bot_settings import settings
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
if __name__ == "__main__":
bot = OpenAssistantBot(
settings.BOT_TOKEN,
bot_channel_name=settings.BOT_CHANNEL_NAME,
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
-79
View File
@@ -1,79 +0,0 @@
# -*- coding: utf-8 -*-
import enum
from typing import Optional, Type
import requests
from oasst_shared.schemas import protocol as protocol_schema
class TaskType(str, enum.Enum):
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class ApiClient:
def __init__(self, backend_url: str, api_key: str):
self.backend_url = backend_url
self.api_key = api_key
task_models_map: dict[str, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
self.task_models_map = task_models_map
def post(self, path: str, json: dict) -> dict:
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return response.json()
def _parse_task(self, data: dict) -> protocol_schema.Task:
if not isinstance(data, dict):
raise ValueError("dict expected")
task_type = data.get("type")
if task_type not in self.task_models_map:
raise RuntimeError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data)
def fetch_task(
self,
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
) -> protocol_schema.Task:
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
data = self.post("/api/v1/tasks/", req.dict())
return self._parse_task(data)
def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
) -> protocol_schema.Task:
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
def ack_task(self, task_id: str, message_id: str) -> None:
req = protocol_schema.TaskAck(message_id=message_id)
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
def nack_task(self, task_id: str, reason: str) -> None:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
-283
View File
@@ -1,283 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
bot_channel_name: str,
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
if isinstance(owner_id, str):
owner_id = int(owner_id)
self.owner_id = owner_id
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
if message.author != client.user:
await self.handle_message(message)
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
handler = task_handlers.InitialPromptHandler()
case TaskType.prompter_reply:
handler = task_handlers.PrompterReplyHandler()
case TaskType.assistant_reply:
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_prompter_replies | TaskType.rank_assistant_replies:
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
logger.info(f"sync tree command received: {command}")
if command == "sync.copy_global":
await self.tree.copy_global_to(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.clear_guild":
self.tree.clear_commands(guild=message.guild)
synced = await self.tree.sync(guild=message.guild)
elif command == "sync.guild":
synced = await self.tree.sync(guild=message.guild)
else:
synced = await self.tree.sync()
logger.info(f"Synced {len(synced)} commands")
await message.reply(f"Synced {len(synced)} commands")
async def handle_command(self, message: discord.Message, is_owner: bool):
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
if channel.type == discord.ChannelType.text and channel.name == channel_name:
return channel
def run(self):
"""Run bot loop blocking."""
self.client.run(self.bot_token)
+2
View File
@@ -0,0 +1,2 @@
# -*- coding: utf-8 -*-
"""The official Open-Assistant Discord Bot."""
+17
View File
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
"""Entry point for the bot."""
import logging
import os
from bot.bot import bot
logger = logging.getLogger(__name__)
if __name__ == "__main__":
if os.name != "nt":
import uvloop
uvloop.install()
logger.info("Starting bot")
bot.run()
+113
View File
@@ -0,0 +1,113 @@
# -*- coding: utf-8 -*-
"""API Client for interacting with the OASST backend."""
import enum
import typing as t
from typing import Optional, Type
from uuid import UUID
import aiohttp
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
# TODO: Move to `protocol`?
class TaskType(str, enum.Enum):
"""Task types."""
summarize_story = "summarize_story"
rate_summary = "rate_summary"
initial_prompt = "initial_prompt"
prompter_reply = "prompter_reply"
assistant_reply = "assistant_reply"
rank_initial_prompts = "rank_initial_prompts"
rank_prompter_replies = "rank_prompter_replies"
rank_assistant_replies = "rank_assistant_replies"
done = "task_done"
class OasstApiClient:
"""API Client for interacting with the OASST backend."""
def __init__(self, backend_url: str, api_key: str):
"""Create a new OasstApiClient.
Args:
----
backend_url (str): The base backend URL.
api_key (str): The API key to use for authentication.
"""
logger.debug("Opening OasstApiClient session")
self.session = aiohttp.ClientSession()
self.backend_url = backend_url
self.api_key = api_key
self.task_models_map: dict[TaskType, Type[protocol_schema.Task]] = {
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
TaskType.rate_summary: protocol_schema.RateSummaryTask,
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
TaskType.done: protocol_schema.TaskDone,
}
async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]:
"""Make a POST request to the backend."""
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
response.raise_for_status()
return await response.json()
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
task_type = TaskType(data.get("type"))
model = self.task_models_map.get(task_type)
if not model:
logger.error(f"Unsupported task type: {task_type}")
raise ValueError(f"Unsupported task type: {task_type}")
return self.task_models_map[task_type].parse_obj(data) # type: ignore
async def fetch_task(
self,
task_type: protocol_schema.TaskRequestType,
user: Optional[protocol_schema.User] = None,
collective: bool = False,
) -> protocol_schema.Task:
"""Fetch a task from the backend."""
logger.debug(f"Fetching task {task_type} for user {user}")
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
resp = await self.post("/api/v1/tasks/", data=req.dict())
logger.debug(f"RESP {resp}")
return self._parse_task(resp)
async def fetch_random_task(
self, user: Optional[protocol_schema.User] = None, collective: bool = False
) -> protocol_schema.Task:
"""Fetch a random task from the backend."""
logger.debug(f"Fetching random for user {user}")
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
async def ack_task(self, task_id: str | UUID, message_id: str):
"""Send an ACK for a task to the backend."""
logger.debug(f"ACK task {task_id} with post {message_id}")
req = protocol_schema.TaskAck(message_id=message_id)
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
async def nack_task(self, task_id: str | UUID, reason: str):
"""Send a NACK for a task to the backend."""
logger.debug(f"NACK task {task_id} with reason {reason}")
req = protocol_schema.TaskNAck(reason=reason)
return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
"""Send a completed task to the backend."""
logger.debug(f"Interaction: {interaction}")
resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
return self._parse_task(resp)
async def close(self):
logger.debug("Closing OasstApiClient session")
await self.session.close()
+40
View File
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
"""Bot logic."""
import aiosqlite
import hikari
import lightbulb
import miru
from bot.api_client import OasstApiClient
from bot.settings import Settings
settings = Settings()
# TODO: Revisit cache settings
bot = lightbulb.BotApp(
token=settings.bot_token,
logs="DEBUG",
prefix=settings.prefix,
default_enabled_guilds=settings.declare_global_commands,
owner_ids=settings.owner_ids,
intents=hikari.Intents.ALL,
)
@bot.listen()
async def on_starting(event: hikari.StartingEvent):
"""Setup."""
miru.install(bot) # component handler
bot.load_extensions_from("./bot/extensions") # load extensions
bot.d.db = await aiosqlite.connect("./bot/db/database.db")
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
await bot.d.db.commit()
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
@bot.listen()
async def on_stopping(event: hikari.StoppingEvent):
"""Cleanup."""
await bot.d.db.close()
await bot.d.oasst_api.close()
+5
View File
@@ -0,0 +1,5 @@
-- Sqlite3 schema for the bot
CREATE TABLE IF NOT EXISTS guild_settings (
guild_id BIGINT NOT NULL PRIMARY KEY,
log_channel_id BIGINT
);
+28
View File
@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
"""Database schemas."""
import typing as t
from aiosqlite import Connection, Row
from pydantic import BaseModel
class GuildSettings(BaseModel):
"""Guild settings."""
guild_id: int
log_channel_id: int | None
@classmethod
def parse_obj(cls, obj: Row) -> "GuildSettings":
"""Deserialize a Row object from aiosqlite into a GuildSettings object."""
return cls(guild_id=obj[0], log_channel_id=obj[1])
@classmethod
async def from_db(cls, conn: Connection, guild_id: int) -> t.Optional["GuildSettings"]:
async with conn.cursor() as cursor:
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
row = await cursor.fetchone()
if row is None:
return None
return cls.parse_obj(row)
+5
View File
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
"""Extensions for the bot.
See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
"""
@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
"""Guild settings."""
import hikari
import lightbulb
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import mention
from lightbulb.utils.permissions import permissions_in
from loguru import logger
plugin = lightbulb.Plugin("GuildSettings")
plugin.add_checks(lightbulb.guild_only)
plugin.add_checks(lightbulb.has_guild_permissions(hikari.Permissions.MANAGE_GUILD))
@plugin.command
@lightbulb.command("settings", "Bot settings for the server.")
@lightbulb.implements(lightbulb.SlashCommandGroup)
async def settings(_: lightbulb.SlashContext) -> None:
"""Bot settings for the server."""
# This will never execute because it is a group
pass
@settings.child
@lightbulb.command("get", "Get all the guild settings.")
@lightbulb.implements(lightbulb.SlashSubCommand)
async def get(ctx: lightbulb.SlashContext) -> None:
"""Get one of or all the guild settings."""
conn: Connection = ctx.bot.d.db
assert ctx.guild_id is not None # `guild_only` check
async with conn.cursor() as cursor:
# Get all settings
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (ctx.guild_id,))
row = await cursor.fetchone()
if row is None:
logger.warning(f"No guild settings for {ctx.guild_id}")
await ctx.respond("No settings found for this guild.")
return
guild_settings = GuildSettings.parse_obj(row)
# Respond with all
# TODO: Embed
await ctx.respond(
f"""\
**Guild Settings**
`log_channel`: {
mention(guild_settings.log_channel_id, "channel")
if guild_settings.log_channel_id else 'not set'}
"""
)
@settings.child
@lightbulb.option("channel", "The channel to use.", hikari.TextableGuildChannel)
@lightbulb.command("log_channel", "Set the channel that the bot logs task and label completions in.")
@lightbulb.implements(lightbulb.SlashSubCommand)
async def log_channel(ctx: lightbulb.SlashContext) -> None:
"""Set the channel that the bot logs task and label completions in."""
channel: hikari.TextableGuildChannel = ctx.options.channel
conn: Connection = ctx.bot.d.db
assert ctx.guild_id is not None # `guild_only` check
assert isinstance(channel, hikari.PermissibleGuildChannel)
# Check if the bot can send messages in that channel
assert (me := ctx.bot.get_me()) is not None # non-None after `StartedEvent`
if (own_member := ctx.bot.cache.get_member(ctx.guild_id, me.id)) is None:
own_member = await ctx.bot.rest.fetch_member(ctx.guild_id, me.id)
perms = permissions_in(channel, own_member)
if perms & ~hikari.Permissions.SEND_MESSAGES:
await ctx.respond("I don't have permission to send messages in that channel.")
return
await ctx.respond(f"Setting `log_channel` to {channel.mention}.")
async with conn.cursor() as cursor:
await cursor.execute(
"INSERT OR REPLACE INTO guild_settings (guild_id, log_channel_id) VALUES (?, ?)",
(ctx.guild_id, channel.id),
)
await conn.commit()
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+64
View File
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
"""Hot reload plugin."""
from glob import glob
import hikari
import lightbulb
from loguru import logger
plugin = lightbulb.Plugin(
"HotReloadPlugin",
)
plugin.add_checks(lightbulb.owner_only)
EXTENSIONS_FOLDER = "bot/extensions"
def _get_extensions() -> list[str]:
# Recursively get all the .py files in the extensions directory not starting with an `_`.
exts = glob("bot/extensions/**/[!_]*.py", recursive=True)
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]:
# Check that the option is a string.
if not isinstance(option.value, str):
raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`")
exts = _get_extensions()
return [ext for ext in exts if option.value in ext]
@plugin.command
@lightbulb.option(
"plugin",
"The plugin to reload. Leave empty to reload all plugins.",
autocomplete=_plugin_autocomplete,
required=False,
default=None,
)
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def reload(ctx: lightbulb.SlashContext):
"""Reload a plugin or all plugins."""
# If the plugin option is None, reload all plugins.
if ctx.options.plugin is None:
ctx.bot.reload_extensions(*_get_extensions())
await ctx.respond("Reloaded all plugins.")
logger.info("Reloaded all plugins.")
# Otherwise, reload the specified plugin.
else:
ctx.bot.reload_extensions(ctx.options.plugin)
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
logger.info(f"Reloaded `{ctx.options.plugin}`.")
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+181
View File
@@ -0,0 +1,181 @@
# -*- coding: utf-8 -*-
"""Hot reload plugin."""
import typing as t
from datetime import datetime
import hikari
import lightbulb
import miru
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from loguru import logger
plugin = lightbulb.Plugin(
"TextLabels",
)
plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds
DISCORD_GRAY = 0x2F3136
def clamp(num: float) -> float:
"""Clamp a number between 0 and 1."""
return min(max(0.0, num), 1.0)
class LabelModal(miru.Modal):
"""Modal for submitting text labels."""
def __init__(self, label: str, content: str, *args: t.Any, **kwargs: t.Any):
super().__init__(*args, **kwargs)
self.label = label
self.original_content = content
# Add the text of the message to the modal
self.content = miru.TextInput(
label="Text", style=hikari.TextInputStyle.PARAGRAPH, value=content, required=True, row=1
)
self.add_item(self.content)
value = miru.TextInput(label="Value", placeholder="Enter a value between 0 and 1", required=True, row=2)
async def callback(self, context: miru.ModalContext) -> None:
val = float(self.value.value) if self.value.value else 0.0
val = clamp(val)
edited = self.content.value != self.original_content
await context.respond(
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
flags=hikari.MessageFlag.EPHEMERAL,
)
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
# Send a notification to the log channel
assert context.guild_id is not None # `guild_only` check
conn: Connection = context.bot.d.db # type: ignore
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
if guild_settings is None or guild_settings.log_channel_id is None:
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
return
embed = (
hikari.Embed(
title="Message Label",
description=f"{context.author.mention} labeled a message as `{self.label}`.",
timestamp=datetime.now().astimezone(),
color=0x00FF00,
)
.set_author(name=context.author.username, icon=context.author.avatar_url)
.add_field("Total Labeled Message", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
)
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
assert isinstance(channel, hikari.TextableChannel)
await channel.send(EMPTY, embed=embed)
class LabelSelect(miru.View):
"""Select menu for selecting a label.
The current labels are:
- contains toxic language
- encourages illegal activity
- good quality
- bad quality
- is spam
"""
def __init__(self, content: str, *args: t.Any, **kwargs: t.Any):
super().__init__(*args, **kwargs)
self.content = content
@miru.select(
options=[
hikari.SelectMenuOption(
label="Toxic Language",
value="toxic_language",
description="The message contains toxic language.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Illegal Activity",
value="illegal_activity",
description="The message encourages illegal activity.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Good Quality",
value="good_quality",
description="The message is good quality.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Bad Quality",
value="bad_quality",
description="The message is bad quality.",
is_default=False,
emoji=None,
),
hikari.SelectMenuOption(
label="Spam",
value="spam",
description="The message is spam.",
is_default=False,
emoji=None,
),
],
min_values=1,
max_values=1,
)
async def label_select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
"""Handle the select menu."""
label = select.values[0]
modal = LabelModal(label, self.content, title=f"Text Label: {label}", timeout=60)
await modal.send(ctx.interaction)
await modal.wait()
self.stop()
@plugin.command
@lightbulb.command("Label Message", "Label a message")
@lightbulb.implements(lightbulb.MessageCommand)
async def label_message_text(ctx: lightbulb.MessageContext):
"""Label a message."""
# We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction
# so the select menu will open the modal
msg: hikari.Message = ctx.options.target
# Exit if the message is empty
if not msg.content:
await ctx.respond("Cannot label an empty message.", flags=hikari.MessageFlag.EPHEMERAL)
return
# Send the select menu
# The modal will be opened from the select menu interaction
embed = hikari.Embed(title="Label Message", description="Select a label for the message.", color=DISCORD_GRAY)
label_select_view = LabelSelect(
msg.content,
timeout=60,
)
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
await label_select_view.start(await resp.message())
await label_select_view.wait()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
@@ -0,0 +1,301 @@
# -*- coding: utf-8 -*-
"""Task plugin for testing different data collection methods."""
# TODO: Delete this once user input method has been decided for final bot.
import asyncio
import typing as t
from datetime import datetime, timedelta
import hikari
import lightbulb
import lightbulb.decorators
import miru
from bot.utils import format_time
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("TaskPlugin")
MAX_TASK_TIME = 60 * 60
MAX_TASK_ACCEPT_TIME = 60
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_thread(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
typ: str = ctx.options.type
# Create a thread for the task
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
# Send the task in the thread
await thread.send(
f"""\
Please complete the task.
Sample Task
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
"""
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.GuildMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
)
await ctx.respond(f"Received message: {event.message.content}")
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
finally:
await thread.delete()
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
async def task_dm(ctx: lightbulb.Context):
"""Request a task from the backend."""
await ctx.respond("Please complete the task in your DMs")
# Send the task in the dm
await ctx.author.send(
f"""\
Please complete the task.
Sample Task
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
"""
)
# Wait for the user to respond
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent,
timeout=MAX_TASK_TIME,
predicate=lambda e: e.author.id == ctx.author.id,
)
await ctx.respond(f"Received message: {event.message.content}")
except asyncio.TimeoutError:
await ctx.respond("You took too long to respond.")
class TaskModal(miru.Modal):
"""Modal for submitting a task."""
response = miru.TextInput(
label="Response",
placeholder="Enter your response!",
required=True,
style=hikari.TextInputStyle.PARAGRAPH,
row=2,
)
async def callback(self, context: miru.ModalContext) -> None:
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
class ModalView(miru.View):
"""View for opening a modal."""
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.modal_title = modal_title
self.task = task
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
modal = TaskModal(title=self.modal_title)
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
await ctx.respond_with_modal(modal)
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
required=False,
default=TaskRequestType.summarize_story,
type=str,
)
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def task_modal(ctx: lightbulb.SlashContext):
"""Request a task from the backend."""
# typ: str = ctx.options.type
view = ModalView(
modal_title="Assistant Response",
task="Please explain the moon landing to a six year old.",
timeout=MAX_TASK_TIME,
)
resp = await ctx.respond(
"Task - Respond to the prompt as if you were the Assistant:",
flags=hikari.MessageFlag.EPHEMERAL,
components=view,
)
await view.start(await resp.message())
class RatingView(miru.View):
"""View for rating a task."""
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
super().__init__(*args, **kwargs)
self.presses: list[str] = []
def _close_if_all_pressed(self) -> None:
if len(self.presses) == 5:
self.stop()
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("1")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("2")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("3")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("4")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
if button.label not in self.presses:
self.presses.append("5")
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
self._close_if_all_pressed()
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.presses = []
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
class SelectRating(miru.View):
"""View for rating a task with a select menu."""
@miru.select(
options=[
hikari.SelectMenuOption(
label="1",
value="1",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="2",
value="2",
description=None,
emoji=None,
is_default=False,
),
hikari.SelectMenuOption(
label="3",
value="3",
description=None,
emoji=None,
is_default=False,
),
],
placeholder="Select the good responses",
min_values=0,
max_values=3,
row=3,
)
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
@plugin.command
@lightbulb.command("rating_task", "Rate stuff.")
@lightbulb.implements(lightbulb.SlashCommand)
async def rating_task(ctx: lightbulb.SlashContext):
"""Rate stuff."""
# Message Based rating
await ctx.respond(
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
)
try:
event = await ctx.bot.wait_for(
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
)
except asyncio.TimeoutError:
await ctx.respond("Timed out waiting for response")
return
if event.content is None:
await ctx.respond("No content in message")
return
ratings = event.content.replace(" ", "").split(",")
# Check if the ratings are valid
if len(ratings) != 5:
await ctx.respond("Invalid number of ratings")
if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]):
await ctx.respond("Invalid rating")
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
# Button Based rating
view = RatingView(timeout=MAX_TASK_TIME)
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
await view.start(await resp.message())
await view.wait()
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
await resp.delete()
# Select Based rating
select_view = SelectRating(timeout=MAX_TASK_TIME)
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
await select_view.start(await resp_2.message())
await select_view.wait()
await resp_2.delete()
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+451
View File
@@ -0,0 +1,451 @@
# -*- coding: utf-8 -*-
"""Work plugin for collecting user data."""
import asyncio
import typing as t
from datetime import datetime
import hikari
import lightbulb
import lightbulb.decorators
import miru
from aiosqlite import Connection
from bot.api_client import OasstApiClient, TaskType
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
@plugin.command
@lightbulb.option(
"type",
"The type of task to request.",
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
required=False,
default=str(TaskRequestType.random),
type=str,
)
@lightbulb.command("work", "Complete a task.")
@lightbulb.implements(lightbulb.SlashCommand)
async def work(ctx: lightbulb.SlashContext):
"""Create and handle a task."""
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
logger.debug(f"Starting task_type: {task_type!r}")
await _handle_task(ctx, task_type)
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
"""Handle creating and collecting user input for a task.
Continually present tasks to the user until they select one, cancel, or time out.
If they select one, present the task steps until a `task_done` task is received.
Finally, ask the user if they want to perform another task (of the same type).
"""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
# Continue to complete tasks until the user doesn't want to do another
done = False
while not done:
# Loop until the user accepts a task
task, msg_id = await _select_task(ctx, task_type)
if task is None:
return
# Task action loop
completed = False
while not completed:
await ctx.author.send("Please type your response here:")
try:
event = await ctx.bot.wait_for(
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
)
except asyncio.TimeoutError:
await ctx.author.send("Task timed out. Exiting")
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
if event.content is None or not _validate_user_input(event.content, task):
await ctx.author.send("Invalid response")
continue
logger.debug(f"Successful user input received: {event.content}")
# Send the response to the backend
reply = protocol_schema.TextReplyToMessage(
message_id=str(msg_id),
user_message_id=str(event.message_id),
user=protocol_schema.User(
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
),
text=event.content,
)
logger.debug(f"Sending reply to backend: {reply!r}")
# Get next task
new_task = await oasst_api.post_interaction(reply)
logger.info(f"New task {new_task}")
if new_task.type == TaskType.done:
await ctx.author.send("Task completed")
completed = True
continue
else:
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in the log channel that the task is complete
# TODO: Maybe do something with the msg ID so users can rate the "answer"
assert ctx.guild_id is not None
conn: Connection = ctx.bot.d.db
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
if guild_settings is not None and guild_settings.log_channel_id is not None:
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
assert isinstance(channel, hikari.TextableChannel) # option converter
done_embed = (
hikari.Embed(
title="Task Completion",
description=f"`{task.type}` completed by {ctx.author.mention}",
color=hikari.Color(0x00FF00),
timestamp=datetime.now().astimezone(),
)
.add_field("Total Tasks", "0", inline=True)
.add_field("Server Ranking", "0/0", inline=True)
.add_field("Global Ranking", "0/0", inline=True)
.set_footer(f"Task ID: {task.id}")
)
await channel.send(EMPTY, embed=done_embed)
# ask the user if they want to do another task
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send("Would you like another task?", components=choice_view)
await choice_view.start(msg)
await choice_view.wait()
match choice_view.choice:
case False | None:
done = True
await ctx.author.send("Exiting, goodbye!")
case True:
pass
async def _select_task(
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
) -> tuple[protocol_schema.Task | None, str]:
"""Present tasks to the user until they accept one, cancel, or time out."""
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
logger.debug(f"Starting task selection for {task_type}")
# Loop until the user accepts a task, cancels, or times out
while True:
logger.debug(f"Requesting task of type {task_type}")
task = await oasst_api.fetch_task(task_type, user)
resp, msg_id = await _send_task(ctx, task)
logger.debug(f"User choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
await oasst_api.ack_task(task.id, msg_id)
return task, msg_id
case "next":
logger.info(f"Task {task.id} rejected, sending NACK")
await oasst_api.nack_task(task.id, "rejected")
await ctx.author.send("Sending next task...")
continue
case "cancel":
logger.info(f"Task {task.id} canceled, sending NACK")
await oasst_api.nack_task(task.id, "canceled")
await ctx.author.send("Task canceled. Exiting")
return None, msg_id
case None:
logger.info(f"Task {task.id} timed out, sending NACK")
await oasst_api.nack_task(task.id, "timed out")
await ctx.author.send("Task timed out. Exiting")
return None, msg_id
async def _send_task(
ctx: lightbulb.SlashContext, task: protocol_schema.Task
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
"""Send a task to the user.
Returns the user's choice and the message ID of the task message.
"""
# The clean way to do this would be to attach a `to_embed` method to the task classes
# but the tasks aren't discord specific so that doesn't really make sense.
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.debug("sending initial prompt task")
embed = _initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.debug("sending rank initial prompt task")
embed = _rank_initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_prompter_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
logger.debug("sending rank user reply task")
embed = _rank_prompter_reply_embed(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.debug("sending rank assistant reply task")
embed = _rank_assistant_reply_embed(task)
elif task.type == TaskRequestType.prompter_reply:
assert isinstance(task, protocol_schema.PrompterReplyTask)
logger.debug("sending user reply task")
embed = _prompter_reply_embed(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.debug("sending assistant reply task")
embed = _assistant_reply_embed(task)
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
msg = await ctx.author.send(
EMPTY,
embed=embed,
components=view,
)
assert msg is not None
await view.start(msg)
await view.wait()
return view.choice, str(msg.id)
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
"""Returns whether the user's input is valid for the task type."""
if content is None:
return False
# User message input
if (
task.type == TaskRequestType.initial_prompt
or task.type == TaskRequestType.prompter_reply
or task.type == TaskRequestType.assistant_reply
):
assert isinstance(
task,
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
)
return len(content) > 0
# Ranking tasks
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
num_replies = len(task.replies)
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
num_prompts = len(task.prompts)
rankings = content.split(",")
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
elif task.type == TaskRequestType.summarize_story:
raise NotImplementedError
elif task.type == TaskRequestType.rate_summary:
raise NotImplementedError
else:
logger.critical(f"Unknown task type {task.type}")
raise ValueError(f"Unknown task type {task.type}")
class TaskAcceptView(miru.View):
"""View with three buttons: accept, next, and cancel.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
"""
choice: t.Literal["accept", "next", "cancel"] | None = None
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Accept button pressed")
self.choice = "accept"
self.stop()
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Next button pressed")
self.choice = "next"
self.stop()
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
logger.info("Cancel button pressed")
self.choice = "cancel"
self.stop()
class ChoiceView(miru.View):
"""View with two buttons: yes and no.
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
"""
choice: bool | None = None
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = True
self.stop()
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
self.choice = False
self.stop()
################################################################
# Template Embeds #
################################################################
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
return (
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
.set_footer(text=f"OASST Assistant | {task.id}")
)
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank Initial Prompt",
description="Rank the following tasks from best to worst (1,2,3,4,5)",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, prompt in enumerate(task.prompts):
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
return embed
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank User Reply",
description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, reply in enumerate(task.replies):
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
return embed
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="Rank Assistant Reply",
description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4",
timestamp=datetime.now().astimezone(),
)
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for i, reply in enumerate(task.replies):
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
return embed
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="User Reply",
description=f"""\
Send the next message in the conversation as if you were the user.
{'Hint: ' if task.hint else ''}
""",
timestamp=datetime.now().astimezone(),
)
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for message in task.conversation.messages:
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
return embed
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
embed = (
hikari.Embed(
title="User Reply",
description="Send the next message in the conversation as if you were the user.",
timestamp=datetime.now().astimezone(),
)
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
.set_footer(text=f"OASST Assistant | {task.id}")
)
for message in task.conversation.messages:
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
return embed
def load(bot: lightbulb.BotApp):
"""Add the plugin to the bot."""
bot.add_plugin(plugin)
def unload(bot: lightbulb.BotApp):
"""Remove the plugin to the bot."""
bot.remove_plugin(plugin)
+18
View File
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""Configuration for the bot."""
from pydantic import BaseSettings, Field
class Settings(BaseSettings):
"""Settings for the bot."""
bot_token: str = Field(env="BOT_TOKEN", default="")
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
prefix: str = Field(env="PREFIX", default="./")
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
class Config(BaseSettings.Config):
env_file = ".env"
case_sensitive = False
+48
View File
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
"""Utility functions."""
import typing as t
from datetime import datetime
import hikari
def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str:
"""Format a datetime object into the discord time format.
```
| t | HH:MM | 16:20
| T | HH:MM:SS | 16:20:11
| D | D Mo Yr | 20 April 2022
| f | D Mo Yr HH:MM | 20 April 2022 16:20
| F | W, D Mo Yr HH:MM | Wednesday, 20 April 2022 16:20
| R | relative | in an hour
```
"""
match fmt:
case "t" | "T" | "D" | "f" | "F" | "R":
return f"<t:{dt.timestamp():.0f}:{fmt}>"
case _:
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
EMPTY = "\u200d"
"""Zero-width joiner.
This appears as an empty message in Discord.
"""
def mention(
id: hikari.Snowflakeish,
type: t.Literal["channel", "role", "user"],
) -> str:
"""Mention an object."""
match type:
case "channel":
return f"<#{id}>"
case "user":
return f"<@{id}>"
case "role":
return f"<@&{id}>"
-61
View File
@@ -1,61 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from abc import ABC
from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@dataclass
class ReplyHandlerInfo:
msg_id: int
handler_task: asyncio.Task
handler: ChannelHandlerBase
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
bot_channel: discord.TextChannel
templates: MessageTemplates
reply_handlers: dict[int, ReplyHandlerInfo]
def __init__(self):
self.reply_handlers = {} # handlers by msg_id
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
) -> discord.Message:
if channel is None:
self.ensure_bot_channel()
channel = self.bot_channel
return await channel.send(content=content, view=view)
async def post_template(
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
) -> discord.Message:
logger.debug(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view=view, channel=channel)
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
if msg_id in self.reply_handlers:
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
task.add_done_callback(lambda t: handler.on_completed())
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
-15
View File
@@ -1,15 +0,0 @@
# -*- coding: utf-8 -*-
from pydantic import AnyHttpUrl, BaseSettings
class BotSettings(BaseSettings):
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
API_KEY: str = "any_key"
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
-88
View File
@@ -1,88 +0,0 @@
# -*- coding: utf-8 -*-
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
import discord
from loguru import logger
class ChannelExpiredException(Exception):
pass
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool = False
expiry_date: datetime
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.queue = asyncio.Queue()
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
if self.expired:
raise ChannelExpiredException()
msg = await self.queue.get()
if msg is None:
if self.expired:
raise ChannelExpiredException()
else:
raise RuntimeError("Unexpected None message read")
return msg
def on_reply(self, message: discord.Message) -> None:
self.queue.put_nowait(message)
def on_expire(self) -> None:
logger.info("ChannelHandler: on_expire")
self.expired = True
self.queue.put_nowait(None)
def on_completed(self) -> None:
logger.info("ChannelHandler: on_completed")
self.completed = True
def tick(self, now: datetime):
if now > self.expiry_date and not self.expired:
self.on_expire()
@abstractmethod
async def handler_loop(self):
...
async def finalize(self):
pass
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
async def read(self) -> discord.Message:
try:
return await super().read()
except ChannelExpiredException:
await self.cleanup()
raise
async def cleanup(self):
logger.debug("AutoDestructThreadHandler.cleanup")
if self.thread:
logger.debug(f"deleting thread: {self.thread.name}")
await self.thread.delete()
self.thread = None
if self.first_message:
logger.debug(f"deleting first_message: {self.first_message.content}")
await self.first_message.delete()
self.first_message = None
async def finalize(self):
await self.cleanup()
return await super().finalize()
+8 -3
View File
@@ -1,16 +1,21 @@
# -*- coding: utf-8 -*-
"""Message templates for the discord bot."""
import typing
import jinja2
from loguru import logger
class MessageTemplates:
def __init__(self, template_dir="./templates"):
self.env = jinja2.Environment(
"""Create message templates for the discord bot."""
def __init__(self, template_dir: str = "./templates"):
self.env = jinja2.Environment( # noqa: S701
loader=jinja2.FileSystemLoader(template_dir),
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
)
def render(self, template_name, **kwargs):
def render(self, template_name: str, **kwargs: typing.Any):
template = self.env.get_template(template_name)
txt = template.render(kwargs)
logger.debug(txt)
+11 -7
View File
@@ -1,7 +1,11 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
aiohttp # http client
aiohttp[speedups] # speedups for aiohttp
aiosqlite # database
hikari # discord framework
hikari-lightbulb # command handler
hikari-miru # modals and buttons
hikari[speedups]
loguru
pydantic
uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop
-267
View File
@@ -1,267 +0,0 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
name = discord.ui.TextInput(label="Name")
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup() # try to cleanup messag or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
pass
@abstractmethod
async def send_first_message(self) -> discord.message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_teaser_msg(self, template_name: str):
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
return await self.bot.post_template(
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToMessage(
message_id=str(self.first_message.id),
user_message_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_text_reply_to_post()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.MessageRanking(
message_id=str(self.first_message.id),
user_message_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
ranking=ranking,
)
)
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
ranking_str = user_msg.content
ranking = [int(x) - 1 for x in ranking_str.split(",")]
self.post_ranking(user_msg, ranking=ranking)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_ranking()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_summarize_story.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
task: protocol_schema.InitialPromptTask
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_initial_prompt.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class PrompterReplyHandler(ChannelTaskBase):
task: protocol_schema.PrompterReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_prompter_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_prompter_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
task: protocol_schema.AssistantReplyTask
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_assistant_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class RankInitialPromptsHandler(ChannelTaskBase):
task: protocol_schema.RankInitialPromptsTask
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RankConversationsHandler(ChannelTaskBase):
task: protocol_schema.RankConversationRepliesTask
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class RateSummaryHandler(ChannelTaskBase):
task: protocol_schema.RateSummaryTask
thread_name: str = "Ratings"
async def _rating_response_handler(self, score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
if self.thread:
try:
self.backend.post_interaction(
protocol_schema.MessageRating(
message_id=str(self.first_message.id),
user_message_id=str(interaction.id),
user=self.to_api_user(interaction.user),
rating=score,
)
)
await interaction.response.send_message(
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
)
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in _rating_response_handler()")
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rate_summary.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
logger.info(f"on_rate_summary_reply: {msg.content}")
await msg.add_reaction("")
await msg.reply("❌ Text intput not supported.")
-52
View File
@@ -1,52 +0,0 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)