mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
add managed_tx_function() decorator and use it for startup db-calls
This commit is contained in:
+78
-80
@@ -20,6 +20,7 @@ from oasst_backend.models import message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository, TaskRepository, UserRepository
|
||||
from oasst_backend.tree_manager import TreeManager
|
||||
from oasst_backend.user_stats_repository import UserStatsRepository, UserStatsTimeFrame
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_function
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from pydantic import BaseModel
|
||||
@@ -120,7 +121,8 @@ if settings.RATE_LIMIT:
|
||||
if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
def seed_data():
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def create_seed_data(session: Session):
|
||||
class DummyMessage(BaseModel):
|
||||
task_message_id: str
|
||||
user_message_id: str
|
||||
@@ -134,73 +136,73 @@ if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
try:
|
||||
logger.info("Seed data check began")
|
||||
with Session(engine) as db:
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=db)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
ur = UserRepository(db=db, api_client=api_client)
|
||||
tr = TaskRepository(db=db, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=db, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(db, pr)
|
||||
api_client = api_auth(settings.OFFICIAL_WEB_API_KEY, db=session)
|
||||
dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
|
||||
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
ur = UserRepository(db=session, api_client=api_client)
|
||||
tr = TaskRepository(db=session, api_client=api_client, client_user=dummy_user, user_repository=ur)
|
||||
pr = PromptRepository(
|
||||
db=session, api_client=api_client, client_user=dummy_user, user_repository=ur, task_repository=tr
|
||||
)
|
||||
tm = TreeManager(session, pr)
|
||||
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
with open(settings.DEBUG_USE_SEED_DATA_PATH) as f:
|
||||
dummy_messages_raw = json.load(f)
|
||||
|
||||
for msg in dummy_messages:
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
db.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = prepare_conversation(conversation_messages)
|
||||
if msg.role == "assistant":
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
else:
|
||||
task = tr.store_task(
|
||||
protocol_schema.PrompterReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text,
|
||||
msg.task_message_id,
|
||||
msg.user_message_id,
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
|
||||
)
|
||||
db.commit()
|
||||
dummy_messages = [DummyMessage(**dm) for dm in dummy_messages_raw]
|
||||
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
for msg in dummy_messages:
|
||||
task = tr.fetch_task_by_frontend_message_id(msg.task_message_id)
|
||||
if task and not task.ack:
|
||||
logger.warning("Deleting unacknowledged seed data task")
|
||||
session.delete(task)
|
||||
task = None
|
||||
if not task:
|
||||
if msg.parent_message_id is None:
|
||||
task = tr.store_task(
|
||||
protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
parent_message = pr.fetch_message_by_frontend_message_id(
|
||||
msg.parent_message_id, fail_if_missing=True
|
||||
)
|
||||
conversation_messages = pr.fetch_message_conversation(parent_message)
|
||||
conversation = prepare_conversation(conversation_messages)
|
||||
if msg.role == "assistant":
|
||||
task = tr.store_task(
|
||||
protocol_schema.AssistantReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
else:
|
||||
task = tr.store_task(
|
||||
protocol_schema.PrompterReplyTask(conversation=conversation),
|
||||
message_tree_id=parent_message.message_tree_id,
|
||||
parent_message_id=parent_message.id,
|
||||
)
|
||||
tr.bind_frontend_message_id(task.id, msg.task_message_id)
|
||||
message = pr.store_text_reply(
|
||||
msg.text,
|
||||
msg.task_message_id,
|
||||
msg.user_message_id,
|
||||
review_count=5,
|
||||
review_result=True,
|
||||
check_tree_state=False,
|
||||
)
|
||||
if message.parent_id is None:
|
||||
tm._insert_default_state(
|
||||
root_message_id=message.id, state=msg.tree_state or message_tree_state.State.GROWING
|
||||
)
|
||||
session.flush()
|
||||
|
||||
logger.info("Seed data check completed")
|
||||
logger.info(
|
||||
f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"seed data task found: {task.id}")
|
||||
|
||||
logger.info("Seed data check completed")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Seed data insertion failed")
|
||||
@@ -220,48 +222,44 @@ def ensure_tree_states():
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_DAY, wait_first=False)
|
||||
def update_leader_board_day() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_day(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.day)
|
||||
except Exception:
|
||||
logger.exception("Error during leaderboard update (daily)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_WEEK, wait_first=False)
|
||||
def update_leader_board_week() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_week(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.week)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (weekly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_MONTH, wait_first=False)
|
||||
def update_leader_board_month() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_month(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.month)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (monthly)")
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
@repeat_every(seconds=60 * settings.USER_STATS_INTERVAL_TOTAL, wait_first=False)
|
||||
def update_leader_board_total() -> None:
|
||||
@managed_tx_function(auto_commit=CommitMode.COMMIT)
|
||||
def update_leader_board_total(session: Session) -> None:
|
||||
try:
|
||||
with Session(engine) as session:
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
session.commit()
|
||||
usr = UserStatsRepository(session)
|
||||
usr.update_stats(time_frame=UserStatsTimeFrame.total)
|
||||
except Exception:
|
||||
logger.exception("Error during user states update (total)")
|
||||
|
||||
|
||||
@@ -67,6 +67,7 @@ class Settings(BaseSettings):
|
||||
POSTGRES_PASSWORD: str = "postgres"
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
DATABASE_MAX_TX_RETRY_COUNT: int = 3
|
||||
|
||||
RATE_LIMIT: bool = True
|
||||
REDIS_HOST: str = "localhost"
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from http import HTTPStatus
|
||||
from typing import Callable
|
||||
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
MAX_DB_RETRY_COUNT = 3
|
||||
from sqlmodel import Session, SQLModel
|
||||
|
||||
|
||||
class CommitMode(IntEnum):
|
||||
@@ -28,7 +29,7 @@ class CommitMode(IntEnum):
|
||||
"""
|
||||
|
||||
|
||||
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
|
||||
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def wrapped_f(self, *args, **kwargs):
|
||||
@@ -44,16 +45,15 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
self.db.rollback()
|
||||
pass
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Db Rollback Failure")
|
||||
logger.error("DB Rollback Failure")
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
@@ -61,7 +61,9 @@ def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=M
|
||||
return decorator
|
||||
|
||||
|
||||
def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
|
||||
def async_managed_tx_method(
|
||||
auto_commit: CommitMode = CommitMode.COMMIT, num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT
|
||||
):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def wrapped_f(self, *args, **kwargs):
|
||||
@@ -77,16 +79,58 @@ def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_ret
|
||||
self.db.refresh(result)
|
||||
return result
|
||||
except OperationalError:
|
||||
logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
|
||||
logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
|
||||
self.db.rollback()
|
||||
pass
|
||||
raise OasstError(
|
||||
"DATABASE_MAX_RETIRES_EXHAUSTED",
|
||||
error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
|
||||
http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Db Rollback Failure")
|
||||
logger.exception("DB Rollback Failure")
|
||||
raise e
|
||||
|
||||
return wrapped_f
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def default_session_factor() -> Session:
    """Create a new ``Session`` bound to the module-level ``engine``.

    Serves as the default ``session_factory`` argument of
    ``managed_tx_function``.
    """
    new_session = Session(engine)
    return new_session
|
||||
|
||||
|
||||
def managed_tx_function(
    auto_commit: CommitMode = CommitMode.COMMIT,
    num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT,
    session_factory: Callable[..., Session] = default_session_factor,
):
    """Build a decorator that runs a plain function inside a managed DB session.

    The wrapped function receives a freshly created ``Session`` as its first
    positional argument, followed by whatever the caller passed.

    Args:
        auto_commit: What to do after the wrapped function returns:
            ``CommitMode.COMMIT`` calls ``session.commit()``,
            ``CommitMode.FLUSH`` calls ``session.flush()`` only, and any other
            mode leaves the session untouched.
            NOTE(review): with FLUSH the session is closed right afterwards
            without a commit -- confirm that discarding the work is intended.
        num_retries: Maximum number of attempts. A new session is opened for
            every attempt.
        session_factory: Zero-argument callable returning a ``Session`` usable
            as a context manager.

    Returns:
        A decorator. The decorated function returns whatever the wrapped
        function returned; ``SQLModel`` results are refreshed from the
        session before being returned.

    Raises:
        OasstError: ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt failed with ``OperationalError``.
        Exception: Any other exception raised by the wrapped function or the
            session is logged and re-raised unchanged.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            try:
                for i in range(num_retries):
                    # One session per attempt; the context manager closes it
                    # on both the success and the retry path.
                    with session_factory() as session:
                        try:
                            result = f(session, *args, **kwargs)
                            if auto_commit == CommitMode.COMMIT:
                                session.commit()
                            elif auto_commit == CommitMode.FLUSH:
                                session.flush()
                            if isinstance(result, SQLModel):
                                session.refresh(result)
                            return result
                        except OperationalError:
                            # Treated as a possibly transient conflict: undo
                            # this attempt and try again with a new session.
                            logger.info(f"Retry {i+1}/{num_retries} after possible DB concurrent update conflict.")
                            session.rollback()
                raise OasstError(
                    "DATABASE_MAX_RETIRES_EXHAUSTED",
                    error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                    http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
                )
            except Exception:
                # logger.exception keeps the traceback of the real cause (this
                # handler also sees non-rollback failures); bare `raise`
                # re-raises it unchanged.
                logger.exception("DB Rollback Failure")
                raise

        return wrapped_f

    # Without this the factory would return None and could not be used as
    # `@managed_tx_function(...)`.
    return decorator
|
||||
|
||||
Reference in New Issue
Block a user