Files
Open-Assistant/backend/oasst_backend/utils/database_utils.py
T

139 lines
4.9 KiB
Python

from enum import IntEnum
from functools import wraps
from http import HTTPStatus
from typing import Callable
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlalchemy.exc import OperationalError
from sqlmodel import Session, SQLModel
class CommitMode(IntEnum):
"""
Commit modes for the managed tx methods
"""
NONE = 0
FLUSH = 1
COMMIT = 2
"""
* managed_tx_method and async_managed_tx_method methods are decorators functions
* to be used on class functions. It expects the Class to have a 'db' Session object
* initialised
* TODO: tx method decorator for non class methods
"""
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
def decorator(f):
@wraps(f)
def wrapped_f(self, *args, **kwargs):
try:
for i in range(num_retries):
try:
result = f(self, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
self.db.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.error("DB Rollback Failure")
raise e
return wrapped_f
return decorator
def async_managed_tx_method(
auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
):
def decorator(f):
@wraps(f)
async def wrapped_f(self, *args, **kwargs):
try:
for i in range(num_retries):
try:
result = await f(self, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
self.db.commit()
elif auto_commit == CommitMode.FLUSH:
self.db.flush()
if isinstance(result, SQLModel):
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
self.db.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.exception("DB Rollback Failure")
raise e
return wrapped_f
return decorator
def default_session_factor() -> Session:
return Session(engine)
def managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
):
"""Passes Session object as first argument to wrapped function."""
def decorator(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
try:
for i in range(num_retries):
with session_factory() as session:
try:
result = f(session, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
session.commit()
elif auto_commit == CommitMode.FLUSH:
session.flush()
if isinstance(result, SQLModel):
session.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
session.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.error("DB Rollback Failure")
raise e
return wrapped_f
return decorator