switch to loguru

This commit is contained in:
Alex Ott
2022-12-30 05:28:51 -08:00
parent 708011e6a0
commit 6cccd74e34
5 changed files with 26 additions and 20 deletions
@@ -5,6 +5,7 @@ import lightbulb
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import mention
from loguru import logger
plugin = lightbulb.Plugin("GuildSettings")
plugin.add_checks(lightbulb.guild_only)
@@ -34,6 +35,7 @@ async def get(ctx: lightbulb.SlashContext) -> None:
row = await cursor.fetchone()
if row is None:
logger.warning(f"No guild settings for {ctx.guild_id}")
await ctx.respond("No settings found for this guild.")
return
@@ -70,6 +72,7 @@ async def log_channel(ctx: lightbulb.SlashContext) -> None:
)
await conn.commit()
logger.info(f"Updated `log_channel` for {ctx.guild_id} to {channel.id}.")
def load(bot: lightbulb.BotApp):
+4 -1
View File
@@ -4,6 +4,7 @@ from glob import glob
import hikari
import lightbulb
from loguru import logger
plugin = lightbulb.Plugin(
"HotReloadPlugin",
@@ -37,7 +38,7 @@ async def _plugin_autocomplete(option: hikari.CommandInteractionOption, _: hikar
required=False,
default=None,
)
@lightbulb.command("reload", "Reload a plugin")
@lightbulb.command("reload", "Reload a plugin", ephemeral=True)
@lightbulb.implements(lightbulb.SlashCommand)
async def reload(ctx: lightbulb.SlashContext):
"""Reload a plugin or all plugins."""
@@ -45,10 +46,12 @@ async def reload(ctx: lightbulb.SlashContext):
if ctx.options.plugin is None:
ctx.bot.reload_extensions(*_get_extensions())
await ctx.respond("Reloaded all plugins.")
logger.info("Reloaded all plugins.")
# Otherwise, reload the specified plugin.
else:
ctx.bot.reload_extensions(ctx.options.plugin)
await ctx.respond(f"Reloaded `{ctx.options.plugin}`.")
logger.info(f"Reloaded `{ctx.options.plugin}`.")
def load(bot: lightbulb.BotApp):
+1 -3
View File
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
"""Task plugin for testing different data collection methods."""
# TODO: Delete this once user input method has been decided for final bot.
import asyncio
import logging
import typing as t
from datetime import datetime, timedelta
@@ -16,8 +16,6 @@ plugin = lightbulb.Plugin("TaskPlugin")
MAX_TASK_TIME = 60 * 60
MAX_TASK_ACCEPT_TIME = 60
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@plugin.command
@@ -9,6 +9,7 @@ import miru
from aiosqlite import Connection
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from loguru import logger
plugin = lightbulb.Plugin(
"TextLabels",
@@ -49,6 +50,7 @@ class LabelModal(miru.Modal):
f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.",
flags=hikari.MessageFlag.EPHEMERAL,
)
logger.info(f"Sending {self.label}=`{val}` for `{self.content.value}` (edited={edited}) to the backend.")
# Send a notification to the log channel
assert context.guild_id is not None # `guild_only` check
@@ -56,6 +58,7 @@ class LabelModal(miru.Modal):
guild_settings = await GuildSettings.from_db(conn, context.guild_id)
if guild_settings is None or guild_settings.log_channel_id is None:
logger.warning(f"No guild settings or log channel for guild {context.guild_id}")
return
embed = (
@@ -148,6 +151,7 @@ async def label_message_text(ctx: lightbulb.MessageContext):
"""Label a message."""
# We have to do some funny interaction chaining because discord only allows one component (select or modal) per interaction
# so the select menu will open the modal
msg: hikari.Message = ctx.options.target
# Exit if the message is empty
if not msg.content:
+14 -16
View File
@@ -1,7 +1,6 @@
# -*- coding: utf-8 -*-
"""Work plugin for collecting user data."""
import asyncio
import logging
import typing as t
from datetime import datetime
@@ -13,6 +12,7 @@ from aiosqlite import Connection
from bot.api_client import OasstApiClient, TaskType
from bot.db.schemas import GuildSettings
from bot.utils import EMPTY
from loguru import logger
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import TaskRequestType
@@ -21,9 +21,6 @@ plugin = lightbulb.Plugin("WorkPlugin")
MAX_TASK_TIME = 60 * 60 # 1 hour
MAX_TASK_ACCEPT_TIME = 60 # 1 minute
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@plugin.command
@lightbulb.option(
@@ -41,7 +38,7 @@ async def work(ctx: lightbulb.SlashContext):
task_type: TaskRequestType = TaskRequestType(ctx.options.type.split(".")[-1])
await ctx.respond("Sending you a task, check your DMs", flags=hikari.MessageFlag.EPHEMERAL)
logger.debug(f"task_type: {task_type!r}, task_type type {type(task_type)}")
logger.debug(f"Starting task_type: {task_type!r}")
await _handle_task(ctx, task_type)
@@ -76,6 +73,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
except asyncio.TimeoutError:
await ctx.author.send("Task timed out. Exiting")
await oasst_api.nack_task(task.id, reason="timed out")
logger.info(f"Task {task.id} timed out")
return
# Invalid response
@@ -83,7 +81,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
await ctx.author.send("Invalid response")
continue
logger.info(f"Successful user input received: {event.content}")
logger.debug(f"Successful user input received: {event.content}")
# Send the response to the backend
reply = protocol_schema.TextReplyToPost(
@@ -105,7 +103,7 @@ async def _handle_task(ctx: lightbulb.SlashContext, task_type: TaskRequestType)
completed = True
continue
else:
logger.fatal(f"Unexpected task type received: {new_task.type}")
logger.critical(f"Unexpected task type received: {new_task.type}")
# Send a message in the log channel that the task is complete
# TODO: Maybe do something with the msg ID so users can rate the "answer"
@@ -159,7 +157,7 @@ async def _select_task(
task = await oasst_api.fetch_task(task_type, user)
resp, msg_id = await _send_task(ctx, task)
logger.debug(f"user choice: {resp}")
logger.debug(f"User choice: {resp}")
match resp:
case "accept":
logger.info(f"Task {task.id} accepted, sending ACK")
@@ -200,32 +198,32 @@ async def _send_task(
# Create an embed based on the task's type
if task.type == TaskRequestType.initial_prompt:
assert isinstance(task, protocol_schema.InitialPromptTask)
logger.info("sending initial prompt task")
logger.debug("sending initial prompt task")
embed = _initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_initial_prompts:
assert isinstance(task, protocol_schema.RankInitialPromptsTask)
logger.info("sending rank initial prompt task")
logger.debug("sending rank initial prompt task")
embed = _rank_initial_prompt_embed(task)
elif task.type == TaskRequestType.rank_user_replies:
assert isinstance(task, protocol_schema.RankUserRepliesTask)
logger.info("sending rank user reply task")
logger.debug("sending rank user reply task")
embed = _rank_user_reply_embed(task)
elif task.type == TaskRequestType.rank_assistant_replies:
assert isinstance(task, protocol_schema.RankAssistantRepliesTask)
logger.info("sending rank assistant reply task")
logger.debug("sending rank assistant reply task")
embed = _rank_assistant_reply_embed(task)
elif task.type == TaskRequestType.user_reply:
assert isinstance(task, protocol_schema.UserReplyTask)
logger.info("sending user reply task")
logger.debug("sending user reply task")
embed = _user_reply_embed(task)
elif task.type == TaskRequestType.assistant_reply:
assert isinstance(task, protocol_schema.AssistantReplyTask)
logger.info("sending assistant reply task")
logger.debug("sending assistant reply task")
embed = _assistant_reply_embed(task)
elif task.type == TaskRequestType.summarize_story:
@@ -234,7 +232,7 @@ async def _send_task(
raise NotImplementedError
else:
logger.error(f"unknown task type {task.type}")
logger.critical(f"unknown task type {task.type}")
raise ValueError(f"unknown task type {task.type}")
view = TaskAcceptView(timeout=MAX_TASK_ACCEPT_TIME)
@@ -279,7 +277,7 @@ def _validate_user_input(content: str | None, task_type: str) -> bool:
raise NotImplementedError
else:
logger.fatal(f"Unknown task type {task_type}")
logger.critical(f"Unknown task type {task_type}")
raise ValueError(f"Unknown task type {task_type}")