Files
Open-Assistant/backend/oasst_backend/user_repository.py
T
James Melvin Ebenezer c6fbf5543b 599 add row versioning to backend tables (#710)
* fix: isolation level and nested db.commit() with retry wrappers on concurrent update errors

* refactor: incorporated review comments

changed decorator methods to managed_tx_method and async_managed_tx_method
added new enum CommitMode
removed commented-out commit() calls from the previous commits

* fix: merge pre-commit errors

* fix: merge pre-commit changes

* fix: conflict in existing OasstErrorCode

* refactor: Added a refresh just to be sure that the select command is triggered on the mapped object

* fix: added refresh for async decorator

Co-authored-by: James Melvin <melvin@gameface.ai>
2023-01-16 08:43:07 +01:00

204 lines
7.2 KiB
Python

from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, User
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class UserRepository:
    """
    Data access for `User` rows on behalf of a single calling API client.

    Untrusted API clients may only access users that belong to their own API client ID.
    Trusted API clients may access users of any API client.
    """

    def __init__(self, db: Session, api_client: ApiClient):
        self.db = db
        self.api_client = api_client

    def get_user(self, id: UUID, api_client_id: Optional[UUID] = None) -> User:
        """
        Get a user by global user ID. All clients may get users with the same API client ID as the querying client.
        Trusted clients can get any user.

        Raises:
            OasstError: 403 if untrusted client attempts to query foreign users. 404 if user with ID not found.
        """
        if not self.api_client.trusted:
            # Untrusted clients are implicitly scoped to their own client ID.
            if api_client_id is None:
                api_client_id = self.api_client.id
            if api_client_id != self.api_client.id:
                # Unprivileged client requests foreign user
                raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        # Will always be unique
        user_query = self.db.query(User).filter(User.id == id)
        if api_client_id:
            user_query = user_query.filter(User.api_client_id == api_client_id)

        user: Optional[User] = user_query.first()
        if user is None:
            raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
        return user

    def query_frontend_user(self, auth_method: str, username: str, api_client_id: Optional[UUID] = None) -> User:
        """
        Get a user by the (auth_method, username) pair used by the frontend, within one API client.

        If no `api_client_id` is given, the calling client's own ID is used (for trusted clients too).
        This method never returns None: a missing user is reported as a 404 error.

        Raises:
            OasstError: 403 if untrusted client attempts to query foreign users. 404 if no matching user exists.
        """
        if not api_client_id:
            api_client_id = self.api_client.id

        if not self.api_client.trusted and api_client_id != self.api_client.id:
            # Unprivileged API client asks for foreign user
            raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        user: Optional[User] = (
            self.db.query(User)
            .filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id)
            .first()
        )
        if user is None:
            raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
        return user

    # NOTE(review): managed_tx_method(CommitMode.COMMIT) presumably commits the session after the
    # method returns; confirm against oasst_backend.utils.database_utils.
    @managed_tx_method(CommitMode.COMMIT)
    def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
        """
        Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
        Arguments left as None are not modified.

        Raises:
            OasstError: 403 if untrusted client attempts to update a user. 404 if user with ID not found.
        """
        if not self.api_client.trusted:
            raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        user: Optional[User] = self.db.query(User).filter(User.id == id).first()
        if user is None:
            raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)

        if enabled is not None:
            user.enabled = enabled
        if notes is not None:
            user.notes = notes

        self.db.add(user)

    @managed_tx_method(CommitMode.COMMIT)
    def mark_user_deleted(self, id: UUID) -> None:
        """
        Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
        The row itself is kept; only `deleted` is set to True.

        Raises:
            OasstError: 403 if untrusted client attempts to delete a user. 404 if user with ID not found.
        """
        if not self.api_client.trusted:
            raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        user: Optional[User] = self.db.query(User).filter(User.id == id).first()
        if user is None:
            raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)

        user.deleted = True
        self.db.add(user)

    @managed_tx_method(CommitMode.COMMIT)
    def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
        """
        Find the stored user that corresponds to a user object sent by the calling API client.

        The lookup is always restricted to the calling client's own API client ID. If the user is unknown and
        `create_missing` is set, a new record is added. If the user exists and the client supplies a different,
        non-empty display name, the stored display name is updated.

        Returns:
            The matching (or newly created) user. None if `client_user` is empty, or if the user is unknown
            and `create_missing` is False.
        """
        if not client_user:
            return None

        user: Optional[User] = (
            self.db.query(User)
            .filter(
                User.api_client_id == self.api_client.id,
                User.username == client_user.id,
                User.auth_method == client_user.auth_method,
            )
            .first()
        )

        if user is None:
            if create_missing:
                # user is unknown, create new record
                user = User(
                    username=client_user.id,
                    display_name=client_user.display_name,
                    api_client_id=self.api_client.id,
                    auth_method=client_user.auth_method,
                )
                self.db.add(user)
        elif client_user.display_name and client_user.display_name != user.display_name:
            # we found the user but the display name changed
            user.display_name = client_user.display_name
            self.db.add(user)

        return user

    def query_users(
        self,
        api_client_id: Optional[UUID] = None,
        limit: Optional[int] = 20,
        gt: Optional[str] = None,
        lt: Optional[str] = None,
        auth_method: Optional[str] = None,
    ) -> list[User]:
        """
        List users ordered by display name, optionally filtered by API client and auth method.

        `gt` and `lt` are exclusive lower/upper bounds on the display name. `limit=None` disables the row limit.
        Untrusted clients are always restricted to their own API client ID.

        Raises:
            OasstError: 403 if untrusted client attempts to query foreign users.
        """
        if not self.api_client.trusted:
            if not api_client_id:
                api_client_id = self.api_client.id
            if api_client_id != self.api_client.id:
                raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        users = self.db.query(User)

        if api_client_id:
            users = users.filter(User.api_client_id == api_client_id)
        if auth_method:
            users = users.filter(User.auth_method == auth_method)

        users = users.order_by(User.display_name)

        if gt:
            users = users.filter(User.display_name > gt)
        if lt:
            users = users.filter(User.display_name < lt)
        if limit is not None:
            users = users.limit(limit)

        return users.all()

    def query_users_by_display_name(
        self,
        search_text: str,
        exact: Optional[bool] = False,
        limit: Optional[int] = 20,
        api_client_id: Optional[UUID] = None,
        auth_method: Optional[str] = None,
    ) -> list[User]:
        """
        Search users by display name, ordered by display name.

        With `exact` set, the display name must equal `search_text`; otherwise `search_text` is matched literally
        as a substring (LIKE wildcards in it are escaped). `limit=None` disables the row limit.
        Untrusted clients are always restricted to their own API client ID.

        Raises:
            OasstError: 403 if untrusted client attempts to query foreign users.
        """
        if not self.api_client.trusted:
            if not api_client_id:
                api_client_id = self.api_client.id
            if api_client_id != self.api_client.id:
                raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)

        qry = self.db.query(User).order_by(User.display_name)

        # Apply the resolved client scope; without this filter an untrusted client that passed the
        # check above would still receive users of every API client.
        if api_client_id:
            qry = qry.filter(User.api_client_id == api_client_id)

        if exact:
            qry = qry.filter(User.display_name == search_text)
        else:
            # Escape the escape character first, then the LIKE wildcards, so the text matches literally.
            pattern = "%{}%".format(search_text.replace("\\", "\\\\").replace("_", "\\_").replace("%", "\\%"))
            # Name the escape character explicitly instead of relying on the database's default.
            qry = qry.filter(User.display_name.like(pattern, escape="\\"))

        if auth_method:
            qry = qry.filter(User.auth_method == auth_method)
        if limit is not None:
            qry = qry.limit(limit)

        return qry.all()