diff --git a/bot/__main__.py b/bot/__main__.py
new file mode 100644
index 00000000..1c456849
--- /dev/null
+++ b/bot/__main__.py
@@ -0,0 +1,15 @@
+# -*- coding: utf-8 -*-
+from bot_settings import settings
+
+from bot import OpenAssistantBot
+
+# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot
+
+if __name__ == "__main__":
+    bot = OpenAssistantBot(
+        settings.BOT_TOKEN,
+        bot_channel_name=settings.BOT_CHANNEL_NAME,
+        backend_url=settings.BACKEND_URL,
+        api_key=settings.API_KEY,
+    )
+    bot.run()
diff --git a/bot/api_client.py b/bot/api_client.py
new file mode 100644
index 00000000..6fe39c8b
--- /dev/null
+++ b/bot/api_client.py
@@ -0,0 +1,99 @@
+# -*- coding: utf-8 -*-
+import enum
+from typing import Optional, Type
+
+import requests
+from schemas import protocol as protocol_schema
+
+
+class TaskType(str, enum.Enum):
+    """Values of the ``type`` field used to select the task model for a backend response."""
+
+    summarize_story = "summarize_story"
+    rate_summary = "rate_summary"
+    initial_prompt = "initial_prompt"
+    user_reply = "user_reply"
+    assistant_reply = "assistant_reply"
+    rank_initial_prompts = "rank_initial_prompts"
+    rank_user_replies = "rank_user_replies"
+    rank_assistant_replies = "rank_assistant_replies"
+    done = "task_done"
+
+
+class ApiClient:
+    """Synchronous (blocking) HTTP client for the backend task API."""
+
+    def __init__(self, backend_url: str, api_key: str, timeout: float = 10.0):
+        """Store the connection settings and build the task type -> model lookup.
+
+        ``timeout`` is the number of seconds passed to every request so that an
+        unresponsive backend cannot block the caller indefinitely.
+        """
+        self.backend_url = backend_url
+        self.api_key = api_key
+        self.timeout = timeout
+
+        task_models_map: dict[str, Type[protocol_schema.Task]] = {
+            TaskType.summarize_story: protocol_schema.SummarizeStoryTask,
+            TaskType.rate_summary: protocol_schema.RateSummaryTask,
+            TaskType.initial_prompt: protocol_schema.InitialPromptTask,
+            TaskType.user_reply: protocol_schema.UserReplyTask,
+            TaskType.assistant_reply: protocol_schema.AssistantReplyTask,
+            TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask,
+            TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask,
+            TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask,
+            TaskType.done: protocol_schema.TaskDone,
+        }
+        self.task_models_map = task_models_map
+
+    def post(self, path: str, json: dict) -> dict:
+        """POST ``json`` to ``backend_url + path`` and return the decoded JSON response.
+
+        Raises ``requests.HTTPError`` for 4xx/5xx responses (via ``raise_for_status``).
+        """
+        response = requests.post(
+            f"{self.backend_url}{path}",
+            json=json,
+            headers={"X-API-Key": self.api_key},
+            timeout=self.timeout,
+        )
+        response.raise_for_status()
+        return response.json()
+
+    def _parse_task(self, data: dict) -> protocol_schema.Task:
+        """Parse ``data`` with the task model registered for its ``type`` value."""
+        if not isinstance(data, dict):
+            raise ValueError("dict expected")
+
+        task_type = data.get("type")
+        if task_type not in self.task_models_map:
+            raise RuntimeError(f"Unsupported task type: {task_type}")
+
+        return self.task_models_map[task_type].parse_obj(data)
+
+    def fetch_task(
+        self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None
+    ) -> protocol_schema.Task:
+        """Request a task of ``task_type`` from the backend and return it parsed."""
+        req = protocol_schema.TaskRequest(type=task_type, user=user)
+        data = self.post("/api/v1/tasks/", req.dict())
+        return self._parse_task(data)
+
+    def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task:
+        """Request a task using ``TaskRequestType.random``."""
+        return self.fetch_task(protocol_schema.TaskRequestType.random, user)
+
+    def ack_task(self, task_id: str, post_id: str) -> None:
+        """Report to the backend that the task was posted as ``post_id``."""
+        req = protocol_schema.TaskAck(post_id=post_id)
+        self.post(f"/api/v1/tasks/{task_id}/ack", req.dict())
+
+    def nack_task(self, task_id: str, reason: str) -> None:
+        """Report to the backend that the task could not be posted, giving ``reason``."""
+        req = protocol_schema.TaskNAck(reason=reason)
+        self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
+
+    def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone:
+        """Send a user interaction to the backend and return the parsed response task."""
+        data = self.post("/api/v1/tasks/interaction", interaction.dict())
+        return self._parse_task(data)
diff --git a/bot/bot.py b/bot/bot.py
index c2da5100..1910af6a 100644
--- a/bot/bot.py
+++ b/bot/bot.py
@@ -1,215 +1,262 @@
 # -*- coding: utf-8 -*-
