diff --git a/backend/alembic/versions/2023_01_10_1733-846cc08ac79f_add_enabled_deleted_notes_fields_to_user.py b/backend/alembic/versions/2023_01_10_1733-846cc08ac79f_add_enabled_deleted_notes_fields_to_user.py new file mode 100644 index 00000000..79e444a7 --- /dev/null +++ b/backend/alembic/versions/2023_01_10_1733-846cc08ac79f_add_enabled_deleted_notes_fields_to_user.py @@ -0,0 +1,32 @@ +"""Add enabled, deleted, notes fields to User + +Revision ID: 846cc08ac79f +Revises: aac6b2f66006 +Create Date: 2023-01-10 17:33:07.104596 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "846cc08ac79f" +down_revision = "befa42582ea4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("user", sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False)) + op.add_column("user", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False)) + op.add_column("user", sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=1024), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("user", "notes") + op.drop_column("user", "deleted") + op.drop_column("user", "enabled") + # ### end Alembic commands ### diff --git a/backend/oasst_backend/api/v1/users.py b/backend/oasst_backend/api/v1/users.py index 5dda88eb..5bc0195e 100644 --- a/backend/oasst_backend/api/v1/users.py +++ b/backend/oasst_backend/api/v1/users.py @@ -1,11 +1,13 @@ import datetime +from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query from oasst_backend.api import deps from oasst_backend.api.v1 import utils -from oasst_backend.models import ApiClient +from oasst_backend.models import ApiClient, User from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -13,6 +15,49 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() +@router.get("/users/{user_id}", response_model=protocol.User) +def get_user( + user_id: UUID, + api_client_id: UUID = None, + db: Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_api_client), +): + """ + Get a user by global user ID. Only trusted clients can resolve users they did not register. + """ + ur = UserRepository(db, api_client) + user: User = ur.get_user(user_id, api_client_id) + return protocol.User(user.username, user.display_name, user.auth_method) + + +@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +def update_user( + user_id: UUID, + enabled: Optional[bool] = None, + notes: Optional[str] = None, + db: Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +): + """ + Update a user by global user ID. Only trusted clients can update users. + """ + ur = UserRepository(db, api_client) + ur.update_user(user_id, enabled, notes) + + +@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT) +def delete_user( + user_id: UUID, + db: Session = Depends(deps.get_db), + api_client: ApiClient = Depends(deps.get_trusted_api_client), +): + """ + Delete a user by global user ID. Only trusted clients can delete users. + """ + ur = UserRepository(db, api_client) + ur.mark_user_deleted(user_id) + + @router.get("/{user_id}/messages", response_model=list[protocol.Message]) def query_user_messages( user_id: UUID, diff --git a/backend/oasst_backend/models/user.py b/backend/oasst_backend/models/user.py index 1a06a524..f2dd76b9 100644 --- a/backend/oasst_backend/models/user.py +++ b/backend/oasst_backend/models/user.py @@ -23,3 +23,6 @@ class User(SQLModel, table=True): sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp()) ) api_client_id: UUID = Field(foreign_key="api_client.id") + enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true())) + notes: str = Field(nullable=False, max_length=1024, default="") + deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false())) diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index 3acb1751..14a15b35 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -14,6 +14,34 @@ class UserRepository: self.db = db self.api_client = api_client + def get_user(self, id: UUID, api_client_id: Optional[UUID] = None) -> User: + """ + Get a user by global user ID. All clients may get users with the same API client ID as the querying client. + Trusted clients can get any user. + + Raises: + OasstError: 403 if untrusted client attempts to query foreign users. 404 if user with ID not found. + """ + if not self.api_client.trusted and api_client_id is None: + api_client_id = self.api_client.id + + if not self.api_client.trusted and api_client_id != self.api_client.id: + # Unprivileged client requests foreign user + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + + # Will always be unique + user_query = self.db.query(User).filter(User.id == id) + + if api_client_id: + user_query = user_query.filter(User.api_client_id == api_client_id) + + user: User = user_query.first() + + if user is None: + raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND) + + return user + def query_frontend_user( self, auth_method: str, username: str, api_client_id: Optional[UUID] = None ) -> Optional[User]: @@ -35,6 +63,49 @@ class UserRepository: return user + def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None: + """ + Update a user by global user ID to disable or set admin notes. Only trusted clients may update users. + + Raises: + OasstError: 403 if untrusted client attempts to update a user. 404 if user with ID not found. + """ + if not self.api_client.trusted: + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + + user: User = self.db.query(User).filter(User.id == id).first() + + if user is None: + raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND) + + if enabled is not None: + user.enabled = enabled + if notes is not None: + user.notes = notes + + self.db.add(user) + self.db.commit() + + def mark_user_deleted(self, id: UUID) -> None: + """ + Update a user by global user ID to set deleted flag. Only trusted clients may delete users. + + Raises: + OasstError: 403 if untrusted client attempts to delete a user. 404 if user with ID not found. + """ + if not self.api_client.trusted: + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + + user: User = self.db.query(User).filter(User.id == id).first() + + if user is None: + raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND) + + user.deleted = True + + self.db.add(user) + self.db.commit() + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: if not client_user: return None