538: Add endpoints to manage users (#601)

* Add endpoints for getting, updating, deleting users by global user ID

* Resolve formatting

* Include alembic revision script

* Updated down_revision to current alembic head

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
Oliver Stanley
2023-01-12 20:30:07 +00:00
committed by GitHub
parent 30242a2f32
commit 050d4902f3
4 changed files with 152 additions and 1 deletions
@@ -0,0 +1,32 @@
"""Add enabled, deleted, notes fields to User
Revision ID: 846cc08ac79f
Revises: aac6b2f66006
Create Date: 2023-01-10 17:33:07.104596
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "846cc08ac79f"
down_revision = "befa42582ea4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("user", sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False))
op.add_column("user", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
op.add_column("user", sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=1024), nullable=False))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("user", "notes")
op.drop_column("user", "deleted")
op.drop_column("user", "enabled")
# ### end Alembic commands ###
+46 -1
View File
@@ -1,11 +1,13 @@
import datetime
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -13,6 +15,49 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/users/{user_id}", response_model=protocol.User)
def get_user(
user_id: UUID,
api_client_id: UUID = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_api_client),
):
"""
Get a user by global user ID. Only trusted clients can resolve users they did not register.
"""
ur = UserRepository(db, api_client)
user: User = ur.get_user(user_id, api_client_id)
return protocol.User(user.username, user.display_name, user.auth_method)
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
def update_user(
user_id: UUID,
enabled: Optional[bool] = None,
notes: Optional[str] = None,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
"""
Update a user by global user ID. Only trusted clients can update users.
"""
ur = UserRepository(db, api_client)
ur.update_user(user_id, enabled, notes)
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
def delete_user(
user_id: UUID,
db: Session = Depends(deps.get_db),
api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
"""
Delete a user by global user ID. Only trusted clients can delete users.
"""
ur = UserRepository(db, api_client)
ur.mark_user_deleted(user_id)
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
def query_user_messages(
user_id: UUID,
+3
View File
@@ -23,3 +23,6 @@ class User(SQLModel, table=True):
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
)
api_client_id: UUID = Field(foreign_key="api_client.id")
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
notes: str = Field(nullable=False, max_length=1024, default="")
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
+71
View File
@@ -14,6 +14,34 @@ class UserRepository:
self.db = db
self.api_client = api_client
def get_user(self, id: UUID, api_client_id: Optional[UUID] = None) -> User:
"""
Get a user by global user ID. All clients may get users with the same API client ID as the querying client.
Trusted clients can get any user.
Raises:
OasstError: 403 if untrusted client attempts to query foreign users. 404 if user with ID not found.
"""
if not self.api_client.trusted and api_client_id is None:
api_client_id = self.api_client.id
if not self.api_client.trusted and api_client_id != self.api_client.id:
# Unprivileged client requests foreign user
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
# Will always be unique
user_query = self.db.query(User).filter(User.id == id)
if api_client_id:
user_query = user_query.filter(User.api_client_id == api_client_id)
user: User = user_query.first()
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
return user
def query_frontend_user(
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
) -> Optional[User]:
@@ -35,6 +63,49 @@ class UserRepository:
return user
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
"""
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
Raises:
OasstError: 403 if untrusted client attempts to update a user. 404 if user with ID not found.
"""
if not self.api_client.trusted:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
user: User = self.db.query(User).filter(User.id == id).first()
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
if enabled is not None:
user.enabled = enabled
if notes is not None:
user.notes = notes
self.db.add(user)
self.db.commit()
def mark_user_deleted(self, id: UUID) -> None:
"""
Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
Raises:
OasstError: 403 if untrusted client attempts to delete a user. 404 if user with ID not found.
"""
if not self.api_client.trusted:
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
user: User = self.db.query(User).filter(User.id == id).first()
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
user.deleted = True
self.db.add(user)
self.db.commit()
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None