mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Merge branch 'main' of https://github.com/AlexanderHOtt/Open-Assistant into AlexanderHOtt-main
This commit is contained in:
@@ -55,7 +55,7 @@ def generate_task(
|
||||
hint="Ask the assistant about a current event." # this is optional
|
||||
)
|
||||
case protocol_schema.TaskRequestType.prompter_reply:
|
||||
logger.info("Generating a UserReplyTask.")
|
||||
logger.info("Generating a PrompterReplyTask.")
|
||||
messages = pr.fetch_random_conversation("assistant")
|
||||
task_messages = [
|
||||
protocol_schema.ConversationMessage(
|
||||
@@ -86,7 +86,7 @@ def generate_task(
|
||||
messages = pr.fetch_random_initial_prompts()
|
||||
task = protocol_schema.RankInitialPromptsTask(prompts=[msg.payload.payload.text for msg in messages])
|
||||
case protocol_schema.TaskRequestType.rank_prompter_replies:
|
||||
logger.info("Generating a RankUserRepliesTask.")
|
||||
logger.info("Generating a RankPrompterRepliesTask.")
|
||||
conversation, replies = pr.fetch_multiple_random_replies(message_role="assistant")
|
||||
|
||||
task_messages = [
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
BOT_TOKEN=<discord bot token>
|
||||
DECLARE_GLOBAL_COMMANDS=<testing guild id>
|
||||
OWNER_IDS=[<your user id>, <other user ids>]
|
||||
PREFIX="./"
|
||||
|
||||
OASST_API_URL="http://localhost:8080" # No trailing '/'
|
||||
OASST_API_KEY=""
|
||||
@@ -1,3 +1,10 @@
|
||||
.env
|
||||
*.egg-info/
|
||||
__pycache__/
|
||||
|
||||
.venv
|
||||
.nox
|
||||
.env
|
||||
|
||||
# Database files
|
||||
*.db
|
||||
|
||||
+105
-5
@@ -1,20 +1,120 @@
|
||||
# Open-Assistant Data Collection Discord Bot
|
||||
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large langugae model. You and other people can teach the bot how to respond to user requests by demonstration and by garding and ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
This bot collects human feedback to create a dataset for RLHF-alignment of an assistant chat bot based on a large language model. You and other people can teach the bot how to respond to user requests by demonstration and by ranking the bot's outputs. If you want to learn more about RLHF please refer [to OpenAI's InstructGPT blog post](https://openai.com/blog/instruction-following/).
|
||||
|
||||
## Invite official bot
|
||||
|
||||
To add the official Open-Assistant data collection bot to your discord server [click here](https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot). The bot needs access to read the contents of user text messages.
|
||||
|
||||
## Bot token for development
|
||||
## Contributing
|
||||
|
||||
If you are unfamiliar with `hikari`, `lightbulb`, or `miru`, please refer to the [large list of examples](https://gist.github.com/AlexanderHOtt/7805843a7120f755938a3b75d680d2e7)
|
||||
|
||||
### Setup
|
||||
|
||||
To run the bot
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
|
||||
python -V # 3.10
|
||||
|
||||
pip install -r requirements.txt
|
||||
python -m bot
|
||||
```
|
||||
|
||||
Before you push, make sure the `pre-commit` hooks are installed and run successfully.
|
||||
|
||||
```bash
|
||||
pip install pre-commit
|
||||
pre-commit install
|
||||
|
||||
...
|
||||
|
||||
git add .
|
||||
git commit -m "<good commit message>"
|
||||
# if the pre-commit fails
|
||||
git add .
|
||||
git commit -m "<good commit message>"
|
||||
```
|
||||
|
||||
To test the bot on your own discord server you need to register a discord application at the [Discord Developer Portal](https://discord.com/developers/applications) and get at bot token.
|
||||
|
||||
1. Follow a tutorial on how to get a bot token, for example this one: [Creating a discord bot & getting a token](https://github.com/reactiflux/discord-irc/wiki/Creating-a-discord-bot-&-getting-a-token)
|
||||
2. The bot script expects the bot token to be in an environment variable called `BOT_TOKEN`.
|
||||
2. The bot script expects the bot token to be in the `.env` file under the `TOKEN` variable.
|
||||
|
||||
The simplest way to configure the token is via an `.env` file:
|
||||
### Resources
|
||||
|
||||
#### Structure
|
||||
|
||||
Important files
|
||||
|
||||
```graphql
|
||||
.env # Environment variables
|
||||
.env.example # Example environment variables
|
||||
CONTRIBUTING.md # This file
|
||||
README.md # Project readme
|
||||
EXAMPLES.md # Examples for commands and listeners
|
||||
requirements.txt # Requirements
|
||||
|
||||
bot/
|
||||
├─ __main__.py # Entrypoint
|
||||
├─ api_client.py # API Client for interacting with the backend
|
||||
├─ bot.py # Main bot class
|
||||
├─ settings.py # Settings and secrets
|
||||
├─ utils.py # Utility Functions
|
||||
│
|
||||
├─ db/ # Database related code
|
||||
│ ├─ database.db # SQLite database
|
||||
│ ├─ schema.sql # SQL schema
|
||||
│ └─ schemas.py # Python table schemas
|
||||
│
|
||||
└── extensions/ # Application logic, see https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
|
||||
├─ work.py # Task handling logic <-- most important file
|
||||
├─ guild_settings.py # Server specific settings
|
||||
└─ hot_reload.py # Utility for hot reload extensions during development
|
||||
```
|
||||
BOT_TOKEN=XYZABC123...
|
||||
|
||||
#### Adding a new command/listener
|
||||
|
||||
1. Create a new file in the `extensions` folder
|
||||
2. Copy the template below
|
||||
|
||||
```py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""My plugin."""
|
||||
import lightbulb
|
||||
|
||||
plugin = lightbulb.Plugin("MyPlugin")
|
||||
|
||||
# Add your commands here
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
```
|
||||
|
||||
#### Docs
|
||||
|
||||
Discord
|
||||
|
||||
- [Discord API Reference](https://discord.com/developers/docs/intro)
|
||||
|
||||
`hikari` (main framework)
|
||||
|
||||
- [Hikari Repo](https://github.com/hikari-py/hikari)
|
||||
- [Hikari Docs](https://docs.hikari-py.dev/en/latest/)
|
||||
|
||||
`lightbulb` (command handler)
|
||||
|
||||
- [Lightbulb Repo](https://github.com/tandemdude/hikari-lightbulb)
|
||||
- [Lightbulb Docs](https://hikari-lightbulb.readthedocs.io/en/latest/)
|
||||
|
||||
`miru` (component handler: buttons, modals, etc... )
|
||||
|
||||
- [Miru Repo](https://github.com/HyperGH/hikari-miru)
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from bot import OpenAssistantBot
|
||||
from bot_settings import settings
|
||||
|
||||
# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
|
||||
|
||||
if __name__ == "__main__":
|
||||
bot = OpenAssistantBot(
|
||||
settings.BOT_TOKEN,
|
||||
bot_channel_name=settings.BOT_CHANNEL_NAME,
|
||||
backend_url=settings.BACKEND_URL,
|
||||
api_key=settings.API_KEY,
|
||||
owner_id=settings.OWNER_ID,
|
||||
template_dir=settings.TEMPLATE_DIR,
|
||||
debug=settings.DEBUG,
|
||||
)
|
||||
bot.run()
|
||||
@@ -1,79 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
from typing import Optional, Type
|
||||
|
||||
import requests
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
class TaskType(str, enum.Enum):
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class ApiClient:
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
task_models_map: dict[str, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
self.task_models_map = task_models_map
|
||||
|
||||
def post(self, path: str, json: dict) -> dict:
|
||||
response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def _parse_task(self, data: dict) -> protocol_schema.Task:
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("dict expected")
|
||||
|
||||
task_type = data.get("type")
|
||||
if task_type not in self.task_models_map:
|
||||
raise RuntimeError(f"Unsupported task type: {task_type}")
|
||||
|
||||
return self.task_models_map[task_type].parse_obj(data)
|
||||
|
||||
def fetch_task(
|
||||
self,
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
) -> protocol_schema.Task:
|
||||
req = protocol_schema.TaskRequest(type=task_type, user=user, collective=collective)
|
||||
data = self.post("/api/v1/tasks/", req.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
) -> protocol_schema.Task:
|
||||
return self.fetch_task(protocol_schema.TaskRequestType.random, user, collective=collective)
|
||||
|
||||
def ack_task(self, task_id: str, message_id: str) -> None:
|
||||
req = protocol_schema.TaskAck(message_id=message_id)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
|
||||
|
||||
def nack_task(self, task_id: str, reason: str) -> None:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
@@ -1,283 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import discord
|
||||
import task_handlers
|
||||
from api_client import ApiClient, TaskType
|
||||
from bot_base import BotBase
|
||||
from discord import app_commands
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import get_git_head_hash, utcnow
|
||||
|
||||
__version__ = "0.0.3"
|
||||
BOT_NAME = "Open-Assistant Junior"
|
||||
|
||||
|
||||
class OpenAssistantBot(BotBase):
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str,
|
||||
bot_channel_name: str,
|
||||
backend_url: str,
|
||||
api_key: str,
|
||||
owner_id: Optional[Union[int, str]] = None,
|
||||
template_dir: str = "./templates",
|
||||
debug: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.template_dir = Path(template_dir)
|
||||
self.bot_channel_name = bot_channel_name
|
||||
self.templates = MessageTemplates(template_dir)
|
||||
self.debug = debug
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
if isinstance(owner_id, str):
|
||||
owner_id = int(owner_id)
|
||||
self.owner_id = owner_id
|
||||
|
||||
self.bot_token = bot_token
|
||||
client = discord.Client(intents=intents)
|
||||
self.client = client
|
||||
self.loop = client.loop
|
||||
|
||||
self.bot_channel: discord.TextChannel = None
|
||||
self.backend = ApiClient(backend_url, api_key)
|
||||
|
||||
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
|
||||
|
||||
@client.event
|
||||
async def on_ready():
|
||||
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
|
||||
logger.info(f"{client.user} is now running!")
|
||||
|
||||
await self.delete_all_old_bot_messages()
|
||||
# if self.debug:
|
||||
# await self.post_boot_message()
|
||||
await self.post_welcome_message()
|
||||
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message):
|
||||
# ignore own messages
|
||||
if message.author != client.user:
|
||||
await self.handle_message(message)
|
||||
|
||||
@self.tree.command()
|
||||
async def tutorial(interaction: discord.Interaction):
|
||||
"""Start the Open-Assistant tutorial via DMs."""
|
||||
|
||||
dm = await self.client.create_dm(discord.Object(interaction.user.id))
|
||||
await dm.send("Tutorial coming soon... :-)")
|
||||
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
|
||||
|
||||
@self.tree.command()
|
||||
async def help(interaction: discord.Interaction):
|
||||
"""Sends the user a list of all available commands"""
|
||||
await self.post_help(interaction.user)
|
||||
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
|
||||
|
||||
@self.tree.command()
|
||||
async def work(interaction: discord.Interaction):
|
||||
"""Request a new personalized task"""
|
||||
|
||||
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
q = task_handlers.Questionnaire()
|
||||
await interaction.response.send_modal(q)
|
||||
|
||||
async def post_help(self, user: discord.abc.User) -> discord.Message:
|
||||
is_bot_owner = user.id == self.owner_id
|
||||
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
|
||||
|
||||
async def post_boot_message(self) -> discord.Message:
|
||||
return await self.post_template(
|
||||
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
|
||||
)
|
||||
|
||||
async def post_welcome_message(self) -> discord.Message:
|
||||
return await self.post_template("welcome.msg")
|
||||
|
||||
async def delete_all_old_bot_messages(self) -> None:
|
||||
logger.info("Deleting old threads...")
|
||||
for thread in self.bot_channel.threads:
|
||||
if thread.owner_id == self.client.user.id:
|
||||
await thread.delete()
|
||||
logger.info("Completed deleting old theards.")
|
||||
|
||||
logger.info("Deleting old messages...")
|
||||
look_until = utcnow() - timedelta(days=365)
|
||||
async for msg in self.bot_channel.history(limit=None):
|
||||
msg: discord.Message
|
||||
if msg.created_at < look_until:
|
||||
break
|
||||
if msg.author.id == self.client.user.id:
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old messages.")
|
||||
|
||||
async def next_task(self):
|
||||
task_type = protocol_schema.TaskRequestType.random
|
||||
task = self.backend.fetch_task(task_type, user=None)
|
||||
|
||||
handler: task_handlers.ChannelTaskBase = None
|
||||
match task.type:
|
||||
case TaskType.summarize_story:
|
||||
handler = task_handlers.SummarizeStoryHandler()
|
||||
case TaskType.rate_summary:
|
||||
handler = task_handlers.RateSummaryHandler()
|
||||
case TaskType.initial_prompt:
|
||||
handler = task_handlers.InitialPromptHandler()
|
||||
case TaskType.prompter_reply:
|
||||
handler = task_handlers.PrompterReplyHandler()
|
||||
case TaskType.assistant_reply:
|
||||
handler = task_handlers.AssistantReplyHandler()
|
||||
case TaskType.rank_initial_prompts:
|
||||
handler = task_handlers.RankInitialPromptsHandler()
|
||||
case TaskType.rank_prompter_replies | TaskType.rank_assistant_replies:
|
||||
handler = task_handlers.RankConversationsHandler()
|
||||
case _:
|
||||
logger.warning(f"Unsupported task type received: {task.type}")
|
||||
self.backend.nack_task(task.id, "not supported")
|
||||
|
||||
if handler:
|
||||
try:
|
||||
logger.info(f"strarting task {task.id}")
|
||||
msg = await handler.start(self, task)
|
||||
self.backend.ack_task(task.id, msg.id)
|
||||
except Exception:
|
||||
logger.exception("Starting task failed.")
|
||||
self.backend.nack_task(task.id, "faled")
|
||||
|
||||
async def background_timer(self):
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
next_fetch_task = utcnow() + timedelta(seconds=1)
|
||||
while True:
|
||||
now = utcnow()
|
||||
|
||||
if self.bot_channel:
|
||||
if now > next_fetch_task:
|
||||
next_fetch_task = utcnow() + timedelta(seconds=60)
|
||||
|
||||
try:
|
||||
await self.next_task()
|
||||
except Exception:
|
||||
logger.exception("fetching next task failed")
|
||||
|
||||
for x in self.reply_handlers.values():
|
||||
x.handler.tick(now)
|
||||
|
||||
if now > next_remove_completed:
|
||||
next_remove_completed = utcnow() + timedelta(seconds=10)
|
||||
await self.remove_completed_handlers()
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _sync(self, command: str, message: discord.Message):
|
||||
|
||||
logger.info(f"sync tree command received: {command}")
|
||||
|
||||
if command == "sync.copy_global":
|
||||
await self.tree.copy_global_to(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.clear_guild":
|
||||
self.tree.clear_commands(guild=message.guild)
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
elif command == "sync.guild":
|
||||
synced = await self.tree.sync(guild=message.guild)
|
||||
else:
|
||||
synced = await self.tree.sync()
|
||||
|
||||
logger.info(f"Synced {len(synced)} commands")
|
||||
await message.reply(f"Synced {len(synced)} commands")
|
||||
|
||||
async def handle_command(self, message: discord.Message, is_owner: bool):
|
||||
command_text: str = message.content
|
||||
command_text = command_text[1:]
|
||||
match command_text:
|
||||
case "help" | "?":
|
||||
await self.post_help(user=message.author)
|
||||
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
|
||||
if is_owner:
|
||||
await self._sync(command_text, message)
|
||||
case _:
|
||||
await message.reply(f"unknown command: {command_text}")
|
||||
|
||||
def recipient_filter(self, message: discord.Message) -> bool:
|
||||
channel = message.channel
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.private
|
||||
or message.channel.type == discord.ChannelType.private_thread
|
||||
):
|
||||
return True
|
||||
|
||||
if (
|
||||
message.channel.type == discord.ChannelType.text
|
||||
or message.channel.type == discord.ChannelType.public_thread
|
||||
):
|
||||
while channel:
|
||||
if self.bot_channel and channel.id == self.bot_channel.id:
|
||||
return True
|
||||
channel = channel.parent
|
||||
|
||||
return False
|
||||
|
||||
async def handle_message(self, message: discord.Message):
|
||||
if not self.recipient_filter(message):
|
||||
return
|
||||
|
||||
user_id = message.author.id
|
||||
user_display_name = message.author.name
|
||||
|
||||
logger.debug(
|
||||
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
|
||||
)
|
||||
|
||||
command_prefix = "!"
|
||||
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
|
||||
is_owner = self.owner_id and user_id == self.owner_id
|
||||
await self.handle_command(message, is_owner)
|
||||
|
||||
if isinstance(message.channel, discord.Thread):
|
||||
handler = self.reply_handlers.get(message.channel.id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
if message.reference:
|
||||
handler = self.reply_handlers.get(message.reference.message_id)
|
||||
if handler and not handler.handler.completed:
|
||||
handler.handler.on_reply(message)
|
||||
|
||||
async def remove_completed_handlers(self):
|
||||
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
|
||||
if len(completed) == 0:
|
||||
return
|
||||
|
||||
for c in completed:
|
||||
handler = self.reply_handlers[c]
|
||||
del self.reply_handlers[c]
|
||||
try:
|
||||
await handler.handler.finalize()
|
||||
except Exception:
|
||||
logger.exception("handler finalize failed")
|
||||
|
||||
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
|
||||
|
||||
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
|
||||
for channel in self.client.get_all_channels():
|
||||
if channel.type == discord.ChannelType.text and channel.name == channel_name:
|
||||
return channel
|
||||
|
||||
def run(self):
|
||||
"""Run bot loop blocking."""
|
||||
self.client.run(self.bot_token)
|
||||
@@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The official Open-Assistant Discord Bot."""
|
||||
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Entry point for the bot."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
from bot.bot import bot
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
if os.name != "nt":
|
||||
import uvloop
|
||||
|
||||
uvloop.install()
|
||||
|
||||
logger.info("Starting bot")
|
||||
bot.run()
|
||||
@@ -0,0 +1,113 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
import enum
|
||||
import typing as t
|
||||
from typing import Optional, Type
|
||||
from uuid import UUID
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
|
||||
|
||||
# TODO: Move to `protocol`?
|
||||
class TaskType(str, enum.Enum):
|
||||
"""Task types."""
|
||||
|
||||
summarize_story = "summarize_story"
|
||||
rate_summary = "rate_summary"
|
||||
initial_prompt = "initial_prompt"
|
||||
prompter_reply = "prompter_reply"
|
||||
assistant_reply = "assistant_reply"
|
||||
rank_initial_prompts = "rank_initial_prompts"
|
||||
rank_prompter_replies = "rank_prompter_replies"
|
||||
rank_assistant_replies = "rank_assistant_replies"
|
||||
done = "task_done"
|
||||
|
||||
|
||||
class OasstApiClient:
|
||||
"""API Client for interacting with the OASST backend."""
|
||||
|
||||
def __init__(self, backend_url: str, api_key: str):
|
||||
"""Create a new OasstApiClient.
|
||||
|
||||
Args:
|
||||
----
|
||||
backend_url (str): The base backend URL.
|
||||
api_key (str): The API key to use for authentication.
|
||||
"""
|
||||
logger.debug("Opening OasstApiClient session")
|
||||
self.session = aiohttp.ClientSession()
|
||||
self.backend_url = backend_url
|
||||
self.api_key = api_key
|
||||
|
||||
self.task_models_map: dict[TaskType, Type[protocol_schema.Task]] = {
|
||||
TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
|
||||
TaskType.rate_summary: protocol_schema.RateSummaryTask,
|
||||
TaskType.initial_prompt: protocol_schema.InitialPromptTask,
|
||||
TaskType.prompter_reply: protocol_schema.PrompterReplyTask,
|
||||
TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
|
||||
TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
|
||||
TaskType.rank_prompter_replies: protocol_schema.RankPrompterRepliesTask,
|
||||
TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
|
||||
TaskType.done: protocol_schema.TaskDone,
|
||||
}
|
||||
|
||||
async def post(self, path: str, data: dict[str, t.Any]) -> dict[str, t.Any]:
|
||||
"""Make a POST request to the backend."""
|
||||
logger.debug(f"POST {self.backend_url}{path} DATA: {data}")
|
||||
response = await self.session.post(f"{self.backend_url}{path}", json=data, headers={"X-API-Key": self.api_key})
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
|
||||
def _parse_task(self, data: dict[str, t.Any]) -> protocol_schema.Task:
|
||||
task_type = TaskType(data.get("type"))
|
||||
|
||||
model = self.task_models_map.get(task_type)
|
||||
if not model:
|
||||
logger.error(f"Unsupported task type: {task_type}")
|
||||
raise ValueError(f"Unsupported task type: {task_type}")
|
||||
return self.task_models_map[task_type].parse_obj(data) # type: ignore
|
||||
|
||||
async def fetch_task(
|
||||
self,
|
||||
task_type: protocol_schema.TaskRequestType,
|
||||
user: Optional[protocol_schema.User] = None,
|
||||
collective: bool = False,
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a task from the backend."""
|
||||
logger.debug(f"Fetching task {task_type} for user {user}")
|
||||
req = protocol_schema.TaskRequest(type=task_type.value, user=user, collective=collective)
|
||||
resp = await self.post("/api/v1/tasks/", data=req.dict())
|
||||
logger.debug(f"RESP {resp}")
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def fetch_random_task(
|
||||
self, user: Optional[protocol_schema.User] = None, collective: bool = False
|
||||
) -> protocol_schema.Task:
|
||||
"""Fetch a random task from the backend."""
|
||||
logger.debug(f"Fetching random for user {user}")
|
||||
return await self.fetch_task(protocol_schema.TaskRequestType.random, user, collective)
|
||||
|
||||
async def ack_task(self, task_id: str | UUID, message_id: str):
|
||||
"""Send an ACK for a task to the backend."""
|
||||
logger.debug(f"ACK task {task_id} with post {message_id}")
|
||||
req = protocol_schema.TaskAck(message_id=message_id)
|
||||
return await self.post(f"/api/v1/tasks/{task_id}/ack", data=req.dict())
|
||||
|
||||
async def nack_task(self, task_id: str | UUID, reason: str):
|
||||
"""Send a NACK for a task to the backend."""
|
||||
logger.debug(f"NACK task {task_id} with reason {reason}")
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return await self.post(f"/api/v1/tasks/{task_id}/nack", data=req.dict())
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
"""Send a completed task to the backend."""
|
||||
logger.debug(f"Interaction: {interaction}")
|
||||
resp = await self.post("/api/v1/tasks/interaction", data=interaction.dict())
|
||||
|
||||
return self._parse_task(resp)
|
||||
|
||||
async def close(self):
|
||||
logger.debug("Closing OasstApiClient session")
|
||||
await self.session.close()
|
||||
@@ -0,0 +1,40 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Bot logic."""
|
||||
import aiosqlite
|
||||
import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from bot.api_client import OasstApiClient
|
||||
from bot.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# TODO: Revisit cache settings
|
||||
bot = lightbulb.BotApp(
|
||||
token=settings.bot_token,
|
||||
logs="DEBUG",
|
||||
prefix=settings.prefix,
|
||||
default_enabled_guilds=settings.declare_global_commands,
|
||||
owner_ids=settings.owner_ids,
|
||||
intents=hikari.Intents.ALL,
|
||||
)
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_starting(event: hikari.StartingEvent):
|
||||
"""Setup."""
|
||||
miru.install(bot) # component handler
|
||||
bot.load_extensions_from("./bot/extensions") # load extensions
|
||||
|
||||
bot.d.db = await aiosqlite.connect("./bot/db/database.db")
|
||||
await bot.d.db.executescript(open("./bot/db/schema.sql").read())
|
||||
await bot.d.db.commit()
|
||||
|
||||
bot.d.oasst_api = OasstApiClient(settings.oasst_api_url, settings.oasst_api_key)
|
||||
|
||||
|
||||
@bot.listen()
|
||||
async def on_stopping(event: hikari.StoppingEvent):
|
||||
"""Cleanup."""
|
||||
await bot.d.db.close()
|
||||
await bot.d.oasst_api.close()
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Sqlite3 schema for the bot
|
||||
CREATE TABLE IF NOT EXISTS guild_settings (
|
||||
guild_id BIGINT NOT NULL PRIMARY KEY,
|
||||
log_channel_id BIGINT
|
||||
);
|
||||
@@ -0,0 +1,28 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Database schemas."""
|
||||
import typing as t
|
||||
|
||||
from aiosqlite import Connection, Row
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GuildSettings(BaseModel):
|
||||
"""Guild settings."""
|
||||
|
||||
guild_id: int
|
||||
log_channel_id: int | None
|
||||
|
||||
@classmethod
|
||||
def parse_obj(cls, obj: Row) -> "GuildSettings":
|
||||
"""Deserialize a Row object from aiosqlite into a GuildSettings object."""
|
||||
return cls(guild_id=obj[0], log_channel_id=obj[1])
|
||||
|
||||
@classmethod
|
||||
async def from_db(cls, conn: Connection, guild_id: int) -> t.Optional["GuildSettings"]:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (guild_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return cls.parse_obj(row)
|
||||
@@ -0,0 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Extensions for the bot.
|
||||
|
||||
See: https://hikari-lightbulb.readthedocs.io/en/latest/guides/extensions.html
|
||||
"""
|
||||
@@ -0,0 +1,96 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Guild settings."""
|
||||
import hikari
|
||||
import lightbulb
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import mention
|
||||
from lightbulb.utils.permissions import permissions_in
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin("GuildSettings")
|
||||
plugin.add_checks(lightbulb.guild_only)
|
||||
plugin.add_checks(lightbulb.has_guild_permissions(hikari.Permissions.MANAGE_GUILD))
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("settings", "Bot settings for the server.")
|
||||
@lightbulb.implements(lightbulb.SlashCommandGroup)
|
||||
async def settings(_: lightbulb.SlashContext) -> None:
|
||||
"""Bot settings for the server."""
|
||||
# This will never execute because it is a group
|
||||
pass
|
||||
|
||||
|
||||
@settings.child
|
||||
@lightbulb.command("get", "Get all the guild settings.")
|
||||
@lightbulb.implements(lightbulb.SlashSubCommand)
|
||||
async def get(ctx: lightbulb.SlashContext) -> None:
|
||||
"""Get one of or all the guild settings."""
|
||||
conn: Connection = ctx.bot.d.db
|
||||
assert ctx.guild_id is not None # `guild_only` check
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
# Get all settings
|
||||
await cursor.execute("SELECT * FROM guild_settings WHERE guild_id = ?", (ctx.guild_id,))
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
logger.warning(f"No guild settings for {ctx.guild_id}")
|
||||
await ctx.respond("No settings found for this guild.")
|
||||
return
|
||||
|
||||
guild_settings = GuildSettings.parse_obj(row)
|
||||
|
||||
# Respond with all
|
||||
# TODO: Embed
|
||||
await ctx.respond(
|
||||
f"""\
|
||||
**Guild Settings**
|
||||
`log_channel`: {
|
||||
mention(guild_settings.log_channel_id, "channel")
|
||||
if guild_settings.log_channel_id else 'not set'}
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
@settings.child
|
||||
@lightbulb.option("channel", "The channel to use.", hikari.TextableGuildChannel)
|
||||
@lightbulb.command("log_channel", "Set the channel that the bot logs task and label completions in.")
|
||||
@lightbulb.implements(lightbulb.SlashSubCommand)
|
||||
async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
"""Set the channel that the bot logs task and label completions in."""
|
||||
channel: hikari.TextableGuildChannel = ctx.options.channel
|
||||
conn: Connection = ctx.bot.d.db
|
||||
assert ctx.guild_id is not None # `guild_only` check
|
||||
assert isinstance(channel, hikari.PermissibleGuildChannel)
|
||||
|
||||
# Check if the bot can send messages in that channel
|
||||
assert (me := ctx.bot.get_me()) is not None # non-None after `StartedEvent`
|
||||
if (own_member := ctx.bot.cache.get_member(ctx.guild_id, me.id)) is None:
|
||||
own_member = await ctx.bot.rest.fetch_member(ctx.guild_id, me.id)
|
||||
perms = permissions_in(channel, own_member)
|
||||
if perms & ~hikari.Permissions.SEND_MESSAGES:
|
||||
await ctx.respond("I don't have permission to send messages in that channel.")
|
||||
return
|
||||
|
||||
await ctx.respond(f"Setting `log_channel` to {channel.mention}.")
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
"INSERT OR REPLACE INTO guild_settings (guild_id, log_channel_id) VALUES (?, ?)",
|
||||
(ctx.guild_id, channel.id),
|
||||
)
|
||||
|
||||
await conn.commit()
|
||||
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,64 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
from glob import glob
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"HotReloadPlugin",
|
||||
)
|
||||
plugin.add_checks(lightbulb.owner_only)
|
||||
|
||||
EXTENSIONS_FOLDER = "bot/extensions"
|
||||
|
||||
|
||||
def _get_extensions() -> list[str]:
|
||||
# Recursively get all the .py files in the extensions directory not starting with an `_`.
|
||||
exts = glob("bot/extensions/**/[!_]*.py", recursive=True)
|
||||
# Turn the path into a plugin path ("path/to/extension.py" -> "path.to.extension")
|
||||
return [ext.replace("/", ".").replace("\\", ".").replace(".py", "") for ext in exts]
|
||||
|
||||
|
||||
async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikari.AutocompleteInteraction) -> list[str]:
|
||||
# Check that the option is a string.
|
||||
if not isinstance(option.value, str):
|
||||
raise TypeError(f"`option.value` must be of type `str`, it is currently a `{type(option.value)}`")
|
||||
|
||||
exts = _get_extensions()
|
||||
return [ext for ext in exts if option.value in ext]
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"plugin",
|
||||
"The plugin to reload. Leave empty to reload all plugins.",
|
||||
autocomplete=_plugin_autocomplete,
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def reload(ctx: lightbulb.SlashContext):
|
||||
"""Reload a plugin or all plugins."""
|
||||
# If the plugin option is None, reload all plugins.
|
||||
if ctx.options.plugin is None:
|
||||
ctx.bot.reload_extensions(*_get_extensions())
|
||||
await ctx.respond("Reloaded all plugins.")
|
||||
logger.info("Reloaded all plugins.")
|
||||
# Otherwise, reload the specified plugin.
|
||||
else:
|
||||
ctx.bot.reload_extensions(ctx.options.plugin)
|
||||
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
|
||||
logger.info(f"Reloaded `{ctx.options.plugin}`.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,181 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Hot reload plugin."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"TextLabels",
|
||||
)
|
||||
plugin.add_checks(lightbulb.guild_only) # Context menus are only enabled in guilds
|
||||
|
||||
|
||||
DISCORD_GRAY = 0x2F3136
|
||||
|
||||
|
||||
def clamp(num: float) -> float:
|
||||
"""Clamp a number between 0 and 1."""
|
||||
return min(max(0.0, num), 1.0)
|
||||
|
||||
|
||||
class LabelModal(miru.Modal):
|
||||
"""Modal for submitting text labels."""
|
||||
|
||||
def __init__(self, label: str, content: str, *args: t.Any, **kwargs: t.Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.label = label
|
||||
self.original_content = content
|
||||
|
||||
# Add the text of the message to the modal
|
||||
self.content = miru.TextInput(
|
||||
label="Text", style=hikari.TextInputStyle.PARAGRAPH, value=content, required=True, row=1
|
||||
)
|
||||
self.add_item(self.content)
|
||||
|
||||
value = miru.TextInput(label="Value", placeholder="Enter a value between 0 and 1", required=True, row=2)
|
||||
|
||||
async def callback(self, context: miru.ModalContext) -> None:
|
||||
val = float(self.value.value) if self.value.value else 0.0
|
||||
val = clamp(val)
|
||||
|
||||
edited = self.content.value != self.original_content
|
||||
await context.respond(
|
||||
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
)
|
||||
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
|
||||
|
||||
# Send a notification to the log channel
|
||||
assert context.guild_id is not None # `guild_only` check
|
||||
conn: Connection = context.bot.d.db # type: ignore
|
||||
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
|
||||
|
||||
if guild_settings is None or guild_settings.log_channel_id is None:
|
||||
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
|
||||
return
|
||||
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Message Label",
|
||||
description=f"{context.author.mention} labeled a message as `{self.label}`.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
color=0x00FF00,
|
||||
)
|
||||
.set_author(name=context.author.username, icon=context.author.avatar_url)
|
||||
.add_field("Total Labeled Message", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
)
|
||||
channel = await context.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel)
|
||||
await channel.send(EMPTY, embed=embed)
|
||||
|
||||
|
||||
class LabelSelect(miru.View):
|
||||
"""Select menu for selecting a label.
|
||||
|
||||
The current labels are:
|
||||
- contains toxic language
|
||||
- encourages illegal activity
|
||||
- good quality
|
||||
- bad quality
|
||||
- is spam
|
||||
"""
|
||||
|
||||
def __init__(self, content: str, *args: t.Any, **kwargs: t.Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.content = content
|
||||
|
||||
@miru.select(
|
||||
options=[
|
||||
hikari.SelectMenuOption(
|
||||
label="Toxic Language",
|
||||
value="toxic_language",
|
||||
description="The message contains toxic language.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Illegal Activity",
|
||||
value="illegal_activity",
|
||||
description="The message encourages illegal activity.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Good Quality",
|
||||
value="good_quality",
|
||||
description="The message is good quality.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Bad Quality",
|
||||
value="bad_quality",
|
||||
description="The message is bad quality.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="Spam",
|
||||
value="spam",
|
||||
description="The message is spam.",
|
||||
is_default=False,
|
||||
emoji=None,
|
||||
),
|
||||
],
|
||||
min_values=1,
|
||||
max_values=1,
|
||||
)
|
||||
async def label_select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
|
||||
"""Handle the select menu."""
|
||||
label = select.values[0]
|
||||
modal = LabelModal(label, self.content, title=f"Text Label: {label}", timeout=60)
|
||||
await modal.send(ctx.interaction)
|
||||
await modal.wait()
|
||||
|
||||
self.stop()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("Label Message", "Label a message")
|
||||
@lightbulb.implements(lightbulb.MessageCommand)
|
||||
async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
"""Label a message."""
|
||||
# We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction
|
||||
# so the select menu will open the modal
|
||||
|
||||
msg: hikari.Message = ctx.options.target
|
||||
# Exit if the message is empty
|
||||
if not msg.content:
|
||||
await ctx.respond("Cannot label an empty message.", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
return
|
||||
|
||||
# Send the select menu
|
||||
# The modal will be opened from the select menu interaction
|
||||
embed = hikari.Embed(title="Label Message", description="Select a label for the message.", color=DISCORD_GRAY)
|
||||
label_select_view = LabelSelect(
|
||||
msg.content,
|
||||
timeout=60,
|
||||
)
|
||||
resp = await ctx.respond(EMPTY, embed=embed, components=label_select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
await label_select_view.start(await resp.message())
|
||||
await label_select_view.wait()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,301 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task plugin for testing different data collection methods."""
|
||||
# TODO: Delete this once user input method has been decided for final bot.
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from bot.utils import format_time
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("TaskPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60
|
||||
MAX_TASK_ACCEPT_TIME = 60
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_thread", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_thread(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
typ: str = ctx.options.type
|
||||
|
||||
# Create a thread for the task
|
||||
thread = await ctx.bot.rest.create_thread(ctx.channel_id, hikari.ChannelType.GUILD_PUBLIC_THREAD, f"Task: {typ}")
|
||||
|
||||
await ctx.respond(f"Please complete the task in the thread: {thread.mention}")
|
||||
|
||||
# Send the task in the thread
|
||||
await thread.send(
|
||||
f"""\
|
||||
Please complete the task.
|
||||
Sample Task
|
||||
|
||||
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
|
||||
"""
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.GuildMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id and e.channel_id == thread.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
finally:
|
||||
await thread.delete()
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_dm", "Request a task from the backend.", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand, lightbulb.PrefixCommand)
|
||||
async def task_dm(ctx: lightbulb.Context):
|
||||
"""Request a task from the backend."""
|
||||
await ctx.respond("Please complete the task in your DMs")
|
||||
|
||||
# Send the task in the dm
|
||||
await ctx.author.send(
|
||||
f"""\
|
||||
Please complete the task.
|
||||
Sample Task
|
||||
|
||||
Self destruct {format_time(datetime.now() + timedelta(seconds=MAX_TASK_TIME), 'R')}
|
||||
"""
|
||||
)
|
||||
|
||||
# Wait for the user to respond
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent,
|
||||
timeout=MAX_TASK_TIME,
|
||||
predicate=lambda e: e.author.id == ctx.author.id,
|
||||
)
|
||||
await ctx.respond(f"Received message: {event.message.content}")
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("You took too long to respond.")
|
||||
|
||||
|
||||
class TaskModal(miru.Modal):
|
||||
"""Modal for submitting a task."""
|
||||
|
||||
response = miru.TextInput(
|
||||
label="Response",
|
||||
placeholder="Enter your response!",
|
||||
required=True,
|
||||
style=hikari.TextInputStyle.PARAGRAPH,
|
||||
row=2,
|
||||
)
|
||||
|
||||
async def callback(self, context: miru.ModalContext) -> None:
|
||||
await context.respond(f"Received response: {self.response.value}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
class ModalView(miru.View):
|
||||
"""View for opening a modal."""
|
||||
|
||||
def __init__(self, modal_title: str, task: str, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.modal_title = modal_title
|
||||
self.task = task
|
||||
|
||||
@miru.button(label="Start Task!", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def modal_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
modal = TaskModal(title=self.modal_title)
|
||||
modal.add_item(miru.TextInput(label="Task", value=self.task, style=hikari.TextInputStyle.PARAGRAPH, row=1))
|
||||
await ctx.respond_with_modal(modal)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.split(".")[-1], value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=TaskRequestType.summarize_story,
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("task_modal", "Request a task from the backend.", ephemeral=True, auto_defer=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def task_modal(ctx: lightbulb.SlashContext):
|
||||
"""Request a task from the backend."""
|
||||
# typ: str = ctx.options.type
|
||||
view = ModalView(
|
||||
modal_title="Assistant Response",
|
||||
task="Please explain the moon landing to a six year old.",
|
||||
timeout=MAX_TASK_TIME,
|
||||
)
|
||||
resp = await ctx.respond(
|
||||
"Task - Respond to the prompt as if you were the Assistant:",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
components=view,
|
||||
)
|
||||
await view.start(await resp.message())
|
||||
|
||||
|
||||
class RatingView(miru.View):
|
||||
"""View for rating a task."""
|
||||
|
||||
def __init__(self, *args: t.Any, **kwargs: t.Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.presses: list[str] = []
|
||||
|
||||
def _close_if_all_pressed(self) -> None:
|
||||
if len(self.presses) == 5:
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="1", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_1(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("1")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="2", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_2(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("2")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="3", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_3(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("3")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="4", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_4(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("4")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="5", style=hikari.ButtonStyle.PRIMARY)
|
||||
async def button_5(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
if button.label not in self.presses:
|
||||
self.presses.append("5")
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
self._close_if_all_pressed()
|
||||
|
||||
@miru.button(label="Reset", style=hikari.ButtonStyle.DANGER)
|
||||
async def reset_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.presses = []
|
||||
await ctx.respond(f"Received response: {button.label}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
class SelectRating(miru.View):
|
||||
"""View for rating a task with a select menu."""
|
||||
|
||||
@miru.select(
|
||||
options=[
|
||||
hikari.SelectMenuOption(
|
||||
label="1",
|
||||
value="1",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="2",
|
||||
value="2",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
hikari.SelectMenuOption(
|
||||
label="3",
|
||||
value="3",
|
||||
description=None,
|
||||
emoji=None,
|
||||
is_default=False,
|
||||
),
|
||||
],
|
||||
placeholder="Select the good responses",
|
||||
min_values=0,
|
||||
max_values=3,
|
||||
row=3,
|
||||
)
|
||||
async def select(self, select: miru.Select, ctx: miru.ViewContext) -> None:
|
||||
await ctx.respond(f"You selected {select.values}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.command("rating_task", "Rate stuff.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def rating_task(ctx: lightbulb.SlashContext):
|
||||
"""Rate stuff."""
|
||||
# Message Based rating
|
||||
await ctx.respond(
|
||||
"List the responses in order of best to worst response (1,2,3,4,5)", flags=hikari.MessageFlag.EPHEMERAL
|
||||
)
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.MessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.respond("Timed out waiting for response")
|
||||
return
|
||||
|
||||
if event.content is None:
|
||||
await ctx.respond("No content in message")
|
||||
return
|
||||
ratings = event.content.replace(" ", "").split(",")
|
||||
|
||||
# Check if the ratings are valid
|
||||
if len(ratings) != 5:
|
||||
await ctx.respond("Invalid number of ratings")
|
||||
if not all([rating in ("1", "2", "3", "4", "5") for rating in ratings]):
|
||||
await ctx.respond("Invalid rating")
|
||||
|
||||
await ctx.respond(f"Your responses: {ratings}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
# Button Based rating
|
||||
view = RatingView(timeout=MAX_TASK_TIME)
|
||||
|
||||
resp = await ctx.respond("Click the buttons in order of best to worst response", components=view)
|
||||
await view.start(await resp.message())
|
||||
await view.wait()
|
||||
await ctx.respond(f"Your responses: {view.presses}", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await resp.delete()
|
||||
|
||||
# Select Based rating
|
||||
select_view = SelectRating(timeout=MAX_TASK_TIME)
|
||||
resp_2 = await ctx.respond("Select the good responses", components=select_view, flags=hikari.MessageFlag.EPHEMERAL)
|
||||
await select_view.start(await resp_2.message())
|
||||
await select_view.wait()
|
||||
await resp_2.delete()
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,451 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
import lightbulb.decorators
|
||||
import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.api_client import OasstApiClient, TaskType
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
plugin = lightbulb.Plugin("WorkPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
"type",
|
||||
"The type of task to request.",
|
||||
choices=[hikari.CommandChoice(name=task.value, value=task) for task in TaskRequestType],
|
||||
required=False,
|
||||
default=str(TaskRequestType.random),
|
||||
type=str,
|
||||
)
|
||||
@lightbulb.command("work", "Complete a task.")
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def work(ctx: lightbulb.SlashContext):
|
||||
"""Create and handle a task."""
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
await _handle_task(ctx, task_type)
|
||||
|
||||
|
||||
async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) -> None:
|
||||
"""Handle creating and collecting user input for a task.
|
||||
|
||||
Continually present tasks to the user until they select one, cancel, or time out.
|
||||
If they select one, present the task steps until a `task_done` task is received.
|
||||
Finally, ask the user if they want to perform another task (of the same type).
|
||||
"""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
|
||||
# Continue to complete tasks until the user doesn't want to do another
|
||||
done = False
|
||||
while not done:
|
||||
|
||||
# Loop until the user accepts a task
|
||||
task, msg_id = await _select_task(ctx, task_type)
|
||||
|
||||
if task is None:
|
||||
return
|
||||
|
||||
# Task action loop
|
||||
completed = False
|
||||
while not completed:
|
||||
await ctx.author.send("Please type your response here:")
|
||||
try:
|
||||
event = await ctx.bot.wait_for(
|
||||
hikari.DMMessageCreateEvent, timeout=MAX_TASK_TIME, predicate=lambda e: e.author.id == ctx.author.id
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
if event.content is None or not _validate_user_input(event.content, task):
|
||||
await ctx.author.send("Invalid response")
|
||||
continue
|
||||
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToMessage(
|
||||
message_id=str(msg_id),
|
||||
user_message_id=str(event.message_id),
|
||||
user=protocol_schema.User(
|
||||
auth_method="discord", id=str(ctx.author.id), display_name=ctx.author.username
|
||||
),
|
||||
text=event.content,
|
||||
)
|
||||
logger.debug(f"Sending reply to backend: {reply!r}")
|
||||
|
||||
# Get next task
|
||||
new_task = await oasst_api.post_interaction(reply)
|
||||
logger.info(f"New task {new_task}")
|
||||
|
||||
if new_task.type == TaskType.done:
|
||||
await ctx.author.send("Task completed")
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
assert ctx.guild_id is not None
|
||||
conn: Connection = ctx.bot.d.db
|
||||
guild_settings = await GuildSettings.from_db(conn, ctx.guild_id)
|
||||
|
||||
if guild_settings is not None and guild_settings.log_channel_id is not None:
|
||||
|
||||
channel = await ctx.bot.rest.fetch_channel(guild_settings.log_channel_id)
|
||||
assert isinstance(channel, hikari.TextableChannel) # option converter
|
||||
|
||||
done_embed = (
|
||||
hikari.Embed(
|
||||
title="Task Completion",
|
||||
description=f"`{task.type}` completed by {ctx.author.mention}",
|
||||
color=hikari.Color(0x00FF00),
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.add_field("Total Tasks", "0", inline=True)
|
||||
.add_field("Server Ranking", "0/0", inline=True)
|
||||
.add_field("Global Ranking", "0/0", inline=True)
|
||||
.set_footer(f"Task ID: {task.id}")
|
||||
)
|
||||
await channel.send(EMPTY, embed=done_embed)
|
||||
|
||||
# ask the user if they want to do another task
|
||||
choice_view = ChoiceView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send("Would you like another task?", components=choice_view)
|
||||
await choice_view.start(msg)
|
||||
await choice_view.wait()
|
||||
|
||||
match choice_view.choice:
|
||||
case False | None:
|
||||
done = True
|
||||
await ctx.author.send("Exiting, goodbye!")
|
||||
case True:
|
||||
pass
|
||||
|
||||
|
||||
async def _select_task(
|
||||
ctx: lightbulb.SlashContext, task_type: TaskRequestType, user: protocol_schema.User | None = None
|
||||
) -> tuple[protocol_schema.Task | None, str]:
|
||||
"""Present tasks to the user until they accept one, cancel, or time out."""
|
||||
oasst_api: OasstApiClient = ctx.bot.d.oasst_api
|
||||
logger.debug(f"Starting task selection for {task_type}")
|
||||
|
||||
# Loop until the user accepts a task, cancels, or times out
|
||||
while True:
|
||||
logger.debug(f"Requesting task of type {task_type}")
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
await oasst_api.ack_task(task.id, msg_id)
|
||||
return task, msg_id
|
||||
|
||||
case "next":
|
||||
logger.info(f"Task {task.id} rejected, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "rejected")
|
||||
await ctx.author.send("Sending next task...")
|
||||
continue
|
||||
|
||||
case "cancel":
|
||||
logger.info(f"Task {task.id} canceled, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "canceled")
|
||||
await ctx.author.send("Task canceled. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
case None:
|
||||
logger.info(f"Task {task.id} timed out, sending NACK")
|
||||
await oasst_api.nack_task(task.id, "timed out")
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
return None, msg_id
|
||||
|
||||
|
||||
async def _send_task(
|
||||
ctx: lightbulb.SlashContext, task: protocol_schema.Task
|
||||
) -> tuple[t.Literal["accept", "next", "cancel"] | None, str]:
|
||||
"""Send a task to the user.
|
||||
|
||||
Returns the user's choice and the message ID of the task message.
|
||||
"""
|
||||
# The clean way to do this would be to attach a `to_embed` method to the task classes
|
||||
# but the tasks aren't discord specific so that doesn't really make sense.
|
||||
|
||||
embed: hikari.UndefinedOr[hikari.Embed] = hikari.UNDEFINED
|
||||
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_prompter_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask)
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.prompter_reply:
|
||||
assert isinstance(task, protocol_schema.PrompterReplyTask)
|
||||
logger.debug("sending user reply task")
|
||||
embed = _prompter_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"unknown task type {task.type}")
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
msg = await ctx.author.send(
|
||||
EMPTY,
|
||||
embed=embed,
|
||||
components=view,
|
||||
)
|
||||
|
||||
assert msg is not None
|
||||
|
||||
await view.start(msg)
|
||||
await view.wait()
|
||||
|
||||
return view.choice, str(msg.id)
|
||||
|
||||
|
||||
def _validate_user_input(content: str | None, task: protocol_schema.Task) -> bool:
|
||||
"""Returns whether the user's input is valid for the task type."""
|
||||
if content is None:
|
||||
return False
|
||||
|
||||
# User message input
|
||||
if (
|
||||
task.type == TaskRequestType.initial_prompt
|
||||
or task.type == TaskRequestType.prompter_reply
|
||||
or task.type == TaskRequestType.assistant_reply
|
||||
):
|
||||
assert isinstance(
|
||||
task,
|
||||
protocol_schema.InitialPromptTask | protocol_schema.PrompterReplyTask | protocol_schema.AssistantReplyTask,
|
||||
)
|
||||
return len(content) > 0
|
||||
|
||||
# Ranking tasks
|
||||
elif task.type == TaskRequestType.rank_prompter_replies or task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankPrompterRepliesTask | protocol_schema.RankAssistantRepliesTask)
|
||||
num_replies = len(task.replies)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_replies + 1)} and len(rankings) == num_replies
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
num_prompts = len(task.prompts)
|
||||
|
||||
rankings = content.split(",")
|
||||
return set(rankings) == {str(i) for i in range(1, num_prompts + 1)} and len(rankings) == num_prompts
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
raise NotImplementedError
|
||||
elif task.type == TaskRequestType.rate_summary:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.critical(f"Unknown task type {task.type}")
|
||||
raise ValueError(f"Unknown task type {task.type}")
|
||||
|
||||
|
||||
class TaskAcceptView(miru.View):
|
||||
"""View with three buttons: accept, next, and cancel.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
"""
|
||||
|
||||
choice: t.Literal["accept", "next", "cancel"] | None = None
|
||||
|
||||
@miru.button(label="Accept", custom_id="accept", row=0, style=hikari.ButtonStyle.SUCCESS)
|
||||
async def accept_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Accept button pressed")
|
||||
self.choice = "accept"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Next Task", custom_id="next_task", row=0, style=hikari.ButtonStyle.SECONDARY)
|
||||
async def next_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Next button pressed")
|
||||
self.choice = "next"
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="Cancel", custom_id="cancel", row=0, style=hikari.ButtonStyle.DANGER)
|
||||
async def cancel_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
logger.info("Cancel button pressed")
|
||||
self.choice = "cancel"
|
||||
self.stop()
|
||||
|
||||
|
||||
class ChoiceView(miru.View):
|
||||
"""View with two buttons: yes and no.
|
||||
|
||||
The view stops once one of the buttons is pressed and the choice is stored in the `choice` attribute.
|
||||
"""
|
||||
|
||||
choice: bool | None = None
|
||||
|
||||
@miru.button(label="Yes", custom_id="yes", style=hikari.ButtonStyle.SUCCESS)
|
||||
async def yes_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = True
|
||||
self.stop()
|
||||
|
||||
@miru.button(label="No", custom_id="no", style=hikari.ButtonStyle.DANGER)
|
||||
async def no_button(self, button: miru.Button, ctx: miru.ViewContext) -> None:
|
||||
self.choice = False
|
||||
self.stop()
|
||||
|
||||
|
||||
################################################################
|
||||
# Template Embeds #
|
||||
################################################################
|
||||
|
||||
# TODO: Maybe implement a better way of creating embeds, like `from_json` or something
|
||||
|
||||
|
||||
def _initial_prompt_embed(task: protocol_schema.InitialPromptTask) -> hikari.Embed:
|
||||
return (
|
||||
hikari.Embed(title="Initial Prompt", description=f"Hint: {task.hint}", timestamp=datetime.now().astimezone())
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
|
||||
def _rank_initial_prompt_embed(task: protocol_schema.RankInitialPromptsTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Initial Prompt",
|
||||
description="Rank the following tasks from best to worst (1,2,3,4,5)",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512")
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, prompt in enumerate(task.prompts):
|
||||
embed.add_field(name=f"Prompt {i + 1}", value=prompt, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_prompter_reply_embed(task: protocol_schema.RankPrompterRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank User Reply",
|
||||
description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _rank_assistant_reply_embed(task: protocol_schema.RankAssistantRepliesTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="Rank Assistant Reply",
|
||||
description="Rank the following tasks from best to worst. e.g. 1,2,5,3,4",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
.set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: update image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for i, reply in enumerate(task.replies):
|
||||
embed.add_field(name=f"Reply {i + 1}", value=reply, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _prompter_reply_embed(task: protocol_schema.PrompterReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description=f"""\
|
||||
Send the next message in the conversation as if you were the user.
|
||||
{'Hint: ' if task.hint else ''}
|
||||
""",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def _assistant_reply_embed(task: protocol_schema.AssistantReplyTask) -> hikari.Embed:
|
||||
embed = (
|
||||
hikari.Embed(
|
||||
title="User Reply",
|
||||
description="Send the next message in the conversation as if you were the user.",
|
||||
timestamp=datetime.now().astimezone(),
|
||||
)
|
||||
# .set_image("https://images.unsplash.com/photo-1455390582262-044cdead277a?w=512") # TODO: change image
|
||||
.set_footer(text=f"OASST Assistant | {task.id}")
|
||||
)
|
||||
|
||||
for message in task.conversation.messages:
|
||||
embed.add_field(name="Assistant" if message.is_assistant else "User", value=message.text, inline=False)
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
"""Add the plugin to the bot."""
|
||||
bot.add_plugin(plugin)
|
||||
|
||||
|
||||
def unload(bot: lightbulb.BotApp):
|
||||
"""Remove the plugin to the bot."""
|
||||
bot.remove_plugin(plugin)
|
||||
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Configuration for the bot."""
|
||||
from pydantic import BaseSettings, Field
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""Settings for the bot."""
|
||||
|
||||
bot_token: str = Field(env="BOT_TOKEN", default="")
|
||||
declare_global_commands: int = Field(env="DECLARE_GLOBAL_COMMANDS", default=0)
|
||||
owner_ids: list[int] = Field(env="OWNER_IDS", default_factory=list)
|
||||
prefix: str = Field(env="PREFIX", default="./")
|
||||
oasst_api_url: str = Field(env="OASST_API_URL", default="http://localhost:8080")
|
||||
oasst_api_key: str = Field(env="OASST_API_KEY", default="")
|
||||
|
||||
class Config(BaseSettings.Config):
|
||||
env_file = ".env"
|
||||
case_sensitive = False
|
||||
@@ -0,0 +1,48 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Utility functions."""
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
import hikari
|
||||
|
||||
|
||||
def format_time(dt: datetime, fmt: t.Literal["t", "T", "D", "f", "F", "R"]) -> str:
|
||||
"""Format a datetime object into the discord time format.
|
||||
|
||||
```
|
||||
| t | HH:MM | 16:20
|
||||
| T | HH:MM:SS | 16:20:11
|
||||
| D | D Mo Yr | 20 April 2022
|
||||
| f | D Mo Yr HH:MM | 20 April 2022 16:20
|
||||
| F | W, D Mo Yr HH:MM | Wednesday, 20 April 2022 16:20
|
||||
| R | relative | in an hour
|
||||
```
|
||||
"""
|
||||
match fmt:
|
||||
case "t" | "T" | "D" | "f" | "F" | "R":
|
||||
return f"<t:{dt.timestamp():.0f}:{fmt}>"
|
||||
case _:
|
||||
raise ValueError(f"`fmt` must be 't', 'T', 'D', 'f', 'F' or 'R', not {fmt}")
|
||||
|
||||
|
||||
EMPTY = "\u200d"
|
||||
"""Zero-width joiner.
|
||||
|
||||
This appears as an empty message in Discord.
|
||||
"""
|
||||
|
||||
|
||||
def mention(
|
||||
id: hikari.Snowflakeish,
|
||||
type: t.Literal["channel", "role", "user"],
|
||||
) -> str:
|
||||
"""Mention an object."""
|
||||
match type:
|
||||
case "channel":
|
||||
return f"<#{id}>"
|
||||
|
||||
case "user":
|
||||
return f"<@{id}>"
|
||||
|
||||
case "role":
|
||||
return f"<@&{id}>"
|
||||
@@ -1,61 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from channel_handlers import ChannelHandlerBase
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
|
||||
|
||||
@dataclass
|
||||
class ReplyHandlerInfo:
|
||||
msg_id: int
|
||||
handler_task: asyncio.Task
|
||||
handler: ChannelHandlerBase
|
||||
|
||||
|
||||
class BotBase(ABC):
|
||||
bot_channel_name: str
|
||||
debug: bool
|
||||
backend: ApiClient
|
||||
client: discord.Client
|
||||
loop: asyncio.BaseEventLoop
|
||||
owner_id: int
|
||||
bot_channel: discord.TextChannel
|
||||
templates: MessageTemplates
|
||||
reply_handlers: dict[int, ReplyHandlerInfo]
|
||||
|
||||
def __init__(self):
|
||||
self.reply_handlers = {} # handlers by msg_id
|
||||
|
||||
def ensure_bot_channel(self) -> None:
|
||||
if self.bot_channel is None:
|
||||
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
|
||||
|
||||
async def post(
|
||||
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
|
||||
) -> discord.Message:
|
||||
if channel is None:
|
||||
self.ensure_bot_channel()
|
||||
channel = self.bot_channel
|
||||
return await channel.send(content=content, view=view)
|
||||
|
||||
async def post_template(
|
||||
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
|
||||
) -> discord.Message:
|
||||
logger.debug(f"rendering {name}")
|
||||
text = self.templates.render(name, **kwargs)
|
||||
return await self.post(text, view=view, channel=channel)
|
||||
|
||||
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
|
||||
if msg_id in self.reply_handlers:
|
||||
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
|
||||
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
|
||||
task.add_done_callback(lambda t: handler.on_completed())
|
||||
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
|
||||
@@ -1,15 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from pydantic import AnyHttpUrl, BaseSettings
|
||||
|
||||
|
||||
class BotSettings(BaseSettings):
|
||||
BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
|
||||
API_KEY: str = "any_key"
|
||||
BOT_TOKEN: str
|
||||
BOT_CHANNEL_NAME: str = "bot"
|
||||
OWNER_ID: int = None
|
||||
TEMPLATE_DIR: str = "./templates"
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
settings = BotSettings(_env_file=".env")
|
||||
@@ -1,88 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
|
||||
import discord
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class ChannelExpiredException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ChannelHandlerBase(ABC):
|
||||
queue: asyncio.Queue
|
||||
completed: bool = False
|
||||
expiry_date: datetime
|
||||
expired: bool = False
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
self.expiry_date = expiry_date
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
"""Call this method to read the next message from the user in the handler method."""
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
|
||||
msg = await self.queue.get()
|
||||
if msg is None:
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
else:
|
||||
raise RuntimeError("Unexpected None message read")
|
||||
return msg
|
||||
|
||||
def on_reply(self, message: discord.Message) -> None:
|
||||
self.queue.put_nowait(message)
|
||||
|
||||
def on_expire(self) -> None:
|
||||
logger.info("ChannelHandler: on_expire")
|
||||
self.expired = True
|
||||
self.queue.put_nowait(None)
|
||||
|
||||
def on_completed(self) -> None:
|
||||
logger.info("ChannelHandler: on_completed")
|
||||
self.completed = True
|
||||
|
||||
def tick(self, now: datetime):
|
||||
if now > self.expiry_date and not self.expired:
|
||||
self.on_expire()
|
||||
|
||||
@abstractmethod
|
||||
async def handler_loop(self):
|
||||
...
|
||||
|
||||
async def finalize(self):
|
||||
pass
|
||||
|
||||
|
||||
class AutoDestructThreadHandler(ChannelHandlerBase):
|
||||
first_message: discord.Message = None
|
||||
thread: discord.Thread = None
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
super().__init__(expiry_date=expiry_date)
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
try:
|
||||
return await super().read()
|
||||
except ChannelExpiredException:
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self):
|
||||
logger.debug("AutoDestructThreadHandler.cleanup")
|
||||
if self.thread:
|
||||
logger.debug(f"deleting thread: {self.thread.name}")
|
||||
await self.thread.delete()
|
||||
self.thread = None
|
||||
if self.first_message:
|
||||
logger.debug(f"deleting first_message: {self.first_message.content}")
|
||||
await self.first_message.delete()
|
||||
self.first_message = None
|
||||
|
||||
async def finalize(self):
|
||||
await self.cleanup()
|
||||
return await super().finalize()
|
||||
@@ -1,16 +1,21 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Message templates for the discord bot."""
|
||||
import typing
|
||||
|
||||
import jinja2
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageTemplates:
|
||||
def __init__(self, template_dir="./templates"):
|
||||
self.env = jinja2.Environment(
|
||||
"""Create message templates for the discord bot."""
|
||||
|
||||
def __init__(self, template_dir: str = "./templates"):
|
||||
self.env = jinja2.Environment( # noqa: S701
|
||||
loader=jinja2.FileSystemLoader(template_dir),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
|
||||
)
|
||||
|
||||
def render(self, template_name, **kwargs):
|
||||
def render(self, template_name: str, **kwargs: typing.Any):
|
||||
template = self.env.get_template(template_name)
|
||||
txt = template.render(kwargs)
|
||||
logger.debug(txt)
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
discord.py==2.1.0
|
||||
Jinja2==3.1.2
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
pytz==2022.7
|
||||
requests==2.28.1
|
||||
schedule==1.1.0
|
||||
aiohttp # http client
|
||||
aiohttp[speedups] # speedups for aiohttp
|
||||
aiosqlite # database
|
||||
hikari # discord framework
|
||||
hikari-lightbulb # command handler
|
||||
hikari-miru # modals and buttons
|
||||
hikari[speedups]
|
||||
loguru
|
||||
pydantic
|
||||
|
||||
uvloop; os_name != 'nt' # Faster drop-in replacement for asyncio event loop
|
||||
|
||||
@@ -1,267 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from bot_base import BotBase
|
||||
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
||||
|
||||
|
||||
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
|
||||
name = discord.ui.TextInput(label="Name")
|
||||
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
|
||||
|
||||
|
||||
class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
thread_name: str = "Replies"
|
||||
expires_after: timedelta = timedelta(minutes=5)
|
||||
backend: ApiClient
|
||||
|
||||
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
|
||||
try:
|
||||
self.bot = bot
|
||||
self.task = task
|
||||
self.backend = bot.backend
|
||||
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
|
||||
msg = await self.send_first_message()
|
||||
self.first_message = msg
|
||||
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
|
||||
await self.on_thread_created(self.thread)
|
||||
except Exception:
|
||||
logger.exception("start task failed")
|
||||
await self.cleanup() # try to cleanup messag or thread
|
||||
raise
|
||||
|
||||
bot.register_reply_handler(msg_id=msg.id, handler=self)
|
||||
return msg
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_first_message(self) -> discord.message:
|
||||
...
|
||||
|
||||
def to_api_user(self, user: discord.User) -> protocol_schema.User:
|
||||
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
|
||||
|
||||
async def post_teaser_msg(self, template_name: str):
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
return await self.bot.post_template(
|
||||
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
|
||||
)
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
api_response = await self.backend.post_interaction(interaction)
|
||||
if api_response.type != "task_done":
|
||||
# multi-step tasks are not supported yet
|
||||
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
|
||||
raise RuntimeError("Unexpected response from backend received")
|
||||
return api_response
|
||||
|
||||
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.TextReplyToMessage(
|
||||
message_id=str(self.first_message.id),
|
||||
user_message_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
text=user_msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
self.post_text_reply_to_post(user_msg)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_text_reply_to_post()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.MessageRanking(
|
||||
message_id=str(self.first_message.id),
|
||||
user_message_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
ranking=ranking,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
ranking_str = user_msg.content
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
self.post_ranking(user_msg, ranking=ranking)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_ranking()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
|
||||
class SummarizeStoryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.SummarizeStoryTask
|
||||
thread_name: str = "Summaries"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_summarize_story.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class InitialPromptHandler(ChannelTaskBase):
|
||||
task: protocol_schema.InitialPromptTask
|
||||
thread_name: str = "Prompts"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_initial_prompt.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class PrompterReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.PrompterReplyTask
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_prompter_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_prompter_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class AssistantReplyHandler(ChannelTaskBase):
|
||||
task: protocol_schema.AssistantReplyTask
|
||||
thread_name: str = "Assistant replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_assistant_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankInitialPromptsTask
|
||||
thread_name: str = "User Responses"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RankConversationsHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RankConversationRepliesTask
|
||||
thread_name: str = "Rankings"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RatingButton(discord.ui.Button):
|
||||
def __init__(self, label, value, response_handler):
|
||||
super().__init__(label=label, style=discord.ButtonStyle.green)
|
||||
self.value = value
|
||||
self.response_handler = response_handler
|
||||
|
||||
async def callback(self, interaction):
|
||||
await self.response_handler(self.value, interaction)
|
||||
|
||||
|
||||
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
view = discord.ui.View()
|
||||
for i in range(lo, hi + 1):
|
||||
view.add_item(RatingButton(str(i), i, response_handler))
|
||||
return view
|
||||
|
||||
|
||||
class RateSummaryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RateSummaryTask
|
||||
thread_name: str = "Ratings"
|
||||
|
||||
async def _rating_response_handler(self, score, interaction: discord.Interaction):
|
||||
logger.info("rating_response_handler", score)
|
||||
if self.thread:
|
||||
try:
|
||||
self.backend.post_interaction(
|
||||
protocol_schema.MessageRating(
|
||||
message_id=str(self.first_message.id),
|
||||
user_message_id=str(interaction.id),
|
||||
user=self.to_api_user(interaction.user),
|
||||
rating=score,
|
||||
)
|
||||
)
|
||||
await interaction.response.send_message(
|
||||
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
|
||||
)
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in _rating_response_handler()")
|
||||
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.post_teaser_msg("teaser_rate_summary.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
|
||||
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info(f"on_rate_summary_reply: {msg.content}")
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply("❌ Text intput not supported.")
|
||||
@@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
def get_git_head_hash():
|
||||
# get current git hash
|
||||
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
if x.returncode == 0:
|
||||
return x.stdout.replace("\n", "")
|
||||
return None
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(pytz.UTC)
|
||||
|
||||
|
||||
class DiscordTimestampStyle(str, enum.Enum):
|
||||
"""
|
||||
Timestamp Styles
|
||||
|
||||
t 16:20 Short Time
|
||||
T 16:20:30 Long Time
|
||||
d 20/04/2021 Short Date
|
||||
D 20 April 2021 Long Date
|
||||
f * 20 April 2021 16:20 Short Date/Time
|
||||
F Tuesday, 20 April 2021 16:20 Long Date/Time
|
||||
R 2 months ago Relative Time
|
||||
|
||||
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
|
||||
"""
|
||||
|
||||
default = ""
|
||||
short_time = "t"
|
||||
long_time = "T"
|
||||
short_date = "d"
|
||||
long_date = "D"
|
||||
short_date_time = "f"
|
||||
long_date_time = "F"
|
||||
relative_time = "R"
|
||||
|
||||
|
||||
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
|
||||
parts = ["<t:", str(int(d.timestamp()))]
|
||||
if style:
|
||||
parts.append(":")
|
||||
parts.append(style)
|
||||
parts.append(">")
|
||||
return "".join(parts)
|
||||
Reference in New Issue
Block a user