mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add first bits of jinja template support
This commit is contained in:
@@ -12,5 +12,7 @@ if __name__ == "__main__":
|
||||
backend_url=settings.BACKEND_URL,
|
||||
api_key=settings.API_KEY,
|
||||
owner_id=settings.OWNER_ID,
|
||||
template_dir=settings.TEMPLATE_DIR,
|
||||
debug=settings.DEBUG,
|
||||
)
|
||||
bot.run()
|
||||
|
||||
+81
-15
@@ -1,12 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import discord
|
||||
import discord.ui as ui
|
||||
import jinja2
|
||||
from api_client import ApiClient, TaskType
|
||||
from discord import app_commands
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from utils import get_git_head_hash, utcnow
|
||||
|
||||
__version__ = "0.0.1"
|
||||
BOT_NAME = "Open-Assistant Junior"
|
||||
|
||||
|
||||
class RatingButton(discord.ui.Button):
|
||||
@@ -26,6 +34,26 @@ def generate_rating_view(lo: int, hi: int, response_handler) -> discord.ui.View:
|
||||
return view
|
||||
|
||||
|
||||
class Questionnaire(ui.Modal, title="Questionnaire Response"):
|
||||
name = ui.TextInput(label="Name")
|
||||
answer = ui.TextInput(label="Answer", style=discord.TextStyle.paragraph)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.send_message(f"Thanks for your response, {self.name}!", ephemeral=True)
|
||||
|
||||
|
||||
class MessageTemplates:
|
||||
def __init__(self, template_dir="./templates"):
|
||||
self.env = jinja2.Environment(
|
||||
loader=jinja2.FileSystemLoader(template_dir),
|
||||
autoescape=jinja2.select_autoescape(disabled_extensions=("msg",), default=False, default_for_string=False),
|
||||
)
|
||||
|
||||
def render(self, template_name, **kwargs):
|
||||
template = self.env.get_template(template_name)
|
||||
return template.render(kwargs)
|
||||
|
||||
|
||||
class OpenAssistantBot:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -34,7 +62,14 @@ class OpenAssistantBot:
|
||||
backend_url: str,
|
||||
api_key: str,
|
||||
owner_id: Optional[Union[int, str]] = None,
|
||||
template_dir: str = "./templates",
|
||||
debug: bool = False,
|
||||
):
|
||||
self.template_dir = Path(template_dir)
|
||||
self.bot_channel_name = bot_channel_name
|
||||
self.templates = MessageTemplates(template_dir)
|
||||
self.debug = debug
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
@@ -59,6 +94,11 @@ class OpenAssistantBot:
|
||||
client.loop.create_task(self.background_timer(), name="OpenAssistantBot.background_timer()")
|
||||
logger.info(f"{client.user} is now running!")
|
||||
|
||||
await self.delete_all_old_bot_messages()
|
||||
if self.debug:
|
||||
await self.post_boot_message()
|
||||
await self.post_welcome_message()
|
||||
|
||||
@client.event
|
||||
async def on_message(message: discord.Message):
|
||||
# ignore own messages
|
||||
@@ -78,7 +118,42 @@ class OpenAssistantBot:
|
||||
@self.tree.command()
|
||||
async def work(interaction: discord.Interaction):
|
||||
"""Request a new personalized task"""
|
||||
await interaction.response.send_message(f"work command by {interaction.user.name}")
|
||||
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.rate_summary, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
q = Questionnaire()
|
||||
await interaction.response.send_modal(q)
|
||||
|
||||
def ensure_bot_channel(self) -> None:
|
||||
if self.bot_channel is None:
|
||||
raise RuntimeError(f"bot channel '{self.bot_channel_name}' not found")
|
||||
|
||||
async def post(self, content: str, view: discord.ui.View = None) -> discord.Message:
|
||||
self.ensure_bot_channel()
|
||||
return await self.bot_channel.send(content=content)
|
||||
|
||||
async def post_template(self, name: str, view: discord.ui.View = None, **kwargs: Any) -> discord.Message:
|
||||
logger.info(f"rendering {name}")
|
||||
text = self.templates.render(name, **kwargs)
|
||||
return await self.post(text, view)
|
||||
|
||||
async def post_boot_message(self) -> discord.Message:
|
||||
return await self.post_template(
|
||||
"boot.msg", bot_name=BOT_NAME, version=__version__, git_hash=get_git_head_hash(), debug=self.debug
|
||||
)
|
||||
|
||||
async def post_welcome_message(self) -> discord.Message:
|
||||
return await self.post_template("welcome.msg")
|
||||
|
||||
async def delete_all_old_bot_messages(self) -> None:
|
||||
logger.info("Begin deleting old bot messages.")
|
||||
look_until = utcnow() - timedelta(days=365)
|
||||
async for msg in self.bot_channel.history(limit=None):
|
||||
msg: discord.Message
|
||||
if msg.created_at < look_until:
|
||||
break
|
||||
if msg.author.id == self.client.user.id:
|
||||
await msg.delete()
|
||||
logger.info("Completed deleting old bot messages.")
|
||||
|
||||
async def print_separtor(self, title: str) -> discord.Message:
|
||||
msg: discord.Message = await self.bot_channel.send(f"\n:point_right: {title} :point_left:\n")
|
||||
@@ -100,21 +175,12 @@ class OpenAssistantBot:
|
||||
return msg
|
||||
|
||||
async def generate_rate_summary(self, task: protocol_schema.RateSummaryTask):
|
||||
s = [
|
||||
"Rate the following summary:",
|
||||
task.summary,
|
||||
"Full text:",
|
||||
task.full_text,
|
||||
f"Rating scale: {task.scale.min} - {task.scale.max}",
|
||||
]
|
||||
text = "\n".join(s)
|
||||
|
||||
async def rating_response_handler(score, interaction: discord.Interaction):
|
||||
logger.info("rating_response_handler", score)
|
||||
await interaction.response.send_message(f"got your feedback: {score}")
|
||||
|
||||
view = generate_rating_view(task.scale.min, task.scale.max, rating_response_handler)
|
||||
msg: discord.Message = await self.bot_channel.send(text, view=view)
|
||||
msg: discord.Message = await self.post_template("rate_summary", task=task, view=view)
|
||||
|
||||
async def on_reply(message: discord.Message):
|
||||
logger.info("on_summary_reply", message)
|
||||
@@ -235,8 +301,8 @@ class OpenAssistantBot:
|
||||
return msg
|
||||
|
||||
async def next_task(self):
|
||||
# task = self.backend.fetch_task(protocol_schema.TaskRequestType.user_reply, user=None)
|
||||
task = self.backend.fetch_random_task(user=None)
|
||||
task = self.backend.fetch_task(protocol_schema.TaskRequestType.summarize_story, user=None)
|
||||
# task = self.backend.fetch_random_task(user=None)
|
||||
|
||||
await self.print_separtor("New Task")
|
||||
|
||||
@@ -269,7 +335,7 @@ class OpenAssistantBot:
|
||||
await self.next_task()
|
||||
except Exception:
|
||||
logger.exception("fetching next task failed")
|
||||
await asyncio.sleep(30)
|
||||
await asyncio.sleep(60)
|
||||
|
||||
async def _sync(self, command: str, message: discord.Message):
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ class BotSettings(BaseSettings):
|
||||
BOT_TOKEN: str
|
||||
BOT_CHANNEL_NAME: str = "bot"
|
||||
OWNER_ID: int = None
|
||||
TEMPLATE_DIR: str = "./templates"
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
settings = BotSettings(_env_file=".env")
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
discord.py==2.1.0
|
||||
Jinja2==3.1.2
|
||||
pydantic==1.9.1
|
||||
python-dotenv==0.21.0
|
||||
pytz==2022.7
|
||||
requests==2.28.1
|
||||
schedule==1.1.0
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
```
|
||||
________ __
|
||||
\_____ \ _____ _____ _______/ |_
|
||||
/ | \\__ \ \__ \ / ___/\ __\
|
||||
/ | \/ __ \_/ __ \_\___ \ | |
|
||||
\_______ (____ (____ /____ > |__|
|
||||
\/ \/ \/ \/
|
||||
|
||||
{{bot_name}} {{version}}
|
||||
git hash: {{git_hash}}
|
||||
debug_mode: {{debug}}
|
||||
```
|
||||
|
||||
https://github.com/LAION-AI/Open-Assistant
|
||||
@@ -0,0 +1,7 @@
|
||||
Rate the following summary:
|
||||
{{task.summary}}
|
||||
|
||||
Full text:
|
||||
{{task.full_text}}
|
||||
|
||||
Rating scale: {{task.scale.min}} - {{task.scale.max}}
|
||||
@@ -0,0 +1,6 @@
|
||||
Hi there,
|
||||
|
||||
I am the **Open-Assistant Junior Bot** 🤖. I would love to get your feedback 🤗!
|
||||
Currently I am still learning from human demonstrations how to reply to instructions. When I am grown up I want to become a fully functional AI Assistant language model that is fully open-sourced and assists millions of humans all over the world.
|
||||
|
||||
Type `/tutorial` to start the tutorial or `/help` to see a list of all my commands.
|
||||
@@ -0,0 +1,17 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
def get_git_head_hash():
|
||||
# get current git hash
|
||||
x = subprocess.run(["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE, universal_newlines=True)
|
||||
if x.returncode == 0:
|
||||
return x.stdout.replace("\n", "")
|
||||
return None
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(pytz.UTC)
|
||||
@@ -11,3 +11,4 @@ line_length = 120
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
target-version = ['py310']
|
||||
exclude = ["bot/templates"]
|
||||
|
||||
Reference in New Issue
Block a user