mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
first api-interaction, fix auth_method unique-index
This commit is contained in:
+30
@@ -0,0 +1,30 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""add_auth_method_to_ix_person_username
|
||||
|
||||
Revision ID: 0daec5f8135f
|
||||
Revises: 6368515778c5
|
||||
Create Date: 2022-12-22 18:35:59.609013
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa # noqa: F401
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0daec5f8135f"
|
||||
down_revision = "6368515778c5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_person_username", table_name="person")
|
||||
op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
@@ -10,7 +10,7 @@ from sqlmodel import Field, Index, SQLModel
|
||||
|
||||
class Person(SQLModel, table=True):
|
||||
__tablename__ = "person"
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),)
|
||||
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
|
||||
|
||||
id: Optional[UUID] = Field(
|
||||
sa_column=sa.Column(
|
||||
|
||||
@@ -32,7 +32,12 @@ class PromptRepository:
|
||||
)
|
||||
if person is None:
|
||||
# user is unknown, create new record
|
||||
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
|
||||
person = Person(
|
||||
username=user.id,
|
||||
display_name=user.display_name,
|
||||
api_client_id=self.api_client.id,
|
||||
auth_method=user.auth_method,
|
||||
)
|
||||
self.db.add(person)
|
||||
self.db.commit()
|
||||
self.db.refresh(person)
|
||||
|
||||
+1
-1
@@ -69,6 +69,6 @@ class ApiClient:
|
||||
req = protocol_schema.TaskNAck(reason=reason)
|
||||
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
|
||||
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone:
|
||||
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
data = self.post("/api/v1/tasks/interaction", interaction.dict())
|
||||
return self._parse_task(data)
|
||||
|
||||
+4
-10
@@ -16,7 +16,7 @@ from message_templates import MessageTemplates
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import get_git_head_hash, utcnow
|
||||
|
||||
__version__ = "0.0.2"
|
||||
__version__ = "0.0.3"
|
||||
BOT_NAME = "Open-Assistant Junior"
|
||||
|
||||
|
||||
@@ -61,8 +61,8 @@ class OpenAssistantBot(BotBase):
|
||||
logger.info(f"{client.user} is now running!")
|
||||
|
||||
await self.delete_all_old_bot_messages()
|
||||
if self.debug:
|
||||
await self.post_boot_message()
|
||||
# if self.debug:
|
||||
# await self.post_boot_message()
|
||||
await self.post_welcome_message()
|
||||
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
@@ -125,16 +125,10 @@ class OpenAssistantBot(BotBase):
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old bot messages.")
|
||||
|
||||
async def print_separtor(self, title: str) -> discord.Message:
|
||||
msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n")
|
||||
return msg
|
||||
|
||||
async def next_task(self):
|
||||
task_type = protocol_schema.TaskRequestType.random
|
||||
task_type = protocol_schema.TaskRequestType.summarize_story
|
||||
task = self.backend.fetch_task(task_type, user=None)
|
||||
|
||||
await self.print_separtor("New Task")
|
||||
|
||||
handler: task_handlers.ChannelTaskBase = None
|
||||
match task.type:
|
||||
case TaskType.summarize_story:
|
||||
|
||||
@@ -7,6 +7,7 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from channel_handlers import ChannelHandlerBase
|
||||
from loguru import logger
|
||||
from message_templates import MessageTemplates
|
||||
@@ -22,6 +23,7 @@ class ReplyHandlerInfo:
|
||||
class BotBase(ABC):
|
||||
bot_channel_name: str
|
||||
debug: bool
|
||||
backend: ApiClient
|
||||
client: discord.Client
|
||||
loop: asyncio.BaseEventLoop
|
||||
owner_id: int
|
||||
|
||||
@@ -13,15 +13,13 @@ class ChannelExpiredException(Exception):
|
||||
|
||||
class ChannelHandlerBase(ABC):
|
||||
queue: asyncio.Queue
|
||||
completed: bool
|
||||
completed: bool = False
|
||||
expiry_date: datetime
|
||||
expired: bool
|
||||
expired: bool = False
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
self.expiry_date = expiry_date
|
||||
self.expired = False
|
||||
self.queue = asyncio.Queue()
|
||||
self.completed = False
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
"""Call this method to read the next message from the user in the handler method."""
|
||||
@@ -55,8 +53,8 @@ class ChannelHandlerBase(ABC):
|
||||
|
||||
|
||||
class AutoDestructThreadHandler(ChannelHandlerBase):
|
||||
first_message: discord.Message
|
||||
thread: discord.Thread
|
||||
first_message: discord.Message = None
|
||||
thread: discord.Thread = None
|
||||
|
||||
def __init__(self, *, expiry_date: datetime = None):
|
||||
super().__init__(expiry_date=expiry_date)
|
||||
|
||||
+84
-24
@@ -5,11 +5,12 @@ from abc import abstractmethod
|
||||
from datetime import timedelta
|
||||
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from bot_base import BotBase
|
||||
from channel_handlers import AutoDestructThreadHandler
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import utcnow
|
||||
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
||||
|
||||
|
||||
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
|
||||
@@ -23,15 +24,23 @@ class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
|
||||
class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
thread_name: str = "Replies"
|
||||
expires_after: timedelta = timedelta(minutes=5)
|
||||
backend: ApiClient
|
||||
|
||||
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
|
||||
self.bot = bot
|
||||
self.task = task
|
||||
msg = await self.send_first_message()
|
||||
self.first_message = msg
|
||||
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
|
||||
await self.on_thread_created(self.thread)
|
||||
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
|
||||
try:
|
||||
self.bot = bot
|
||||
self.task = task
|
||||
self.backend = bot.backend
|
||||
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
|
||||
msg = await self.send_first_message()
|
||||
self.first_message = msg
|
||||
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
|
||||
await self.on_thread_created(self.thread)
|
||||
except Exception:
|
||||
logger.exception("start task failed")
|
||||
await self.cleanup() # try to cleanup messag or thread
|
||||
raise
|
||||
|
||||
bot.register_reply_handler(msg_id=msg.id, handler=self)
|
||||
return msg
|
||||
|
||||
@@ -42,20 +51,57 @@ class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
async def send_first_message(self) -> discord.message:
|
||||
...
|
||||
|
||||
def to_api_user(self, user: discord.User) -> protocol_schema.User:
|
||||
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
api_response = await self.backend.post_interaction(interaction)
|
||||
if api_response.type != "task_done":
|
||||
# multi-step tasks are not supported yet
|
||||
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
|
||||
raise RuntimeError("Unexpected response from backend received")
|
||||
return api_response
|
||||
|
||||
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.TextReplyToPost(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
text=user_msg.content,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Member) -> protocol_schema.Task:
|
||||
try:
|
||||
self.post_text_reply_to_post(user_msg)
|
||||
await user_msg.add_reaction("✅")
|
||||
except Exception as e:
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
|
||||
class SummarizeStoryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.SummarizeStoryTask
|
||||
thread_name: str = "Summaries"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_summarize_story.msg", task=self.task)
|
||||
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relatve = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
msg = await self.bot.post_template(
|
||||
"task_summarize_story_teaser.msg", task=self.task, expiry_time=expiry_time, expiry_relatve=expiry_relatve
|
||||
)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
print("received: ", msg, type(msg))
|
||||
logger.info("on_summarize_story_reply")
|
||||
await msg.add_reaction("✅")
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class InitialPromptHandler(ChannelTaskBase):
|
||||
@@ -63,13 +109,14 @@ class InitialPromptHandler(ChannelTaskBase):
|
||||
thread_name: str = "Prompts"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_initial_prompt.msg", task=self.task)
|
||||
msg = await self.bot.post_template("task_initial_prompt.msg", task=self.task)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_initial_prompt_reply")
|
||||
await msg.add_reaction("✅")
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class UserReplyHandler(ChannelTaskBase):
|
||||
@@ -77,13 +124,14 @@ class UserReplyHandler(ChannelTaskBase):
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_user_reply.msg", task=self.task)
|
||||
msg = await self.bot.post_template("task_user_reply.msg", task=self.task)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_user_reply_reply")
|
||||
await msg.add_reaction("✅")
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class AssistantReplyHandler(ChannelTaskBase):
|
||||
@@ -91,13 +139,19 @@ class AssistantReplyHandler(ChannelTaskBase):
|
||||
thread_name: str = "Assistant replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_assistant_reply.msg", task=self.task)
|
||||
msg = await self.bot.post_template("task_assistant_reply.msg", task=self.task)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_assistant_reply_reply")
|
||||
await msg.add_reaction("✅")
|
||||
try:
|
||||
self.post_text_reply_to_post(msg)
|
||||
await msg.add_reaction("✅")
|
||||
except Exception as e:
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
|
||||
class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
@@ -105,7 +159,9 @@ class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
thread_name: str = "User Responses"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task)
|
||||
msg = await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
@@ -119,7 +175,9 @@ class RankConversationsHandler(ChannelTaskBase):
|
||||
thread_name: str = "Rankings"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task)
|
||||
msg = await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task)
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
@@ -156,7 +214,9 @@ class RateSummaryHandler(ChannelTaskBase):
|
||||
await interaction.response.send_message(f"got your feedback: {score}")
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
return await self.bot.post("first message")
|
||||
msg = await self.bot.post("first message")
|
||||
self.backend.ack_task(self.task.id, str(msg.id))
|
||||
return msg
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
:point_right: **Challenge: Summarize Story :books: ** :point_left:
|
||||
|
||||
:point_down: Work on this in the theard.
|
||||
|
||||
:fire: Message will self-destruct at {{ expiry_time }} UTC ({{ expiry_relatve }}).
|
||||
@@ -1,4 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
@@ -15,3 +16,37 @@ def get_git_head_hash():
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(pytz.UTC)
|
||||
|
||||
|
||||
class DiscordTimestampStyle(str, enum.Enum):
|
||||
"""
|
||||
Timestamp Styles
|
||||
|
||||
t 16:20 Short Time
|
||||
T 16:20:30 Long Time
|
||||
d 20/04/2021 Short Date
|
||||
D 20 April 2021 Long Date
|
||||
f * 20 April 2021 16:20 Short Date/Time
|
||||
F Tuesday, 20 April 2021 16:20 Long Date/Time
|
||||
R 2 months ago Relative Time
|
||||
|
||||
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
|
||||
"""
|
||||
|
||||
default = ""
|
||||
short_time = "t"
|
||||
long_time = "T"
|
||||
short_date = "d"
|
||||
long_date = "D"
|
||||
short_date_time = "f"
|
||||
long_date_time = "F"
|
||||
relative_time = "R"
|
||||
|
||||
|
||||
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
|
||||
parts = ["<t:", str(int(d.timestamp()))]
|
||||
if style:
|
||||
parts.append(":")
|
||||
parts.append(style)
|
||||
parts.append(">")
|
||||
return "".join(parts)
|
||||
|
||||
Reference in New Issue
Block a user