From 6cccd74e3491efec03b4ffed5841ec510cae0e50 Mon Sep 17 00:00:00 2001 From: Alex Ott <66271487+AlexanderHOtt@users.noreply.github.com> Date: Fri, 30 Dec 2022 05:28:51 -0800 Subject: [PATCH] switch to loguru --- discord-bot/bot/extensions/guild_settings.py | 3 ++ discord-bot/bot/extensions/hot_reload.py | 5 +++- discord-bot/bot/extensions/tasks.py | 4 +-- discord-bot/bot/extensions/text_labels.py | 4 +++ discord-bot/bot/extensions/work.py | 30 +++++++++----------- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/discord-bot/bot/extensions/guild_settings.py b/discord-bot/bot/extensions/guild_settings.py index 5623cd5a..f5785b8d 100644 --- a/discord-bot/bot/extensions/guild_settings.py +++ b/discord-bot/bot/extensions/guild_settings.py @@ -5,6 +5,7 @@ import lightbulb from aiosqlite import Connection from bot.db.schemas import GuildSettings from bot.utils import mention +from loguru import logger plugin = lightbulb.Plugin("GuildSettings") plugin.add_checks(lightbulb.guild_only) @@ -34,6 +35,7 @@ async def get(ctx: lightbulb.SlashContext) -> None: row = await cursor.fetchone() if row is None: + logger.warning(f"No guild settings for {ctx.guild_id}") await ctx.respond("No settings found for this guild.") return @@ -70,6 +72,7 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None: ) await conn.commit() + logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.") def load(bot: lightbulb.BotApp): diff --git a/discord-bot/bot/extensions/hot_reload.py b/discord-bot/bot/extensions/hot_reload.py index 28bcede3..ad2cd730 100644 --- a/discord-bot/bot/extensions/hot_reload.py +++ b/discord-bot/bot/extensions/hot_reload.py @@ -4,6 +4,7 @@ from glob import glob import hikari import lightbulb +from loguru import logger plugin = lightbulb.Plugin( "HotReloadPlugin", @@ -37,7 +38,7 @@ async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikar required=False, default=None, ) -@lightbulb.command("reload", "Reload a plugin") +@lightbulb.command("reload", "Reload a plugin", ephemeral=True) @lightbulb.implements(lightbulb.SlashCommand) async def reload(ctx: lightbulb.SlashContext): """Reload a plugin or all plugins.""" @@ -45,10 +46,12 @@ async def reload(ctx: lightbulb.SlashContext): if ctx.options.plugin is None: ctx.bot.reload_extensions(*_get_extensions()) await ctx.respond("Reloaded all plugins.") + logger.info("Reloaded all plugins.") # Otherwise, reload the specified plugin. else: ctx.bot.reload_extensions(ctx.options.plugin) await ctx.respond(f"Reloaded `{ctx.options.plugin}`.") + logger.info(f"Reloaded `{ctx.options.plugin}`.") def load(bot: lightbulb.BotApp): diff --git a/discord-bot/bot/extensions/tasks.py b/discord-bot/bot/extensions/tasks.py index 70fa5257..94ddb973 100644 --- a/discord-bot/bot/extensions/tasks.py +++ b/discord-bot/bot/extensions/tasks.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """Task plugin for testing different data collection methods.""" +# TODO: Delete this once user input method has been decided for final bot. import asyncio -import logging import typing as t from datetime import datetime, timedelta @@ -16,8 +16,6 @@ plugin = lightbulb.Plugin("TaskPlugin") MAX_TASK_TIME = 60 * 60 MAX_TASK_ACCEPT_TIME = 60 -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) @plugin.command diff --git a/discord-bot/bot/extensions/text_labels.py b/discord-bot/bot/extensions/text_labels.py index 53d0a1fd..618e6642 100644 --- a/discord-bot/bot/extensions/text_labels.py +++ b/discord-bot/bot/extensions/text_labels.py @@ -9,6 +9,7 @@ import miru from aiosqlite import Connection from bot.db.schemas import GuildSettings from bot.utils import EMPTY +from loguru import logger plugin = lightbulb.Plugin( "TextLabels", @@ -49,6 +50,7 @@ class LabelModal(miru.Modal): f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.", flags=hikari.MessageFlag.EPHEMERAL, ) + logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.") # Send a notification to the log channel assert context.guild_id is not None # `guild_only` check @@ -56,6 +58,7 @@ class LabelModal(miru.Modal): guild_settings = await GuildSettings.from_db(conn, context.guild_id) if guild_settings is None or guild_settings.log_channel_id is None: + logger.warning(f"No guild settings or log channel for guild {context.guild_id}") return embed = ( @@ -148,6 +151,7 @@ async def label_message_text(ctx: lightbulb.MessageContext): """Label a message.""" # We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction # so the select menu will open the modal + msg: hikari.Message = ctx.options.target # Exit if the message is empty if not msg.content: diff --git a/discord-bot/bot/extensions/work.py b/discord-bot/bot/extensions/work.py index 8e3ad7b5..5244920b 100644 --- a/discord-bot/bot/extensions/work.py +++ b/discord-bot/bot/extensions/work.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- """Work plugin for collecting user data.""" import asyncio -import logging import typing as t from datetime import datetime @@ -13,6 +12,7 @@ from aiosqlite import Connection from bot.api_client import OasstApiClient, TaskType from bot.db.schemas import GuildSettings from bot.utils import EMPTY +from loguru import logger from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import TaskRequestType @@ -21,9 +21,6 @@ plugin = lightbulb.Plugin("WorkPlugin") MAX_TASK_TIME = 60 * 60 # 1 hour MAX_TASK_ACCEPT_TIME = 60 # 1 minute -logger = logging.getLogger(__name__) -logger.setLevel(logging.DEBUG) - @plugin.command @lightbulb.option( @@ -41,7 +38,7 @@ async def work(ctx: lightbulb.SlashContext): task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1]) await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL) - logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}") + logger.debug(f"Starting task_type: {task_type!r}") await _handle_task(ctx, task_type) @@ -76,6 +73,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) except asyncio.TimeoutError: await ctx.author.send("Task timed out. Exiting") await oasst_api.nack_task(task.id, reason="timed out") + logger.info(f"Task {task.id} timed out") return # Invalid response @@ -83,7 +81,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) await ctx.author.send("Invalid response") continue - logger.info(f"Successful user input received: {event.content}") + logger.debug(f"Successful user input received: {event.content}") # Send the response to the backend reply = protocol_schema.TextReplyToPost( @@ -105,7 +103,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType) completed = True continue else: - logger.fatal(f"Unexpected task type received: {new_task.type}") + logger.critical(f"Unexpected task type received: {new_task.type}") # Send a message in the log channel that the task is complete # TODO: Maybe do something with the msg ID so users can rate the "answer" @@ -159,7 +157,7 @@ async def _select_task( task = await oasst_api.fetch_task(task_type, user) resp, msg_id = await _send_task(ctx, task) - logger.debug(f"user choice: {resp}") + logger.debug(f"User choice: {resp}") match resp: case "accept": logger.info(f"Task {task.id} accepted, sending ACK") @@ -200,32 +198,32 @@ async def _send_task( # Create an embed based on the task's type if task.type == TaskRequestType.initial_prompt: assert isinstance(task, protocol_schema.InitialPromptTask) - logger.info("sending initial prompt task") + logger.debug("sending initial prompt task") embed = _initial_prompt_embed(task) elif task.type == TaskRequestType.rank_initial_prompts: assert isinstance(task, protocol_schema.RankInitialPromptsTask) - logger.info("sending rank initial prompt task") + logger.debug("sending rank initial prompt task") embed = _rank_initial_prompt_embed(task) elif task.type == TaskRequestType.rank_user_replies: assert isinstance(task, protocol_schema.RankUserRepliesTask) - logger.info("sending rank user reply task") + logger.debug("sending rank user reply task") embed = _rank_user_reply_embed(task) elif task.type == TaskRequestType.rank_assistant_replies: assert isinstance(task, protocol_schema.RankAssistantRepliesTask) - logger.info("sending rank assistant reply task") + logger.debug("sending rank assistant reply task") embed = _rank_assistant_reply_embed(task) elif task.type == TaskRequestType.user_reply: assert isinstance(task, protocol_schema.UserReplyTask) - logger.info("sending user reply task") + logger.debug("sending user reply task") embed = _user_reply_embed(task) elif task.type == TaskRequestType.assistant_reply: assert isinstance(task, protocol_schema.AssistantReplyTask) - logger.info("sending assistant reply task") + logger.debug("sending assistant reply task") embed = _assistant_reply_embed(task) elif task.type == TaskRequestType.summarize_story: @@ -234,7 +232,7 @@ async def _send_task( raise NotImplementedError else: - logger.error(f"unknown task type {task.type}") + logger.critical(f"unknown task type {task.type}") raise ValueError(f"unknown task type {task.type}") view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME) @@ -279,7 +277,7 @@ def _validate_user_input(content: str | None, task_type: str) -> bool: raise NotImplementedError else: - logger.fatal(f"Unknown task type {task_type}") + logger.critical(f"Unknown task type {task_type}") raise ValueError(f"Unknown task type {task_type}")