mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add teaser msgs & remaining task handling
This commit is contained in:
+4
-4
@@ -115,7 +115,7 @@ class OpenAssistantBot(BotBase):
|
||||
await thread.delete()
|
||||
logger.info("Completed deleting old theards.")
|
||||
|
||||
logger.info("Deleting old bot messages...")
|
||||
logger.info("Deleting old messages...")
|
||||
look_until = utcnow() - timedelta(days=365)
|
||||
async for msg in self.bot_channel.history(limit=None):
|
||||
msg: discord.Message
|
||||
@@ -123,10 +123,10 @@ class OpenAssistantBot(BotBase):
|
||||
break
|
||||
if msg.author.id == self.client.user.id:
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old bot messages.")
|
||||
logger.info("Completed deleting old messages.")
|
||||
|
||||
async def next_task(self):
|
||||
task_type = protocol_schema.TaskRequestType.summarize_story
|
||||
task_type = protocol_schema.TaskRequestType.random
|
||||
task = self.backend.fetch_task(task_type, user=None)
|
||||
|
||||
handler: task_handlers.ChannelTaskBase = None
|
||||
@@ -166,7 +166,7 @@ class OpenAssistantBot(BotBase):
|
||||
|
||||
if self.bot_channel:
|
||||
if now > next_fetch_task:
|
||||
next_fetch_task = utcnow() + timedelta(seconds=600)
|
||||
next_fetch_task = utcnow() + timedelta(seconds=60)
|
||||
|
||||
try:
|
||||
await self.next_task()
|
||||
|
||||
@@ -23,9 +23,15 @@ class ChannelHandlerBase(ABC):
|
||||
|
||||
async def read(self) -> discord.Message:
|
||||
"""Call this method to read the next message from the user in the handler method."""
|
||||
msg = await self.queue.get()
|
||||
if msg is None and self.expired:
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
|
||||
msg = await self.queue.get()
|
||||
if msg is None:
|
||||
if self.expired:
|
||||
raise ChannelExpiredException()
|
||||
else:
|
||||
raise RuntimeError("Unexpected None message read")
|
||||
return msg
|
||||
|
||||
def on_reply(self, message: discord.Message) -> None:
|
||||
@@ -64,6 +70,7 @@ class AutoDestructThreadHandler(ChannelHandlerBase):
|
||||
return await super().read()
|
||||
except ChannelExpiredException:
|
||||
await self.cleanup()
|
||||
raise
|
||||
|
||||
async def cleanup(self):
|
||||
logger.debug("AutoDestructThreadHandler.cleanup")
|
||||
|
||||
+81
-35
@@ -7,7 +7,7 @@ from datetime import timedelta
|
||||
import discord
|
||||
from api_client import ApiClient
|
||||
from bot_base import BotBase
|
||||
from channel_handlers import AutoDestructThreadHandler
|
||||
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
|
||||
@@ -54,6 +54,13 @@ class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
def to_api_user(self, user: discord.User) -> protocol_schema.User:
|
||||
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
|
||||
|
||||
async def post_teaser_msg(self, template_name: str):
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relatve = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
return await self.bot.post_template(
|
||||
template_name, task=self.task, expiry_time=expiry_time, expiry_relatve=expiry_relatve
|
||||
)
|
||||
|
||||
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
|
||||
api_response = await self.backend.post_interaction(interaction)
|
||||
if api_response.type != "task_done":
|
||||
@@ -72,11 +79,37 @@ class ChannelTaskBase(AutoDestructThreadHandler):
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Member) -> protocol_schema.Task:
|
||||
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
self.post_text_reply_to_post(user_msg)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_text_reply_to_post()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
|
||||
return self.backend.post_interaction(
|
||||
protocol_schema.PostRanking(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(user_msg.id),
|
||||
user=self.to_api_user(user_msg.author),
|
||||
ranking=ranking,
|
||||
)
|
||||
)
|
||||
|
||||
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
|
||||
try:
|
||||
ranking_str = user_msg.content
|
||||
ranking = [int(x) - 1 for x in ranking_str.split(",")]
|
||||
self.post_ranking(user_msg, ranking=ranking)
|
||||
await user_msg.add_reaction("✅")
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in handle_ranking()")
|
||||
await user_msg.add_reaction("❌")
|
||||
await user_msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
@@ -86,12 +119,7 @@ class SummarizeStoryHandler(ChannelTaskBase):
|
||||
thread_name: str = "Summaries"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
|
||||
expiry_relatve = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
|
||||
msg = await self.bot.post_template(
|
||||
"task_summarize_story_teaser.msg", task=self.task, expiry_time=expiry_time, expiry_relatve=expiry_relatve
|
||||
)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_summarize_story.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
|
||||
@@ -107,8 +135,10 @@ class InitialPromptHandler(ChannelTaskBase):
|
||||
thread_name: str = "Prompts"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post_template("task_initial_prompt.msg", task=self.task)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_initial_prompt.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
@@ -121,8 +151,10 @@ class UserReplyHandler(ChannelTaskBase):
|
||||
thread_name: str = "User replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post_template("task_user_reply.msg", task=self.task)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_user_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
@@ -135,18 +167,15 @@ class AssistantReplyHandler(ChannelTaskBase):
|
||||
thread_name: str = "Assistant replies"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post_template("task_assistant_reply.msg", task=self.task)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_assistant_reply.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
try:
|
||||
self.post_text_reply_to_post(msg)
|
||||
await msg.add_reaction("✅")
|
||||
except Exception as e:
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply(f"❌ Error communicating with backend: {e}")
|
||||
await self.handle_text_reply_to_post(msg)
|
||||
|
||||
|
||||
class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
@@ -154,14 +183,15 @@ class RankInitialPromptsHandler(ChannelTaskBase):
|
||||
thread_name: str = "User Responses"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post_template("task_rank_initial_prompts.msg", task=self.task)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_rank_initial_prompts_reply")
|
||||
await msg.add_reaction("✅")
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RankConversationsHandler(ChannelTaskBase):
|
||||
@@ -169,14 +199,15 @@ class RankConversationsHandler(ChannelTaskBase):
|
||||
thread_name: str = "Rankings"
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post_template("task_rank_conversation_replies.msg", task=self.task)
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
|
||||
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_rank_conversation_reply")
|
||||
await msg.add_reaction("✅")
|
||||
await self.handle_ranking(msg)
|
||||
|
||||
|
||||
class RatingButton(discord.ui.Button):
|
||||
@@ -198,17 +229,31 @@ def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
|
||||
class RateSummaryHandler(ChannelTaskBase):
|
||||
task: protocol_schema.RateSummaryTask
|
||||
thread_name: str = "Rate"
|
||||
thread_name: str = "Ratings"
|
||||
|
||||
async def _rating_response_handler(self, score, interaction: discord.Interaction):
|
||||
logger.info("rating_response_handler", score)
|
||||
if self.thread:
|
||||
await self.thread.send(f"{interaction.user.name} got your feedback: {score}")
|
||||
await interaction.response.send_message(f"got your feedback: {score}")
|
||||
try:
|
||||
self.backend.post_interaction(
|
||||
protocol_schema.PostRating(
|
||||
post_id=str(self.first_message.id),
|
||||
user_post_id=str(interaction.id),
|
||||
user=self.to_api_user(interaction.user),
|
||||
rating=score,
|
||||
)
|
||||
)
|
||||
await interaction.response.send_message(
|
||||
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
|
||||
)
|
||||
except ChannelExpiredException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Error in _rating_response_handler()")
|
||||
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
|
||||
|
||||
async def send_first_message(self) -> discord.message:
|
||||
msg = await self.bot.post("first message")
|
||||
return msg
|
||||
return await self.post_teaser_msg("teaser_rate_summary.msg")
|
||||
|
||||
async def on_thread_created(self, thread: discord.Thread) -> None:
|
||||
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
|
||||
@@ -217,5 +262,6 @@ class RateSummaryHandler(ChannelTaskBase):
|
||||
async def handler_loop(self):
|
||||
while True:
|
||||
msg = await self.read()
|
||||
logger.info("on_rate_summary_reply")
|
||||
await msg.add_reaction("✅")
|
||||
logger.info(f"on_rate_summary_reply: {msg.content}")
|
||||
await msg.add_reaction("❌")
|
||||
await msg.reply("❌ Text intput not supported.")
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
:point_right: **Challenge: Summarize Story :books: ** :point_left:
|
||||
|
||||
:point_down: Work on this in the theard.
|
||||
|
||||
:fire: Message will self-destruct at {{ expiry_time }} UTC ({{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:robot: **Challenge: Assistant Reply**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:microphone2: **Challenge: Initial Prompt**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:bar_chart: **Challenge: Rank Replies**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:bar_chart: **Challenge: Rank Initial Prompts**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:ballot_box: **Challenge: Rate Summary**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:books: **Challenge: Summarize Story**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
@@ -0,0 +1,3 @@
|
||||
:person_red_hair: **Challenge: User Reply**
|
||||
|
||||
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relatve }}).
|
||||
Reference in New Issue
Block a user