-
-import json
-import os
+import asyncio
+from typing import Any, Optional
 
 import discord
-import requests
-from discord import app_commands
-from dotenv import load_dotenv
-from loguru import logger
-
-bot_url = "https://discord.com/api/oauth2/authorize?client_id=1051614245940375683&permissions=8&scope=bot"
-
-# Load up all the important environment variables.
-load_dotenv()
-
-# For authentication.
-TOKEN = os.getenv("DISCORD_TOKEN")
-
-# For Backends.
-API_SERVER_URL = os.getenv("API_SERVER_URL")
-API_SERVER_KEY = os.getenv("API_SERVER_KEY")
-
-labelers_url = f"{API_SERVER_URL}/api/v1/labelers/"
-prompts_url = f"{API_SERVER_URL}/api/v1/prompts/"
-headers = {"X-API-Key": API_SERVER_KEY}
-
-# For testing only.
-TEST_GUILD = os.getenv("TEST_GUILD")
-TEST_GUILD_LAION = os.getenv("TEST_GUILD_LAION")
-# TEST_GUILD = False
-guild_ids = [TEST_GUILD, TEST_GUILD_LAION]
+from api_client import ApiClient, TaskType
+from schemas import protocol as protocol_schema
 
 
-# Initiate the client and command tree to create slash commands.
-class OpenAssistantClient(discord.Client):
-    def __init__(self, *, intents: discord.Intents):
-        super().__init__(intents=intents)
-        self.tree = app_commands.CommandTree(self)
+class RatingButton(discord.ui.Button):
+    """Button that forwards a fixed rating value to ``response_handler`` when clicked."""
+
+    def __init__(self, label, value, response_handler):
+        super().__init__(label=label, style=discord.ButtonStyle.green)
+        self.value = value
+        self.response_handler = response_handler
+
+    async def callback(self, interaction):
+        await self.response_handler(self.value, interaction)
+
+
+def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
+    """Build a view with one ``RatingButton`` per integer in ``lo..hi`` (inclusive)."""
+    view = discord.ui.View()
+    for i in range(lo, hi + 1):
+        view.add_item(RatingButton(str(i), i, response_handler))
+    return view
+
+
+class ModifiedClient(discord.Client):
+    """``discord.Client`` subclass whose setup hook only prints a marker to stdout."""
+
+    def __init__(self, *, intents: discord.Intents, **options: Any):
+        super().__init__(intents=intents, **options)
 
     async def setup_hook(self):
