first api-interaction, fix auth_method unique-index

This commit is contained in:
Andreas Köpf
2022-12-22 18:41:50 +01:00
parent cad6a450c0
commit 8a48722e72
10 changed files with 172 additions and 43 deletions
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""add_auth_method_to_ix_person_username
Revision ID: 0daec5f8135f
Revises: 6368515778c5
Create Date: 2022-12-22 18:35:59.609013
"""
import sqlalchemy as sa # noqa: F401
from alembic import op
# revision identifiers, used by Alembic.
revision = "0daec5f8135f"
down_revision = "6368515778c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False)
# ### end Alembic commands ###
+1 -1
View File
@@ -10,7 +10,7 @@ from sqlmodel import Field, Index, SQLModel
class Person(SQLModel, table=True):
__tablename__ = "person"
__table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),)
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
+6 -1
View File
@@ -32,7 +32,12 @@ class PromptRepository:
)
if person is None:
# user is unknown, create new record
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
person = Person(
username=user.id,
display_name=user.display_name,
api_client_id=self.api_client.id,
auth_method=user.auth_method,
)
self.db.add(person)
self.db.commit()
self.db.refresh(person)
+1 -1
View File
@@ -69,6 +69,6 @@ class ApiClient:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone:
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
+4 -10
View File
@@ -16,7 +16,7 @@ from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.2"
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
@@ -61,8 +61,8 @@ class OpenAssistantBot(BotBase):
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
if self.debug:
await self.post_boot_message()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@@ -125,16 +125,10 @@ class OpenAssistantBot(BotBase):
await msg.delete()
logger.info("Completed deleting old bot messages.")
async def print_separtor(self, title: str) -> discord.Message:
msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n")
return msg
async def next_task(self):
task_type = protocol_schema.TaskRequestType.random
task_type = protocol_schema.TaskRequestType.summarize_story
task = self.backend.fetch_task(task_type, user=None)
await self.print_separtor("New Task")
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
+2
View File
@@ -7,6 +7,7 @@ from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@@ -22,6 +23,7 @@ class ReplyHandlerInfo:
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
+4 -6
View File
@@ -13,15 +13,13 @@ class ChannelExpiredException(Exception):
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool
completed: bool = False
expiry_date: datetime
expired: bool
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.expired = False
self.queue = asyncio.Queue()
self.completed = False
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
@@ -55,8 +53,8 @@ class ChannelHandlerBase(ABC):
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message
thread: discord.Thread
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
+84 -24
View File
@@ -5,11 +5,12 @@ from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import utcnow
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
@@ -23,15 +24,23 @@ class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
self.bot = bot
self.task = task
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup() # try to cleanup messag or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
@@ -42,20 +51,57 @@ class ChannelTaskBase(AutoDestructThreadHandler):
async def send_first_message(self) -> discord.message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToPost(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Member) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except Exception as e:
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_summarize_story.msg", task=self.task)
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relatve = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
msg = await self.bot.post_template(
"task_summarize_story_teaser.msg", task=self.task, expiry_time=expiry_time, expiry_relatve=expiry_relatve
)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
print("received: ", msg, type(msg))
logger.info("on_summarize_story_reply")
await msg.add_reaction("")
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
@@ -63,13 +109,14 @@ class InitialPromptHandler(ChannelTaskBase):
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_initial_prompt.msg", task=self.task)
msg = await self.bot.post_template("task_initial_prompt.msg", task=self.task)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def handler_loop(self):
while True:
msg = await self.read()
logger.info("on_initial_prompt_reply")
await msg.add_reaction("")
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
@@ -77,13 +124,14 @@ class UserReplyHandler(ChannelTaskBase):
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_user_reply.msg", task=self.task)
msg = await self.bot.post_template("task_user_reply.msg", task=self.task)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def handler_loop(self):
while True:
msg = await self.read()
logger.info("on_user_reply_reply")
await msg.add_reaction("")
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
@@ -91,13 +139,19 @@ class AssistantReplyHandler(ChannelTaskBase):
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_assistant_reply.msg", task=self.task)
msg = await self.bot.post_template("task_assistant_reply.msg", task=self.task)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def handler_loop(self):
while True:
msg = await self.read()
logger.info("on_assistant_reply_reply")
await msg.add_reaction("")
try:
self.post_text_reply_to_post(msg)
await msg.add_reaction("")
except Exception as e:
await msg.add_reaction("")
await msg.reply(f"❌ Error communicating with backend: {e}")
class RankInitialPromptsHandler(ChannelTaskBase):
@@ -105,7 +159,9 @@ class RankInitialPromptsHandler(ChannelTaskBase):
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task)
msg = await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def handler_loop(self):
while True:
@@ -119,7 +175,9 @@ class RankConversationsHandler(ChannelTaskBase):
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.message:
return await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task)
msg = await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task)
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def handler_loop(self):
while True:
@@ -156,7 +214,9 @@ class RateSummaryHandler(ChannelTaskBase):
await interaction.response.send_message(f"got your feedback: {score}")
async def send_first_message(self) -> discord.message:
return await self.bot.post("first message")
msg = await self.bot.post("first message")
self.backend.ack_task(self.task.id, str(msg.id))
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
@@ -0,0 +1,5 @@
:point_right: **Challenge: Summarize Story :books: ** :point_left:
:point_down: Work on this in the theard.
:fire: Message will self-destruct at {{ expiry_time }} UTC ({{ expiry_relatve }}).
+35
View File
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
@@ -15,3 +16,37 @@ def get_git_head_hash():
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)