add managed_tx_function() decorator and use it for startup db-calls

This commit is contained in:
Andreas Köpf
2023-01-16 09:16:28 +01:00
parent 35887decfb
commit 7f562fbbae
3 changed files with 134 additions and 91 deletions
+78 -80
View File
@@ -20,6 +20,7 @@ from oasst_backend.models import message_tree_state
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
from oasst_backend.tree_manager import TreeManager
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from pydantic import BaseModel
@@ -120,7 +121,8 @@ if settings.RATE_LIMIT:
if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
def seed_data():
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def create_seed_data(session: Session):
class DummyMessage(BaseModel):
task_message_id: str
user_message_id: str
@@ -134,73 +136,73 @@ if settings.DEBUG_USE_SEED_DATA:
try:
logger.info("Seed data check began")
with Session(engine) as db:
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
ur = UserRepository(db=db, api_client=api_client)
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
tm = TreeManager(db, pr)
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=session)
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
ur = UserRepository(db=session, api_client=api_client)
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
pr = PromptRepository(
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
)
tm = TreeManager(session, pr)
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
dummy_messages_raw = json.load(f)
for msg in dummy_messages:
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
db.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
conversation_messages = pr.fetch_message_conversation(parent_message)
conversation = prepare_conversation(conversation_messages)
if msg.role == "assistant":
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
else:
task = tr.store_task(
protocol_schema.PrompterReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(
msg.text,
msg.task_message_id,
msg.user_message_id,
review_count=5,
review_result=True,
check_tree_state=False,
)
if message.parent_id is None:
tm._insert_default_state(
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
)
db.commit()
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
logger.info(
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
for msg in dummy_messages:
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
if task and not task.ack:
logger.warning("Deleting unacknowledged seed data task")
session.delete(task)
task = None
if not task:
if msg.parent_message_id is None:
task = tr.store_task(
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
)
else:
logger.debug(f"seed data task found: {task.id}")
parent_message = pr.fetch_message_by_frontend_message_id(
msg.parent_message_id, fail_if_missing=True
)
conversation_messages = pr.fetch_message_conversation(parent_message)
conversation = prepare_conversation(conversation_messages)
if msg.role == "assistant":
task = tr.store_task(
protocol_schema.AssistantReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
else:
task = tr.store_task(
protocol_schema.PrompterReplyTask(conversation=conversation),
message_tree_id=parent_message.message_tree_id,
parent_message_id=parent_message.id,
)
tr.bind_frontend_message_id(task.id, msg.task_message_id)
message = pr.store_text_reply(
msg.text,
msg.task_message_id,
msg.user_message_id,
review_count=5,
review_result=True,
check_tree_state=False,
)
if message.parent_id is None:
tm._insert_default_state(
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
)
session.flush()
logger.info("Seed data check completed")
logger.info(
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
)
else:
logger.debug(f"seed data task found: {task.id}")
logger.info("Seed data check completed")
except Exception:
logger.exception("Seed data insertion failed")
@@ -220,48 +222,44 @@ def ensure_tree_states():
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
def update_leader_board_day() -> None:
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def update_leader_board_day(session: Session) -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.day)
session.commit()
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.day)
except Exception:
logger.exception("Error during leaderboard update (daily)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
def update_leader_board_week() -> None:
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def update_leader_board_week(session: Session) -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.week)
session.commit()
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.week)
except Exception:
logger.exception("Error during user states update (weekly)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
def update_leader_board_month() -> None:
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def update_leader_board_month(session: Session) -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.month)
session.commit()
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.month)
except Exception:
logger.exception("Error during user states update (monthly)")
@app.on_event("startup")
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
def update_leader_board_total() -> None:
@managed_tx_function(auto_commit=CommitMode.COMMIT)
def update_leader_board_total(session: Session) -> None:
try:
with Session(engine) as session:
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.total)
session.commit()
usr = UserStatsRepository(session)
usr.update_stats(time_frame=UserStatsTimeFrame.total)
except Exception:
logger.exception("Error during user states update (total)")
+1
View File
@@ -67,6 +67,7 @@ class Settings(BaseSettings):
POSTGRES_PASSWORD: str = "postgres"
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
DATABASE_MAX_TX_RETRY_COUNT: int = 3
RATE_LIMIT: bool = True
REDIS_HOST: str = "localhost"
+55 -11
View File
@@ -1,13 +1,14 @@
from enum import IntEnum
from functools import wraps
from http import HTTPStatus
from typing import Callable
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlalchemy.exc import OperationalError
from sqlmodel import SQLModel
MAX_DB_RETRY_COUNT = 3
from sqlmodel import Session, SQLModel
class CommitMode(IntEnum):
@@ -28,7 +29,7 @@ class CommitMode(IntEnum):
"""
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
def decorator(f):
@wraps(f)
def wrapped_f(self, *args, **kwargs):
@@ -44,16 +45,15 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
self.db.rollback()
pass
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.error("Db Rollback Failure")
logger.error("DB Rollback Failure")
raise e
return wrapped_f
@@ -61,7 +61,9 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M
return decorator
def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
def async_managed_tx_method(
auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
):
def decorator(f):
@wraps(f)
async def wrapped_f(self, *args, **kwargs):
@@ -77,16 +79,58 @@ def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_ret
self.db.refresh(result)
return result
except OperationalError:
logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
self.db.rollback()
pass
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.error("Db Rollback Failure")
logger.exception("DB Rollback Failure")
raise e
return wrapped_f
return decorator
def default_session_factor() -> Session:
return Session(engine)
def managed_tx_function(
auto_commit: CommitMode = CommitMode.COMMIT,
num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
session_factory: Callable[..., Session] = default_session_factor,
):
"""Passes Session object as first argument to wrapped function."""
def decorator(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
try:
for i in range(num_retries):
with session_factory() as session:
try:
result = f(session, *args, **kwargs)
if auto_commit == CommitMode.COMMIT:
session.commit()
elif auto_commit == CommitMode.FLUSH:
session.flush()
if isinstance(result, SQLModel):
session.refresh(result)
return result
except OperationalError:
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
session.rollback()
raise OasstError(
"DATABASE_MAX_RETIRES_EXHAUSTED",
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
)
except Exception as e:
logger.error("DB Rollback Failure")
raise e
return wrapped_f