-        if TEST_GUILD:
-            # When testing the bot it's handy to run in a single server (called a
-            # Guide in the API). This is relatively fast.
-            for guild_id in guild_ids:
-                guild = discord.Object(id=guild_id)
-                self.tree.copy_global_to(guild=guild)
-                await self.tree.sync(guild=guild)
-
-            # guild = discord.Object(id=TEST_GUILD)
-            # self.tree.copy_global_to(guild=guild)
-            # await self.tree.sync(guild=guild)
-        else:
-            # This can take up to an hour for the commands to be registered.
-            await self.tree.sync()
-        logger.debug("Ready!")
+        print("setup")
 
 
-# List the set of intents needed for commands to operate properly.
-intents = discord.Intents.default()
-intents.message_content = True
-client = OpenAssistantClient(intents=intents)
+class OpenAssistantBot:
+    """Discord bot that posts backend tasks to a channel and dispatches replies to them."""
+
+    def __init__(self, bot_token: str, bot_channel_name: str, backend_url: str, api_key: str):
+        intents = discord.Intents.default()
+        intents.message_content = True
+        self.bot_token = bot_token
+        client = ModifiedClient(intents=intents)
+        self.client = client
+        # Resolved in on_ready(); stays None while the channel has not been found.
+        self.bot_channel: Optional[discord.TextChannel] = None
+        self.backend = ApiClient(backend_url, api_key)
+        self.reply_handlers = {}  # handlers by msg_id
 
+        @client.event
+        async def on_ready():
+            self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
 
-class LikeButton(discord.ui.Button):
-    def __init__(self, label, channel, username, prompt):
-        super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👍")
-        self.channel = channel
-        self.username = username
-        self.prompt = prompt
+            client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
+            print(f"{client.user} is now running!")
 
-    async def callback(self, interaction):
-        # interaction holds the interaction object
-        # await interaction.response.defer()
-        await interaction.response.send_message("Thanks for your feedback. You liked this 👍 ")
-
-
-class NeutralButton(discord.ui.Button):
-    def __init__(self, label, channel, username, prompt):
-        super().__init__(label=label, style=discord.ButtonStyle.green, emoji="😐")
-        self.channel = channel
-        self.username = username
-        self.prompt = prompt
-
-    async def callback(self, interaction):
-        # interaction holds the interaction object
-        # await interaction.response.defer()
-        await interaction.response.send_message("Thanks for your feedback. You thought this was neutral 😐 ")
-
-
-class DislikeButton(discord.ui.Button):
-    def __init__(self, label, channel, username, prompt):
-        super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👎")
-        self.channel = channel
-        self.username = username
-        self.prompt = prompt
-
-    async def callback(self, interaction):
-        # interaction holds the interaction object
-        # await interaction.response.defer()
-        # send the feedback to the backend #
-        await interaction.response.send_message("Thanks for your feedback. You disliked this 👎 ")
-
-
-@client.tree.command()
-async def register(interaction: discord.Interaction):
-    """Registers the user for submissions."""
-    labeler = {
-        "discord_username": f"{interaction.user.id}",
-        "display_name": interaction.user.name,
-        "is_enabled": True,
-    }
-    response = requests.post(labelers_url, headers=headers, json=labeler)
-    if response.status_code == 200:
-        await interaction.response.send_message(f"Added you {interaction.user.name}")
-    else:
-        logger.debug(response)
-        await interaction.response.send_message("Failed to add you")
-
-
-@client.tree.command()
-async def list_participants(interaction: discord.Interaction):
-    """Reports the set of registered participants."""
-    response = requests.get(labelers_url, headers=headers)
-    if response.status_code == 200:
-        names = ",".join([labeler["display_name"] for labeler in response.json()])
-        await interaction.response.send_message(f"Found these users: {names}")
-    else:
-        await interaction.response.send_message("Failed to fetch participants")
-
-
-async def send_prompt_with_response_and_button(channel, username, prompt, response):
-    await channel.send(f"What do you think about the following interaction: \nprompt: {prompt} \nresponse: {response}")
-    # await channel.send(f'Please click on the button that best describes your reaction to the response:')
-
-    # add buttons
-    view = discord.ui.View()
-    like = LikeButton(label="Like", channel=channel, username=username, prompt=prompt)
-    neutral = NeutralButton(label="Neutral", channel=channel, username=username, prompt=prompt)
-    dislike = DislikeButton(label="Dislike", channel=channel, username=username, prompt=prompt)
-
-    view.add_item(item=like)
-    view.add_item(item=neutral)
-    view.add_item(item=dislike)
-    await channel.send(view=view)
-
-
-@client.tree.command()
-async def review_prompts(interaction: discord.Interaction, number_of_prompts: int):
-    # get the prompt from the db
-    url = f"{prompts_url}?begin_id=0&limit={number_of_prompts}"
-    response = requests.get(url, headers=headers)
-    if response.status_code == 200:
-        prompts = response.json()
-        logger.debug("the responses are:", prompts)
-        for prompt in prompts:
-            await send_prompt_with_response_and_button(
-                interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
-            )
-    else:
-        await interaction.response.send_message("Failed to get prompts for review")
-
-
-@client.tree.command()
-async def add_prompt(interaction: discord.Interaction, prompt: str, response: str, language: str = "en"):
-    """Uploads a single prompt to the server."""
-    prompt = {
-        "discord_username": f"{interaction.user.id}",
-        "labeler_id": 5,
-        "prompt": prompt,
-        "response": response,
-        "lang": language,
-    }
-    response = requests.post(prompts_url, headers=headers, json=prompt)
-    if response.status_code == 200:
-        await send_prompt_with_response_and_button(
-            interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"]
-        )
-        # send the prompt back with buttons for the user to click on
-        # await interaction.response.send_message("Added your prompt")
-    else:
-        await interaction.response.send_message("Failed to add the prompt")
-
-
-@client.tree.command()
-async def add_prompts_set(interaction: discord.Interaction, prompts: discord.Attachment):
-    """Uploads a batch of prompts to the server."""
-    # Loading a bunch of prompts from a file can take a while. So first defer
-    # the response to ensure we're able to later tell the user what happened.
-    await interaction.response.defer(ephemeral=True)
-
-    # Read the prompts and load them one by one.
-    # TODO: Upload a batch when the API supports it.
-    # TODO: Handle incorrect file types and parsing errors.
-    prompts_raw = await prompts.read()
-    prompts_loaded = json.loads(prompts_raw)
-    count = 0
-    for entry in prompts_loaded:
-        for response in entry["responses"]:
-            prompt = {
-                "discord_username": f"{interaction.user.id}",
-                "labeler_id": 5,
-                "prompt": entry["prompt"],
-                "response": response,
-                "lang": "en",
-            }
-            response = requests.post(prompts_url, headers=headers, json=prompt)
-            if response.status_code != 200:
-                await interaction.followup.send("Failed to upload")
+        @client.event
+        async def on_message(message: discord.Message):
+            # ignore own messages
+            if message.author == client.user:
                 return
