From 50421dfada93447eddf9a85675130693b3c73f8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andreas=20K=C3=B6pf?= Date: Tue, 24 Jan 2023 09:57:45 +0100 Subject: [PATCH] retry user lookup in case of UniqueViolation (ix_user_username conflict) --- backend/oasst_backend/user_repository.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 136d3d29..873b0f6e 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -1,10 +1,12 @@ from typing import Optional from uuid import UUID +from oasst_backend.config import settings from oasst_backend.models import ApiClient, User from oasst_backend.utils.database_utils import CommitMode, managed_tx_method from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema +from sqlalchemy.exc import IntegrityError from sqlmodel import Session, and_, or_ from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND @@ -107,7 +109,7 @@ class UserRepository: self.db.add(user) @managed_tx_method(CommitMode.COMMIT) - def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + def _lookup_client_user_tx(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: if not client_user: return None user: User = ( @@ -135,6 +137,16 @@ class UserRepository: self.db.add(user) return user + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: + num_retries = settings.DATABASE_MAX_TX_RETRY_COUNT + for i in range(num_retries): + try: + return self._lookup_client_user_tx(client_user, create_missing) + except IntegrityError: + # catch UniqueViolation exception, for concurrent requests due to conflicts in ix_user_username + if i + 1 == num_retries: + raise + def query_users_ordered_by_username( self, api_client_id: Optional[UUID] = None,