mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
switch to loguru
This commit is contained in:
@@ -5,6 +5,7 @@ import lightbulb
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import mention
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin("GuildSettings")
|
||||
plugin.add_checks(lightbulb.guild_only)
|
||||
@@ -34,6 +35,7 @@ async def get(ctx: lightbulb.SlashContext) -> None:
|
||||
row = await cursor.fetchone()
|
||||
|
||||
if row is None:
|
||||
logger.warning(f"No guild settings for {ctx.guild_id}")
|
||||
await ctx.respond("No settings found for this guild.")
|
||||
return
|
||||
|
||||
@@ -70,6 +72,7 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
|
||||
)
|
||||
|
||||
await conn.commit()
|
||||
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
|
||||
@@ -4,6 +4,7 @@ from glob import glob
|
||||
|
||||
import hikari
|
||||
import lightbulb
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"HotReloadPlugin",
|
||||
@@ -37,7 +38,7 @@ async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikar
|
||||
required=False,
|
||||
default=None,
|
||||
)
|
||||
@lightbulb.command("reload", "Reload a plugin")
|
||||
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
|
||||
@lightbulb.implements(lightbulb.SlashCommand)
|
||||
async def reload(ctx: lightbulb.SlashContext):
|
||||
"""Reload a plugin or all plugins."""
|
||||
@@ -45,10 +46,12 @@ async def reload(ctx: lightbulb.SlashContext):
|
||||
if ctx.options.plugin is None:
|
||||
ctx.bot.reload_extensions(*_get_extensions())
|
||||
await ctx.respond("Reloaded all plugins.")
|
||||
logger.info("Reloaded all plugins.")
|
||||
# Otherwise, reload the specified plugin.
|
||||
else:
|
||||
ctx.bot.reload_extensions(ctx.options.plugin)
|
||||
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
|
||||
logger.info(f"Reloaded `{ctx.options.plugin}`.")
|
||||
|
||||
|
||||
def load(bot: lightbulb.BotApp):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Task plugin for testing different data collection methods."""
|
||||
# TODO: Delete this once user input method has been decided for final bot.
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
@@ -16,8 +16,6 @@ plugin = lightbulb.Plugin("TaskPlugin")
|
||||
|
||||
MAX_TASK_TIME = 60 * 60
|
||||
MAX_TASK_ACCEPT_TIME = 60
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@plugin.command
|
||||
|
||||
@@ -9,6 +9,7 @@ import miru
|
||||
from aiosqlite import Connection
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
|
||||
plugin = lightbulb.Plugin(
|
||||
"TextLabels",
|
||||
@@ -49,6 +50,7 @@ class LabelModal(miru.Modal):
|
||||
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
|
||||
flags=hikari.MessageFlag.EPHEMERAL,
|
||||
)
|
||||
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
|
||||
|
||||
# Send a notification to the log channel
|
||||
assert context.guild_id is not None # `guild_only` check
|
||||
@@ -56,6 +58,7 @@ class LabelModal(miru.Modal):
|
||||
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
|
||||
|
||||
if guild_settings is None or guild_settings.log_channel_id is None:
|
||||
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
|
||||
return
|
||||
|
||||
embed = (
|
||||
@@ -148,6 +151,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
|
||||
"""Label a message."""
|
||||
# We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction
|
||||
# so the select menu will open the modal
|
||||
|
||||
msg: hikari.Message = ctx.options.target
|
||||
# Exit if the message is empty
|
||||
if not msg.content:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Work plugin for collecting user data."""
|
||||
import asyncio
|
||||
import logging
|
||||
import typing as t
|
||||
from datetime import datetime
|
||||
|
||||
@@ -13,6 +12,7 @@ from aiosqlite import Connection
|
||||
from bot.api_client import OasstApiClient, TaskType
|
||||
from bot.db.schemas import GuildSettings
|
||||
from bot.utils import EMPTY
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import TaskRequestType
|
||||
|
||||
@@ -21,9 +21,6 @@ plugin = lightbulb.Plugin("WorkPlugin")
|
||||
MAX_TASK_TIME = 60 * 60 # 1 hour
|
||||
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
@plugin.command
|
||||
@lightbulb.option(
|
||||
@@ -41,7 +38,7 @@ async def work(ctx: lightbulb.SlashContext):
|
||||
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
|
||||
|
||||
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
|
||||
logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}")
|
||||
logger.debug(f"Starting task_type: {task_type!r}")
|
||||
|
||||
await _handle_task(ctx, task_type)
|
||||
|
||||
@@ -76,6 +73,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
except asyncio.TimeoutError:
|
||||
await ctx.author.send("Task timed out. Exiting")
|
||||
await oasst_api.nack_task(task.id, reason="timed out")
|
||||
logger.info(f"Task {task.id} timed out")
|
||||
return
|
||||
|
||||
# Invalid response
|
||||
@@ -83,7 +81,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
await ctx.author.send("Invalid response")
|
||||
continue
|
||||
|
||||
logger.info(f"Successful user input received: {event.content}")
|
||||
logger.debug(f"Successful user input received: {event.content}")
|
||||
|
||||
# Send the response to the backend
|
||||
reply = protocol_schema.TextReplyToPost(
|
||||
@@ -105,7 +103,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
|
||||
completed = True
|
||||
continue
|
||||
else:
|
||||
logger.fatal(f"Unexpected task type received: {new_task.type}")
|
||||
logger.critical(f"Unexpected task type received: {new_task.type}")
|
||||
|
||||
# Send a message in the log channel that the task is complete
|
||||
# TODO: Maybe do something with the msg ID so users can rate the "answer"
|
||||
@@ -159,7 +157,7 @@ async def _select_task(
|
||||
task = await oasst_api.fetch_task(task_type, user)
|
||||
resp, msg_id = await _send_task(ctx, task)
|
||||
|
||||
logger.debug(f"user choice: {resp}")
|
||||
logger.debug(f"User choice: {resp}")
|
||||
match resp:
|
||||
case "accept":
|
||||
logger.info(f"Task {task.id} accepted, sending ACK")
|
||||
@@ -200,32 +198,32 @@ async def _send_task(
|
||||
# Create an embed based on the task's type
|
||||
if task.type == TaskRequestType.initial_prompt:
|
||||
assert isinstance(task, protocol_schema.InitialPromptTask)
|
||||
logger.info("sending initial prompt task")
|
||||
logger.debug("sending initial prompt task")
|
||||
embed = _initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_initial_prompts:
|
||||
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
|
||||
logger.info("sending rank initial prompt task")
|
||||
logger.debug("sending rank initial prompt task")
|
||||
embed = _rank_initial_prompt_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_user_replies:
|
||||
assert isinstance(task, protocol_schema.RankUserRepliesTask)
|
||||
logger.info("sending rank user reply task")
|
||||
logger.debug("sending rank user reply task")
|
||||
embed = _rank_user_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.rank_assistant_replies:
|
||||
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
|
||||
logger.info("sending rank assistant reply task")
|
||||
logger.debug("sending rank assistant reply task")
|
||||
embed = _rank_assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.user_reply:
|
||||
assert isinstance(task, protocol_schema.UserReplyTask)
|
||||
logger.info("sending user reply task")
|
||||
logger.debug("sending user reply task")
|
||||
embed = _user_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.assistant_reply:
|
||||
assert isinstance(task, protocol_schema.AssistantReplyTask)
|
||||
logger.info("sending assistant reply task")
|
||||
logger.debug("sending assistant reply task")
|
||||
embed = _assistant_reply_embed(task)
|
||||
|
||||
elif task.type == TaskRequestType.summarize_story:
|
||||
@@ -234,7 +232,7 @@ async def _send_task(
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.error(f"unknown task type {task.type}")
|
||||
logger.critical(f"unknown task type {task.type}")
|
||||
raise ValueError(f"unknown task type {task.type}")
|
||||
|
||||
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
|
||||
@@ -279,7 +277,7 @@ def _validate_user_input(content: str | None, task_type: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
else:
|
||||
logger.fatal(f"Unknown task type {task_type}")
|
||||
logger.critical(f"Unknown task type {task_type}")
|
||||
raise ValueError(f"Unknown task type {task_type}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user