diff --git a/backend/oasst_backend/database.py b/backend/oasst_backend/database.py
index b160da61..9632b9da 100644
--- a/backend/oasst_backend/database.py
+++ b/backend/oasst_backend/database.py
@@ -5,4 +5,5 @@ from sqlmodel import create_engine
 if settings.DATABASE_URI is None:
     raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
 
-engine = create_engine(settings.DATABASE_URI)
+# SQL statement echo is left at its default (off); only the isolation level is set explicitly.
+engine = create_engine(settings.DATABASE_URI, isolation_level="REPEATABLE READ")
diff --git a/backend/oasst_backend/journal_writer.py b/backend/oasst_backend/journal_writer.py index 67892ded..b39b498d 100644 --- a/backend/oasst_backend/journal_writer.py +++ b/backend/oasst_backend/journal_writer.py @@ -4,6 +4,7 @@ from uuid import UUID from oasst_backend.models import ApiClient, Journal, Task, User from oasst_backend.models.payload_column_type import PayloadContainer, payload_type +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.utils import utcnow from pydantic import BaseModel from sqlmodel import Session @@ -80,6 +81,7 @@ class JournalWriter: message_id=message_id, ) + @managed_tx_method(CommitMode.FLUSH) def log( self, *, @@ -115,7 +117,4 @@ class JournalWriter: ) self.db.add(entry) - if commit: - self.db.commit() - return entry diff --git a/backend/oasst_backend/prompt_repository.py b/backend/oasst_backend/prompt_repository.py index 6483cdc2..7f51ea19 100644 --- a/backend/oasst_backend/prompt_repository.py +++ b/backend/oasst_backend/prompt_repository.py @@ -24,6 +24,7 @@ from oasst_backend.models import ( from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import 
protocol as protocol_schema from oasst_shared.schemas.protocol import SystemStats @@ -67,6 +68,7 @@ class PromptRepository: ) return message + @managed_tx_method(CommitMode.FLUSH) def insert_message( self, *, @@ -104,8 +106,8 @@ class PromptRepository: review_result=review_result, ) self.db.add(message) - self.db.commit() - self.db.refresh(message) + + # self.db.refresh(message) return message def _validate_task( @@ -134,6 +136,7 @@ class PromptRepository: def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState: return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one() + @managed_tx_method(CommitMode.FLUSH) def store_text_reply( self, text: str, @@ -205,10 +208,10 @@ class PromptRepository: if not task.collective: task.done = True self.db.add(task) - self.db.commit() self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text)) return user_message + @managed_tx_method(CommitMode.FLUSH) def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction: message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True) @@ -238,6 +241,7 @@ class PromptRepository: logger.info(f"Ranking {rating.rating} stored for task {task.id}.") return reaction + @managed_tx_method(CommitMode.COMMIT) def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]: # fetch task task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id) @@ -310,6 +314,7 @@ class PromptRepository: return reaction, task + @managed_tx_method(CommitMode.FLUSH) def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity: """Save the toxicity score of a new message in the database. 
Args: @@ -325,10 +330,9 @@ class PromptRepository: message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label) self.db.add(message_toxicity) - self.db.commit() - self.db.refresh(message_toxicity) return message_toxicity + @managed_tx_method(CommitMode.FLUSH) def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding: """Insert the embedding of a new message in the database. @@ -346,10 +350,9 @@ class PromptRepository: message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding) self.db.add(message_embedding) - self.db.commit() - self.db.refresh(message_embedding) return message_embedding + @managed_tx_method(CommitMode.FLUSH) def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction: if self.user_id is None: raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED) @@ -363,10 +366,9 @@ class PromptRepository: payload_type=type(payload).__name__, ) self.db.add(reaction) - self.db.commit() - self.db.refresh(reaction) return reaction + @managed_tx_method(CommitMode.FLUSH) def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]: valid_labels: Optional[list[str]] = None @@ -436,8 +438,6 @@ class PromptRepository: self.db.add(message) self.db.add(model) - self.db.commit() - self.db.refresh(model) return model, task, message def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]: @@ -702,6 +702,7 @@ class PromptRepository: return messages.all() + @managed_tx_method(CommitMode.COMMIT) def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True): """ Marks deleted messages and all their descendants. 
@@ -730,8 +731,6 @@ class PromptRepository: parent_ids = self.db.execute(query).scalars().all() - self.db.commit() - def get_stats(self) -> SystemStats: """ Get data stats such as number of all messages in the system, diff --git a/backend/oasst_backend/task_repository.py b/backend/oasst_backend/task_repository.py index acf48182..eb100fe3 100644 --- a/backend/oasst_backend/task_repository.py +++ b/backend/oasst_backend/task_repository.py @@ -6,6 +6,7 @@ from loguru import logger from oasst_backend.models import ApiClient, Task from oasst_backend.models.payload_column_type import PayloadContainer from oasst_backend.user_repository import UserRepository +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -128,6 +129,7 @@ class TaskRepository: assert task_model.id == task.id return task_model + @managed_tx_method(CommitMode.COMMIT) def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str): validate_frontend_message_id(frontend_message_id) @@ -142,10 +144,9 @@ class TaskRepository: task.frontend_message_id = frontend_message_id task.ack = True - # ToDo: check race-condition, transaction self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False): """ Mark task as done. No further messages will be accepted for this task. 
@@ -166,8 +167,8 @@ class TaskRepository: task.done = True self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def acknowledge_task_failure(self, task_id): # find task task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first() @@ -181,8 +182,8 @@ class TaskRepository: task.ack = False # ToDo: check race-condition, transaction self.db.add(task) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def insert_task( self, payload: db_payload.TaskPayload, @@ -204,8 +205,6 @@ class TaskRepository: ) logger.debug(f"inserting {task=}") self.db.add(task) - self.db.commit() - self.db.refresh(task) return task def fetch_task_by_frontend_message_id(self, message_id: str) -> Task: diff --git a/backend/oasst_backend/tree_manager.py b/backend/oasst_backend/tree_manager.py index a9a282e2..2d8a7f4d 100644 --- a/backend/oasst_backend/tree_manager.py +++ b/backend/oasst_backend/tree_manager.py @@ -11,6 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio from oasst_backend.config import TreeManagerConfiguration, settings from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema @@ -340,6 +341,7 @@ class TreeManager: return task, message_tree_id, parent_message_id + @async_managed_tx_method(CommitMode.COMMIT) async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task: pr = self.pr match type(interaction): @@ -358,7 +360,6 @@ class TreeManager: if not message.parent_id: 
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}") self._insert_default_state(message.id) - self.db.commit() if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION: try: @@ -428,7 +429,6 @@ class TreeManager: if acceptance_score > self.cfg.acceptance_threshold_initial_prompt: msg.review_result = True self.db.add(msg) - self.db.commit() logger.info( f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" ) @@ -439,7 +439,6 @@ class TreeManager: if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply: msg.review_result = True self.db.add(msg) - self.db.commit() logger.info( f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}" ) @@ -451,6 +450,7 @@ class TreeManager: return protocol_schema.TaskDone() + @managed_tx_method(CommitMode.FLUSH) def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State): assert mts and mts.active @@ -460,7 +460,6 @@ class TreeManager: mts.active = False mts.state = state.value self.db.add(mts) - self.db.commit() if is_terminal: logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})") @@ -472,6 +471,7 @@ class TreeManager: mts = self.pr.fetch_tree_state(message_tree_id) self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE) + @managed_tx_method(CommitMode.COMMIT) def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_growing_state({message_tree_id=})") @@ -489,6 +489,7 @@ class TreeManager: self._enter_state(mts, message_tree_state.State.GROWING) return True + @managed_tx_method(CommitMode.COMMIT) def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool: logger.debug(f"check_condition_for_ranking_state({message_tree_id=})") @@ -735,6 +736,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin rankings_by_message[parent_id].append(MessageReaction.from_orm(x)) return 
rankings_by_message + @managed_tx_method(CommitMode.COMMIT) def ensure_tree_states(self): """Add message tree state rows for all root nodes (inital prompt messages).""" @@ -746,7 +748,6 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin state = message_tree_state.State.GROWING logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})") self._insert_default_state(id, state=state) - self.db.commit() def query_num_active_trees(self) -> int: query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active) @@ -763,6 +764,7 @@ WHERE t.done = TRUE r = self.db.execute(text(sql_qry), {"message_id": message_id}) return [TextLabels.from_orm(x) for x in r.all()] + @managed_tx_method(CommitMode.FLUSH) def _insert_tree_state( self, root_message_id: UUID, @@ -784,6 +786,7 @@ WHERE t.done = TRUE self.db.add(model) return model + @managed_tx_method(CommitMode.FLUSH) def _insert_default_state( self, root_message_id: UUID, diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 8d8a96d5..578dc5f1 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -2,6 +2,7 @@ from typing import Optional from uuid import UUID from oasst_backend.models import ApiClient, User +from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -62,6 +63,7 @@ class UserRepository: return user + @managed_tx_method(CommitMode.COMMIT) def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None: """ Update a user by global user ID to disable or set admin notes. Only trusted clients may update users. 
@@ -83,8 +85,8 @@ class UserRepository: user.notes = notes self.db.add(user) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def mark_user_deleted(self, id: UUID) -> None: """ Update a user by global user ID to set deleted flag. Only trusted clients may delete users. @@ -103,8 +105,8 @@ class UserRepository: user.deleted = True self.db.add(user) - self.db.commit() + @managed_tx_method(CommitMode.COMMIT) def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: if not client_user: return None @@ -127,13 +129,10 @@ class UserRepository: auth_method=client_user.auth_method, ) self.db.add(user) - self.db.commit() - self.db.refresh(user) elif client_user.display_name and client_user.display_name != user.display_name: # we found the user but the display name changed user.display_name = client_user.display_name self.db.add(user) - self.db.commit() return user def query_users(
diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py
new file mode 100644
index 00000000..b2d92127
--- /dev/null
+++ b/backend/oasst_backend/utils/database_utils.py
@@ -0,0 +1,121 @@
+from enum import IntEnum
+from functools import wraps
+from http import HTTPStatus
+
+from loguru import logger
+from oasst_shared.exceptions import OasstError, OasstErrorCode
+from sqlalchemy.exc import OperationalError
+from sqlmodel import SQLModel
+
+MAX_DB_RETRY_COUNT = 3
+
+
+class CommitMode(IntEnum):
+    """
+    Commit modes for the managed tx methods
+    """
+
+    NONE = 0  # neither flush nor commit after the method returns
+    FLUSH = 1  # call db.flush() after the method returns
+    COMMIT = 2  # call db.commit() after the method returns
+
+
+# managed_tx_method and async_managed_tx_method are decorator factories to be
+# used on instance methods. The owning class must expose an initialised 'db'
+# Session attribute.
+# TODO: tx method decorator for non class methods
+#
+# NOTE(review): a retry calls db.rollback(), which rolls back the session's whole
+# current transaction, but only the decorated method is re-run. If a decorated
+# method is called from inside another decorated method, work done by the outer
+# method before the call is not replayed - confirm nested calls are safe to retry.
+
+
+def _finish_tx(db, result, auto_commit: CommitMode):
+    """Apply ``auto_commit`` to ``db``, refresh ``result`` if it is a SQLModel and return it."""
+    if auto_commit == CommitMode.COMMIT:
+        db.commit()
+    elif auto_commit == CommitMode.FLUSH:
+        db.flush()
+    if isinstance(result, SQLModel):
+        db.refresh(result)
+    return result
+
+
+def _rollback_for_retry(db, attempt: int) -> None:
+    """Log failed attempt number ``attempt`` and roll ``db`` back."""
+    logger.info(f"Retrying count: {attempt} after possible db concurrent update conflict")
+    try:
+        db.rollback()
+    except Exception:
+        # Only a failing rollback is reported with this message.
+        logger.error("Db Rollback Failure")
+        raise
+
+
+def _retries_exhausted() -> OasstError:
+    """Build the error raised when every attempt ended in an OperationalError."""
+    return OasstError(
+        "DATABASE_MAX_RETRIES_EXHAUSTED",
+        error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
+        http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
+    )
+
+
+def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries: int = MAX_DB_RETRY_COUNT):
+    """
+    Decorator factory for methods of objects that own a ``db`` Session.
+
+    After the wrapped method returns, the session is committed or flushed according
+    to ``auto_commit`` and a returned SQLModel instance is refreshed. When an
+    OperationalError is raised, the session is rolled back and the method is run
+    again, up to ``num_retries`` attempts in total (at least one attempt is made).
+
+    Raises:
+        OasstError: DATABASE_MAX_RETRIES_EXHAUSTED (HTTP 503) when every attempt
+            failed with an OperationalError; the last one is chained as the cause.
+        Exception: anything else raised by the wrapped method propagates unchanged.
+    """
+    attempts = max(1, num_retries)
+
+    def decorator(f):
+        @wraps(f)
+        def wrapped_f(self, *args, **kwargs):
+            last_error = None
+            for attempt in range(1, attempts + 1):
+                try:
+                    return _finish_tx(self.db, f(self, *args, **kwargs), auto_commit)
+                except OperationalError as e:
+                    last_error = e
+                    _rollback_for_retry(self.db, attempt)
+            raise _retries_exhausted() from last_error
+
+        return wrapped_f
+
+    return decorator
+
+
+def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries: int = MAX_DB_RETRY_COUNT):
+    """
+    Coroutine variant of ``managed_tx_method``: the wrapped method is awaited.
+
+    Commit mode handling, retry on OperationalError and the raised errors are the
+    same as for ``managed_tx_method``.
+    """
+    attempts = max(1, num_retries)
+
+    def decorator(f):
+        @wraps(f)
+        async def wrapped_f(self, *args, **kwargs):
+            last_error = None
+            for attempt in range(1, attempts + 1):
+                try:
+                    return _finish_tx(self.db, await f(self, *args, **kwargs), auto_commit)
+                except OperationalError as e:
+                    last_error = e
+                    _rollback_for_retry(self.db, attempt)
+            raise _retries_exhausted() from last_error
+
+        return wrapped_f
+
+    return decorator
diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py 
b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 31ba00f6..7ad0b65e 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -18,6 +18,7 @@ class OasstErrorCode(IntEnum): DATABASE_URI_NOT_SET = 1 API_CLIENT_NOT_AUTHORIZED = 2 ROOT_TOKEN_NOT_AUTHORIZED = 3 + DATABASE_MAX_RETRIES_EXHAUSTED = 4 TOO_MANY_REQUESTS = 429 SERVER_ERROR0 = 500