From 65d69dac6661be6130eb9308fe37d0af98e82fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 31 Jan 2023 13:08:13 +0100 Subject: [PATCH] fix steak interaction changes --- backend/oasst_backend/api/v1/tasks.py | 18 +++--- backend/oasst_backend/user_repository.py | 6 +- backend/oasst_backend/utils/database_utils.py | 61 +++++++++++++++++++ 3 files changed, 70 insertions(+), 15 deletions(-) diff --git a/backend/oasst_backend/api/v1/tasks.py b/backend/oasst_backend/api/v1/tasks.py index f76f2f3c..e46c2358 100644 --- a/backend/oasst_backend/api/v1/tasks.py +++ b/backend/oasst_backend/api/v1/tasks.py @@ -8,7 +8,7 @@ from oasst_backend.api import deps from oasst_backend.prompt_repository import PromptRepository, TaskRepository from oasst_backend.tree_manager import TreeManager from oasst_backend.user_repository import UserRepository -from oasst_backend.utils.database_utils import CommitMode, managed_tx_function +from oasst_backend.utils.database_utils import CommitMode, async_managed_tx_function from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from sqlmodel import Session @@ -129,7 +129,6 @@ def tasks_acknowledge_failure( @router.post("/interaction", response_model=protocol_schema.TaskDone) async def tasks_interaction( *, - db: Session = Depends(deps.get_db), api_key: APIKey = Depends(deps.get_api_key), interaction: protocol_schema.AnyInteraction, ) -> Any: @@ -137,16 +136,15 @@ async def tasks_interaction( The frontend reports an interaction. """ - @managed_tx_function(CommitMode.COMMIT) + @async_managed_tx_function(CommitMode.COMMIT) async def interaction_tx(session: deps.Session): - api_client = deps.api_auth(api_key, db) - pr = PromptRepository(db, api_client, client_user=interaction.user) - tm = TreeManager(db, pr) - ur = UserRepository(db, api_client) + api_client = deps.api_auth(api_key, session) + pr = PromptRepository(session, api_client, client_user=interaction.user) + tm = TreeManager(session, pr) + ur = UserRepository(session, api_client) task = await tm.handle_interaction(interaction) - match (type(task)): - case protocol_schema.TaskDone: - ur.update_user_last_activity(client_user=interaction.user) + if type(task) is protocol_schema.TaskDone: + ur.update_user_last_activity(user=pr.user) return task try: diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 679adf05..984964b6 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -312,10 +312,6 @@ class UserRepository: return qry.all() @managed_tx_method(CommitMode.FLUSH) - def update_user_last_activity(self, client_user: protocol_schema.User) -> None: - user = self.lookup_client_user(client_user=client_user, create_missing=False) - if user is None: - raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND) - + def update_user_last_activity(self, user: User) -> None: user.last_activity_date = utcnow() self.db.add(user) diff --git a/backend/oasst_backend/utils/database_utils.py b/backend/oasst_backend/utils/database_utils.py index 5ac25f50..d2a9ebe7 100644 --- a/backend/oasst_backend/utils/database_utils.py +++ b/backend/oasst_backend/utils/database_utils.py @@ -203,3 +203,64 @@ def managed_tx_function( return wrapped_f return decorator + + +def async_managed_tx_function( + auto_commit: CommitMode = CommitMode.COMMIT, + num_retries=settings.DATABASE_MAX_TX_RETRY_COUNT, + session_factory: Callable[..., Session] = default_session_factor, +): + """Passes Session object as first argument to wrapped function.""" + + def decorator(f): + @wraps(f) + async def wrapped_f(*args, **kwargs): + try: + result = None + if auto_commit == CommitMode.COMMIT: + retry_exhausted = True + for i in range(num_retries): + with session_factory() as session: + try: + result = await f(session, *args, **kwargs) + session.commit() + if isinstance(result, SQLModel): + session.refresh(result) + retry_exhausted = False + break + except PendingRollbackError as e: + logger.info(str(e)) + session.rollback() + except OperationalError as e: + if e.orig is not None and isinstance( + e.orig, + (SerializationFailure, DeadlockDetected, UniqueViolation, ExclusionViolation), + ): + logger.info(f"{type(e.orig)} Inner {e.orig.pgcode} {type(e.orig.pgcode)}") + session.rollback() + else: + raise e + logger.info(f"Retry {i+1}/{num_retries}") + if retry_exhausted: + raise OasstError( + "DATABASE_MAX_RETIRES_EXHAUSTED", + error_code=OasstErrorCode.DATABASE_MAX_RETRIES_EXHAUSTED, + http_status_code=HTTPStatus.SERVICE_UNAVAILABLE, + ) + else: + with session_factory() as session: + result = await f(session, *args, **kwargs) + if auto_commit == CommitMode.FLUSH: + session.flush() + if isinstance(result, SQLModel): + session.refresh(result) + elif auto_commit == CommitMode.ROLLBACK: + session.rollback() + return result + except Exception as e: + logger.info(str(e)) + raise e + + return wrapped_f + + return decorator