diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 90279415..cccb2167 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: "build|stubs" +exclude: "build|stubs|^bot/templates/" default_language_version: python: python3 diff --git a/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py b/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py new file mode 100644 index 00000000..c65b8319 --- /dev/null +++ b/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""add_auth_method_to_ix_person_username + +Revision ID: 0daec5f8135f +Revises: 6368515778c5 +Create Date: 2022-12-22 18:35:59.609013 + +""" +import sqlalchemy as sa # noqa: F401 +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0daec5f8135f" +down_revision = "6368515778c5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_person_username", table_name="person") + op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_person_username", table_name="person") + op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False) + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index dac2a9bd..7ec5aa96 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -169,8 +169,8 @@ def acknowledge_task( pr = PromptRepository(db, api_client, user=None) # here we store the post id in the database for the task - pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id) logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.") + pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id) except Exception: logger.exception("Failed to acknowledge task.") diff --git a/backend/oasst_backend/models/person.py b/backend/oasst_backend/models/person.py index 57f134a4..f01f85f0 100644 --- a/backend/oasst_backend/models/person.py +++ b/backend/oasst_backend/models/person.py @@ -10,7 +10,7 @@ from sqlmodel import Field, Index, SQLModel class Person(SQLModel, table=True): __tablename__ = "person" - __table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),) + __table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),) id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 9f7bb1dd..b0063cdf 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -32,7 +32,12 @@ class PromptRepository: ) if person is None: # user is unknown, create new record - person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + person = Person( + username=user.id, + display_name=user.display_name, + api_client_id=self.api_client.id, + auth_method=user.auth_method, + ) self.db.add(person) self.db.commit() self.db.refresh(person) diff --git a/bot/__main__.py b/bot/__main__.py index 362b16f0..0047bce7 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -12,5 +12,7 @@ if __name__ == "__main__": backend_url=settings.BACKEND_URL, api_key=settings.API_KEY, owner_id=settings.OWNER_ID, + template_dir=settings.TEMPLATE_DIR, + debug=settings.DEBUG, ) bot.run() diff --git a/bot/api_client.py b/bot/api_client.py index 19a62188..1de6bb17 100644 --- a/bot/api_client.py +++ b/bot/api_client.py @@ -69,6 +69,6 @@ class ApiClient: req = protocol_schema.TaskNAck(reason=reason) return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone: + def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: data = self.post("/api/v1/tasks/interaction", interaction.dict()) return self._parse_task(data) diff --git a/bot/bot.py b/bot/bot.py index e6e90770..a19fdfe1 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,32 +1,26 @@ # -*- coding: utf-8 -*- +from __future__ import annotations + import asyncio +from datetime import timedelta +from pathlib import Path from typing import Optional, Union import discord +import task_handlers from api_client import ApiClient, TaskType +from bot_base import BotBase from discord import app_commands from loguru import logger +from message_templates import MessageTemplates from oasst_shared.schemas import protocol as protocol_schema +from utils import get_git_head_hash, utcnow + +__version__ = "0.0.3" +BOT_NAME = "Open-Assistant Junior" -class RatingButton(discord.ui.Button): - def __init__(self, label, value, response_handler): - super().__init__(label=label, style=discord.ButtonStyle.green) - self.value = value - self.response_handler = response_handler - - async def callback(self, interaction): - await self.response_handler(self.value, interaction) - - -def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: - view = discord.ui.View() - for i in range(lo, hi + 1): - view.add_item(RatingButton(str(i), i, response_handler)) - return view - - -class OpenAssistantBot: +class OpenAssistantBot(BotBase): def __init__( self, bot_token: str, @@ -34,7 +28,16 @@ class OpenAssistantBot: backend_url: str, api_key: str, owner_id: Optional[Union[int, str]] = None, + template_dir: str = "./templates", + debug: bool = False, ): + super().__init__() + + self.template_dir = Path(template_dir) + self.bot_channel_name = bot_channel_name + self.templates = MessageTemplates(template_dir) + self.debug = debug + intents = discord.Intents.default() intents.message_content = True @@ -45,20 +48,25 @@ class OpenAssistantBot: self.bot_token = bot_token client = discord.Client(intents=intents) self.client = client + self.loop = client.loop self.bot_channel: discord.TextChannel = None self.backend = ApiClient(backend_url, api_key) - self.reply_handlers = {} # handlers by msg_id - self.tree = app_commands.CommandTree(self.client, fallback_to_global=True) - self.auto_archive_minutes = 60 # ToDo: add to bot config + self.tree = app_commands.CommandTree(self.client, fallback_to_global=True) @client.event async def on_ready(): self.bot_channel = self.get_text_channel_by_name(bot_channel_name) - client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") logger.info(f"{client.user} is now running!") + await self.delete_all_old_bot_messages() + # if self.debug: + # await self.post_boot_message() + await self.post_welcome_message() + + client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") + @client.event async def on_message(message: discord.Message): # ignore own messages @@ -68,208 +76,111 @@ class OpenAssistantBot: @self.tree.command() async def tutorial(interaction: discord.Interaction): """Start the Open-Assistant tutorial via DMs.""" + + dm = await self.client.create_dm(discord.Object(interaction.user.id)) + await dm.send("Tutorial coming soon... :-)") await interaction.response.send_message(f"tutorial command by {interaction.user.name}") @self.tree.command() async def help(interaction: discord.Interaction): """Sends the user a list of all available commands""" - await interaction.response.send_message(f"help command by {interaction.user.name}") + await self.post_help(interaction.user) + await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.") @self.tree.command() async def work(interaction: discord.Interaction): """Request a new personalized task""" - await interaction.response.send_message(f"work command by {interaction.user.name}") - async def print_separtor(self, title: str) -> discord.Message: - msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n") - return msg + # task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None) + # task = self.backend.fetch_random_task(user=None) + q = task_handlers.Questionnaire() + await interaction.response.send_modal(q) - async def generate_summarize_story(self, task: protocol_schema.SummarizeStoryTask): - text = f"Summarize to the following story:\n{task.story}" - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="Summaries", auto_archive_duration=self.auto_archive_minutes + async def post_help(self, user: discord.abc.User) -> discord.Message: + is_bot_owner = user.id == self.owner_id + return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner) + + async def post_boot_message(self) -> discord.Message: + return await self.post_template( + "boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug ) - async def on_reply(message: discord.Message): - logger.info("on_summarize_story_reply", message) - await message.add_reaction("✅") + async def post_welcome_message(self) -> discord.Message: + return await self.post_template("welcome.msg") - self.reply_handlers[msg.id] = on_reply + async def delete_all_old_bot_messages(self) -> None: + logger.info("Deleting old threads...") + for thread in self.bot_channel.threads: + if thread.owner_id == self.client.user.id: + await thread.delete() + logger.info("Completed deleting old theards.") - return msg - - async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask): - s = [ - "Rate the following summary:", - task.summary, - "Full text:", - task.full_text, - f"Rating scale: {task.scale.min} - {task.scale.max}", - ] - text = "\n".join(s) - - async def rating_response_handler(score, interaction: discord.Interaction): - logger.info("rating_response_handler", score) - await interaction.response.send_message(f"got your feedback: {score}") - - view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler) - msg: discord.Message = await self.bot_channel.send(text, view=view) - - async def on_reply(message: discord.Message): - logger.info("on_summary_reply", message) - await message.add_reaction("") - - self.reply_handlers[msg.id] = on_reply - - return msg - - async def generate_initial_prompt(self, task: protocol_schema.InitialPromptTask): - text = "Please provide an initial prompt to the assistant." - if task.hint: - text += f"\nHint: {task.hint}" - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="Prompts", auto_archive_duration=self.auto_archive_minutes - ) - - async def on_reply(message: discord.Message): - logger.info("on_initial_prompt_reply", message) - await message.add_reaction("✅") - - self.reply_handlers[msg.id] = on_reply - - return msg - - def _render_message(self, message: protocol_schema.ConversationMessage) -> str: - """Render a message to the user.""" - if message.is_assistant: - return f":robot: Assistant:\n{message.text}" - else: - return f":person_red_hair: User:\n**{message.text}**" - - async def generate_user_reply(self, task: protocol_schema.UserReplyTask): - s = ["Please provide a reply to the assistant.", "Here is the conversation so far:\n"] - for message in task.conversation.messages: - s.append(self._render_message(message)) - s.append("") - if task.hint: - s.append(f"Hint: {task.hint}") - text = "\n".join(s) - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes - ) - - async def on_reply(message: discord.Message): - logger.info("on_user_reply_reply", message) - await message.add_reaction("✅") - - self.reply_handlers[msg.id] = on_reply - - return msg - - async def generate_assistant_reply(self, task: protocol_schema.AssistantReplyTask): - s = ["Act as the assistant and reply to the user.", "Here is the conversation so far\n:"] - for message in task.conversation.messages: - s.append(self._render_message(message)) - s.append("") - s.append(":robot: Assistant: { human, pls help me! ... }") - text = "\n".join(s) - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="Agent responses", auto_archive_duration=self.auto_archive_minutes - ) - - async def on_reply(message: discord.Message): - logger.info("on_assistant_reply_reply", message) - await message.add_reaction("✅") - - self.reply_handlers[msg.id] = on_reply - - return msg - - async def generate_rank_initial_prompts(self, task: protocol_schema.RankInitialPromptsTask): - s = ["Rank the following prompts:"] - for idx, prompt in enumerate(task.prompts, start=1): - s.append(f"{idx}: {prompt}") - s.append("") - s.append(':scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").') - text = "\n".join(s) - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes - ) - - async def on_reply(message: discord.Message): - logger.info("on_rank_initial_prompts_reply", message) - await message.add_reaction("✅") - - self.reply_handlers[msg.id] = on_reply - - return msg - - async def generate_rank_conversation(self, task: protocol_schema.RankConversationRepliesTask): - s = ["Here is the conversation so far:"] - for message in task.conversation.messages: - s.append(self._render_message(message)) - s.append("") - s.append("Rank the following replies:") - for idx, reply in enumerate(task.replies, start=1): - s.append(f"{idx}: {reply}") - s.append("") - s.append(':scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").') - text = "\n".join(s) - msg: discord.Message = await self.bot_channel.send(text) - await self.bot_channel.create_thread( - message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes - ) - - async def on_reply(message: discord.Message): - logger.info("on_rank_conversation_reply", message) - await message.add_reaction("✅") - message - - self.reply_handlers[msg.id] = on_reply - - return msg + logger.info("Deleting old messages...") + look_until = utcnow() - timedelta(days=365) + async for msg in self.bot_channel.history(limit=None): + msg: discord.Message + if msg.created_at < look_until: + break + if msg.author.id == self.client.user.id: + await msg.delete() + logger.info("Completed deleting old messages.") async def next_task(self): - # task = self.backend.fetch_task(protocol_schema.TaskRequestType.user_reply, user=None) - task = self.backend.fetch_random_task(user=None) + task_type = protocol_schema.TaskRequestType.random + task = self.backend.fetch_task(task_type, user=None) - await self.print_separtor("New Task") - - msg: discord.Message = None + handler: task_handlers.ChannelTaskBase = None match task.type: case TaskType.summarize_story: - msg = await self.generate_summarize_story(task) + handler = task_handlers.SummarizeStoryHandler() case TaskType.rate_summary: - msg = await self.generate_rate_summary(task) + handler = task_handlers.RateSummaryHandler() case TaskType.initial_prompt: - msg = await self.generate_initial_prompt(task) + handler = task_handlers.InitialPromptHandler() case TaskType.user_reply: - msg = await self.generate_user_reply(task) + handler = task_handlers.UserReplyHandler() case TaskType.assistant_reply: - msg = await self.generate_assistant_reply(task) + handler = task_handlers.AssistantReplyHandler() case TaskType.rank_initial_prompts: - msg = await self.generate_rank_initial_prompts(task) + handler = task_handlers.RankInitialPromptsHandler() case TaskType.rank_user_replies | TaskType.rank_assistant_replies: - msg = await self.generate_rank_conversation(task) + handler = task_handlers.RankConversationsHandler() + case _: + logger.warning(f"Unsupported task type received: {task.type}") + self.backend.nack_task(task.id, "not supported") - if msg is not None: - self.backend.ack_task(task.id, msg.id) - else: - self.backend.nack_task(task.id, "not supported") + if handler: + try: + logger.info(f"strarting task {task.id}") + msg = await handler.start(self, task) + self.backend.ack_task(task.id, msg.id) + except Exception: + logger.exception("Starting task failed.") + self.backend.nack_task(task.id, "faled") async def background_timer(self): + next_remove_completed = utcnow() + timedelta(seconds=10) + next_fetch_task = utcnow() + timedelta(seconds=1) while True: + now = utcnow() + if self.bot_channel: - try: - await self.next_task() - except Exception: - logger.exception("fetching next task failed") - await asyncio.sleep(30) + if now > next_fetch_task: + next_fetch_task = utcnow() + timedelta(seconds=60) + + try: + await self.next_task() + except Exception: + logger.exception("fetching next task failed") + + for x in self.reply_handlers.values(): + x.handler.tick(now) + + if now > next_remove_completed: + next_remove_completed = utcnow() + timedelta(seconds=10) + await self.remove_completed_handlers() + + await asyncio.sleep(1) async def _sync(self, command: str, message: discord.Message): @@ -293,38 +204,74 @@ class OpenAssistantBot: command_text: str = message.content command_text = command_text[1:] match command_text: - case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild" | "sync.clear_guild": + case "help" | "?": + await self.post_help(user=message.author) + case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild": if is_owner: await self._sync(command_text, message) case _: await message.reply(f"unknown command: {command_text}") + def recipient_filter(self, message: discord.Message) -> bool: + channel = message.channel + + if ( + message.channel.type == discord.ChannelType.private + or message.channel.type == discord.ChannelType.private_thread + ): + return True + + if ( + message.channel.type == discord.ChannelType.text + or message.channel.type == discord.ChannelType.public_thread + ): + while channel: + if self.bot_channel and channel.id == self.bot_channel.id: + return True + channel = channel.parent + + return False + async def handle_message(self, message: discord.Message): + if not self.recipient_filter(message): + return + user_id = message.author.id user_display_name = message.author.name + logger.debug( + f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})" + ) + command_prefix = "!" - if ( - message.channel.type == discord.ChannelType.private - and message.type == discord.MessageType.default - and message.content.startswith(command_prefix) - ): + if message.type == discord.MessageType.default and message.content.startswith(command_prefix): is_owner = self.owner_id and user_id == self.owner_id await self.handle_command(message, is_owner) if isinstance(message.channel, discord.Thread): handler = self.reply_handlers.get(message.channel.id) - if handler: - await handler(message) + if handler and not handler.handler.completed: + handler.handler.on_reply(message) if message.reference: handler = self.reply_handlers.get(message.reference.message_id) - if handler: - await handler(message) + if handler and not handler.handler.completed: + handler.handler.on_reply(message) - logger.debug( - f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})" - ) + async def remove_completed_handlers(self): + completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed] + if len(completed) == 0: + return + + for c in completed: + handler = self.reply_handlers[c] + del self.reply_handlers[c] + try: + await handler.handler.finalize() + except Exception: + logger.exception("handler finalize failed") + + logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})") def get_text_channel_by_name(self, channel_name) -> discord.TextChannel: for channel in self.client.get_all_channels(): diff --git a/bot/bot_base.py b/bot/bot_base.py new file mode 100644 index 00000000..76eca22d --- /dev/null +++ b/bot/bot_base.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +import asyncio +from abc import ABC +from dataclasses import dataclass +from typing import Any + +import discord +from api_client import ApiClient +from channel_handlers import ChannelHandlerBase +from loguru import logger +from message_templates import MessageTemplates + + +@dataclass +class ReplyHandlerInfo: + msg_id: int + handler_task: asyncio.Task + handler: ChannelHandlerBase + + +class BotBase(ABC): + bot_channel_name: str + debug: bool + backend: ApiClient + client: discord.Client + loop: asyncio.BaseEventLoop + owner_id: int + bot_channel: discord.TextChannel + templates: MessageTemplates + reply_handlers: dict[int, ReplyHandlerInfo] + + def __init__(self): + self.reply_handlers = {} # handlers by msg_id + + def ensure_bot_channel(self) -> None: + if self.bot_channel is None: + raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found") + + async def post( + self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None + ) -> discord.Message: + if channel is None: + self.ensure_bot_channel() + channel = self.bot_channel + return await channel.send(content=content, view=view) + + async def post_template( + self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any + ) -> discord.Message: + logger.debug(f"rendering {name}") + text = self.templates.render(name, **kwargs) + return await self.post(text, view=view, channel=channel) + + def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase): + if msg_id in self.reply_handlers: + raise RuntimeError(f"Handler already registered for msg_id: {msg_id}") + task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})") + task.add_done_callback(lambda t: handler.on_completed()) + self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler) diff --git a/bot/bot_settings.py b/bot/bot_settings.py index b7a46aa6..c976d7cd 100644 --- a/bot/bot_settings.py +++ b/bot/bot_settings.py @@ -8,6 +8,8 @@ class BotSettings(BaseSettings): BOT_TOKEN: str BOT_CHANNEL_NAME: str = "bot" OWNER_ID: int = None + TEMPLATE_DIR: str = "./templates" + DEBUG: bool = True settings = BotSettings(_env_file=".env") diff --git a/bot/channel_handlers.py b/bot/channel_handlers.py new file mode 100644 index 00000000..75f03c0e --- /dev/null +++ b/bot/channel_handlers.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +import asyncio +from abc import ABC, abstractmethod +from datetime import datetime + +import discord +from loguru import logger + + +class ChannelExpiredException(Exception): + pass + + +class ChannelHandlerBase(ABC): + queue: asyncio.Queue + completed: bool = False + expiry_date: datetime + expired: bool = False + + def __init__(self, *, expiry_date: datetime = None): + self.expiry_date = expiry_date + self.queue = asyncio.Queue() + + async def read(self) -> discord.Message: + """Call this method to read the next message from the user in the handler method.""" + if self.expired: + raise ChannelExpiredException() + + msg = await self.queue.get() + if msg is None: + if self.expired: + raise ChannelExpiredException() + else: + raise RuntimeError("Unexpected None message read") + return msg + + def on_reply(self, message: discord.Message) -> None: + self.queue.put_nowait(message) + + def on_expire(self) -> None: + logger.info("ChannelHandler: on_expire") + self.expired = True + self.queue.put_nowait(None) + + def on_completed(self) -> None: + logger.info("ChannelHandler: on_completed") + self.completed = True + + def tick(self, now: datetime): + if now > self.expiry_date and not self.expired: + self.on_expire() + + @abstractmethod + async def handler_loop(self): + ... + + async def finalize(self): + pass + + +class AutoDestructThreadHandler(ChannelHandlerBase): + first_message: discord.Message = None + thread: discord.Thread = None + + def __init__(self, *, expiry_date: datetime = None): + super().__init__(expiry_date=expiry_date) + + async def read(self) -> discord.Message: + try: + return await super().read() + except ChannelExpiredException: + await self.cleanup() + raise + + async def cleanup(self): + logger.debug("AutoDestructThreadHandler.cleanup") + if self.thread: + logger.debug(f"deleting thread: {self.thread.name}") + await self.thread.delete() + self.thread = None + if self.first_message: + logger.debug(f"deleting first_message: {self.first_message.content}") + await self.first_message.delete() + self.first_message = None + + async def finalize(self): + await self.cleanup() + return await super().finalize() diff --git a/bot/message_templates.py b/bot/message_templates.py new file mode 100644 index 00000000..df3ef1ac --- /dev/null +++ b/bot/message_templates.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +import jinja2 +from loguru import logger + + +class MessageTemplates: + def __init__(self, template_dir="./templates"): + self.env = jinja2.Environment( + loader=jinja2.FileSystemLoader(template_dir), + autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False), + ) + + def render(self, template_name, **kwargs): + template = self.env.get_template(template_name) + txt = template.render(kwargs) + logger.debug(txt) + + return txt diff --git a/bot/requirements.txt b/bot/requirements.txt index da4762a6..927ebcf2 100644 --- a/bot/requirements.txt +++ b/bot/requirements.txt @@ -1,4 +1,7 @@ discord.py==2.1.0 +Jinja2==3.1.2 pydantic==1.9.1 python-dotenv==0.21.0 +pytz==2022.7 requests==2.28.1 +schedule==1.1.0 diff --git a/bot/task_handlers.py b/bot/task_handlers.py new file mode 100644 index 00000000..1434d17c --- /dev/null +++ b/bot/task_handlers.py @@ -0,0 +1,267 @@ +# -*- coding: utf-8 -*- +from __future__ import annotations + +from abc import abstractmethod +from datetime import timedelta + +import discord +from api_client import ApiClient +from bot_base import BotBase +from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException +from loguru import logger +from oasst_shared.schemas import protocol as protocol_schema +from utils import DiscordTimestampStyle, discord_timestamp, utcnow + + +class Questionnaire(discord.ui.Modal, title="Questionnaire Response"): + name = discord.ui.TextInput(label="Name") + answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph) + + async def on_submit(self, interaction: discord.Interaction): + await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True) + + +class ChannelTaskBase(AutoDestructThreadHandler): + thread_name: str = "Replies" + expires_after: timedelta = timedelta(minutes=5) + backend: ApiClient + + async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message: + try: + self.bot = bot + self.task = task + self.backend = bot.backend + self.expiry_date = utcnow() + self.expires_after if self.expires_after else None + msg = await self.send_first_message() + self.first_message = msg + self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name) + await self.on_thread_created(self.thread) + except Exception: + logger.exception("start task failed") + await self.cleanup() # try to cleanup messag or thread + raise + + bot.register_reply_handler(msg_id=msg.id, handler=self) + return msg + + async def on_thread_created(self, thread: discord.Thread) -> None: + pass + + @abstractmethod + async def send_first_message(self) -> discord.message: + ... + + def to_api_user(self, user: discord.User) -> protocol_schema.User: + return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name) + + async def post_teaser_msg(self, template_name: str): + expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time) + expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time) + return await self.bot.post_template( + template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative + ) + + async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: + api_response = await self.backend.post_interaction(interaction) + if api_response.type != "task_done": + # multi-step tasks are not supported yet + logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})") + raise RuntimeError("Unexpected response from backend received") + return api_response + + def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: + return self.backend.post_interaction( + protocol_schema.TextReplyToPost( + post_id=str(self.first_message.id), + user_post_id=str(user_msg.id), + user=self.to_api_user(user_msg.author), + text=user_msg.content, + ) + ) + + async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: + try: + self.post_text_reply_to_post(user_msg) + await user_msg.add_reaction("✅") + except ChannelExpiredException: + raise + except Exception as e: + logger.exception("Error in handle_text_reply_to_post()") + await user_msg.add_reaction("❌") + await user_msg.reply(f"❌ Error communicating with backend: {e}") + + def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task: + return self.backend.post_interaction( + protocol_schema.PostRanking( + post_id=str(self.first_message.id), + user_post_id=str(user_msg.id), + user=self.to_api_user(user_msg.author), + ranking=ranking, + ) + ) + + async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task: + try: + ranking_str = user_msg.content + ranking = [int(x) - 1 for x in ranking_str.split(",")] + self.post_ranking(user_msg, ranking=ranking) + await user_msg.add_reaction("✅") + except ChannelExpiredException: + raise + except Exception as e: + logger.exception("Error in handle_ranking()") + await user_msg.add_reaction("❌") + await user_msg.reply(f"❌ Error communicating with backend: {e}") + + +class SummarizeStoryHandler(ChannelTaskBase): + task: protocol_schema.SummarizeStoryTask + thread_name: str = "Summaries" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_summarize_story.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_text_reply_to_post(msg) + + +class InitialPromptHandler(ChannelTaskBase): + task: protocol_schema.InitialPromptTask + thread_name: str = "Prompts" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_initial_prompt.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_text_reply_to_post(msg) + + +class UserReplyHandler(ChannelTaskBase): + task: protocol_schema.UserReplyTask + thread_name: str = "User replies" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_user_reply.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_text_reply_to_post(msg) + + +class AssistantReplyHandler(ChannelTaskBase): + task: protocol_schema.AssistantReplyTask + thread_name: str = "Assistant replies" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_assistant_reply.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_text_reply_to_post(msg) + + +class RankInitialPromptsHandler(ChannelTaskBase): + task: protocol_schema.RankInitialPromptsTask + thread_name: str = "User Responses" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_rank_initial_prompts.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_ranking(msg) + + +class RankConversationsHandler(ChannelTaskBase): + task: protocol_schema.RankConversationRepliesTask + thread_name: str = "Rankings" + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_rank_conversation_replies.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + await self.handle_ranking(msg) + + +class RatingButton(discord.ui.Button): + def __init__(self, label, value, response_handler): + super().__init__(label=label, style=discord.ButtonStyle.green) + self.value = value + self.response_handler = response_handler + + async def callback(self, interaction): + await self.response_handler(self.value, interaction) + + +def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: + view = discord.ui.View() + for i in range(lo, hi + 1): + view.add_item(RatingButton(str(i), i, response_handler)) + return view + + +class RateSummaryHandler(ChannelTaskBase): + task: protocol_schema.RateSummaryTask + thread_name: str = "Ratings" + + async def _rating_response_handler(self, score, interaction: discord.Interaction): + logger.info("rating_response_handler", score) + if self.thread: + try: + self.backend.post_interaction( + protocol_schema.PostRating( + post_id=str(self.first_message.id), + user_post_id=str(interaction.id), + user=self.to_api_user(interaction.user), + rating=score, + ) + ) + await interaction.response.send_message( + f"Thanks {interaction.user.display_name}, got your feedback: {score}!" + ) + except ChannelExpiredException: + raise + except Exception as e: + logger.exception("Error in _rating_response_handler()") + interaction.response.send_message(f"❌ Error communicating with backend: {e}") + + async def send_first_message(self) -> discord.message: + return await self.post_teaser_msg("teaser_rate_summary.msg") + + async def on_thread_created(self, thread: discord.Thread) -> None: + view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler) + return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task) + + async def handler_loop(self): + while True: + msg = await self.read() + logger.info(f"on_rate_summary_reply: {msg.content}") + await msg.add_reaction("❌") + await msg.reply("❌ Text intput not supported.") diff --git a/bot/templates/boot.msg b/bot/templates/boot.msg new file mode 100644 index 00000000..a3629715 --- /dev/null +++ b/bot/templates/boot.msg @@ -0,0 +1,13 @@ +``` +________ __ +\_____ \ _____ ______ _______/ |_ + / | \\__ \ / ___// ___/\ __\ +/ | \/ __ \_\___ \ \___ \ | | +\_______ (____ /____ >____ > |__| + \/ \/ \/ \/ + +{{bot_name}} {{version}} +git hash: {{git_hash}} +debug_mode: {{debug}} +``` +https://github.com/LAION-AI/Open-Assistant diff --git a/bot/templates/help.msg b/bot/templates/help.msg new file mode 100644 index 00000000..ca033c47 --- /dev/null +++ b/bot/templates/help.msg @@ -0,0 +1,15 @@ +**Open-Assistant Bot Help** + +Available slash-commands: + +`/work` Requests a new personalized human feedback task +`/help` Show this message + +{% if is_bot_owner %} +Commands for bot owners: + +`!sync` +`!sync.guild` +`!sync.copy_global` +`!sync.clear_guild` +{% endif %} \ No newline at end of file diff --git a/bot/templates/task_assistant_reply.msg b/bot/templates/task_assistant_reply.msg new file mode 100644 index 00000000..3dfe84a3 --- /dev/null +++ b/bot/templates/task_assistant_reply.msg @@ -0,0 +1,12 @@ +Act as the assistant and reply to the user. +Here is the conversation so far: +{% for message in task.conversation.messages %} +{% if message.is_assistant %} +:robot: Assistant: +{{ message.text }} +{% else %} +:person_red_hair: User: +**{{ message.text }}**" +{% endif %} +{% endfor %} +:robot: Assistant: { human, pls help me! ... } \ No newline at end of file diff --git a/bot/templates/task_initial_prompt.msg b/bot/templates/task_initial_prompt.msg new file mode 100644 index 00000000..47cf0f45 --- /dev/null +++ b/bot/templates/task_initial_prompt.msg @@ -0,0 +1,4 @@ +Please provide an initial prompt to the assistant. +{% if task.hint is not none %} +Hint: {{task.hint}} +{% endif %} \ No newline at end of file diff --git a/bot/templates/task_rank_conversation_replies.msg b/bot/templates/task_rank_conversation_replies.msg new file mode 100644 index 00000000..c0c8bc80 --- /dev/null +++ b/bot/templates/task_rank_conversation_replies.msg @@ -0,0 +1,13 @@ +Here is the conversation so far: +{% for message in task.conversation.messages %}{% if message.is_assistant %} +:robot: Assistant: +{{ message.text }} +{% else %} +:person_red_hair: User: +**{{ message.text }}**" +{% endif %}{% endfor %} +Rank the following replies: +{% for reply in task.replies %} +{{loop.index}}: {{reply}}{% endfor %} + +:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2"). \ No newline at end of file diff --git a/bot/templates/task_rank_initial_prompts.msg b/bot/templates/task_rank_initial_prompts.msg new file mode 100644 index 00000000..5a75cbd1 --- /dev/null +++ b/bot/templates/task_rank_initial_prompts.msg @@ -0,0 +1,5 @@ +Rank the following prompts: +{% for prompt in task.prompts %} +{{loop.index}}: {{prompt}}{% endfor %} + +:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2"). \ No newline at end of file diff --git a/bot/templates/task_rate_summary.msg b/bot/templates/task_rate_summary.msg new file mode 100644 index 00000000..31007c40 --- /dev/null +++ b/bot/templates/task_rate_summary.msg @@ -0,0 +1,7 @@ +Rate the following summary: +{{task.summary}} + +Full text: +{{task.full_text}} + +Rating scale: {{task.scale.min}} - {{task.scale.max}} diff --git a/bot/templates/task_summarize_story.msg b/bot/templates/task_summarize_story.msg new file mode 100644 index 00000000..24753841 --- /dev/null +++ b/bot/templates/task_summarize_story.msg @@ -0,0 +1,2 @@ +Summarize to the following story: +{{task.story}} diff --git a/bot/templates/task_user_reply.msg b/bot/templates/task_user_reply.msg new file mode 100644 index 00000000..c247daa5 --- /dev/null +++ b/bot/templates/task_user_reply.msg @@ -0,0 +1,12 @@ +Please provide a reply to the assistant. +Here is the conversation so far: +{% for message in task.conversation.messages %}{% if message.is_assistant %} +:robot: Assistant: +{{ message.text }} +{% else %} +:person_red_hair: User: +**{{ message.text }}**" +{% endif %}{% endfor %} +{% if task.hint %} +Hint: {{ task.hint }} +{% endif %} \ No newline at end of file diff --git a/bot/templates/teaser_assistant_reply.msg b/bot/templates/teaser_assistant_reply.msg new file mode 100644 index 00000000..6975d417 --- /dev/null +++ b/bot/templates/teaser_assistant_reply.msg @@ -0,0 +1,3 @@ +:robot: **Challenge: Assistant Reply** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_initial_prompt.msg b/bot/templates/teaser_initial_prompt.msg new file mode 100644 index 00000000..e9ae5c7a --- /dev/null +++ b/bot/templates/teaser_initial_prompt.msg @@ -0,0 +1,3 @@ +:microphone2: **Challenge: Initial Prompt** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_rank_conversation_replies.msg b/bot/templates/teaser_rank_conversation_replies.msg new file mode 100644 index 00000000..744f7a76 --- /dev/null +++ b/bot/templates/teaser_rank_conversation_replies.msg @@ -0,0 +1,3 @@ +:bar_chart: **Challenge: Rank Replies** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_rank_initial_prompts.msg b/bot/templates/teaser_rank_initial_prompts.msg new file mode 100644 index 00000000..07399f56 --- /dev/null +++ b/bot/templates/teaser_rank_initial_prompts.msg @@ -0,0 +1,3 @@ +:bar_chart: **Challenge: Rank Initial Prompts** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_rate_summary.msg b/bot/templates/teaser_rate_summary.msg new file mode 100644 index 00000000..41357b06 --- /dev/null +++ b/bot/templates/teaser_rate_summary.msg @@ -0,0 +1,3 @@ +:ballot_box: **Challenge: Rate Summary** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_summarize_story.msg b/bot/templates/teaser_summarize_story.msg new file mode 100644 index 00000000..6e5ee5e5 --- /dev/null +++ b/bot/templates/teaser_summarize_story.msg @@ -0,0 +1,3 @@ +:books: **Challenge: Summarize Story** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/teaser_user_reply.msg b/bot/templates/teaser_user_reply.msg new file mode 100644 index 00000000..47ec8a2d --- /dev/null +++ b/bot/templates/teaser_user_reply.msg @@ -0,0 +1,3 @@ +:person_red_hair: **Challenge: User Reply** + +:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}). \ No newline at end of file diff --git a/bot/templates/welcome.msg b/bot/templates/welcome.msg new file mode 100644 index 00000000..553f7925 --- /dev/null +++ b/bot/templates/welcome.msg @@ -0,0 +1,6 @@ +Hi there, + +I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗! +Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world. + +Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands. diff --git a/bot/utils.py b/bot/utils.py new file mode 100644 index 00000000..968e4498 --- /dev/null +++ b/bot/utils.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +import enum +import subprocess +from datetime import datetime + +import pytz + + +def get_git_head_hash(): + # get current git hash + x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True) + if x.returncode == 0: + return x.stdout.replace("\n", "") + return None + + +def utcnow() -> datetime: + return datetime.now(pytz.UTC) + + +class DiscordTimestampStyle(str, enum.Enum): + """ + Timestamp Styles + + t 16:20 Short Time + T 16:20:30 Long Time + d 20/04/2021 Short Date + D 20 April 2021 Long Date + f * 20 April 2021 16:20 Short Date/Time + F Tuesday, 20 April 2021 16:20 Long Date/Time + R 2 months ago Relative Time + + See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles + """ + + default = "" + short_time = "t" + long_time = "T" + short_date = "d" + long_date = "D" + short_date_time = "f" + long_date_time = "F" + relative_time = "R" + + +def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default): + parts = ["") + return "".join(parts)