mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-28 16:20:34 +08:00
538: Add endpoints to manage users (#601)
* Add endpoints for getting, updating, deleting users by global user ID * Resolve formatting * Include alembic revision script * Updated down_revision to current alembic head Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
This commit is contained in:
+32
@@ -0,0 +1,32 @@
|
||||
"""Add enabled, deleted, notes fields to User
|
||||
|
||||
Revision ID: 846cc08ac79f
|
||||
Revises: aac6b2f66006
|
||||
Create Date: 2023-01-10 17:33:07.104596
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "846cc08ac79f"
|
||||
down_revision = "befa42582ea4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("user", sa.Column("enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False))
|
||||
op.add_column("user", sa.Column("deleted", sa.Boolean(), server_default=sa.text("false"), nullable=False))
|
||||
op.add_column("user", sa.Column("notes", sqlmodel.sql.sqltypes.AutoString(length=1024), nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("user", "notes")
|
||||
op.drop_column("user", "deleted")
|
||||
op.drop_column("user", "enabled")
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,11 +1,13 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.models import ApiClient, User
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -13,6 +15,49 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/users/{user_id}", response_model=protocol.User)
|
||||
def get_user(
|
||||
user_id: UUID,
|
||||
api_client_id: UUID = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
):
|
||||
"""
|
||||
Get a user by global user ID. Only trusted clients can resolve users they did not register.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
user: User = ur.get_user(user_id, api_client_id)
|
||||
return protocol.User(user.username, user.display_name, user.auth_method)
|
||||
|
||||
|
||||
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def update_user(
|
||||
user_id: UUID,
|
||||
enabled: Optional[bool] = None,
|
||||
notes: Optional[str] = None,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
"""
|
||||
Update a user by global user ID. Only trusted clients can update users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.update_user(user_id, enabled, notes)
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
|
||||
def delete_user(
|
||||
user_id: UUID,
|
||||
db: Session = Depends(deps.get_db),
|
||||
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
||||
):
|
||||
"""
|
||||
Delete a user by global user ID. Only trusted clients can delete users.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
ur.mark_user_deleted(user_id)
|
||||
|
||||
|
||||
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
|
||||
def query_user_messages(
|
||||
user_id: UUID,
|
||||
|
||||
@@ -23,3 +23,6 @@ class User(SQLModel, table=True):
|
||||
sa_column=sa.Column(sa.DateTime(), nullable=False, server_default=sa.func.current_timestamp())
|
||||
)
|
||||
api_client_id: UUID = Field(foreign_key="api_client.id")
|
||||
enabled: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.true()))
|
||||
notes: str = Field(nullable=False, max_length=1024, default="")
|
||||
deleted: bool = Field(sa_column=sa.Column(sa.Boolean, nullable=False, server_default=sa.false()))
|
||||
|
||||
@@ -14,6 +14,34 @@ class UserRepository:
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def get_user(self, id: UUID, api_client_id: Optional[UUID] = None) -> User:
|
||||
"""
|
||||
Get a user by global user ID. All clients may get users with the same API client ID as the querying client.
|
||||
Trusted clients can get any user.
|
||||
|
||||
Raises:
|
||||
OasstError: 403 if untrusted client attempts to query foreign users. 404 if user with ID not found.
|
||||
"""
|
||||
if not self.api_client.trusted and api_client_id is None:
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if not self.api_client.trusted and api_client_id != self.api_client.id:
|
||||
# Unprivileged client requests foreign user
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
# Will always be unique
|
||||
user_query = self.db.query(User).filter(User.id == id)
|
||||
|
||||
if api_client_id:
|
||||
user_query = user_query.filter(User.api_client_id == api_client_id)
|
||||
|
||||
user: User = user_query.first()
|
||||
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
return user
|
||||
|
||||
def query_frontend_user(
|
||||
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
|
||||
) -> Optional[User]:
|
||||
@@ -35,6 +63,49 @@ class UserRepository:
|
||||
|
||||
return user
|
||||
|
||||
def update_user(self, id: UUID, enabled: Optional[bool] = None, notes: Optional[str] = None) -> None:
|
||||
"""
|
||||
Update a user by global user ID to disable or set admin notes. Only trusted clients may update users.
|
||||
|
||||
Raises:
|
||||
OasstError: 403 if untrusted client attempts to update a user. 404 if user with ID not found.
|
||||
"""
|
||||
if not self.api_client.trusted:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
user: User = self.db.query(User).filter(User.id == id).first()
|
||||
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
if enabled is not None:
|
||||
user.enabled = enabled
|
||||
if notes is not None:
|
||||
user.notes = notes
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
def mark_user_deleted(self, id: UUID) -> None:
|
||||
"""
|
||||
Update a user by global user ID to set deleted flag. Only trusted clients may delete users.
|
||||
|
||||
Raises:
|
||||
OasstError: 403 if untrusted client attempts to delete a user. 404 if user with ID not found.
|
||||
"""
|
||||
if not self.api_client.trusted:
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
user: User = self.db.query(User).filter(User.id == id).first()
|
||||
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
user.deleted = True
|
||||
|
||||
self.db.add(user)
|
||||
self.db.commit()
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user