Files
Open-Assistant/backend/oasst_backend/utils/database_utils.py
T
2023-01-27 19:44:48 +01:00

206 lines
8.3 KiB
Python

from enum import IntEnum
from functools import wraps
from http import HTTPStatus
from typing import Callable
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_shared.exceptions import OasstError, OasstErrorCode
from psycopg2.errors import DeadlockDetected, ExclusionViolation, SerializationFailure, UniqueViolation
from sqlalchemy.exc import OperationalError, PendingRollbackError
from sqlmodel import Session, SQLModel
"""
Error Handling Reference: https://www.postgresql.org/docs/15/mvcc-serialization-failure-handling.html
"""
class CommitMode(IntEnum):
    """Selects what the managed-transaction decorators do with the session after the wrapped call."""

    # Leave the session untouched.
    NONE = 0
    # Call ``flush()`` and refresh a returned SQLModel instance.
    FLUSH = 1
    # Call ``commit()``, retrying the whole call on retryable database errors.
    COMMIT = 2
    # Call ``rollback()``, discarding the call's changes.
    ROLLBACK = 3
"""
* managed_tx_method and async_managed_tx_method methods are decorators functions
* to be used on class functions. It expects the Class to have a 'db' Session object
* initialised
"""
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
    """Return a decorator that runs an instance method inside a managed DB transaction.

    The decorated method must belong to a class whose instances hold an
    initialised ``Session`` in ``self.db``.

    Args:
        auto_commit: What to do with ``self.db`` after the method runs.
            ``COMMIT`` calls ``commit()`` and re-runs the whole method on
            retryable errors. ``FLUSH`` calls ``flush()``. ``ROLLBACK`` calls
            ``rollback()``. ``NONE`` leaves the session alone.
        num_retries: Maximum number of attempts in ``COMMIT`` mode. If it is
            <= 0, the method is never called and the retries-exhausted error
            is raised immediately.

    Returns:
        The decorator. The wrapped method returns the original return value.
        A ``SQLModel`` result is refreshed after a successful commit or flush.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            ``COMMIT`` attempt failed with a retryable error. The last such
            error is attached as ``__cause__``.
        Exception: Any other error from the method or the session is logged
            and re-raised unchanged.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(self, *args, **kwargs):
            try:
                result = None
                if auto_commit == CommitMode.COMMIT:
                    # Most recent retryable failure, chained onto the final error for diagnosis.
                    last_error = None
                    for i in range(num_retries):
                        try:
                            result = f(self, *args, **kwargs)
                            self.db.commit()
                            if isinstance(result, SQLModel):
                                self.db.refresh(result)
                            break
                        except PendingRollbackError as e:
                            # The session must be rolled back before it can be used again.
                            logger.info(str(e))
                            self.db.rollback()
                            last_error = e
                        except OperationalError as e:
                            # NOTE(review): SQLAlchemy wraps psycopg2 IntegrityError subclasses
                            # (UniqueViolation, ExclusionViolation) as sqlalchemy.exc.IntegrityError,
                            # so those two entries likely never match here. Confirm the intent
                            # before relying on retries for them.
                            if e.orig is not None and isinstance(
                                e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
                            ):
                                logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
                                self.db.rollback()
                                last_error = e
                            else:
                                raise
                        logger.info(f"Retry {i+1}/{num_retries}")
                    else:
                        # The loop ended without ``break``, so no attempt committed successfully.
                        raise OasstError(
                            "DATABASE_MAX_RETIRES_EXHAUSTED",
                            error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                            http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                        ) from last_error
                else:
                    result = f(self, *args, **kwargs)
                    if auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                        if isinstance(result, SQLModel):
                            self.db.refresh(result)
                    elif auto_commit == CommitMode.ROLLBACK:
                        self.db.rollback()
                return result
            except Exception as e:
                logger.info(str(e))
                # Bare ``raise`` re-raises the active exception with its original traceback.
                raise

        return wrapped_f

    return decorator
def async_managed_tx_method(
    auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
):
    """Return a decorator that runs an ``async`` instance method inside a managed DB transaction.

    This is the coroutine counterpart of the synchronous method decorator. The
    decorated coroutine must belong to a class whose instances hold an
    initialised ``Session`` in ``self.db``. Session calls (``commit``,
    ``flush``, ``refresh``, ``rollback``) are made synchronously; only the
    wrapped coroutine is awaited.

    Args:
        auto_commit: What to do with ``self.db`` after the coroutine completes.
            ``COMMIT`` calls ``commit()`` and re-awaits the whole coroutine
            function on retryable errors. ``FLUSH`` calls ``flush()``.
            ``ROLLBACK`` calls ``rollback()``. ``NONE`` leaves the session alone.
        num_retries: Maximum number of attempts in ``COMMIT`` mode. If it is
            <= 0, the coroutine is never awaited and the retries-exhausted
            error is raised immediately.

    Returns:
        The decorator. The wrapped coroutine returns the original result.
        A ``SQLModel`` result is refreshed after a successful commit or flush.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            ``COMMIT`` attempt failed with a retryable error. The last such
            error is attached as ``__cause__``.
        Exception: Any other error is logged and re-raised unchanged.
    """

    def decorator(f):
        @wraps(f)
        async def wrapped_f(self, *args, **kwargs):
            try:
                result = None
                if auto_commit == CommitMode.COMMIT:
                    # Most recent retryable failure, chained onto the final error for diagnosis.
                    last_error = None
                    for i in range(num_retries):
                        try:
                            result = await f(self, *args, **kwargs)
                            self.db.commit()
                            if isinstance(result, SQLModel):
                                self.db.refresh(result)
                            break
                        except PendingRollbackError as e:
                            # The session must be rolled back before it can be used again.
                            logger.info(str(e))
                            self.db.rollback()
                            last_error = e
                        except OperationalError as e:
                            # NOTE(review): SQLAlchemy wraps psycopg2 IntegrityError subclasses
                            # (UniqueViolation, ExclusionViolation) as sqlalchemy.exc.IntegrityError,
                            # so those two entries likely never match here. Confirm the intent
                            # before relying on retries for them.
                            if e.orig is not None and isinstance(
                                e.orig, (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation)
                            ):
                                logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
                                self.db.rollback()
                                last_error = e
                            else:
                                raise
                        logger.info(f"Retry {i+1}/{num_retries}")
                    else:
                        # The loop ended without ``break``, so no attempt committed successfully.
                        raise OasstError(
                            "DATABASE_MAX_RETIRES_EXHAUSTED",
                            error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                            http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                        ) from last_error
                else:
                    result = await f(self, *args, **kwargs)
                    if auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                        if isinstance(result, SQLModel):
                            self.db.refresh(result)
                    elif auto_commit == CommitMode.ROLLBACK:
                        self.db.rollback()
                return result
            except Exception as e:
                logger.info(str(e))
                # Bare ``raise`` re-raises the active exception with its original traceback.
                raise

        return wrapped_f

    return decorator
def default_session_factor() -> Session:
    """Open and return a new ``Session`` bound to the module-level ``engine``.

    This is the default ``session_factory`` for ``managed_tx_function``.
    """
    new_session = Session(bind=engine)
    return new_session
def managed_tx_function(
    auto_commit: CommitMode = CommitMode.COMMIT,
    num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
    session_factory: Callable[..., Session] = default_session_factor,
):
    """Passes Session object as first argument to wrapped function.

    The returned decorator opens a session with ``session_factory()``. It calls
    the wrapped function as ``f(session, *args, **kwargs)`` and closes the
    session when the ``with`` block exits. In ``COMMIT`` mode, every attempt
    gets a fresh session.

    Args:
        auto_commit: What to do with the session after the function returns.
            ``COMMIT`` calls ``commit()`` and re-runs the function on retryable
            errors. ``FLUSH`` calls ``flush()``. ``ROLLBACK`` calls
            ``rollback()``. ``NONE`` leaves the session alone.
        num_retries: Maximum number of attempts in ``COMMIT`` mode. If it is
            <= 0, the function is never called and the retries-exhausted error
            is raised immediately.
        session_factory: Zero-argument callable returning a ``Session`` that
            can be used as a context manager.

    Returns:
        The decorator. The wrapped function returns the original return value.
        A ``SQLModel`` result is refreshed before its session is closed.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            ``COMMIT`` attempt failed with a retryable error. The last such
            error is attached as ``__cause__``.
        Exception: Any other error is logged and re-raised unchanged.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            try:
                result = None
                if auto_commit == CommitMode.COMMIT:
                    # Most recent retryable failure, chained onto the final error for diagnosis.
                    last_error = None
                    for i in range(num_retries):
                        with session_factory() as session:
                            try:
                                result = f(session, *args, **kwargs)
                                session.commit()
                                if isinstance(result, SQLModel):
                                    session.refresh(result)
                                break
                            except PendingRollbackError as e:
                                # The session must be rolled back before it can be used again.
                                logger.info(str(e))
                                session.rollback()
                                last_error = e
                            except OperationalError as e:
                                # NOTE(review): SQLAlchemy wraps psycopg2 IntegrityError subclasses
                                # (UniqueViolation, ExclusionViolation) as sqlalchemy.exc.IntegrityError,
                                # so those two entries likely never match here. Confirm the intent.
                                if e.orig is not None and isinstance(
                                    e.orig,
                                    (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation),
                                ):
                                    logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}")
                                    session.rollback()
                                    last_error = e
                                else:
                                    raise
                        logger.info(f"Retry {i+1}/{num_retries}")
                    else:
                        # The loop ended without ``break``, so no attempt committed successfully.
                        raise OasstError(
                            "DATABASE_MAX_RETIRES_EXHAUSTED",
                            error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                            http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                        ) from last_error
                else:
                    # NOTE(review): no commit happens in FLUSH/NONE modes before the session is
                    # closed, so flushed changes are presumably discarded. Confirm this is intended.
                    with session_factory() as session:
                        result = f(session, *args, **kwargs)
                        if auto_commit == CommitMode.FLUSH:
                            session.flush()
                            if isinstance(result, SQLModel):
                                session.refresh(result)
                        elif auto_commit == CommitMode.ROLLBACK:
                            session.rollback()
                return result
            except Exception as e:
                logger.info(str(e))
                # Bare ``raise`` re-raises the active exception with its original traceback.
                raise

        return wrapped_f

    return decorator