-            count += 1
-    await interaction.followup.send(f"Loaded up {count} prompts")
 
+            await self.handle_message(message)
 
-client.run(TOKEN)
+    async def generate_summarize_story(self, task: protocol_schema.SummarizeStoryTask):
+        """Post a story to the bot channel and register a handler for replies to that post."""
+        text = f"Summarize the following story:\n{task.story}"
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_summarize_story_reply", message)
+            await message.reply("thx, on_summarize_story_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask):
+        """Post a summary with rating buttons and register a handler for replies to that post."""
+        s = [
+            "Rate the following summary:",
+            task.summary,
+            "Full text:",
+            task.full_text,
+            f"Rating scale: {task.scale.min} - {task.scale.max}",
+        ]
+        text = "\n".join(s)
+
+        async def rating_response_handler(score, interaction: discord.Interaction):
+            print("rating_response_handler", score)
+            await interaction.response.send_message(f"got your feedback: {score}")
+
+        view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler)
+        msg: discord.Message = await self.bot_channel.send(text, view=view)
+
+        async def on_reply(message: discord.Message):
+            print("on_summary_reply", message)
+            await message.reply("thx, on_summary_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def generate_initial_prompt(self, task: protocol_schema.InitialPromptTask):
+        """Ask for an initial prompt (with the optional hint) and register a reply handler."""
+        text = "Please provide an initial prompt to the assistant."
+        if task.hint:
+            text += f"\nHint: {task.hint}"
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_initial_prompt_reply", message)
+            await message.reply("thx, on_initial_prompt_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    def _render_message(self, message: protocol_schema.ConversationMessage) -> str:
+        """Render a message to the user."""
+        if message.is_assistant:
+            return f"Assistant: {message.text}"
+        return f"User: {message.text}"
+
+    async def generate_user_reply(self, task: protocol_schema.UserReplyTask):
+        """Post the conversation so far and ask for a user reply; register a reply handler."""
+        s = ["Please provide a reply to the assistant.", "Here is the conversation so far:"]
+        for message in task.conversation.messages:
+            s.append(self._render_message(message))
+        if task.hint:
+            s.append(f"Hint: {task.hint}")
+        text = "\n".join(s)
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_user_reply_reply", message)
+            await message.reply("thx, on_user_reply_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def generate_assistant_reply(self, task: protocol_schema.AssistantReplyTask):
+        """Post the conversation so far and ask for an assistant reply; register a reply handler."""
+        s = ["Act as the assistant and reply to the user.", "Here is the conversation so far:"]
+        for message in task.conversation.messages:
+            s.append(self._render_message(message))
+        text = "\n".join(s)
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_assistant_reply_reply", message)
+            await message.reply("thx, on_assistant_reply_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def generate_rank_initial_prompts(self, task: protocol_schema.RankInitialPromptsTask):
+        """Post a numbered list of prompts to rank and register a reply handler."""
+        s = ["Rank the following prompts:"]
+        for idx, prompt in enumerate(task.prompts, start=1):
+            s.append(f"{idx}: {prompt}")
+        text = "\n".join(s)
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_rank_initial_prompts_reply", message)
+            await message.reply("thx, on_rank_initial_prompts_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def generate_rank_conversation(self, task: protocol_schema.RankConversationRepliesTask):
+        """Post the conversation and a numbered list of replies to rank; register a reply handler."""
+        s = ["Here is the conversation so far:"]
+        for message in task.conversation.messages:
+            s.append(self._render_message(message))
+        s.append("Rank the following replies:")
+        for idx, reply in enumerate(task.replies, start=1):
+            s.append(f"{idx}: {reply}")
+        text = "\n".join(s)
+        msg: discord.Message = await self.bot_channel.send(text)
+
+        async def on_reply(message: discord.Message):
+            print("on_rank_conversation_reply", message)
+            await message.reply("thx, on_rank_conversation_reply")
+
+        self.reply_handlers[msg.id] = on_reply
+
+        return msg
+
+    async def next_task(self):
+        """Fetch one task, post it to the bot channel and ack/nack it with the backend."""
+        # ApiClient is synchronous (requests); run its calls in a worker thread so
+        # the discord event loop is not blocked while waiting for the backend.
+        # NOTE(review): the request type is pinned to rate_summary; the random fetch below is disabled.
+        task = await asyncio.to_thread(
+            self.backend.fetch_task, protocol_schema.TaskRequestType.rate_summary, user=None
+        )
+        # task = self.backend.fetch_random_task(user=None)
+
+        msg: Optional[discord.Message] = None
+        match task.type:
+            case TaskType.summarize_story:
+                msg = await self.generate_summarize_story(task)
+            case TaskType.rate_summary:
+                msg = await self.generate_rate_summary(task)
+            case TaskType.initial_prompt:
+                msg = await self.generate_initial_prompt(task)
+            case TaskType.user_reply:
+                msg = await self.generate_user_reply(task)
+            case TaskType.assistant_reply:
+                msg = await self.generate_assistant_reply(task)
+            case TaskType.rank_initial_prompts:
+                msg = await self.generate_rank_initial_prompts(task)
+            case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
+                msg = await self.generate_rank_conversation(task)
+
+        # ack_task/nack_task are plain functions, so they must not be awaited directly.
+        if msg is not None:
+            await asyncio.to_thread(self.backend.ack_task, str(task.id), str(msg.id))
+        else:
+            await asyncio.to_thread(self.backend.nack_task, str(task.id), "not supported")
+
+    async def background_timer(self):
+        """Loop forever: post the next task if the bot channel is known, then sleep 30 seconds."""
+        while True:
+            if self.bot_channel:
+                try:
+                    await self.next_task()
+                except Exception as e:
+                    print(e)
+            await asyncio.sleep(30)
+
+    def run(self):
+        """Run bot loop blocking."""
+        self.client.run(self.bot_token)
+
+    async def handle_message(self, message: discord.Message):
+        """Dispatch a reply to the handler registered for the referenced message, then log it."""
+        user_id = message.author.id
+        user_display_name = message.author.name
+
+        if message.reference:
+            handler = self.reply_handlers.get(message.reference.message_id)
+            if handler:
+                await handler(message)
+
+        print(user_id, user_display_name, message.content, type(message.content))
+
+    def get_text_channel_by_name(self, channel_name) -> Optional[discord.TextChannel]:
+        """Return the first text channel named ``channel_name`` among the client's channels, or None."""
+        for channel in self.client.get_all_channels():
+            if channel.type == discord.ChannelType.text and channel.name == channel_name:
+                return channel
+        return None
diff --git a/bot/bot_settings.py b/bot/bot_settings.py
new file mode 100644
index 00000000..3323b2fe
--- /dev/null
+++ b/bot/bot_settings.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+from pydantic import AnyHttpUrl, BaseSettings
+
+
+class BotSettings(BaseSettings):
+    BACKEND_URL: AnyHttpUrl = "http://localhost:8080"
+    API_KEY: str = "any_key"
+    BOT_TOKEN: str
+    BOT_CHANNEL_NAME: str = "bot"
+    TEST_GUILD: str = None
+
+
+settings = BotSettings(_env_file=".env")
diff --git a/bot/requirements.txt b/bot/requirements.txt
index 617e7071..da4762a6 100644
--- a/bot/requirements.txt
+++ b/bot/requirements.txt
@@ -1,2 +1,4 @@
 discord.py==2.1.0
+pydantic==1.9.1
 python-dotenv==0.21.0
+requests==2.28.1
diff --git a/bot/schemas/protocol.py b/bot/schemas/protocol.py
new file mode 100644
index 00000000..d5f508b6
--- /dev/null
+++ b/bot/schemas/protocol.py
@@ -0,0 +1,206 @@
+# -*- coding: utf-8 -*-
+import enum
+from typing import Literal, Optional, Union
+from uuid import UUID, uuid4
+
+import pydantic
+from pydantic import BaseModel
+
+
+class TaskRequestType(str, enum.Enum):
+    random = "random"
+    summarize_story = "summarize_story"
+    rate_summary = "rate_summary"
+    initial_prompt = "initial_prompt"
+    user_reply = "user_reply"
+    assistant_reply = "assistant_reply"
+    rank_initial_prompts = "rank_initial_prompts"
+    rank_user_replies = "rank_user_replies"
+    rank_assistant_replies = "rank_assistant_replies"
+
+
+class User(BaseModel):
+    id: str
+    display_name: str
+    auth_method: Literal["discord", "local"]
+
+
+class ConversationMessage(BaseModel):
+    """Represents a message in a conversation between the user and the assistant."""
+
+    text: str
+    is_assistant: bool
+
+
+class Conversation(BaseModel):
+    """Represents a conversation between the user and the assistant."""
+
+    messages: list[ConversationMessage] = []
+
+
+class TaskRequest(BaseModel):
+    """The frontend asks the backend for a task."""
+
+    type: TaskRequestType = TaskRequestType.random
+    user: Optional[User] = None
+
+
+class TaskAck(BaseModel):
+    """The frontend acknowledges that it has received a task and created a post."""
+
+    post_id: str
+
+
+class TaskNAck(BaseModel):
+    """The frontend acknowledges that it has received a task but cannot create a post."""
+
+    reason: str
+
+
+class Task(BaseModel):
+    """A task is a unit of work that the backend gives to the frontend."""
+
+    id: UUID = pydantic.Field(default_factory=uuid4)
+    type: str
+
+
+class SummarizeStoryTask(Task):
+    """A task to summarize a story."""
+
+    type: Literal["summarize_story"] = "summarize_story"
+    story: str
+
+
+class RatingScale(BaseModel):
+    min: int
+    max: int
+
+
+class AbstractRatingTask(Task):
+    """A task to rate something."""
+
+    scale: RatingScale = RatingScale(min=1, max=5)
+
+
+class RateSummaryTask(AbstractRatingTask):
+    """A task to rate a summary."""
+
+    type: Literal["rate_summary"] = "rate_summary"
+    full_text: str
+    summary: str
+
+
+class WithHintMixin(BaseModel):
+    hint: str | None = None  # provide a hint to the user to spark their imagination
+
+
+class InitialPromptTask(Task, WithHintMixin):
+    """A task to prompt the user to submit an initial prompt to the assistant."""
+
+    type: Literal["initial_prompt"] = "initial_prompt"
+
+
+class ReplyToConversationTask(Task):
+    """A task to prompt the user to submit a reply to a conversation."""
+
+    type: Literal["reply_to_conversation"] = "reply_to_conversation"
+    conversation: Conversation  # the conversation so far
+
+
+class UserReplyTask(ReplyToConversationTask, WithHintMixin):
+    """A task to prompt the user to submit a reply to the assistant."""
+
+    type: Literal["user_reply"] = "user_reply"
+
+
+class AssistantReplyTask(ReplyToConversationTask):
+    """A task to prompt the user to act as the assistant."""
+
+    type: Literal["assistant_reply"] = "assistant_reply"
+
+
+class RankInitialPromptsTask(Task):
+    """A task to rank a set of initial prompts."""
+
+    type: Literal["rank_initial_prompts"] = "rank_initial_prompts"
+    prompts: list[str]
+
+
+class RankConversationRepliesTask(Task):
+    """A task to rank a set of replies to a conversation."""
+
+    type: Literal["rank_conversation_replies"] = "rank_conversation_replies"
+    conversation: Conversation  # the conversation so far
+    replies: list[str]
+
+
+class RankUserRepliesTask(RankConversationRepliesTask):
+    """A task to rank a set of user replies to a conversation."""
+
+    type: Literal["rank_user_replies"] = "rank_user_replies"
+
+
+class RankAssistantRepliesTask(RankConversationRepliesTask):
+    """A task to rank a set of assistant replies to a conversation."""
+
+    type: Literal["rank_assistant_replies"] = "rank_assistant_replies"
+
+
+class TaskDone(Task):
+    """Signals to the frontend that the task is done."""
+
+    type: Literal["task_done"] = "task_done"
+
+
+AnyTask = Union[
+    TaskDone,
+    SummarizeStoryTask,
+    RateSummaryTask,
+    InitialPromptTask,
+    ReplyToConversationTask,
+    UserReplyTask,
+    AssistantReplyTask,
+    RankInitialPromptsTask,
+    RankConversationRepliesTask,
+    RankUserRepliesTask,
+    RankAssistantRepliesTask,
+]
+
+
+class Interaction(BaseModel):
+    """An interaction is a user-generated action in the frontend."""
+
+    type: str
+    user: User
+
+
+class TextReplyToPost(Interaction):
+    """A user has replied to a post with text."""
+
+    type: Literal["text_reply_to_post"] = "text_reply_to_post"
+    post_id: str
+    user_post_id: str
+    text: str
+
+
+class PostRating(Interaction):
+    """A user has rated a post."""
+
+    type: Literal["post_rating"] = "post_rating"
+    post_id: str
+    rating: int
+
+
+class PostRanking(Interaction):
+    """A user has given a ranking for a post."""
+
+    type: Literal["post_ranking"] = "post_ranking"
+    post_id: str
+    ranking: list[int]
+
+
+AnyInteraction = Union[
+    TextReplyToPost,
+    PostRating,
+    PostRanking,
+]
+