add first bits of jinja template support

This commit is contained in:
Andreas Köpf
2022-12-21 20:15:38 +01:00
parent 55c79b98f1
commit d20828e759
9 changed files with 133 additions and 15 deletions
+2
View File
@@ -12,5 +12,7 @@ if __name__ == "__main__":
backend_url=settings.BACKEND_URL,
api_key=settings.API_KEY,
owner_id=settings.OWNER_ID,
template_dir=settings.TEMPLATE_DIR,
debug=settings.DEBUG,
)
bot.run()
+81 -15
View File
@@ -1,12 +1,20 @@
# -*- coding: utf-8 -*-
import asyncio
from typing import Optional, Union
from datetime import timedelta
from pathlib import Path
from typing import Any, Optional, Union
import discord
import discord.ui as ui
import jinja2
from api_client import ApiClient, TaskType
from discord import app_commands
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from utils import get_git_head_hash, utcnow
__version__ = "0.0.1"
BOT_NAME = "Open-Assistant Junior"
class RatingButton(discord.ui.Button):
@@ -26,6 +34,26 @@ def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
return view
class Questionnaire(ui.Modal, title="Questionnaire Response"):
name = ui.TextInput(label="Name")
answer = ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
class MessageTemplates:
def __init__(self, template_dir="./templates"):
self.env = jinja2.Environment(
loader=jinja2.FileSystemLoader(template_dir),
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
)
def render(self, template_name, **kwargs):
template = self.env.get_template(template_name)
return template.render(kwargs)
class OpenAssistantBot:
def __init__(
self,
@@ -34,7 +62,14 @@ class OpenAssistantBot:
backend_url: str,
api_key: str,
owner_id: Optional[Union[int, str]] = None,
template_dir: str = "./templates",
debug: bool = False,
):
self.template_dir = Path(template_dir)
self.bot_channel_name = bot_channel_name
self.templates = MessageTemplates(template_dir)
self.debug = debug
intents = discord.Intents.default()
intents.message_content = True
@@ -59,6 +94,11 @@ class OpenAssistantBot:
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
logger.info(f"{client.user} is now running!")
await self.delete_all_old_bot_messages()
if self.debug:
await self.post_boot_message()
await self.post_welcome_message()
@client.event
async def on_message(message: discord.Message):
# ignore own messages
@@ -78,7 +118,42 @@ class OpenAssistantBot:
@self.tree.command()
async def work(interaction: discord.Interaction):
"""Request a new personalized task"""
await interaction.response.send_message(f"work command by {interaction.user.name}")
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
# task = self.backend.fetch_random_task(user=None)
q = Questionnaire()
await interaction.response.send_modal(q)
def ensure_bot_channel(self) -> None:
if self.bot_channel is None:
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
async def post(self, content: str, view: discord.ui.View = None) -> discord.Message:
self.ensure_bot_channel()
return await self.bot_channel.send(content=content)
async def post_template(self, name: str, view: discord.ui.View = None, **kwargs: Any) -> discord.Message:
logger.info(f"rendering {name}")
text = self.templates.render(name, **kwargs)
return await self.post(text, view)
async def post_boot_message(self) -> discord.Message:
return await self.post_template(
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
)
async def post_welcome_message(self) -> discord.Message:
return await self.post_template("welcome.msg")
async def delete_all_old_bot_messages(self) -> None:
logger.info("Begin deleting old bot messages.")
look_until = utcnow() - timedelta(days=365)
async for msg in self.bot_channel.history(limit=None):
msg: discord.Message
if msg.created_at < look_until:
break
if msg.author.id == self.client.user.id:
await msg.delete()
logger.info("Completed deleting old bot messages.")
async def print_separtor(self, title: str) -> discord.Message:
msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n")
@@ -100,21 +175,12 @@ class OpenAssistantBot:
return msg
async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask):
s = [
"Rate the following summary:",
task.summary,
"Full text:",
task.full_text,
f"Rating scale: {task.scale.min} - {task.scale.max}",
]
text = "\n".join(s)
async def rating_response_handler(score, interaction: discord.Interaction):
logger.info("rating_response_handler", score)
await interaction.response.send_message(f"got your feedback: {score}")
view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler)
msg: discord.Message = await self.bot_channel.send(text, view=view)
msg: discord.Message = await self.post_template("rate_summary", task=task, view=view)
async def on_reply(message: discord.Message):
logger.info("on_summary_reply", message)
@@ -235,8 +301,8 @@ class OpenAssistantBot:
return msg
async def next_task(self):
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.user_reply, user=None)
task = self.backend.fetch_random_task(user=None)
task = self.backend.fetch_task(protocol_schema.TaskRequestType.summarize_story, user=None)
# task = self.backend.fetch_random_task(user=None)
await self.print_separtor("New Task")
@@ -269,7 +335,7 @@ class OpenAssistantBot:
await self.next_task()
except Exception:
logger.exception("fetching next task failed")
await asyncio.sleep(30)
await asyncio.sleep(60)
async def _sync(self, command: str, message: discord.Message):
+2
View File
@@ -8,6 +8,8 @@ class BotSettings(BaseSettings):
BOT_TOKEN: str
BOT_CHANNEL_NAME: str = "bot"
OWNER_ID: int = None
TEMPLATE_DIR: str = "./templates"
DEBUG: bool = True
settings = BotSettings(_env_file=".env")
+3
View File
@@ -1,4 +1,7 @@
discord.py==2.1.0
Jinja2==3.1.2
pydantic==1.9.1
python-dotenv==0.21.0
pytz==2022.7
requests==2.28.1
schedule==1.1.0
+14
View File
@@ -0,0 +1,14 @@
```
________ __
\_____ \ _____ _____ _______/ |_
/ | \\__ \ \__ \ / ___/\ __\
/ | \/ __ \_/ __ \_\___ \ | |
\_______ (____ (____ /____ > |__|
\/ \/ \/ \/
{{bot_name}} {{version}}
git hash: {{git_hash}}
debug_mode: {{debug}}
```
https://github.com/LAION-AI/Open-Assistant
+7
View File
@@ -0,0 +1,7 @@
Rate the following summary:
{{task.summary}}
Full text:
{{task.full_text}}
Rating scale: {{task.scale.min}} - {{task.scale.max}}
+6
View File
@@ -0,0 +1,6 @@
Hi there,
I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗!
Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world.
Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands.
+17
View File
@@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
import subprocess
from datetime import datetime
import pytz
def get_git_head_hash():
# get current git hash
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
if x.returncode == 0:
return x.stdout.replace("\n", "")
return None
def utcnow() -> datetime:
return datetime.now(pytz.UTC)
+1
View File
@@ -11,3 +11,4 @@ line_length = 120
[tool.black]
line-length = 120
target-version = ['py310']
exclude = ["bot/templates"]