599 add row versioning to backend tables (#710)

* fix: isolation level and nested db.commit() with retry wrappers on concurrent update errors

* refactor: incorporated review comments

changes decorator methods to managed_tx_method and async_managed_tx_method
new enum CommitMode
removed commented commit() from the previous commits

* fix: merge pre-commit errors

* fix: merge pre-commit changes

* fix: conflict in existing OasstErrorCode

* refactor: Added a refresh just to be sure that the select command is triggered on the mapped object

* fix: added refresh for async decorator

Co-authored-by: James Melvin <melvin@gameface.ai>
This commit is contained in:
James Melvin Ebenezer
2023-01-16 13:13:07 +05:30
committed by GitHub
parent 72a58ca2d3
commit c6fbf5543b
8 changed files with 127 additions and 33 deletions
+1 -1
View File
@@ -5,4 +5,4 @@ from sqlmodel import create_engine
if settings.DATABASE_URI is None:
raise OasstError("DATABASE_URI is not set", error_code=OasstErrorCode.DATABASE_URI_NOT_SET)
engine = create_engine(settings.DATABASE_URI)
engine = create_engine(settings.DATABASE_URI, echo=True, isolation_level="REPEATABLE READ")
+2 -3
View File
@@ -4,6 +4,7 @@ from uuid import UUID
from oasst_backend.models import ApiClient, Journal, Task, User
from oasst_backend.models.payload_column_type import PayloadContainer, payload_type
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.utils import utcnow
from pydantic import BaseModel
from sqlmodel import Session
@@ -80,6 +81,7 @@ class JournalWriter:
message_id=message_id,
)
@managed_tx_method(CommitMode.FLUSH)
def log(
self,
*,
@@ -115,7 +117,4 @@ class JournalWriter:
)
self.db.add(entry)
if commit:
self.db.commit()
return entry
+12 -13
View File
@@ -24,6 +24,7 @@ from oasst_backend.models import (
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.task_repository import TaskRepository, validate_frontend_message_id
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import SystemStats
@@ -67,6 +68,7 @@ class PromptRepository:
)
return message
@managed_tx_method(CommitMode.FLUSH)
def insert_message(
self,
*,
@@ -104,8 +106,8 @@ class PromptRepository:
review_result=review_result,
)
self.db.add(message)
self.db.commit()
self.db.refresh(message)
# self.db.refresh(message)
return message
def _validate_task(
@@ -134,6 +136,7 @@ class PromptRepository:
def fetch_tree_state(self, message_tree_id: UUID) -> MessageTreeState:
return self.db.query(MessageTreeState).filter(MessageTreeState.message_tree_id == message_tree_id).one()
@managed_tx_method(CommitMode.FLUSH)
def store_text_reply(
self,
text: str,
@@ -205,10 +208,10 @@ class PromptRepository:
if not task.collective:
task.done = True
self.db.add(task)
self.db.commit()
self.journal.log_text_reply(task=task, message_id=new_message_id, role=role, length=len(text))
return user_message
@managed_tx_method(CommitMode.FLUSH)
def store_rating(self, rating: protocol_schema.MessageRating) -> MessageReaction:
message = self.fetch_message_by_frontend_message_id(rating.message_id, fail_if_missing=True)
@@ -238,6 +241,7 @@ class PromptRepository:
logger.info(f"Ranking {rating.rating} stored for task {task.id}.")
return reaction
@managed_tx_method(CommitMode.COMMIT)
def store_ranking(self, ranking: protocol_schema.MessageRanking) -> Tuple[MessageReaction, Task]:
# fetch task
task = self.task_repository.fetch_task_by_frontend_message_id(ranking.message_id)
@@ -310,6 +314,7 @@ class PromptRepository:
return reaction, task
@managed_tx_method(CommitMode.FLUSH)
def insert_toxicity(self, message_id: UUID, model: str, score: float, label: str) -> MessageToxicity:
"""Save the toxicity score of a new message in the database.
Args:
@@ -325,10 +330,9 @@ class PromptRepository:
message_toxicity = MessageToxicity(message_id=message_id, model=model, score=score, label=label)
self.db.add(message_toxicity)
self.db.commit()
self.db.refresh(message_toxicity)
return message_toxicity
@managed_tx_method(CommitMode.FLUSH)
def insert_message_embedding(self, message_id: UUID, model: str, embedding: List[float]) -> MessageEmbedding:
"""Insert the embedding of a new message in the database.
@@ -346,10 +350,9 @@ class PromptRepository:
message_embedding = MessageEmbedding(message_id=message_id, model=model, embedding=embedding)
self.db.add(message_embedding)
self.db.commit()
self.db.refresh(message_embedding)
return message_embedding
@managed_tx_method(CommitMode.FLUSH)
def insert_reaction(self, task_id: UUID, payload: db_payload.ReactionPayload) -> MessageReaction:
if self.user_id is None:
raise OasstError("User required", OasstErrorCode.USER_NOT_SPECIFIED)
@@ -363,10 +366,9 @@ class PromptRepository:
payload_type=type(payload).__name__,
)
self.db.add(reaction)
self.db.commit()
self.db.refresh(reaction)
return reaction
@managed_tx_method(CommitMode.FLUSH)
def store_text_labels(self, text_labels: protocol_schema.TextLabels) -> Tuple[TextLabels, Task, Message]:
valid_labels: Optional[list[str]] = None
@@ -436,8 +438,6 @@ class PromptRepository:
self.db.add(message)
self.db.add(model)
self.db.commit()
self.db.refresh(model)
return model, task, message
def fetch_random_message_tree(self, require_role: str = None, reviewed: bool = True) -> list[Message]:
@@ -702,6 +702,7 @@ class PromptRepository:
return messages.all()
@managed_tx_method(CommitMode.COMMIT)
def mark_messages_deleted(self, messages: Message | UUID | list[Message | UUID], recursive: bool = True):
"""
Marks deleted messages and all their descendants.
@@ -730,8 +731,6 @@ class PromptRepository:
parent_ids = self.db.execute(query).scalars().all()
self.db.commit()
def get_stats(self) -> SystemStats:
"""
Get data stats such as number of all messages in the system,
+5 -6
View File
@@ -6,6 +6,7 @@ from loguru import logger
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -128,6 +129,7 @@ class TaskRepository:
assert task_model.id == task.id
return task_model
@managed_tx_method(CommitMode.COMMIT)
def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str):
validate_frontend_message_id(frontend_message_id)
@@ -142,10 +144,9 @@ class TaskRepository:
task.frontend_message_id = frontend_message_id
task.ack = True
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
@managed_tx_method(CommitMode.COMMIT)
def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
"""
Mark task as done. No further messages will be accepted for this task.
@@ -166,8 +167,8 @@ class TaskRepository:
task.done = True
self.db.add(task)
self.db.commit()
@managed_tx_method(CommitMode.COMMIT)
def acknowledge_task_failure(self, task_id):
# find task
task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
@@ -181,8 +182,8 @@ class TaskRepository:
task.ack = False
# ToDo: check race-condition, transaction
self.db.add(task)
self.db.commit()
@managed_tx_method(CommitMode.COMMIT)
def insert_task(
self,
payload: db_payload.TaskPayload,
@@ -204,8 +205,6 @@ class TaskRepository:
)
logger.debug(f"inserting {task=}")
self.db.add(task)
self.db.commit()
self.db.refresh(task)
return task
def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
+8 -5
View File
@@ -11,6 +11,7 @@ from oasst_backend.api.v1.utils import prepare_conversation, prepare_conversatio
from oasst_backend.config import TreeManagerConfiguration, settings
from oasst_backend.models import Message, MessageReaction, MessageTreeState, TextLabels, message_tree_state
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_method, managed_tx_method
from oasst_backend.utils.hugging_face import HfClassificationModel, HfEmbeddingModel, HfUrl, HuggingFaceAPI
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
@@ -340,6 +341,7 @@ class TreeManager:
return task, message_tree_id, parent_message_id
@async_managed_tx_method(CommitMode.COMMIT)
async def handle_interaction(self, interaction: protocol_schema.AnyInteraction) -> protocol_schema.Task:
pr = self.pr
match type(interaction):
@@ -358,7 +360,6 @@ class TreeManager:
if not message.parent_id:
logger.info(f"TreeManager: Inserting new tree state for initial prompt {message.id=}")
self._insert_default_state(message.id)
self.db.commit()
if not settings.DEBUG_SKIP_EMBEDDING_COMPUTATION:
try:
@@ -428,7 +429,6 @@ class TreeManager:
if acceptance_score > self.cfg.acceptance_threshold_initial_prompt:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Initial prompt message was accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
@@ -439,7 +439,6 @@ class TreeManager:
if not msg.review_result and acceptance_score > self.cfg.acceptance_threshold_reply:
msg.review_result = True
self.db.add(msg)
self.db.commit()
logger.info(
f"Reply message message accepted: {msg.id=}, {acceptance_score=}, {len(reviews)=}"
)
@@ -451,6 +450,7 @@ class TreeManager:
return protocol_schema.TaskDone()
@managed_tx_method(CommitMode.FLUSH)
def _enter_state(self, mts: MessageTreeState, state: message_tree_state.State):
assert mts and mts.active
@@ -460,7 +460,6 @@ class TreeManager:
mts.active = False
mts.state = state.value
self.db.add(mts)
self.db.commit()
if is_terminal:
logger.info(f"Tree entered terminal '{mts.state}' state ({mts.message_tree_id=})")
@@ -472,6 +471,7 @@ class TreeManager:
mts = self.pr.fetch_tree_state(message_tree_id)
self._enter_state(mts, message_tree_state.State.ABORTED_LOW_GRADE)
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_growing_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_growing_state({message_tree_id=})")
@@ -489,6 +489,7 @@ class TreeManager:
self._enter_state(mts, message_tree_state.State.GROWING)
return True
@managed_tx_method(CommitMode.COMMIT)
def check_condition_for_ranking_state(self, message_tree_id: UUID) -> bool:
logger.debug(f"check_condition_for_ranking_state({message_tree_id=})")
@@ -735,6 +736,7 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
rankings_by_message[parent_id].append(MessageReaction.from_orm(x))
return rankings_by_message
@managed_tx_method(CommitMode.COMMIT)
def ensure_tree_states(self):
"""Add message tree state rows for all root nodes (inital prompt messages)."""
@@ -746,7 +748,6 @@ LEFT JOIN message_reaction mr ON mr.task_id = t.id AND mr.payload_type = 'Rankin
state = message_tree_state.State.GROWING
logger.info(f"Inserting missing message tree state for message: {id} ({tree_size=}, {state=})")
self._insert_default_state(id, state=state)
self.db.commit()
def query_num_active_trees(self) -> int:
query = self.db.query(func.count(MessageTreeState.message_tree_id)).filter(MessageTreeState.active)
@@ -763,6 +764,7 @@ WHERE t.done = TRUE
r = self.db.execute(text(sql_qry), {"message_id": message_id})
return [TextLabels.from_orm(x) for x in r.all()]
@managed_tx_method(CommitMode.FLUSH)
def _insert_tree_state(
self,
root_message_id: UUID,
@@ -784,6 +786,7 @@ WHERE t.done = TRUE
self.db.add(model)
return model
@managed_tx_method(CommitMode.FLUSH)
def _insert_default_state(
self,
root_message_id: UUID,
+4 -5
View File
@@ -2,6 +2,7 @@ from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, User
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
@@ -62,6 +63,7 @@ class UserRepository:
return user
@managed_tx_method(CommitMode.COMMIT)
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
@@ -83,8 +85,8 @@ class UserRepository:
user.notes = notes
self.db.add(user)
self.db.commit()
@managed_tx_method(CommitMode.COMMIT)
def mark_user_deleted(self, id: UUID) -> None:
"""
Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
@@ -103,8 +105,8 @@ class UserRepository:
user.deleted = True
self.db.add(user)
self.db.commit()
@managed_tx_method(CommitMode.COMMIT)
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None
@@ -127,13 +129,10 @@ class UserRepository:
auth_method=client_user.auth_method,
)
self.db.add(user)
self.db.commit()
self.db.refresh(user)
elif client_user.display_name and client_user.display_name != user.display_name:
# we found the user but the display name changed
user.display_name = client_user.display_name
self.db.add(user)
self.db.commit()
return user
def query_users(
@@ -0,0 +1,94 @@
from enum import IntEnum
from functools import wraps
from http import HTTPStatus
from loguru import logger
from oasst_shared.exceptions import OasstError, OasstErrorCode
from sqlalchemy.exc import OperationalError
from sqlmodel import SQLModel
MAX_DB_RETRY_COUNT = 3
class CommitMode(IntEnum):
    """Select how a managed transaction method finalises the session.

    The value is passed as ``auto_commit`` to ``managed_tx_method`` and
    ``async_managed_tx_method``, which inspect it after the wrapped method
    has returned.
    """

    # Neither commit() nor flush() is called by the decorator.
    NONE = 0
    # The decorator calls flush() on the session once the method returns.
    FLUSH = 1
    # The decorator calls commit() on the session once the method returns.
    COMMIT = 2
"""
* managed_tx_method and async_managed_tx_method are decorator factories
* to be used on instance methods. They expect the instance to have an
* initialised 'db' Session attribute.
* TODO: tx method decorator for non class methods
"""
def managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
    """Build a decorator that runs an instance method as a retried DB unit of work.

    The decorated method must live on an object exposing a ``db`` session
    attribute with ``commit``, ``flush``, ``refresh`` and ``rollback``.

    Args:
        auto_commit: What to do with the session after the method returns:
            ``CommitMode.COMMIT`` commits, ``CommitMode.FLUSH`` flushes and
            ``CommitMode.NONE`` does neither.
        num_retries: Maximum number of attempts. Only ``OperationalError``
            triggers another attempt; any other exception propagates
            immediately and unchanged.

    Returns:
        A decorator that wraps the method with the behaviour above. The wrapped
        method returns whatever the original returns; if that is a ``SQLModel``
        instance it is refreshed from the session first.

    Raises:
        OasstError: With ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt ended in an ``OperationalError``. The last such error is
            attached as ``__cause__``.
    """

    def decorator(f):
        @wraps(f)
        def wrapped_f(self, *args, **kwargs):
            last_error = None
            for attempt in range(1, num_retries + 1):
                try:
                    result = f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    # Reload mapped objects so the caller sees the persisted state.
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    last_error = e
                    logger.info(f"Retrying count: {attempt} after possible db concurrent update conflict")
                    # NOTE(review): rollback() resets the whole session, so when this
                    # wrapper is nested inside another managed method the outer
                    # method's pending work is presumably discarded too - confirm.
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only report a rollback failure when the rollback itself failed.
                        logger.error("Db Rollback Failure")
                        raise

            logger.error(f"Giving up after {num_retries} attempts due to repeated db operational errors")
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
def async_managed_tx_method(auto_commit: CommitMode = CommitMode.COMMIT, num_retries=MAX_DB_RETRY_COUNT):
    """Build a decorator that runs an async instance method as a retried DB unit of work.

    Coroutine counterpart of the synchronous managed transaction decorator: the
    wrapped coroutine function is awaited, while the session calls on
    ``self.db`` (``commit``, ``flush``, ``refresh``, ``rollback``) are made
    synchronously.

    Args:
        auto_commit: What to do with the session after the coroutine returns:
            ``CommitMode.COMMIT`` commits, ``CommitMode.FLUSH`` flushes and
            ``CommitMode.NONE`` does neither.
        num_retries: Maximum number of attempts. Only ``OperationalError``
            triggers another attempt; any other exception propagates
            immediately and unchanged.

    Returns:
        A decorator producing an ``async`` wrapper. The wrapper returns whatever
        the original coroutine returns; if that is a ``SQLModel`` instance it is
        refreshed from the session first.

    Raises:
        OasstError: With ``DATABASE_MAX_RETRIES_EXHAUSTED`` (HTTP 503) when every
            attempt ended in an ``OperationalError``. The last such error is
            attached as ``__cause__``.
    """

    def decorator(f):
        @wraps(f)
        async def wrapped_f(self, *args, **kwargs):
            last_error = None
            for attempt in range(1, num_retries + 1):
                try:
                    result = await f(self, *args, **kwargs)
                    if auto_commit == CommitMode.COMMIT:
                        self.db.commit()
                    elif auto_commit == CommitMode.FLUSH:
                        self.db.flush()
                    # Reload mapped objects so the caller sees the persisted state.
                    if isinstance(result, SQLModel):
                        self.db.refresh(result)
                    return result
                except OperationalError as e:
                    last_error = e
                    logger.info(f"Retrying count: {attempt} after possible db concurrent update conflict")
                    # NOTE(review): rollback() resets the whole session, so work done by
                    # an enclosing managed method is presumably discarded too - confirm.
                    try:
                        self.db.rollback()
                    except Exception:
                        # Only report a rollback failure when the rollback itself failed.
                        logger.error("Db Rollback Failure")
                        raise

            logger.error(f"Giving up after {num_retries} attempts due to repeated db operational errors")
            raise OasstError(
                "DATABASE_MAX_RETRIES_EXHAUSTED",
                error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED,
                http_status_code=HTTPStatus.SERVICE_UNAVAILABLE,
            ) from last_error

        return wrapped_f

    return decorator
@@ -18,6 +18,7 @@ class OasstErrorCode(IntEnum):
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
ROOT_TOKEN_NOT_AUTHORIZED = 3
DATABASE_MAX_RETRIES_EXHAUSTED = 4
TOO_MANY_REQUESTS = 429
SERVER_ERROR0 = 500