diff --git a/.gitignore b/.gitignore index ce7a9b8a..2c698b44 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ .venv *.pyc *.swp +*.egg-info +__pycache__ diff --git a/.vscode/settings.json b/.vscode/settings.json index b7368caa..56a51f78 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,3 +1,4 @@ { - "python.formatting.provider": "black" + "python.formatting.provider": "black", + "python.analysis.extraPaths": ["${workspaceFolder}/oasst-shared"] } diff --git a/backend/Dockerfile b/backend/Dockerfile deleted file mode 100644 index 8074ef3a..00000000 --- a/backend/Dockerfile +++ /dev/null @@ -1,9 +0,0 @@ -FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10 - -COPY ./requirements.txt /app/requirements.txt - -RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt - -ENV PORT 8080 - -COPY ./app /app diff --git a/backend/app/alembic.ini b/backend/alembic.ini similarity index 100% rename from backend/app/alembic.ini rename to backend/alembic.ini diff --git a/backend/app/alembic/README b/backend/alembic/README similarity index 100% rename from backend/app/alembic/README rename to backend/alembic/README diff --git a/backend/app/alembic/env.py b/backend/alembic/env.py similarity index 97% rename from backend/app/alembic/env.py rename to backend/alembic/env.py index 6d2ec0c3..83de474c 100644 --- a/backend/app/alembic/env.py +++ b/backend/alembic/env.py @@ -3,7 +3,7 @@ from logging.config import fileConfig import sqlmodel from alembic import context -from oasst import models # noqa: F401 +from oasst_backend import models # noqa: F401 from sqlalchemy import engine_from_config, pool # this is the Alembic Config object, which provides diff --git a/backend/app/alembic/script.py.mako b/backend/alembic/script.py.mako similarity index 100% rename from backend/app/alembic/script.py.mako rename to backend/alembic/script.py.mako diff --git a/backend/app/alembic/versions/2022_12_15_0000-23e5fea252dd_first_revision.py b/backend/alembic/versions/2022_12_15_0000-23e5fea252dd_first_revision.py similarity index 100% rename from backend/app/alembic/versions/2022_12_15_0000-23e5fea252dd_first_revision.py rename to backend/alembic/versions/2022_12_15_0000-23e5fea252dd_first_revision.py diff --git a/backend/app/alembic/versions/2022_12_16_0000-cd7de470586e_v1_db_structure.py b/backend/alembic/versions/2022_12_16_0000-cd7de470586e_v1_db_structure.py similarity index 100% rename from backend/app/alembic/versions/2022_12_16_0000-cd7de470586e_v1_db_structure.py rename to backend/alembic/versions/2022_12_16_0000-cd7de470586e_v1_db_structure.py diff --git a/backend/app/alembic/versions/2022_12_17_2230-6368515778c5_add_auth_method_to_person.py b/backend/alembic/versions/2022_12_17_2230-6368515778c5_add_auth_method_to_person.py similarity index 100% rename from backend/app/alembic/versions/2022_12_17_2230-6368515778c5_add_auth_method_to_person.py rename to backend/alembic/versions/2022_12_17_2230-6368515778c5_add_auth_method_to_person.py diff --git a/backend/app/oasst/schemas/__init__.py b/backend/app/oasst/schemas/__init__.py deleted file mode 100644 index 5ee00d4a..00000000 --- a/backend/app/oasst/schemas/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -*- coding: utf-8 -*- -__all__ = [] diff --git a/backend/app/main.py b/backend/main.py similarity index 93% rename from backend/app/main.py rename to backend/main.py index f78f608c..abcd2391 100644 --- a/backend/app/main.py +++ b/backend/main.py @@ -5,8 +5,8 @@ import alembic.command import alembic.config import fastapi from loguru import logger -from oasst.api.v1.api import api_router -from oasst.config import settings +from oasst_backend.api.v1.api import api_router +from oasst_backend.config import settings from starlette.middleware.cors import CORSMiddleware app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json") diff --git a/backend/app/oasst/__init__.py b/backend/oasst_backend/__init__.py similarity index 100% rename from backend/app/oasst/__init__.py rename to backend/oasst_backend/__init__.py diff --git a/backend/app/oasst/api/__init__.py b/backend/oasst_backend/api/__init__.py similarity index 100% rename from backend/app/oasst/api/__init__.py rename to backend/oasst_backend/api/__init__.py diff --git a/backend/app/oasst/api/deps.py b/backend/oasst_backend/api/deps.py similarity index 93% rename from backend/app/oasst/api/deps.py rename to backend/oasst_backend/api/deps.py index bfa54931..96af5c5e 100644 --- a/backend/app/oasst/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -6,9 +6,9 @@ from uuid import UUID from fastapi import HTTPException, Security from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery from loguru import logger -from oasst.config import settings -from oasst.database import engine -from oasst.models import ApiClient +from oasst_backend.config import settings +from oasst_backend.database import engine +from oasst_backend.models import ApiClient from sqlmodel import Session from starlette.status import HTTP_403_FORBIDDEN diff --git a/backend/app/oasst/api/v1/__init__.py b/backend/oasst_backend/api/v1/__init__.py similarity index 100% rename from backend/app/oasst/api/v1/__init__.py rename to backend/oasst_backend/api/v1/__init__.py diff --git a/backend/app/oasst/api/v1/api.py b/backend/oasst_backend/api/v1/api.py similarity index 79% rename from backend/app/oasst/api/v1/api.py rename to backend/oasst_backend/api/v1/api.py index 3d568cb9..cd1119d6 100644 --- a/backend/app/oasst/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- from fastapi import APIRouter -from oasst.api.v1 import tasks +from oasst_backend.api.v1 import tasks api_router = APIRouter() api_router.include_router(tasks.router, prefix="/tasks", tags=["tasks"]) diff --git a/backend/app/oasst/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py similarity index 97% rename from backend/app/oasst/api/v1/tasks.py rename to backend/oasst_backend/api/v1/tasks.py index 41f01f3c..40965022 100644 --- a/backend/app/oasst/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -6,10 +6,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException from fastapi.security.api_key import APIKey from loguru import logger -from oasst.api import deps -from oasst.models.db_payload import TaskPayload -from oasst.prompt_repository import PromptRepository -from oasst.schemas import protocol as protocol_schema +from oasst_backend.api import deps +from oasst_backend.models.db_payload import TaskPayload +from oasst_backend.prompt_repository import PromptRepository +from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session from starlette.status import HTTP_400_BAD_REQUEST diff --git a/backend/app/oasst/config.py b/backend/oasst_backend/config.py similarity index 100% rename from backend/app/oasst/config.py rename to backend/oasst_backend/config.py diff --git a/backend/app/oasst/crud/__init__.py b/backend/oasst_backend/crud/__init__.py similarity index 100% rename from backend/app/oasst/crud/__init__.py rename to backend/oasst_backend/crud/__init__.py diff --git a/backend/app/oasst/crud/base.py b/backend/oasst_backend/crud/base.py similarity index 100% rename from backend/app/oasst/crud/base.py rename to backend/oasst_backend/crud/base.py diff --git a/backend/app/oasst/database.py b/backend/oasst_backend/database.py similarity index 81% rename from backend/app/oasst/database.py rename to backend/oasst_backend/database.py index ca729f4e..66d7a857 100644 --- a/backend/app/oasst/database.py +++ b/backend/oasst_backend/database.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -from oasst.config import settings +from oasst_backend.config import settings from sqlmodel import create_engine if settings.DATABASE_URI is None: diff --git a/backend/app/oasst/models/__init__.py b/backend/oasst_backend/models/__init__.py similarity index 100% rename from backend/app/oasst/models/__init__.py rename to backend/oasst_backend/models/__init__.py diff --git a/backend/app/oasst/models/api_client.py b/backend/oasst_backend/models/api_client.py similarity index 100% rename from backend/app/oasst/models/api_client.py rename to backend/oasst_backend/models/api_client.py diff --git a/backend/app/oasst/models/db_payload.py b/backend/oasst_backend/models/db_payload.py similarity index 94% rename from backend/app/oasst/models/db_payload.py rename to backend/oasst_backend/models/db_payload.py index b01cecce..2a4438e2 100644 --- a/backend/app/oasst/models/db_payload.py +++ b/backend/oasst_backend/models/db_payload.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- from typing import Literal -from oasst.models.payload_column_type import payload_type -from oasst.schemas import protocol as protocol_schema +from oasst_backend.models.payload_column_type import payload_type +from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel diff --git a/backend/app/oasst/models/payload_column_type.py b/backend/oasst_backend/models/payload_column_type.py similarity index 100% rename from backend/app/oasst/models/payload_column_type.py rename to backend/oasst_backend/models/payload_column_type.py diff --git a/backend/app/oasst/models/person.py b/backend/oasst_backend/models/person.py similarity index 100% rename from backend/app/oasst/models/person.py rename to backend/oasst_backend/models/person.py diff --git a/backend/app/oasst/models/person_stats.py b/backend/oasst_backend/models/person_stats.py similarity index 100% rename from backend/app/oasst/models/person_stats.py rename to backend/oasst_backend/models/person_stats.py diff --git a/backend/app/oasst/models/post.py b/backend/oasst_backend/models/post.py similarity index 100% rename from backend/app/oasst/models/post.py rename to backend/oasst_backend/models/post.py diff --git a/backend/app/oasst/models/post_reaction.py b/backend/oasst_backend/models/post_reaction.py similarity index 100% rename from backend/app/oasst/models/post_reaction.py rename to backend/oasst_backend/models/post_reaction.py diff --git a/backend/app/oasst/models/work_package.py b/backend/oasst_backend/models/work_package.py similarity index 100% rename from backend/app/oasst/models/work_package.py rename to backend/oasst_backend/models/work_package.py diff --git a/backend/app/oasst/prompt_repository.py b/backend/oasst_backend/prompt_repository.py similarity index 97% rename from backend/app/oasst/prompt_repository.py rename to backend/oasst_backend/prompt_repository.py index 35f1c9b3..9f7bb1dd 100644 --- a/backend/app/oasst/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -3,11 +3,11 @@ from datetime import datetime from typing import Optional from uuid import UUID, uuid4 -import oasst.models.db_payload as db_payload +import oasst_backend.models.db_payload as db_payload from loguru import logger -from oasst.models import ApiClient, Person, Post, PostReaction, WorkPackage -from oasst.models.payload_column_type import PayloadContainer -from oasst.schemas import protocol as protocol_schema +from oasst_backend.models import ApiClient, Person, Post, PostReaction, WorkPackage +from oasst_backend.models.payload_column_type import PayloadContainer +from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session diff --git a/bot/__main__.py b/bot/__main__.py new file mode 100644 index 00000000..1c456849 --- /dev/null +++ b/bot/__main__.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from bot_settings import settings + +from bot import OpenAssistantBot + +# invite bot url: https://discord.com/api/oauth2/authorize?client_id=1054078345542910022&permissions=1634235579456&scope=bot + +if __name__ == "__main__": + bot = OpenAssistantBot( + settings.BOT_TOKEN, + bot_channel_name=settings.BOT_CHANNEL_NAME, + backend_url=settings.BACKEND_URL, + api_key=settings.API_KEY, + ) + bot.run() diff --git a/bot/api_client.py b/bot/api_client.py new file mode 100644 index 00000000..6fe39c8b --- /dev/null +++ b/bot/api_client.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +import enum +from typing import Optional, Type + +import requests +from schemas import protocol as protocol_schema + + +class TaskType(str, enum.Enum): + summarize_story = "summarize_story" + rate_summary = "rate_summary" + initial_prompt = "initial_prompt" + user_reply = "user_reply" + assistant_reply = "assistant_reply" + rank_initial_prompts = "rank_initial_prompts" + rank_user_replies = "rank_user_replies" + rank_assistant_replies = "rank_assistant_replies" + done = "task_done" + + +class ApiClient: + def __init__(self, backend_url: str, api_key: str): + self.backend_url = backend_url + self.api_key = api_key + + task_models_map: dict[str, Type[protocol_schema.Task]] = { + TaskType.summarize_story: protocol_schema.SummarizeStoryTask, + TaskType.rate_summary: protocol_schema.RateSummaryTask, + TaskType.initial_prompt: protocol_schema.InitialPromptTask, + TaskType.user_reply: protocol_schema.UserReplyTask, + TaskType.assistant_reply: protocol_schema.AssistantReplyTask, + TaskType.rank_initial_prompts: protocol_schema.RankInitialPromptsTask, + TaskType.rank_user_replies: protocol_schema.RankUserRepliesTask, + TaskType.rank_assistant_replies: protocol_schema.RankAssistantRepliesTask, + TaskType.done: protocol_schema.TaskDone, + } + self.task_models_map = task_models_map + + def post(self, path: str, json: dict) -> dict: + response = requests.post(f"{self.backend_url}{path}", json=json, headers={"X-API-Key": self.api_key}) + response.raise_for_status() + return response.json() + + def _parse_task(self, data: dict) -> protocol_schema.Task: + if not isinstance(data, dict): + raise ValueError("dict expected") + + task_type = data.get("type") + if task_type not in self.task_models_map: + raise RuntimeError(f"Unsupported task type: {task_type}") + + return self.task_models_map[task_type].parse_obj(data) + + def fetch_task( + self, task_type: protocol_schema.TaskRequestType, user: Optional[protocol_schema.User] = None + ) -> protocol_schema.Task: + req = protocol_schema.TaskRequest(type=task_type, user=user) + data = self.post("/api/v1/tasks/", req.dict()) + return self._parse_task(data) + + def fetch_random_task(self, user: Optional[protocol_schema.User] = None) -> protocol_schema.Task: + return self.fetch_task(protocol_schema.TaskRequestType.random, user) + + def ack_task(self, task_id: str, post_id: str) -> None: + req = protocol_schema.TaskAck(post_id=post_id) + return self.post(f"/api/v1/tasks/{task_id}/ack", req.dict()) + + def nack_task(self, task_id: str, reason: str) -> None: + req = protocol_schema.TaskNAck(reason=reason) + return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) + + def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone: + data = self.post("/api/v1/tasks/interaction", interaction.dict()) + return self._parse_task(data) diff --git a/bot/bot.py b/bot/bot.py index c2da5100..376b0b3c 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,215 +1,236 @@ # -*- coding: utf-8 -*- - -import json -import os +import asyncio +from typing import Any import discord -import requests -from discord import app_commands -from dotenv import load_dotenv -from loguru import logger - -bot_url = "https://discord.com/api/oauth2/authorize?client_id=1051614245940375683&permissions=8&scope=bot" - -# Load up all the important environment variables. -load_dotenv() - -# For authentication. -TOKEN = os.getenv("DISCORD_TOKEN") - -# For Backends. -API_SERVER_URL = os.getenv("API_SERVER_URL") -API_SERVER_KEY = os.getenv("API_SERVER_KEY") - -labelers_url = f"{API_SERVER_URL}/api/v1/labelers/" -prompts_url = f"{API_SERVER_URL}/api/v1/prompts/" -headers = {"X-API-Key": API_SERVER_KEY} - -# For testing only. -TEST_GUILD = os.getenv("TEST_GUILD") -TEST_GUILD_LAION = os.getenv("TEST_GUILD_LAION") -# TEST_GUILD = False -guild_ids = [TEST_GUILD, TEST_GUILD_LAION] +from api_client import ApiClient, TaskType +from oasst_shared.schemas import protocol as protocol_schema -# Initiate the client and command tree to create slash commands. -class OpenAssistantClient(discord.Client): - def __init__(self, *, intents: discord.Intents): - super().__init__(intents=intents) - self.tree = app_commands.CommandTree(self) +class RatingButton(discord.ui.Button): + def __init__(self, label, value, response_handler): + super().__init__(label=label, style=discord.ButtonStyle.green) + self.value = value + self.response_handler = response_handler + + async def callback(self, interaction): + await self.response_handler(self.value, interaction) + + +def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: + view = discord.ui.View() + for i in range(lo, hi + 1): + view.add_item(RatingButton(str(i), i, response_handler)) + return view + + +class ModifiedClient(discord.Client): + def __init__(self, *, intents: discord.Intents, **options: Any): + super().__init__(intents=intents, **options) async def setup_hook(self): - if TEST_GUILD: - # When testing the bot it's handy to run in a single server (called a - # Guide in the API). This is relatively fast. - for guild_id in guild_ids: - guild = discord.Object(id=guild_id) - self.tree.copy_global_to(guild=guild) - await self.tree.sync(guild=guild) - - # guild = discord.Object(id=TEST_GUILD) - # self.tree.copy_global_to(guild=guild) - # await self.tree.sync(guild=guild) - else: - # This can take up to an hour for the commands to be registered. - await self.tree.sync() - logger.debug("Ready!") + print("setup") -# List the set of intents needed for commands to operate properly. -intents = discord.Intents.default() -intents.message_content = True -client = OpenAssistantClient(intents=intents) +class OpenAssistantBot: + def __init__(self, bot_token: str, bot_channel_name: str, backend_url: str, api_key: str): + intents = discord.Intents.default() + intents.message_content = True + self.bot_token = bot_token + client = ModifiedClient(intents=intents) + self.client = client + self.bot_channel: discord.TextChannel = None + self.backend = ApiClient(backend_url, api_key) + self.reply_handlers = {} # handlers by msg_id + @client.event + async def on_ready(): + self.bot_channel = self.get_text_channel_by_name(bot_channel_name) -class LikeButton(discord.ui.Button): - def __init__(self, label, channel, username, prompt): - super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👍") - self.channel = channel - self.username = username - self.prompt = prompt + client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") + print(f"{client.user} is now running!") - async def callback(self, interaction): - # interaction holds the interaction object - # await interaction.response.defer() - await interaction.response.send_message("Thanks for your feedback. You liked this 👍 ") - - -class NeutralButton(discord.ui.Button): - def __init__(self, label, channel, username, prompt): - super().__init__(label=label, style=discord.ButtonStyle.green, emoji="😐") - self.channel = channel - self.username = username - self.prompt = prompt - - async def callback(self, interaction): - # interaction holds the interaction object - # await interaction.response.defer() - await interaction.response.send_message("Thanks for your feedback. You thought this was neutral 😐 ") - - -class DislikeButton(discord.ui.Button): - def __init__(self, label, channel, username, prompt): - super().__init__(label=label, style=discord.ButtonStyle.green, emoji="👎") - self.channel = channel - self.username = username - self.prompt = prompt - - async def callback(self, interaction): - # interaction holds the interaction object - # await interaction.response.defer() - # send the feedback to the backend # - await interaction.response.send_message("Thanks for your feedback. You disliked this 👎 ") - - -@client.tree.command() -async def register(interaction: discord.Interaction): - """Registers the user for submissions.""" - labeler = { - "discord_username": f"{interaction.user.id}", - "display_name": interaction.user.name, - "is_enabled": True, - } - response = requests.post(labelers_url, headers=headers, json=labeler) - if response.status_code == 200: - await interaction.response.send_message(f"Added you {interaction.user.name}") - else: - logger.debug(response) - await interaction.response.send_message("Failed to add you") - - -@client.tree.command() -async def list_participants(interaction: discord.Interaction): - """Reports the set of registered participants.""" - response = requests.get(labelers_url, headers=headers) - if response.status_code == 200: - names = ",".join([labeler["display_name"] for labeler in response.json()]) - await interaction.response.send_message(f"Found these users: {names}") - else: - await interaction.response.send_message("Failed to fetch participants") - - -async def send_prompt_with_response_and_button(channel, username, prompt, response): - await channel.send(f"What do you think about the following interaction: \nprompt: {prompt} \nresponse: {response}") - # await channel.send(f'Please click on the button that best describes your reaction to the response:') - - # add buttons - view = discord.ui.View() - like = LikeButton(label="Like", channel=channel, username=username, prompt=prompt) - neutral = NeutralButton(label="Neutral", channel=channel, username=username, prompt=prompt) - dislike = DislikeButton(label="Dislike", channel=channel, username=username, prompt=prompt) - - view.add_item(item=like) - view.add_item(item=neutral) - view.add_item(item=dislike) - await channel.send(view=view) - - -@client.tree.command() -async def review_prompts(interaction: discord.Interaction, number_of_prompts: int): - # get the prompt from the db - url = f"{prompts_url}?begin_id=0&limit={number_of_prompts}" - response = requests.get(url, headers=headers) - if response.status_code == 200: - prompts = response.json() - logger.debug("the responses are:", prompts) - for prompt in prompts: - await send_prompt_with_response_and_button( - interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"] - ) - else: - await interaction.response.send_message("Failed to get prompts for review") - - -@client.tree.command() -async def add_prompt(interaction: discord.Interaction, prompt: str, response: str, language: str = "en"): - """Uploads a single prompt to the server.""" - prompt = { - "discord_username": f"{interaction.user.id}", - "labeler_id": 5, - "prompt": prompt, - "response": response, - "lang": language, - } - response = requests.post(prompts_url, headers=headers, json=prompt) - if response.status_code == 200: - await send_prompt_with_response_and_button( - interaction.channel, interaction.user.name, prompt["prompt"], prompt["response"] - ) - # send the prompt back with buttons for the user to click on - # await interaction.response.send_message("Added your prompt") - else: - await interaction.response.send_message("Failed to add the prompt") - - -@client.tree.command() -async def add_prompts_set(interaction: discord.Interaction, prompts: discord.Attachment): - """Uploads a batch of prompts to the server.""" - # Loading a bunch of prompts from a file can take a while. So first defer - # the response to ensure we're able to later tell the user what happened. - await interaction.response.defer(ephemeral=True) - - # Read the prompts and load them one by one. - # TODO: Upload a batch when the API supports it. - # TODO: Handle incorrect file types and parsing errors. - prompts_raw = await prompts.read() - prompts_loaded = json.loads(prompts_raw) - count = 0 - for entry in prompts_loaded: - for response in entry["responses"]: - prompt = { - "discord_username": f"{interaction.user.id}", - "labeler_id": 5, - "prompt": entry["prompt"], - "response": response, - "lang": "en", - } - response = requests.post(prompts_url, headers=headers, json=prompt) - if response.status_code != 200: - await interaction.followup.send("Failed to upload") + @client.event + async def on_message(message: discord.Message): + # ignore own messages + if message.author == client.user: return - count += 1 - await interaction.followup.send(f"Loaded up {count} prompts") + await self.handle_message(message) -client.run(TOKEN) + async def generate_summarize_story(self, task: protocol_schema.SummarizeStoryTask): + text = f"Summarize to the following story:\n{task.story}" + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_summarize_story_reply", message) + await message.reply("thx, on_summarize_story_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask): + s = [ + "Rate the following summary:", + task.summary, + "Full text:", + task.full_text, + f"Rating scale: {task.scale.min} - {task.scale.max}", + ] + text = "\n".join(s) + + async def rating_response_handler(score, interaction: discord.Interaction): + print("rating_response_handler", score) + await interaction.response.send_message(f"got your feedback: {score}") + + view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler) + msg: discord.Message = await self.bot_channel.send(text, view=view) + + async def on_reply(message: discord.Message): + print("on_summary_reply", message) + await message.reply("thx, on_summary_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def generate_initial_prompt(self, task: protocol_schema.InitialPromptTask): + text = "Please provide an initial prompt to the assistant." + if task.hint: + text += f"\nHint: {task.hint}" + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_initial_prompt_reply", message) + await message.reply("thx, on_initial_prompt_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + def _render_message(self, message: protocol_schema.ConversationMessage) -> str: + """Render a message to the user.""" + if message.is_assistant: + return f"Assistant: {message.text}" + return f"User: {message.text}" + + async def generate_user_reply(self, task: protocol_schema.UserReplyTask): + s = ["Please provide a reply to the assistant.", "Here is the conversation so far:"] + for message in task.conversation.messages: + s.append(self._render_message(message)) + if task.hint: + s.append(f"Hint: {task.hint}") + text = "\n".join(s) + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_user_reply_reply", message) + await message.reply("thx, on_user_reply_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def generate_assistant_reply(self, task: protocol_schema.AssistantReplyTask): + s = ["Act as the assistant and reply to the user.", "Here is the conversation so far:"] + for message in task.conversation.messages: + s.append(self._render_message(message)) + text = "\n".join(s) + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_assistant_reply_reply", message) + await message.reply("thx, on_assistant_reply_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def generate_rank_initial_prompts(self, task: protocol_schema.RankInitialPromptsTask): + s = ["Rank the following prompts:"] + for idx, prompt in enumerate(task.prompts, start=1): + s.append(f"{idx}: {prompt}") + text = "\n".join(s) + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_rank_initial_prompts_reply", message) + await message.reply("thx, on_rank_initial_prompts_reply") + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def generate_rank_conversation(self, task: protocol_schema.RankConversationRepliesTask): + s = ["Here is the conversation so far:"] + for message in task.conversation.messages: + s.append(self._render_message(message)) + s.append("Rank the following replies:") + for idx, reply in enumerate(task.replies, start=1): + s.append(f"{idx}: {reply}") + text = "\n".join(s) + msg: discord.Message = await self.bot_channel.send(text) + + async def on_reply(message: discord.Message): + print("on_rrank_conversation_reply", message) + message + + self.reply_handlers[msg.id] = on_reply + + return msg + + async def next_task(self): + task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None) + # task = self.backend.fetch_random_task(user=None) + + msg: discord.Message = None + match task.type: + case TaskType.summarize_story: + msg = await self.generate_summarize_story(task) + case TaskType.rate_summary: + msg = await self.generate_rate_summary(task) + case TaskType.initial_prompt: + msg = await self.generate_initial_prompt(task) + case TaskType.user_reply: + msg = await self.generate_user_reply(task) + case TaskType.assistant_reply: + msg = await self.generate_assistant_reply(task) + case TaskType.rank_initial_prompts: + msg = await self.generate_rank_initial_prompts(task) + case TaskType.rank_user_replies | TaskType.rank_assistant_replies: + msg = await self.generate_rank_conversation(task) + + if msg is not None: + await self.backend.ack_task(task.id, msg.id) + else: + await self.backend.nack_task(task.id, "not supported") + + async def background_timer(self): + while True: + if self.bot_channel: + try: + await self.next_task() + except Exception as e: + print(e) + await asyncio.sleep(30) + + def run(self): + """Run bot loop blocking.""" + self.client.run(self.bot_token) + + async def handle_message(self, message: discord.Message): + user_id = message.author.id + user_display_name = message.author.name + + if message.reference: + handler = self.reply_handlers.get(message.reference.message_id) + if handler: + await handler(message) + + print(user_id, user_display_name, message.content, type(message.content)) + + def get_text_channel_by_name(self, channel_name) -> discord.TextChannel: + for channel in self.client.get_all_channels(): + if channel.type == discord.ChannelType.text and channel.name == channel_name: + return channel diff --git a/bot/bot_settings.py b/bot/bot_settings.py new file mode 100644 index 00000000..3323b2fe --- /dev/null +++ b/bot/bot_settings.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +from pydantic import AnyHttpUrl, BaseSettings + + +class BotSettings(BaseSettings): + BACKEND_URL: AnyHttpUrl = "http://localhost:8080" + API_KEY: str = "any_key" + BOT_TOKEN: str + BOT_CHANNEL_NAME: str = "bot" + TEST_GUILD: str = None + + +settings = BotSettings(_env_file=".env") diff --git a/bot/requirements.txt b/bot/requirements.txt index 617e7071..da4762a6 100644 --- a/bot/requirements.txt +++ b/bot/requirements.txt @@ -1,2 +1,4 @@ discord.py==2.1.0 +pydantic==1.9.1 python-dotenv==0.21.0 +requests==2.28.1 diff --git a/bot/setup.py b/bot/setup.py deleted file mode 100644 index 06f7326e..00000000 --- a/bot/setup.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -from setuptools import find_packages, setup - -if __name__ == "__main__": - import os - - def _read_reqs(relpath): - fullpath = os.path.join(os.path.dirname(__file__), relpath) - with open(fullpath) as f: - return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] - - REQUIREMENTS = _read_reqs("requirements.txt") - - setup( - name="open-assistant", - packages=find_packages(), - version="0.0.1", - license="Apache 2.0", - description="A Discord Bot for collecting and ranking prompts to train an Open Assistant", - keywords=["machine learning", "natural language processing", "discord"], - install_requires=REQUIREMENTS, - classifiers=[ - "Development Status :: Alpha", - "Intended Audience :: Developers", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "License :: OSI Approved :: Apache License", - "Programming Language :: Python :: 3.6", - ], - ) diff --git a/docker/Dockerfile.backend b/docker/Dockerfile.backend new file mode 100644 index 00000000..d9458ae0 --- /dev/null +++ b/docker/Dockerfile.backend @@ -0,0 +1,15 @@ +FROM tiangolo/uvicorn-gunicorn-fastapi:python3.10 + +COPY ./backend/requirements.txt /app/requirements.txt + +RUN pip install --no-cache-dir --upgrade -r /app/requirements.txt + +ENV PORT 8080 + +COPY ./oasst-shared /oasst-shared +RUN pip install -e /oasst-shared + +COPY ./backend/alembic /app/alembic +COPY ./backend/alembic.ini /app/alembic.ini +COPY ./backend/main.py /app/main.py +COPY ./backend/oasst_backend /app/oasst_backend diff --git a/bot/Dockerfile b/docker/Dockerfile.discord-bot similarity index 60% rename from bot/Dockerfile rename to docker/Dockerfile.discord-bot index ab215b5b..13ae308a 100644 --- a/bot/Dockerfile +++ b/docker/Dockerfile.discord-bot @@ -1,7 +1,7 @@ FROM python:3.10-slim-bullseye RUN mkdir /app -ADD requirements.txt /app/requirements.txt -WORKDIR /app +COPY ./discord-bot/requirements.txt /requirements.txt RUN pip install -r requirements.txt -ADD . /app +WORKDIR /app +COPY ./discord-bot /app CMD ["python", "bot.py"] diff --git a/oasst-shared/README.md b/oasst-shared/README.md new file mode 100644 index 00000000..28761ced --- /dev/null +++ b/oasst-shared/README.md @@ -0,0 +1,3 @@ +# Shared Python code for Open Assisstant + +Run `pip install -e .` to install the package in editable mode. diff --git a/backend/app/oasst/schemas/protocol.py b/oasst-shared/oasst_shared/schemas/protocol.py similarity index 100% rename from backend/app/oasst/schemas/protocol.py rename to oasst-shared/oasst_shared/schemas/protocol.py diff --git a/oasst-shared/setup.py b/oasst-shared/setup.py new file mode 100644 index 00000000..ebaf4217 --- /dev/null +++ b/oasst-shared/setup.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# setup.py for the shared python modules + +from distutils.core import setup + +from setuptools import find_namespace_packages + +setup( + name="oasst-shared", + version="1.0", + packages=find_namespace_packages(), + author="OASST Team", + install_requires=[ + "pydantic==1.9.1", + ], +) diff --git a/scripts/backend-development/README.md b/scripts/backend-development/README.md index 80706c79..3f7e6509 100644 --- a/scripts/backend-development/README.md +++ b/scripts/backend-development/README.md @@ -2,5 +2,5 @@ Run `docker compose up` to start a database. The default settings are already configured to connect to the database at `localhost:5432`. -Make sure you have all requirements installed. You can do this by running `pip install -r requirements.txt` inside the `backend` folder. +Make sure you have all requirements installed. You can do this by running `pip install -r requirements.txt` inside the `backend` folder and `pip install -e .` inside the `oasst-shared` folder. Then, run the backend using the `run-local.sh` script. This will start the backend server at `http://localhost:8080`. diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 064612c3..e9df6ca2 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -2,7 +2,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) # switch to backend directory -pushd "$parent_path/../../backend/app" +pushd "$parent_path/../../backend" export ALLOW_ANY_API_KEY=True diff --git a/scripts/frontend-development/docker-compose.yaml b/scripts/frontend-development/docker-compose.yaml index 1127d35d..7ff5dd28 100644 --- a/scripts/frontend-development/docker-compose.yaml +++ b/scripts/frontend-development/docker-compose.yaml @@ -15,7 +15,9 @@ services: file: ../backend-development/docker-compose.yaml service: adminer backend: - build: ../../backend/. + build: + dockerfile: docker/Dockerfile.backend + context: ../.. image: oasst-backend environment: - POSTGRES_HOST=db diff --git a/backend/postprocessing/rankings.py b/scripts/postprocessing/rankings.py similarity index 100% rename from backend/postprocessing/rankings.py rename to scripts/postprocessing/rankings.py