mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-05 17:30:48 +08:00
599 add row versioning to backend tables (#710)
* fix: isolation level and nested db.commit() with retry wrappers on concurrent update errors * refactor: incorporated review comments changes decorator methods to managed_tx_method and async_managed_tx_method new enum CommitMode removed commented commit() from the previous commits * fix: merge pre-commit errors * fix: merge pre-commit changes * fix: conflict in existing OasstErrorCode * refactor: Added a refresh just to be sure that the select command is triggered on the mapped object * fix: added refresh for async decorator Co-authored-by: James Melvin <melvin@gameface.ai>
This commit is contained in:
committed by
GitHub
parent
72a58ca2d3
commit
c6fbf5543b
@@ -5,4 +5,4 @@ from sqlmodel import create_engine
|
||||
if settings.DATABASE_URI is None:
|
||||
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
|
||||
|
||||
engine = create_engine(settings.DATABASE_URI)
|
||||
engine = create_engine(settings.DATABASE_URI, echo=True, isolation_level="REPEATABLE READ")
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Journal, Task, User
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.utils import utcnow
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import Session
|
||||
@@ -80,6 +81,7 @@ class JournalWriter:
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def log(
|
||||
self,
|
||||
*,
|
||||
@@ -115,7 +117,4 @@ class JournalWriter:
|
||||
)
|
||||
|
||||
self.db.add(entry)
|
||||
if commit:
|
||||
self.db.commit()
|
||||
|
||||
return entry
|
||||
|
||||
@@ -24,6 +24,7 @@ from oasst_backend.models import (
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import SystemStats
|
||||
@@ -67,6 +68,7 @@ class PromptRepository:
|
||||
)
|
||||
return message
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_message(
|
||||
self,
|
||||
*,
|
||||
@@ -104,8 +106,8 @@ class PromptRepository:
|
||||
review_result=review_result,
|
||||
)
|
||||
self.db.add(message)
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
# self.db.refresh(message)
|
||||
return message
|
||||
|
||||
def _validate_task(
|
||||
@@ -134,6 +136,7 @@ class PromptRepository:
|
||||
def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
|
||||
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_reply(
|
||||
self,
|
||||
text: str,
|
||||
@@ -205,10 +208,10 @@ class PromptRepository:
|
||||
if not task.collective:
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
|
||||
return user_message
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
|
||||
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
|
||||
|
||||
@@ -238,6 +241,7 @@ class PromptRepository:
|
||||
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
|
||||
# fetch task
|
||||
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
|
||||
@@ -310,6 +314,7 @@ class PromptRepository:
|
||||
|
||||
return reaction, task
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity:
|
||||
"""Save the toxicity score of a new message in the database.
|
||||
Args:
|
||||
@@ -325,10 +330,9 @@ class PromptRepository:
|
||||
|
||||
message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label)
|
||||
self.db.add(message_toxicity)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_toxicity)
|
||||
return message_toxicity
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
|
||||
"""Insert the embedding of a new message in the database.
|
||||
|
||||
@@ -346,10 +350,9 @@ class PromptRepository:
|
||||
|
||||
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
|
||||
self.db.add(message_embedding)
|
||||
self.db.commit()
|
||||
self.db.refresh(message_embedding)
|
||||
return message_embedding
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
|
||||
if self.user_id is None:
|
||||
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
|
||||
@@ -363,10 +366,9 @@ class PromptRepository:
|
||||
payload_type=type(payload).__name__,
|
||||
)
|
||||
self.db.add(reaction)
|
||||
self.db.commit()
|
||||
self.db.refresh(reaction)
|
||||
return reaction
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
|
||||
|
||||
valid_labels: Optional[list[str]] = None
|
||||
@@ -436,8 +438,6 @@ class PromptRepository:
|
||||
self.db.add(message)
|
||||
|
||||
self.db.add(model)
|
||||
self.db.commit()
|
||||
self.db.refresh(model)
|
||||
return model, task, message
|
||||
|
||||
def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]:
|
||||
@@ -702,6 +702,7 @@ class PromptRepository:
|
||||
|
||||
return messages.all()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
|
||||
"""
|
||||
Marks deleted messages and all their descendants.
|
||||
@@ -730,8 +731,6 @@ class PromptRepository:
|
||||
|
||||
parent_ids = self.db.execute(query).scalars().all()
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def get_stats(self) -> SystemStats:
|
||||
"""
|
||||
Get data stats such as number of all messages in the system,
|
||||
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
from oasst_backend.models import ApiClient, Task
|
||||
from oasst_backend.models.payload_column_type import PayloadContainer
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -128,6 +129,7 @@ class TaskRepository:
|
||||
assert task_model.id == task.id
|
||||
return task_model
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
|
||||
validate_frontend_message_id(frontend_message_id)
|
||||
|
||||
@@ -142,10 +144,9 @@ class TaskRepository:
|
||||
|
||||
task.frontend_message_id = frontend_message_id
|
||||
task.ack = True
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
|
||||
"""
|
||||
Mark task as done. No further messages will be accepted for this task.
|
||||
@@ -166,8 +167,8 @@ class TaskRepository:
|
||||
|
||||
task.done = True
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def acknowledge_task_failure(self, task_id):
|
||||
# find task
|
||||
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
|
||||
@@ -181,8 +182,8 @@ class TaskRepository:
|
||||
task.ack = False
|
||||
# ToDo: check race-condition, transaction
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def insert_task(
|
||||
self,
|
||||
payload: db_payload.TaskPayload,
|
||||
@@ -204,8 +205,6 @@ class TaskRepository:
|
||||
)
|
||||
logger.debug(f"inserting {task=}")
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
self.db.refresh(task)
|
||||
return task
|
||||
|
||||
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
|
||||
|
||||
@@ -11,6 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio
|
||||
from oasst_backend.config import TreeManagerConfiguration, settings
|
||||
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
|
||||
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
|
||||
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
@@ -340,6 +341,7 @@ class TreeManager:
|
||||
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
@async_managed_tx_method(CommitMode.COMMIT)
|
||||
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
|
||||
pr = self.pr
|
||||
match type(interaction):
|
||||
@@ -358,7 +360,6 @@ class TreeManager:
|
||||
if not message.parent_id:
|
||||
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
|
||||
self._insert_default_state(message.id)
|
||||
self.db.commit()
|
||||
|
||||
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
|
||||
try:
|
||||
@@ -428,7 +429,6 @@ class TreeManager:
|
||||
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
@@ -439,7 +439,6 @@ class TreeManager:
|
||||
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
|
||||
msg.review_result = True
|
||||
self.db.add(msg)
|
||||
self.db.commit()
|
||||
logger.info(
|
||||
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
|
||||
)
|
||||
@@ -451,6 +450,7 @@ class TreeManager:
|
||||
|
||||
return protocol_schema.TaskDone()
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
|
||||
assert mts and mts.active
|
||||
|
||||
@@ -460,7 +460,6 @@ class TreeManager:
|
||||
mts.active = False
|
||||
mts.state = state.value
|
||||
self.db.add(mts)
|
||||
self.db.commit()
|
||||
|
||||
if is_terminal:
|
||||
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
|
||||
@@ -472,6 +471,7 @@ class TreeManager:
|
||||
mts = self.pr.fetch_tree_state(message_tree_id)
|
||||
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
|
||||
|
||||
@@ -489,6 +489,7 @@ class TreeManager:
|
||||
self._enter_state(mts, message_tree_state.State.GROWING)
|
||||
return True
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
|
||||
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
|
||||
|
||||
@@ -735,6 +736,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
rankings_by_message[parent_id].append(MessageReaction.from_orm(x))
|
||||
return rankings_by_message
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def ensure_tree_states(self):
|
||||
"""Add message tree state rows for all root nodes (inital prompt messages)."""
|
||||
|
||||
@@ -746,7 +748,6 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
|
||||
state = message_tree_state.State.GROWING
|
||||
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
|
||||
self._insert_default_state(id, state=state)
|
||||
self.db.commit()
|
||||
|
||||
def query_num_active_trees(self) -> int:
|
||||
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
|
||||
@@ -763,6 +764,7 @@ WHERE t.done = TRUE
|
||||
r = self.db.execute(text(sql_qry), {"message_id": message_id})
|
||||
return [TextLabels.from_orm(x) for x in r.all()]
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_tree_state(
|
||||
self,
|
||||
root_message_id: UUID,
|
||||
@@ -784,6 +786,7 @@ WHERE t.done = TRUE
|
||||
self.db.add(model)
|
||||
return model
|
||||
|
||||
@managed_tx_method(CommitMode.FLUSH)
|
||||
def _insert_default_state(
|
||||
self,
|
||||
root_message_id: UUID,
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from sqlmodel import Session
|
||||
@@ -62,6 +63,7 @@ class UserRepository:
|
||||
|
||||
return user
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
@@ -83,8 +85,8 @@ class UserRepository:
|
||||
user.notes = notes
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
"""
|
||||
Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
|
||||
@@ -103,8 +105,8 @@ class UserRepository:
|
||||
user.deleted = True
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
@managed_tx_method(CommitMode.COMMIT)
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
@@ -127,13 +129,10 @@ class UserRepository:
|
||||
auth_method=client_user.auth_method,
|
||||
)
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
self.db.refresh(user)
|
||||
elif client_user.display_name and client_user.display_name != user.display_name:
|
||||
# we found the user but the display name changed
|
||||
user.display_name = client_user.display_name
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
return user
|
||||
|
||||
def query_users(
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
from enum import IntEnum
|
||||
from functools import wraps
|
||||
from http import HTTPStatus
|
||||
|
||||
from loguru import logger
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
MAX_DB_RETRY_COUNT = 3
|
||||
|
||||
|
||||
class CommitMode(IntEnum):
    """
    Commit modes for the managed tx methods.

    Selects what the managed_tx_method / async_managed_tx_method wrappers do
    with the instance's `db` session once the wrapped method has returned.
    """

    # Neither flush nor commit is issued by the wrapper.
    NONE = 0
    # The wrapper calls `db.flush()` after the wrapped method returns.
    FLUSH = 1
    # The wrapper calls `db.commit()` after the wrapped method returns.
    COMMIT = 2
|
||||
|
||||
|
||||
"""
|
||||
* managed_tx_method and async_managed_tx_method methods are decorators functions
|
||||
* to be used on class functions. It expects the Class to have a 'db' Session object
|
||||
* initialised
|
||||
* TODO: tx method decorator for non class methods
|
||||
"""
|
||||
|
||||
|
||||
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
    """
    Decorator factory that runs an instance method with retry-on-conflict
    transaction handling.

    The decorated method must belong to an object exposing a `db` session
    attribute. On every attempt the wrapped method is called and then:

    * `CommitMode.COMMIT` -> `self.db.commit()` is called,
    * `CommitMode.FLUSH`  -> `self.db.flush()` is called,
    * `CommitMode.NONE`   -> neither is called.

    If the wrapped method returned an `SQLModel` instance it is refreshed from
    the session before being handed back to the caller.

    Args:
        auto_commit: what to do with the session after the wrapped call.
        num_retries: maximum number of attempts made when an
            `OperationalError` is raised.

    Raises:
        OasstError: with `DATABASE_MAX_RETRIES_EXHAUSTED` and HTTP 503 once
            every attempt failed with an `OperationalError`; the last such
            error is attached as the cause.
        Exception: any other exception raised by the wrapped method, by the
            commit/flush/refresh, or by the rollback propagates unchanged.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(self, *args, **kwargs):
            last_error = None
            for i in range(num_retries):
                try:
                    result = f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    last_error = e
                    logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
                    # NOTE(review): rollback() acts on the shared session, so when
                    # decorated methods are nested this presumably also discards
                    # pending work of the outer call - confirm this is intended.
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only a failing rollback is reported with this message;
                        # errors from the wrapped method propagate without it.
                        logger.error("Db Rollback Failure")
                        raise

            # All attempts hit an OperationalError (or num_retries allowed none).
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
|
||||
|
||||
|
||||
def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
    """
    Decorator factory that runs an async instance method with
    retry-on-conflict transaction handling.

    The decorated coroutine method must belong to an object exposing a `db`
    session attribute. On every attempt the wrapped coroutine is awaited and
    then:

    * `CommitMode.COMMIT` -> `self.db.commit()` is called,
    * `CommitMode.FLUSH`  -> `self.db.flush()` is called,
    * `CommitMode.NONE`   -> neither is called.

    If the wrapped method returned an `SQLModel` instance it is refreshed from
    the session before being handed back to the caller.

    Args:
        auto_commit: what to do with the session after the wrapped call.
        num_retries: maximum number of attempts made when an
            `OperationalError` is raised.

    Raises:
        OasstError: with `DATABASE_MAX_RETRIES_EXHAUSTED` and HTTP 503 once
            every attempt failed with an `OperationalError`; the last such
            error is attached as the cause.
        Exception: any other exception raised by the wrapped method, by the
            commit/flush/refresh, or by the rollback propagates unchanged.
    """

    def decorator(f):
        @wraps(f)
        async def wrapped_f(self, *args, **kwargs):
            last_error = None
            for i in range(num_retries):
                try:
                    result = await f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    last_error = e
                    logger.info(f"Retrying count: {i+1} after possible db concurrent update conflict")
                    # NOTE(review): rollback() acts on the shared session, so when
                    # decorated methods are nested this presumably also discards
                    # pending work of the outer call - confirm this is intended.
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only a failing rollback is reported with this message;
                        # errors from the wrapped method propagate without it.
                        logger.error("Db Rollback Failure")
                        raise

            # All attempts hit an OperationalError (or num_retries allowed none).
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
|
||||
@@ -18,6 +18,7 @@ class OasstErrorCode(IntEnum):
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
ROOT_TOKEN_NOT_AUTHORIZED = 3
|
||||
DATABASE_MAX_RETRIES_EXHAUSTED = 4
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
|
||||
Reference in New Issue
Block a user