Merge pull request #50 from LAION-AI/bot_dev

add first bits of jinja template support
This commit is contained in:
Andreas Köpf
2022-12-23 01:03:00 +01:00
committed by GitHub
32 changed files with 796 additions and 211 deletions
+1 -1
View File
@@ -1,4 +1,4 @@
exclude: "build|stubs"
exclude: "build|stubs|^bot/templates/"
default_language_version:
python: python3
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
"""add_auth_method_to_ix_person_username
Revision ID: 0daec5f8135f
Revises: 6368515778c5
Create Date: 2022-12-22 18:35:59.609013
"""
import sqlalchemy as sa # noqa: F401
from alembic import op
# revision identifiers, used by Alembic.
revision = "0daec5f8135f"
down_revision = "6368515778c5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username", "auth_method"], unique=True)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index("ix_person_username", table_name="person")
op.create_index("ix_person_username", "person", ["api_client_id", "username"], unique=False)
# ### end Alembic commands ###
+1 -1
View File
@@ -169,8 +169,8 @@ def acknowledge_task(
pr = PromptRepository(db, api_client, user=None)
# here we store the post id in the database for the task
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
logger.info(f"Frontend acknowledges task {task_id=}, {ack_request=}.")
pr.bind_frontend_post_id(task_id=task_id, post_id=ack_request.post_id)
except Exception:
logger.exception("Failed to acknowledge task.")
+1 -1
View File
@@ -10,7 +10,7 @@ from sqlmodel import Field, Index, SQLModel
class Person(SQLModel, table=True):
__tablename__ = "person"
__table_args__ = (Index("ix_person_username", "api_client_id", "username", unique=True),)
__table_args__ = (Index("ix_person_username", "api_client_id", "username", "auth_method", unique=True),)
id: Optional[UUID] = Field(
sa_column=sa.Column(
+6 -1
View File
@@ -32,7 +32,12 @@ class PromptRepository:
)
if person is None:
# user is unknown, create new record
person = Person(username=user.id, display_name=user.display_name, api_client_id=self.api_client.id)
person = Person(
username=user.id,
display_name=user.display_name,
api_client_id=self.api_client.id,
auth_method=user.auth_method,
)
self.db.add(person)
self.db.commit()
self.db.refresh(person)
+2
View File
@@ -12,5 +12,7 @@ if __name__ == "__main__":
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
+1 -1
View File
@@ -69,6 +69,6 @@ class ApiClient:
req = protocol_schema.TaskNAck(reason=reason)
return self.post(f"/api/v1/tasks/{task_id}/nack", req.dict())
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.TaskDone:
def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
data = self.post("/api/v1/tasks/interaction", interaction.dict())
return self._parse_task(data)
+153 -206
View File
@@ -1,32 +1,26 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from datetime import timedelta
from pathlib import Path
from typing import Optional, Union
import discord
import task_handlers
from api_client import ApiClient, TaskType
from bot_base import BotBase
from discord import app_commands
from loguru import logger
from message_templates import MessageTemplates
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.3"
BOT_NAME = "Open-Assistant Junior"
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class OpenAssistantBot:
class OpenAssistantBot(BotBase):
def __init__(
self,
bot_token: str,
@@ -34,7 +28,16 @@ class OpenAssistantBot:
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
super().__init__()
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
@@ -45,20 +48,25 @@ class OpenAssistantBot:
self.bot_token = bot_token
client = discord.Client(intents=intents)
self.client = client
self.loop = client.loop
self.bot_channel: discord.TextChannel = None
self.backend = ApiClient(backend_url, api_key)
self.reply_handlers = {} # handlers by msg_id
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
self.auto_archive_minutes = 60 # ToDo: add to bot config
self.tree = app_commands.CommandTree(self.client, fallback_to_global=True)
@client.event
async def on_ready():
self.bot_channel = self.get_text_channel_by_name(bot_channel_name)
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
# if self.debug:
# await self.post_boot_message()
await self.post_welcome_message()
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
@client.event
async def on_message(message: discord.Message):
# ignore own messages
@@ -68,208 +76,111 @@ class OpenAssistantBot:
@self.tree.command()
async def tutorial(interaction: discord.Interaction):
"""Start the Open-Assistant tutorial via DMs."""
dm = await self.client.create_dm(discord.Object(interaction.user.id))
await dm.send("Tutorial coming soon... :-)")
await interaction.response.send_message(f"tutorial command by {interaction.user.name}")
@self.tree.command()
async def help(interaction: discord.Interaction):
"""Sends the user a list of all available commands"""
await interaction.response.send_message(f"help command by {interaction.user.name}")
await self.post_help(interaction.user)
await interaction.response.send_message(f"@{interaction.user.display_name}, I've sent you a PM.")
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
await interaction.response.send_message(f"work command by {interaction.user.name}")
async def print_separtor(self, title: str) -> discord.Message:
msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n")
return msg
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = task_handlers.Questionnaire()
await interaction.response.send_modal(q)
async def generate_summarize_story(self, task: protocol_schema.SummarizeStoryTask):
text = f"Summarize to the following story:\n{task.story}"
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="Summaries", auto_archive_duration=self.auto_archive_minutes
async def post_help(self, user: discord.abc.User) -> discord.Message:
is_bot_owner = user.id == self.owner_id
return await self.post_template("help.msg", channel=user, is_bot_owner=is_bot_owner)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def on_reply(message: discord.Message):
logger.info("on_summarize_story_reply", message)
await message.add_reaction("")
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
self.reply_handlers[msg.id] = on_reply
async def delete_all_old_bot_messages(self) -> None:
logger.info("Deleting old threads...")
for thread in self.bot_channel.threads:
if thread.owner_id == self.client.user.id:
await thread.delete()
logger.info("Completed deleting old theards.")
return msg
async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask):
s = [
"Rate the following summary:",
task.summary,
"Full text:",
task.full_text,
f"Rating scale: {task.scale.min} - {task.scale.max}",
]
text = "\n".join(s)
async def rating_response_handler(score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
await interaction.response.send_message(f"got your feedback: {score}")
view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler)
msg: discord.Message = await self.bot_channel.send(text, view=view)
async def on_reply(message: discord.Message):
logger.info("on_summary_reply", message)
await message.add_reaction("")
self.reply_handlers[msg.id] = on_reply
return msg
async def generate_initial_prompt(self, task: protocol_schema.InitialPromptTask):
text = "Please provide an initial prompt to the assistant."
if task.hint:
text += f"\nHint: {task.hint}"
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="Prompts", auto_archive_duration=self.auto_archive_minutes
)
async def on_reply(message: discord.Message):
logger.info("on_initial_prompt_reply", message)
await message.add_reaction("")
self.reply_handlers[msg.id] = on_reply
return msg
def _render_message(self, message: protocol_schema.ConversationMessage) -> str:
"""Render a message to the user."""
if message.is_assistant:
return f":robot: Assistant:\n{message.text}"
else:
return f":person_red_hair: User:\n**{message.text}**"
async def generate_user_reply(self, task: protocol_schema.UserReplyTask):
s = ["Please provide a reply to the assistant.", "Here is the conversation so far:\n"]
for message in task.conversation.messages:
s.append(self._render_message(message))
s.append("")
if task.hint:
s.append(f"Hint: {task.hint}")
text = "\n".join(s)
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes
)
async def on_reply(message: discord.Message):
logger.info("on_user_reply_reply", message)
await message.add_reaction("")
self.reply_handlers[msg.id] = on_reply
return msg
async def generate_assistant_reply(self, task: protocol_schema.AssistantReplyTask):
s = ["Act as the assistant and reply to the user.", "Here is the conversation so far\n:"]
for message in task.conversation.messages:
s.append(self._render_message(message))
s.append("")
s.append(":robot: Assistant: { human, pls help me! ... }")
text = "\n".join(s)
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="Agent responses", auto_archive_duration=self.auto_archive_minutes
)
async def on_reply(message: discord.Message):
logger.info("on_assistant_reply_reply", message)
await message.add_reaction("")
self.reply_handlers[msg.id] = on_reply
return msg
async def generate_rank_initial_prompts(self, task: protocol_schema.RankInitialPromptsTask):
s = ["Rank the following prompts:"]
for idx, prompt in enumerate(task.prompts, start=1):
s.append(f"{idx}: {prompt}")
s.append("")
s.append(':scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").')
text = "\n".join(s)
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes
)
async def on_reply(message: discord.Message):
logger.info("on_rank_initial_prompts_reply", message)
await message.add_reaction("")
self.reply_handlers[msg.id] = on_reply
return msg
async def generate_rank_conversation(self, task: protocol_schema.RankConversationRepliesTask):
s = ["Here is the conversation so far:"]
for message in task.conversation.messages:
s.append(self._render_message(message))
s.append("")
s.append("Rank the following replies:")
for idx, reply in enumerate(task.replies, start=1):
s.append(f"{idx}: {reply}")
s.append("")
s.append(':scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").')
text = "\n".join(s)
msg: discord.Message = await self.bot_channel.send(text)
await self.bot_channel.create_thread(
message=discord.Object(msg.id), name="User responses", auto_archive_duration=self.auto_archive_minutes
)
async def on_reply(message: discord.Message):
logger.info("on_rank_conversation_reply", message)
await message.add_reaction("")
message
self.reply_handlers[msg.id] = on_reply
return msg
logger.info("Deleting old messages...")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old messages.")
async def next_task(self):
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.user_reply, user=None)
task = self.backend.fetch_random_task(user=None)
task_type = protocol_schema.TaskRequestType.random
task = self.backend.fetch_task(task_type, user=None)
await self.print_separtor("New Task")
msg: discord.Message = None
handler: task_handlers.ChannelTaskBase = None
match task.type:
case TaskType.summarize_story:
msg = await self.generate_summarize_story(task)
handler = task_handlers.SummarizeStoryHandler()
case TaskType.rate_summary:
msg = await self.generate_rate_summary(task)
handler = task_handlers.RateSummaryHandler()
case TaskType.initial_prompt:
msg = await self.generate_initial_prompt(task)
handler = task_handlers.InitialPromptHandler()
case TaskType.user_reply:
msg = await self.generate_user_reply(task)
handler = task_handlers.UserReplyHandler()
case TaskType.assistant_reply:
msg = await self.generate_assistant_reply(task)
handler = task_handlers.AssistantReplyHandler()
case TaskType.rank_initial_prompts:
msg = await self.generate_rank_initial_prompts(task)
handler = task_handlers.RankInitialPromptsHandler()
case TaskType.rank_user_replies | TaskType.rank_assistant_replies:
msg = await self.generate_rank_conversation(task)
handler = task_handlers.RankConversationsHandler()
case _:
logger.warning(f"Unsupported task type received: {task.type}")
self.backend.nack_task(task.id, "not supported")
if msg is not None:
self.backend.ack_task(task.id, msg.id)
else:
self.backend.nack_task(task.id, "not supported")
if handler:
try:
logger.info(f"strarting task {task.id}")
msg = await handler.start(self, task)
self.backend.ack_task(task.id, msg.id)
except Exception:
logger.exception("Starting task failed.")
self.backend.nack_task(task.id, "faled")
async def background_timer(self):
next_remove_completed = utcnow() + timedelta(seconds=10)
next_fetch_task = utcnow() + timedelta(seconds=1)
while True:
now = utcnow()
if self.bot_channel:
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
await asyncio.sleep(30)
if now > next_fetch_task:
next_fetch_task = utcnow() + timedelta(seconds=60)
try:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
for x in self.reply_handlers.values():
x.handler.tick(now)
if now > next_remove_completed:
next_remove_completed = utcnow() + timedelta(seconds=10)
await self.remove_completed_handlers()
await asyncio.sleep(1)
async def _sync(self, command: str, message: discord.Message):
@@ -293,38 +204,74 @@ class OpenAssistantBot:
command_text: str = message.content
command_text = command_text[1:]
match command_text:
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild" | "sync.clear_guild":
case "help" | "?":
await self.post_help(user=message.author)
case "sync" | "sync.guild" | "sync.copy_global" | "sync.clear_guild":
if is_owner:
await self._sync(command_text, message)
case _:
await message.reply(f"unknown command: {command_text}")
def recipient_filter(self, message: discord.Message) -> bool:
channel = message.channel
if (
message.channel.type == discord.ChannelType.private
or message.channel.type == discord.ChannelType.private_thread
):
return True
if (
message.channel.type == discord.ChannelType.text
or message.channel.type == discord.ChannelType.public_thread
):
while channel:
if self.bot_channel and channel.id == self.bot_channel.id:
return True
channel = channel.parent
return False
async def handle_message(self, message: discord.Message):
if not self.recipient_filter(message):
return
user_id = message.author.id
user_display_name = message.author.name
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
command_prefix = "!"
if (
message.channel.type == discord.ChannelType.private
and message.type == discord.MessageType.default
and message.content.startswith(command_prefix)
):
if message.type == discord.MessageType.default and message.content.startswith(command_prefix):
is_owner = self.owner_id and user_id == self.owner_id
await self.handle_command(message, is_owner)
if isinstance(message.channel, discord.Thread):
handler = self.reply_handlers.get(message.channel.id)
if handler:
await handler(message)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
if message.reference:
handler = self.reply_handlers.get(message.reference.message_id)
if handler:
await handler(message)
if handler and not handler.handler.completed:
handler.handler.on_reply(message)
logger.debug(
f"{message.type} {message.channel.type} from ({user_display_name}) {user_id}: {message.content} ({type(message.content)})"
)
async def remove_completed_handlers(self):
completed = [k for k, v in self.reply_handlers.items() if v.handler is None or v.handler.completed]
if len(completed) == 0:
return
for c in completed:
handler = self.reply_handlers[c]
del self.reply_handlers[c]
try:
await handler.handler.finalize()
except Exception:
logger.exception("handler finalize failed")
logger.info(f"removed {len(completed)} completed handlers (remaining: {len(self.reply_handlers)})")
def get_text_channel_by_name(self, channel_name) -> discord.TextChannel:
for channel in self.client.get_all_channels():
+61
View File
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
from abc import ABC
from dataclasses import dataclass
from typing import Any
import discord
from api_client import ApiClient
from channel_handlers import ChannelHandlerBase
from loguru import logger
from message_templates import MessageTemplates
@dataclass
class ReplyHandlerInfo:
msg_id: int
handler_task: asyncio.Task
handler: ChannelHandlerBase
class BotBase(ABC):
bot_channel_name: str
debug: bool
backend: ApiClient
client: discord.Client
loop: asyncio.BaseEventLoop
owner_id: int
bot_channel: discord.TextChannel
templates: MessageTemplates
reply_handlers: dict[int, ReplyHandlerInfo]
def __init__(self):
self.reply_handlers = {} # handlers by msg_id
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(
self, content: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None
) -> discord.Message:
if channel is None:
self.ensure_bot_channel()
channel = self.bot_channel
return await channel.send(content=content, view=view)
async def post_template(
self, name: str, *, view: discord.ui.View = None, channel: discord.abc.Messageable = None, **kwargs: Any
) -> discord.Message:
logger.debug(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view=view, channel=channel)
def register_reply_handler(self, msg_id: int, handler: ChannelHandlerBase):
if msg_id in self.reply_handlers:
raise RuntimeError(f"Handler already registered for msg_id: {msg_id}")
task = asyncio.create_task(coro=handler.handler_loop(), name=f"reply_handler(msg_id={msg_id})")
task.add_done_callback(lambda t: handler.on_completed())
self.reply_handlers[msg_id] = ReplyHandlerInfo(msg_id=msg_id, handler_task=task, handler=handler)
+2
View File
@@ -8,6 +8,8 @@ class BotSettings(BaseSettings):
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
+88
View File
@@ -0,0 +1,88 @@
# -*- coding: utf-8 -*-
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
import discord
from loguru import logger
class ChannelExpiredException(Exception):
pass
class ChannelHandlerBase(ABC):
queue: asyncio.Queue
completed: bool = False
expiry_date: datetime
expired: bool = False
def __init__(self, *, expiry_date: datetime = None):
self.expiry_date = expiry_date
self.queue = asyncio.Queue()
async def read(self) -> discord.Message:
"""Call this method to read the next message from the user in the handler method."""
if self.expired:
raise ChannelExpiredException()
msg = await self.queue.get()
if msg is None:
if self.expired:
raise ChannelExpiredException()
else:
raise RuntimeError("Unexpected None message read")
return msg
def on_reply(self, message: discord.Message) -> None:
self.queue.put_nowait(message)
def on_expire(self) -> None:
logger.info("ChannelHandler: on_expire")
self.expired = True
self.queue.put_nowait(None)
def on_completed(self) -> None:
logger.info("ChannelHandler: on_completed")
self.completed = True
def tick(self, now: datetime):
if now > self.expiry_date and not self.expired:
self.on_expire()
@abstractmethod
async def handler_loop(self):
...
async def finalize(self):
pass
class AutoDestructThreadHandler(ChannelHandlerBase):
first_message: discord.Message = None
thread: discord.Thread = None
def __init__(self, *, expiry_date: datetime = None):
super().__init__(expiry_date=expiry_date)
async def read(self) -> discord.Message:
try:
return await super().read()
except ChannelExpiredException:
await self.cleanup()
raise
async def cleanup(self):
logger.debug("AutoDestructThreadHandler.cleanup")
if self.thread:
logger.debug(f"deleting thread: {self.thread.name}")
await self.thread.delete()
self.thread = None
if self.first_message:
logger.debug(f"deleting first_message: {self.first_message.content}")
await self.first_message.delete()
self.first_message = None
async def finalize(self):
await self.cleanup()
return await super().finalize()
+18
View File
@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
import jinja2
from loguru import logger
class MessageTemplates:
def __init__(self, template_dir="./templates"):
self.env = jinja2.Environment(
loader=jinja2.FileSystemLoader(template_dir),
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
)
def render(self, template_name, **kwargs):
template = self.env.get_template(template_name)
txt = template.render(kwargs)
logger.debug(txt)
return txt
+3
View File
@@ -1,4 +1,7 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
+267
View File
@@ -0,0 +1,267 @@
# -*- coding: utf-8 -*-
from __future__ import annotations
from abc import abstractmethod
from datetime import timedelta
import discord
from api_client import ApiClient
from bot_base import BotBase
from channel_handlers import AutoDestructThreadHandler, ChannelExpiredException
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import DiscordTimestampStyle, discord_timestamp, utcnow
class Questionnaire(discord.ui.Modal, title="Questionnaire Response"):
name = discord.ui.TextInput(label="Name")
answer = discord.ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class ChannelTaskBase(AutoDestructThreadHandler):
thread_name: str = "Replies"
expires_after: timedelta = timedelta(minutes=5)
backend: ApiClient
async def start(self, bot: BotBase, task: protocol_schema.Task) -> discord.Message:
try:
self.bot = bot
self.task = task
self.backend = bot.backend
self.expiry_date = utcnow() + self.expires_after if self.expires_after else None
msg = await self.send_first_message()
self.first_message = msg
self.thread = await bot.bot_channel.create_thread(message=discord.Object(msg.id), name=self.thread_name)
await self.on_thread_created(self.thread)
except Exception:
logger.exception("start task failed")
await self.cleanup() # try to cleanup messag or thread
raise
bot.register_reply_handler(msg_id=msg.id, handler=self)
return msg
async def on_thread_created(self, thread: discord.Thread) -> None:
pass
@abstractmethod
async def send_first_message(self) -> discord.message:
...
def to_api_user(self, user: discord.User) -> protocol_schema.User:
return protocol_schema.User(auth_method="discord", id=user.id, display_name=user.display_name)
async def post_teaser_msg(self, template_name: str):
expiry_time = discord_timestamp(self.expiry_date, DiscordTimestampStyle.long_time)
expiry_relative = discord_timestamp(self.expiry_date, DiscordTimestampStyle.relative_time)
return await self.bot.post_template(
template_name, task=self.task, expiry_time=expiry_time, expiry_relative=expiry_relative
)
async def post_interaction(self, interaction: protocol_schema.Interaction) -> protocol_schema.Task:
api_response = await self.backend.post_interaction(interaction)
if api_response.type != "task_done":
# multi-step tasks are not supported yet
logger.error(f"multi-step tasks are not supported yet (got response type: {api_response.type})")
raise RuntimeError("Unexpected response from backend received")
return api_response
def post_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.TextReplyToPost(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
text=user_msg.content,
)
)
async def handle_text_reply_to_post(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
self.post_text_reply_to_post(user_msg)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_text_reply_to_post()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
def post_ranking(self, user_msg: discord.Message, ranking: list[int]) -> protocol_schema.Task:
return self.backend.post_interaction(
protocol_schema.PostRanking(
post_id=str(self.first_message.id),
user_post_id=str(user_msg.id),
user=self.to_api_user(user_msg.author),
ranking=ranking,
)
)
async def handle_ranking(self, user_msg: discord.Message) -> protocol_schema.Task:
try:
ranking_str = user_msg.content
ranking = [int(x) - 1 for x in ranking_str.split(",")]
self.post_ranking(user_msg, ranking=ranking)
await user_msg.add_reaction("")
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in handle_ranking()")
await user_msg.add_reaction("")
await user_msg.reply(f"❌ Error communicating with backend: {e}")
class SummarizeStoryHandler(ChannelTaskBase):
task: protocol_schema.SummarizeStoryTask
thread_name: str = "Summaries"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_summarize_story.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_summarize_story.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class InitialPromptHandler(ChannelTaskBase):
task: protocol_schema.InitialPromptTask
thread_name: str = "Prompts"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_initial_prompt.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_initial_prompt.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class UserReplyHandler(ChannelTaskBase):
task: protocol_schema.UserReplyTask
thread_name: str = "User replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_user_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_user_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class AssistantReplyHandler(ChannelTaskBase):
task: protocol_schema.AssistantReplyTask
thread_name: str = "Assistant replies"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_assistant_reply.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_assistant_reply.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_text_reply_to_post(msg)
class RankInitialPromptsHandler(ChannelTaskBase):
task: protocol_schema.RankInitialPromptsTask
thread_name: str = "User Responses"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_initial_prompts.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_initial_prompts.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RankConversationsHandler(ChannelTaskBase):
task: protocol_schema.RankConversationRepliesTask
thread_name: str = "Rankings"
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rank_conversation_replies.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
await self.bot.post_template("task_rank_conversation_replies.msg", channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
await self.handle_ranking(msg)
class RatingButton(discord.ui.Button):
def __init__(self, label, value, response_handler):
super().__init__(label=label, style=discord.ButtonStyle.green)
self.value = value
self.response_handler = response_handler
async def callback(self, interaction):
await self.response_handler(self.value, interaction)
def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
view = discord.ui.View()
for i in range(lo, hi + 1):
view.add_item(RatingButton(str(i), i, response_handler))
return view
class RateSummaryHandler(ChannelTaskBase):
task: protocol_schema.RateSummaryTask
thread_name: str = "Ratings"
async def _rating_response_handler(self, score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
if self.thread:
try:
self.backend.post_interaction(
protocol_schema.PostRating(
post_id=str(self.first_message.id),
user_post_id=str(interaction.id),
user=self.to_api_user(interaction.user),
rating=score,
)
)
await interaction.response.send_message(
f"Thanks {interaction.user.display_name}, got your feedback: {score}!"
)
except ChannelExpiredException:
raise
except Exception as e:
logger.exception("Error in _rating_response_handler()")
interaction.response.send_message(f"❌ Error communicating with backend: {e}")
async def send_first_message(self) -> discord.message:
return await self.post_teaser_msg("teaser_rate_summary.msg")
async def on_thread_created(self, thread: discord.Thread) -> None:
view = generate_rating_view(self.task.scale.min, self.task.scale.max, self._rating_response_handler)
return await self.bot.post_template("task_rate_summary.msg", view=view, channel=thread, task=self.task)
async def handler_loop(self):
while True:
msg = await self.read()
logger.info(f"on_rate_summary_reply: {msg.content}")
await msg.add_reaction("")
await msg.reply("❌ Text intput not supported.")
+13
View File
@@ -0,0 +1,13 @@
```
________ __
\_____ \ _____ ______ _______/ |_
/ | \\__ \ / ___// ___/\ __\
/ | \/ __ \_\___ \ \___ \ | |
\_______ (____ /____ >____ > |__|
\/ \/ \/ \/
{{bot_name}} {{version}}
git hash: {{git_hash}}
debug_mode: {{debug}}
```
https://github.com/LAION-AI/Open-Assistant
+15
View File
@@ -0,0 +1,15 @@
**Open-Assistant Bot Help**
Available slash-commands:
`/work` Requests a new personalized human feedback task
`/help` Show this message
{% if is_bot_owner %}
Commands for bot owners:
`!sync`
`!sync.guild`
`!sync.copy_global`
`!sync.clear_guild`
{% endif %}
+12
View File
@@ -0,0 +1,12 @@
Act as the assistant and reply to the user.
Here is the conversation so far:
{% for message in task.conversation.messages %}
{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}
{% endfor %}
:robot: Assistant: { human, pls help me! ... }
+4
View File
@@ -0,0 +1,4 @@
Please provide an initial prompt to the assistant.
{% if task.hint is not none %}
Hint: {{task.hint}}
{% endif %}
@@ -0,0 +1,13 @@
Here is the conversation so far:
{% for message in task.conversation.messages %}{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}{% endfor %}
Rank the following replies:
{% for reply in task.replies %}
{{loop.index}}: {{reply}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
@@ -0,0 +1,5 @@
Rank the following prompts:
{% for prompt in task.prompts %}
{{loop.index}}: {{prompt}}{% endfor %}
:scroll: Reply with the numbers of best to worst prompts separated by commas (example: "4,1,3,2").
+7
View File
@@ -0,0 +1,7 @@
Rate the following summary:
{{task.summary}}
Full text:
{{task.full_text}}
Rating scale: {{task.scale.min}} - {{task.scale.max}}
+2
View File
@@ -0,0 +1,2 @@
Summarize to the following story:
{{task.story}}
+12
View File
@@ -0,0 +1,12 @@
Please provide a reply to the assistant.
Here is the conversation so far:
{% for message in task.conversation.messages %}{% if message.is_assistant %}
:robot: Assistant:
{{ message.text }}
{% else %}
:person_red_hair: User:
**{{ message.text }}**"
{% endif %}{% endfor %}
{% if task.hint %}
Hint: {{ task.hint }}
{% endif %}
+3
View File
@@ -0,0 +1,3 @@
:robot: **Challenge: Assistant Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+3
View File
@@ -0,0 +1,3 @@
:microphone2: **Challenge: Initial Prompt**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:bar_chart: **Challenge: Rank Replies**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
@@ -0,0 +1,3 @@
:bar_chart: **Challenge: Rank Initial Prompts**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+3
View File
@@ -0,0 +1,3 @@
:ballot_box: **Challenge: Rate Summary**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+3
View File
@@ -0,0 +1,3 @@
:books: **Challenge: Summarize Story**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+3
View File
@@ -0,0 +1,3 @@
:person_red_hair: **Challenge: User Reply**
:point_down: Work on it here (:fire: Thread will self-destruct at {{ expiry_time }}, {{ expiry_relative }}).
+6
View File
@@ -0,0 +1,6 @@
Hi there,
I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗!
Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world.
Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands.
+52
View File
@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
import enum
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
class DiscordTimestampStyle(str, enum.Enum):
"""
Timestamp Styles
t 16:20 Short Time
T 16:20:30 Long Time
d 20/04/2021 Short Date
D 20 April 2021 Long Date
f * 20 April 2021 16:20 Short Date/Time
F Tuesday, 20 April 2021 16:20 Long Date/Time
R 2 months ago Relative Time
See https://discord.com/developers/docs/reference#message-formatting-timestamp-styles
"""
default = ""
short_time = "t"
long_time = "T"
short_date = "d"
long_date = "D"
short_date_time = "f"
long_date_time = "F"
relative_time = "R"
def discord_timestamp(d: datetime, style: DiscordTimestampStyle = DiscordTimestampStyle.default):
parts = ["<t:", str(int(d.timestamp()))]
if style:
parts.append(":")
parts.append(style)
parts.append(">")
return "".join(parts)