diff --git a/bot/__main__.py b/bot/__main__.py index 362b16f0..0047bce7 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -12,5 +12,7 @@ if __name__ == "__main__": backend_url=settings.BACKEND_URL, api_key=settings.API_KEY, owner_id=settings.OWNER_ID, + template_dir=settings.TEMPLATE_DIR, + debug=settings.DEBUG, ) bot.run() diff --git a/bot/bot.py b/bot/bot.py index e6e90770..6ab11ed0 100644 --- a/bot/bot.py +++ b/bot/bot.py @@ -1,12 +1,20 @@ # -*- coding: utf-8 -*- import asyncio -from typing import Optional, Union +from datetime import timedelta +from pathlib import Path +from typing import Any, Optional, Union import discord +import discord.ui as ui +import jinja2 from api_client import ApiClient, TaskType from discord import app_commands from loguru import logger from oasst_shared.schemas import protocol as protocol_schema +from utils import get_git_head_hash, utcnow + +__version__ = "0.0.1" +BOT_NAME = "Open-Assistant Junior" class RatingButton(discord.ui.Button): @@ -26,6 +34,26 @@ def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View: return view +class Questionnaire(ui.Modal, title="Questionnaire Response"): + name = ui.TextInput(label="Name") + answer = ui.TextInput(label="Answer", style=discord.TextStyle.paragraph) + + async def on_submit(self, interaction: discord.Interaction): + await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True) + + +class MessageTemplates: + def __init__(self, template_dir="./templates"): + self.env = jinja2.Environment( + loader=jinja2.FileSystemLoader(template_dir), + autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False), + ) + + def render(self, template_name, **kwargs): + template = self.env.get_template(template_name) + return template.render(kwargs) + + class OpenAssistantBot: def __init__( self, @@ -34,7 +62,14 @@ class OpenAssistantBot: backend_url: str, api_key: str, owner_id: Optional[Union[int, str]] = None, + template_dir: str = "./templates", + debug: bool = False, ): + self.template_dir = Path(template_dir) + self.bot_channel_name = bot_channel_name + self.templates = MessageTemplates(template_dir) + self.debug = debug + intents = discord.Intents.default() intents.message_content = True @@ -59,6 +94,11 @@ class OpenAssistantBot: client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()") logger.info(f"{client.user} is now running!") + await self.delete_all_old_bot_messages() + if self.debug: + await self.post_boot_message() + await self.post_welcome_message() + @client.event async def on_message(message: discord.Message): # ignore own messages @@ -78,7 +118,42 @@ class OpenAssistantBot: @self.tree.command() async def work(interaction: discord.Interaction): """Request a new personalized task""" - await interaction.response.send_message(f"work command by {interaction.user.name}") + # task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None) + # task = self.backend.fetch_random_task(user=None) + q = Questionnaire() + await interaction.response.send_modal(q) + + def ensure_bot_channel(self) -> None: + if self.bot_channel is None: + raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found") + + async def post(self, content: str, view: discord.ui.View = None) -> discord.Message: + self.ensure_bot_channel() + return await self.bot_channel.send(content=content) + + async def post_template(self, name: str, view: discord.ui.View = None, **kwargs: Any) -> discord.Message: + logger.info(f"rendering {name}") + text = self.templates.render(name, **kwargs) + return await self.post(text, view) + + async def post_boot_message(self) -> discord.Message: + return await self.post_template( + "boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug + ) + + async def post_welcome_message(self) -> discord.Message: + return await self.post_template("welcome.msg") + + async def delete_all_old_bot_messages(self) -> None: + logger.info("Begin deleting old bot messages.") + look_until = utcnow() - timedelta(days=365) + async for msg in self.bot_channel.history(limit=None): + msg: discord.Message + if msg.created_at < look_until: + break + if msg.author.id == self.client.user.id: + await msg.delete() + logger.info("Completed deleting old bot messages.") async def print_separtor(self, title: str) -> discord.Message: msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n") @@ -100,21 +175,12 @@ class OpenAssistantBot: return msg async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask): - s = [ - "Rate the following summary:", - task.summary, - "Full text:", - task.full_text, - f"Rating scale: {task.scale.min} - {task.scale.max}", - ] - text = "\n".join(s) - async def rating_response_handler(score, interaction: discord.Interaction): logger.info("rating_response_handler", score) await interaction.response.send_message(f"got your feedback: {score}") view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler) - msg: discord.Message = await self.bot_channel.send(text, view=view) + msg: discord.Message = await self.post_template("rate_summary", task=task, view=view) async def on_reply(message: discord.Message): logger.info("on_summary_reply", message) @@ -235,8 +301,8 @@ class OpenAssistantBot: return msg async def next_task(self): - # task = self.backend.fetch_task(protocol_schema.TaskRequestType.user_reply, user=None) - task = self.backend.fetch_random_task(user=None) + task = self.backend.fetch_task(protocol_schema.TaskRequestType.summarize_story, user=None) + # task = self.backend.fetch_random_task(user=None) await self.print_separtor("New Task") @@ -269,7 +335,7 @@ class OpenAssistantBot: await self.next_task() except Exception: logger.exception("fetching next task failed") - await asyncio.sleep(30) + await asyncio.sleep(60) async def _sync(self, command: str, message: discord.Message): diff --git a/bot/bot_settings.py b/bot/bot_settings.py index b7a46aa6..c976d7cd 100644 --- a/bot/bot_settings.py +++ b/bot/bot_settings.py @@ -8,6 +8,8 @@ class BotSettings(BaseSettings): BOT_TOKEN: str BOT_CHANNEL_NAME: str = "bot" OWNER_ID: int = None + TEMPLATE_DIR: str = "./templates" + DEBUG: bool = True settings = BotSettings(_env_file=".env") diff --git a/bot/requirements.txt b/bot/requirements.txt index da4762a6..927ebcf2 100644 --- a/bot/requirements.txt +++ b/bot/requirements.txt @@ -1,4 +1,7 @@ discord.py==2.1.0 +Jinja2==3.1.2 pydantic==1.9.1 python-dotenv==0.21.0 +pytz==2022.7 requests==2.28.1 +schedule==1.1.0 diff --git a/bot/templates/boot.msg b/bot/templates/boot.msg new file mode 100644 index 00000000..1d212e99 --- /dev/null +++ b/bot/templates/boot.msg @@ -0,0 +1,14 @@ +``` +________ __ +\_____ \ _____ _____ _______/ |_ + / | \\__ \ \__ \ / ___/\ __\ +/ | \/ __ \_/ __ \_\___ \ | | +\_______ (____ (____ /____ > |__| + \/ \/ \/ \/ + +{{bot_name}} {{version}} +git hash: {{git_hash}} +debug_mode: {{debug}} +``` + +https://github.com/LAION-AI/Open-Assistant diff --git a/bot/templates/rate_summary.msg b/bot/templates/rate_summary.msg new file mode 100644 index 00000000..31007c40 --- /dev/null +++ b/bot/templates/rate_summary.msg @@ -0,0 +1,7 @@ +Rate the following summary: +{{task.summary}} + +Full text: +{{task.full_text}} + +Rating scale: {{task.scale.min}} - {{task.scale.max}} diff --git a/bot/templates/welcome.msg b/bot/templates/welcome.msg new file mode 100644 index 00000000..553f7925 --- /dev/null +++ b/bot/templates/welcome.msg @@ -0,0 +1,6 @@ +Hi there, + +I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗! +Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world. + +Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands. diff --git a/bot/utils.py b/bot/utils.py new file mode 100644 index 00000000..1a06b833 --- /dev/null +++ b/bot/utils.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +import subprocess +from datetime import datetime + +import pytz + + +def get_git_head_hash(): + # get current git hash + x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True) + if x.returncode == 0: + return x.stdout.replace("\n", "") + return None + + +def utcnow() -> datetime: + return datetime.now(pytz.UTC) diff --git a/pyproject.toml b/pyproject.toml index 83b614a2..30541eec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,4 @@ line_length = 120 [tool.black] line-length = 120 target-version = ['py310'] +exclude = ["bot/templates"]