From 8a48722e7204e05926e7551dfb0f99bd472cacd2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Thu, 22 Dec 2022 18:41:50 +0100 Subject: [PATCH] first api-interaction, fix auth_method unique-index --- ...f_add_auth_method_to_ix_person_username.py | 30 +++++ backend/oasst_backend/models/person.py | 2 +- backend/oasst_backend/prompt_repository.py | 7 +- bot/api_client.py | 2 +- bot/bot.py | 14 +-- bot/bot_base.py | 2 + bot/channel_handlers.py | 10 +- bot/task_handlers.py | 108 ++++++++++++++---- bot/templates/task_summarize_story_teaser.msg | 5 + bot/utils.py | 35 ++++++ 10 files changed, 172 insertions(+), 43 deletions(-) create mode 100644 backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py create mode 100644 bot/templates/task_summarize_story_teaser.msg diff --git a/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py b/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py new file mode 100644 index 00000000..c65b8319 --- /dev/null +++ b/backend/alembic/versions/2022_12_22_1835-0daec5f8135f_add_auth_method_to_ix_person_username.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +"""add_auth_method_to_ix_person_username + +Revision ID: 0daec5f8135f +Revises: 6368515778c5 +Create Date: 2022-12-22 18:35:59.609013 + +""" +import sqlalchemy as sa # noqa: F401 +from alembic import op + +# revision identifiers, used by Alembic. +revision = "0daec5f8135f" +down_revision = "6368515778c5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_person_username", table_name="person") + op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index("ix_person_username", table_name="person") + op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False) + # ### end Alembic commands ### diff --git a/backend/oasst_backend/models/person.py b/backend/oasst_backend/models/person.py index 57f134a4..f01f85f0 100644 --- a/backend/oasst_backend/models/person.py +++ b/backend/oasst_backend/models/person.py @@ -10,7 +10,7 @@ from sqlmodel import Field, Index, SQLModel class Person(SQLModel, table=True): __tablename__ = "person" - __table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),) + __table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),) id: Optional[UUID] = Field( sa_column=sa.Column( diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 9f7bb1dd..b0063cdf 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -32,7 +32,12 @@ class PromptRepository: ) if person is None: # user is unknown, create new record - person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id) + person = Person( + username=user.id, + display_name=user.display_name, + api_client_id=self.api_client.id, + auth_method=user.auth_method, + ) self.db.add(person) self.db.commit() self.db.refresh(person) diff --git a/bot/api_client.py b/bot/api_client.py index 19a62188..1de6bb17 100644 --- a/bot/api_client.py +++ b/bot/api_client.py @@ -69,6 +69,6 @@ class ApiClient: req = protocol_schema.TaskNAck(reason=reason) return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict()) - def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone: + def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: data = self.post("/api/v1/tasks/interaction", interaction.dict()) return self._parse_task(data) diff --git a/bot/bot.py b/bot/bot.py index a8df4cf5..9f2f9247 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -16,7 +16,7 @@ from message_templates import MessageTemplates from oasst_shared.schemas import protocol as protocol_schema from utils import get_git_head_hash, utcnow -__version__ = "0.0.2" +__version__ = "0.0.3" BOT_NAME = "Open-Assistant Junior" @@ -61,8 +61,8 @@ class OpenAssistantBot(BotBase): logger.info(f"{client.user} is now running!") await self.delete_all_old_bot_messages() - if self.debug: - await self.post_boot_message() + # if self.debug: + # await self.post_boot_message() await self.post_welcome_message() client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") @@ -125,16 +125,10 @@ class OpenAssistantBot(BotBase): await msg.delete() logger.info("Completed deleting old bot messages.") - async def print_separtor(self, title: str) -> discord.Message: - msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n") - return msg - async def next_task(self): - task_type = protocol_schema.TaskRequestType.random + task_type = protocol_schema.TaskRequestType.summarize_story task = self.backend.fetch_task(task_type, user=None) - await self.print_separtor("New Task") - handler: task_handlers.ChannelTaskBase = None match task.type: case TaskType.summarize_story: diff --git a/bot/bot_base.py b/bot/bot_base.py index 7ac2e2ac..76eca22d 100644 --- a/bot/bot_base.py +++ b/bot/bot_base.py @@ -7,6 +7,7 @@ from dataclasses import dataclass from typing import Any import discord +from api_client import ApiClient from channel_handlers import ChannelHandlerBase from loguru import logger from message_templates import MessageTemplates @@ -22,6 +23,7 @@ class ReplyHandlerInfo: class BotBase(ABC): bot_channel_name: str debug: bool + backend: ApiClient client: discord.Client loop: asyncio.BaseEventLoop owner_id: int diff --git a/bot/channel_handlers.py b/bot/channel_handlers.py index b92d2273..deed5049 100644 --- a/bot/channel_handlers.py +++ b/bot/channel_handlers.py @@ -13,15 +13,13 @@ class ChannelExpiredException(Exception): class ChannelHandlerBase(ABC): queue: asyncio.Queue - completed: bool + completed: bool = False expiry_date: datetime - expired: bool + expired: bool = False def __init__(self, *, expiry_date: datetime = None): self.expiry_date = expiry_date - self.expired = False self.queue = asyncio.Queue() - self.completed = False async def read(self) -> discord.Message: """Call this method to read the next message from the user in the handler method.""" @@ -55,8 +53,8 @@ class ChannelHandlerBase(ABC): class AutoDestructThreadHandler(ChannelHandlerBase): - first_message: discord.Message - thread: discord.Thread + first_message: discord.Message = None + thread: discord.Thread = None def __init__(self, *, expiry_date: datetime = None): super().__init__(expiry_date=expiry_date) diff --git a/bot/task_handlers.py b/bot/task_handlers.py index dd81bbde..261860ab 100644 --- a/bot/task_handlers.py +++ b/bot/task_handlers.py @@ -5,11 +5,12 @@ from abc import abstractmethod from datetime import timedelta import discord +from api_client import ApiClient from bot_base import BotBase from channel_handlers import AutoDestructThreadHandler from loguru import logger from oasst_shared.schemas import protocol as protocol_schema -from utils import utcnow +from utils import DiscordTimestampStyle, discord_timestamp, utcnow class Questionnaire(discord.ui.Modal, title="Questionnaire Response"): @@ -23,15 +24,23 @@ class Questionnaire(discord.ui.Modal, title="Questionnaire Response"): class ChannelTaskBase(AutoDestructThreadHandler): thread_name: str = "Replies" expires_after: timedelta = timedelta(minutes=5) + backend: ApiClient async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message: - self.bot = bot - self.task = task - msg = await self.send_first_message() - self.first_message = msg - self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name) - await self.on_thread_created(self.thread) - self.expiry_date = utcnow() + self.expires_after if self.expires_after else None + try: + self.bot = bot + self.task = task + self.backend = bot.backend + self.expiry_date = utcnow() + self.expires_after if self.expires_after else None + msg = await self.send_first_message() + self.first_message = msg + self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name) + await self.on_thread_created(self.thread) + except Exception: + logger.exception("start task failed") + await self.cleanup() # try to cleanup messag or thread + raise + bot.register_reply_handler(msg_id=msg.id, handler=self) return msg @@ -42,20 +51,57 @@ class ChannelTaskBase(AutoDestructThreadHandler): async def send_first_message(self) -> discord.message: ... + def to_api_user(self, user: discord.User) -> protocol_schema.User: + return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name) + + async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task: + api_response = await self.backend.post_interaction(interaction) + if api_response.type != "task_done": + # multi-step tasks are not supported yet + logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})") + raise RuntimeError("Unexpected response from backend received") + return api_response + + def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task: + return self.backend.post_interaction( + protocol_schema.TextReplyToPost( + post_id=str(self.first_message.id), + user_post_id=str(user_msg.id), + user=self.to_api_user(user_msg.author), + text=user_msg.content, + ) + ) + + async def handle_text_reply_to_post(self, user_msg: discord.Member) -> protocol_schema.Task: + try: + self.post_text_reply_to_post(user_msg) + await user_msg.add_reaction("✅") + except Exception as e: + await user_msg.add_reaction("❌") + await user_msg.reply(f"❌ Error communicating with backend: {e}") + class SummarizeStoryHandler(ChannelTaskBase): task: protocol_schema.SummarizeStoryTask thread_name: str = "Summaries" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_summarize_story.msg", task=self.task) + + expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time) + expiry_relatve = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time) + msg = await self.bot.post_template( + "task_summarize_story_teaser.msg", task=self.task, expiry_time=expiry_time, expiry_relatve=expiry_relatve + ) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg + + async def on_thread_created(self, thread: discord.Thread) -> None: + await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task) async def handler_loop(self): while True: msg = await self.read() - print("received: ", msg, type(msg)) - logger.info("on_summarize_story_reply") - await msg.add_reaction("✅") + await self.handle_text_reply_to_post(msg) class InitialPromptHandler(ChannelTaskBase): @@ -63,13 +109,14 @@ class InitialPromptHandler(ChannelTaskBase): thread_name: str = "Prompts" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_initial_prompt.msg", task=self.task) + msg = await self.bot.post_template("task_initial_prompt.msg", task=self.task) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def handler_loop(self): while True: msg = await self.read() - logger.info("on_initial_prompt_reply") - await msg.add_reaction("✅") + await self.handle_text_reply_to_post(msg) class UserReplyHandler(ChannelTaskBase): @@ -77,13 +124,14 @@ class UserReplyHandler(ChannelTaskBase): thread_name: str = "User replies" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_user_reply.msg", task=self.task) + msg = await self.bot.post_template("task_user_reply.msg", task=self.task) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def handler_loop(self): while True: msg = await self.read() - logger.info("on_user_reply_reply") - await msg.add_reaction("✅") + await self.handle_text_reply_to_post(msg) class AssistantReplyHandler(ChannelTaskBase): @@ -91,13 +139,19 @@ class AssistantReplyHandler(ChannelTaskBase): thread_name: str = "Assistant replies" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_assistant_reply.msg", task=self.task) + msg = await self.bot.post_template("task_assistant_reply.msg", task=self.task) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def handler_loop(self): while True: msg = await self.read() - logger.info("on_assistant_reply_reply") - await msg.add_reaction("✅") + try: + self.post_text_reply_to_post(msg) + await msg.add_reaction("✅") + except Exception as e: + await msg.add_reaction("❌") + await msg.reply(f"❌ Error communicating with backend: {e}") class RankInitialPromptsHandler(ChannelTaskBase): @@ -105,7 +159,9 @@ class RankInitialPromptsHandler(ChannelTaskBase): thread_name: str = "User Responses" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task) + msg = await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def handler_loop(self): while True: @@ -119,7 +175,9 @@ class RankConversationsHandler(ChannelTaskBase): thread_name: str = "Rankings" async def send_first_message(self) -> discord.message: - return await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task) + msg = await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task) + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def handler_loop(self): while True: @@ -156,7 +214,9 @@ class RateSummaryHandler(ChannelTaskBase): await interaction.response.send_message(f"got your feedback: {score}") async def send_first_message(self) -> discord.message: - return await self.bot.post("first message") + msg = await self.bot.post("first message") + self.backend.ack_task(self.task.id, str(msg.id)) + return msg async def on_thread_created(self, thread: discord.Thread) -> None: view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler) diff --git a/bot/templates/task_summarize_story_teaser.msg b/bot/templates/task_summarize_story_teaser.msg new file mode 100644 index 00000000..3493982b --- /dev/null +++ b/bot/templates/task_summarize_story_teaser.msg @@ -0,0 +1,5 @@ +:point_right: **Challenge: Summarize Story :books: ** :point_left: + +:point_down: Work on this in the theard. + +:fire: Message will self-destruct at {{ expiry_time }} UTC ({{ expiry_relatve }}). \ No newline at end of file diff --git a/bot/utils.py b/bot/utils.py index 1a06b833..968e4498 100644 --- a/bot/utils.py +++ b/bot/utils.py @@ -1,4 +1,5 @@ # -*- coding: utf-8 -*- +import enum import subprocess from datetime import datetime @@ -15,3 +16,37 @@ def get_git_head_hash(): def utcnow() -> datetime: return datetime.now(pytz.UTC) + + +class DiscordTimestampStyle(str, enum.Enum): + """ + Timestamp Styles + + t 16:20 Short Time + T 16:20:30 Long Time + d 20/04/2021 Short Date + D 20 April 2021 Long Date + f * 20 April 2021 16:20 Short Date/Time + F Tuesday, 20 April 2021 16:20 Long Date/Time + R 2 months ago Relative Time + + See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles + """ + + default = "" + short_time = "t" + long_time = "T" + short_date = "d" + long_date = "D" + short_date_time = "f" + long_date_time = "F" + relative_time = "R" + + +def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default): + parts = ["") + return "".join(parts)