From 7f562fbbae2af75b67180aa3246ebe8243dfc5bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Mon, 16 Jan 2023 09:16:28 +0100 Subject: [PATCH] add managed_tx_function() decorator and use it for startup db-calls --- backend/main.py | 158 +++++++++--------- backend/oasst_backend/config.py | 1 + backend/oasst_backend/utils/database_utils.py | 66 ++++++-- 3 files changed, 134 insertions(+), 91 deletions(-) diff --git a/backend/main.py b/backend/main.py index 9c900f8b..ce316f41 100644 --- a/backend/main.py +++ b/backend/main.py @@ -20,6 +20,7 @@ from oasst_backend.models import message_tree_state from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame +from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from pydantic import BaseModel @@ -120,7 +121,8 @@ if settings.RATE_LIMIT: if settings.DEBUG_USE_SEED_DATA: @app.on_event("startup") - def seed_data(): + @managed_tx_function(auto_commit=CommitMode.COMMIT) + def create_seed_data(session: Session): class DummyMessage(BaseModel): task_message_id: str user_message_id: str @@ -134,73 +136,73 @@ if settings.DEBUG_USE_SEED_DATA: try: logger.info("Seed data check began") - with Session(engine) as db: - api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db) - dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - ur = UserRepository(db=db, api_client=api_client) - tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur) - pr = PromptRepository( - db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr - ) - tm = TreeManager(db, pr) + api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=session) + dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local") - with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: - dummy_messages_raw = json.load(f) + ur = UserRepository(db=session, api_client=api_client) + tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur) + pr = PromptRepository( + db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr + ) + tm = TreeManager(session, pr) - dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] + with open(settings.DEBUG_USE_SEED_DATA_PATH) as f: + dummy_messages_raw = json.load(f) - for msg in dummy_messages: - task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) - if task and not task.ack: - logger.warning("Deleting unacknowledged seed data task") - db.delete(task) - task = None - if not task: - if msg.parent_message_id is None: - task = tr.store_task( - protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None - ) - else: - parent_message = pr.fetch_message_by_frontend_message_id( - msg.parent_message_id, fail_if_missing=True - ) - conversation_messages = pr.fetch_message_conversation(parent_message) - conversation = prepare_conversation(conversation_messages) - if msg.role == "assistant": - task = tr.store_task( - protocol_schema.AssistantReplyTask(conversation=conversation), - message_tree_id=parent_message.message_tree_id, - parent_message_id=parent_message.id, - ) - else: - task = tr.store_task( - protocol_schema.PrompterReplyTask(conversation=conversation), - message_tree_id=parent_message.message_tree_id, - parent_message_id=parent_message.id, - ) - tr.bind_frontend_message_id(task.id, msg.task_message_id) - message = pr.store_text_reply( - msg.text, - msg.task_message_id, - msg.user_message_id, - review_count=5, - review_result=True, - check_tree_state=False, - ) - if message.parent_id is None: - tm._insert_default_state( - root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING - ) - db.commit() + dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw] - logger.info( - f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" + for msg in dummy_messages: + task = tr.fetch_task_by_frontend_message_id(msg.task_message_id) + if task and not task.ack: + logger.warning("Deleting unacknowledged seed data task") + session.delete(task) + task = None + if not task: + if msg.parent_message_id is None: + task = tr.store_task( + protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None ) else: - logger.debug(f"seed data task found: {task.id}") + parent_message = pr.fetch_message_by_frontend_message_id( + msg.parent_message_id, fail_if_missing=True + ) + conversation_messages = pr.fetch_message_conversation(parent_message) + conversation = prepare_conversation(conversation_messages) + if msg.role == "assistant": + task = tr.store_task( + protocol_schema.AssistantReplyTask(conversation=conversation), + message_tree_id=parent_message.message_tree_id, + parent_message_id=parent_message.id, + ) + else: + task = tr.store_task( + protocol_schema.PrompterReplyTask(conversation=conversation), + message_tree_id=parent_message.message_tree_id, + parent_message_id=parent_message.id, + ) + tr.bind_frontend_message_id(task.id, msg.task_message_id) + message = pr.store_text_reply( + msg.text, + msg.task_message_id, + msg.user_message_id, + review_count=5, + review_result=True, + check_tree_state=False, + ) + if message.parent_id is None: + tm._insert_default_state( + root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING + ) + session.flush() - logger.info("Seed data check completed") + logger.info( + f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}" + ) + else: + logger.debug(f"seed data task found: {task.id}") + + logger.info("Seed data check completed") except Exception: logger.exception("Seed data insertion failed") @@ -220,48 +222,44 @@ def ensure_tree_states(): @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False) -def update_leader_board_day() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_day(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.day) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.day) except Exception: logger.exception("Error during leaderboard update (daily)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False) -def update_leader_board_week() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_week(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.week) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.week) except Exception: logger.exception("Error during user states update (weekly)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False) -def update_leader_board_month() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_month(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.month) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.month) except Exception: logger.exception("Error during user states update (monthly)") @app.on_event("startup") @repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False) -def update_leader_board_total() -> None: +@managed_tx_function(auto_commit=CommitMode.COMMIT) +def update_leader_board_total(session: Session) -> None: try: - with Session(engine) as session: - usr = UserStatsRepository(session) - usr.update_stats(time_frame=UserStatsTimeFrame.total) - session.commit() + usr = UserStatsRepository(session) + usr.update_stats(time_frame=UserStatsTimeFrame.total) except Exception: logger.exception("Error during user states update (total)") diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 2a8a3a7c..71a36160 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -67,6 +67,7 @@ class Settings(BaseSettings): POSTGRES_PASSWORD: str = "postgres" POSTGRES_DB: str = "postgres" DATABASE_URI: Optional[PostgresDsn] = None + DATABASE_MAX_TX_RETRY_COUNT: int = 3 RATE_LIMIT: bool = True REDIS_HOST: str = "localhost" diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index b2d92127..d378c6a6 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -1,13 +1,14 @@ from enum import IntEnum from functools import wraps from http import HTTPStatus +from typing import Callable from loguru import logger +from oasst_backend.config import settings +from oasst_backend.database import engine from oasst_shared.exceptions import OasstError, OasstErrorCode from sqlalchemy.exc import OperationalError -from sqlmodel import SQLModel - -MAX_DB_RETRY_COUNT = 3 +from sqlmodel import Session, SQLModel class CommitMode(IntEnum): @@ -28,7 +29,7 @@ class CommitMode(IntEnum): """ -def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT): +def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT): def decorator(f): @wraps(f) def wrapped_f(self, *args, **kwargs): @@ -44,16 +45,15 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M self.db.refresh(result) return result except OperationalError: - logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict") + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") self.db.rollback() - pass raise OasstError( "DATABASE_MAX_RETIRES_EXHAUSTED", error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, ) except Exception as e: - logger.error("Db Rollback Failure") + logger.error("DB Rollback Failure") raise e return wrapped_f @@ -61,7 +61,9 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M return decorator -def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT): +def async_managed_tx_method( + auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT +): def decorator(f): @wraps(f) async def wrapped_f(self, *args, **kwargs): @@ -77,16 +79,58 @@ def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_ret self.db.refresh(result) return result except OperationalError: - logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict") + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") self.db.rollback() - pass raise OasstError( "DATABASE_MAX_RETIRES_EXHAUSTED", error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, ) except Exception as e: - logger.error("Db Rollback Failure") + logger.exception("DB Rollback Failure") + raise e + + return wrapped_f + + return decorator + + +def default_session_factor() -> Session: + return Session(engine) + + +def managed_tx_function( + auto_commit: CommitMode = CommitMode.COMMIT, + num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, + session_factory: Callable[..., Session] = default_session_factor, +): + """Passes Session object as first argument to wrapped function.""" + + def decorator(f): + @wraps(f) + def wrapped_f(*args, **kwargs): + try: + for i in range(num_retries): + with session_factory() as session: + try: + result = f(session, *args, **kwargs) + if auto_commit == CommitMode.COMMIT: + session.commit() + elif auto_commit == CommitMode.FLUSH: + session.flush() + if isinstance(result, SQLModel): + session.refresh(result) + return result + except OperationalError: + logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.") + session.rollback() + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + except Exception as e: + logger.error("DB Rollback Failure") raise e return wrapped_f