Files
Open-Assistant/backend/oasst_backend/api/v1/users.py
T
Oliver Stanley 050d4902f3 538: Add endpoints to manage users (#601)
* Add endpoints for getting, updating, deleting users by global user ID

* Resolve formatting

* Include alembic revision script

* Updated down_revision to current alembic head

Co-authored-by: Andreas Köpf <andreas.koepf@xamla.com>
2023-01-12 21:30:07 +01:00

99 lines
3.0 KiB
Python

import datetime
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient, User
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
# Router on which every user-related endpoint in this module is registered
# (all handlers below are attached via @router.get/put/delete).
router = APIRouter()
@router.get("/users/{user_id}", response_model=protocol.User)
def get_user(
    user_id: UUID,
    api_client_id: Optional[UUID] = None,
    db: Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_api_client),
):
    """
    Get a user by global user ID. Only trusted clients can resolve users they did not register.

    Args:
        user_id: Global ID of the user to look up.
        api_client_id: Optional API client ID, forwarded unchanged to
            ``UserRepository.get_user``.
        db: Database session injected via ``deps.get_db``.
        api_client: Calling API client injected via ``deps.get_api_client``.

    Returns:
        The user serialized as ``protocol.User``.
    """
    ur = UserRepository(db, api_client)
    user: User = ur.get_user(user_id, api_client_id)
    # pydantic models accept field values as keyword arguments only; passing
    # them positionally raises TypeError on every request.
    # NOTE(review): assumes protocol.User's fields are named id/display_name/
    # auth_method, with `id` holding the username — confirm against the schema.
    return protocol.User(
        id=user.username,
        display_name=user.display_name,
        auth_method=user.auth_method,
    )
@router.put("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
def update_user(
    user_id: UUID,
    enabled: Optional[bool] = None,
    notes: Optional[str] = None,
    db: Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
    """
    Update the user identified by a global user ID (trusted clients only).

    The calling client is resolved through ``deps.get_trusted_api_client``.
    ``enabled`` and ``notes`` are passed as-is, including ``None``, to
    ``UserRepository.update_user``. The endpoint answers with HTTP 204 and
    no body.
    """
    UserRepository(db, api_client).update_user(user_id, enabled, notes)
@router.delete("/users/{user_id}", status_code=HTTP_204_NO_CONTENT)
def delete_user(
    user_id: UUID,
    db: Session = Depends(deps.get_db),
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
):
    """
    Delete the user identified by a global user ID (trusted clients only).

    The calling client is resolved through ``deps.get_trusted_api_client``.
    The work is delegated to ``UserRepository.mark_user_deleted``. The
    endpoint answers with HTTP 204 and no body.
    """
    UserRepository(db, api_client).mark_user_deleted(user_id)
# NOTE(review): this path has no "/users" segment, while the user endpoints in
# this module use "/users/{user_id}". Depending on the prefix this router is
# mounted under, one of the two styles is probably wrong — confirm.
@router.get("/{user_id}/messages", response_model=list[protocol.Message])
def query_user_messages(
    user_id: UUID,
    api_client_id: Optional[UUID] = None,
    max_count: int = Query(10, gt=0, le=1000),
    start_date: Optional[datetime.datetime] = None,
    end_date: Optional[datetime.datetime] = None,
    only_roots: bool = False,
    desc: bool = True,
    include_deleted: bool = False,
    api_client: ApiClient = Depends(deps.get_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Query user messages.

    Args:
        user_id: Global ID of the user whose messages are listed.
        api_client_id: Optional API client ID, forwarded to
            ``PromptRepository.query_messages``.
        max_count: Maximum number of messages to return (1-1000, default 10).
        start_date: Optional start bound, forwarded as ``start_date``.
        end_date: Optional end bound, forwarded as ``end_date``.
        only_roots: Forwarded as ``only_roots`` (presumably restricts the result
            to thread root messages — confirm in the repository).
        desc: Forwarded as ``desc`` (presumably descending sort order).
        include_deleted: When False (the default), ``deleted=False`` is passed.
            When True, ``deleted=None`` is passed, presumably meaning "do not
            filter on deletion state".
        api_client: Calling API client injected via ``deps.get_api_client``.
        db: Database session injected via ``deps.get_db``.

    Returns:
        The matching messages converted by ``utils.prepare_message_list``.
    """
    pr = PromptRepository(db, api_client)
    messages = pr.query_messages(
        user_id=user_id,
        api_client_id=api_client_id,
        desc=desc,
        limit=max_count,
        start_date=start_date,
        end_date=end_date,
        only_roots=only_roots,
        deleted=None if include_deleted else False,
    )
    return utils.prepare_message_list(messages)
@router.delete("/{user_id}/messages", status_code=HTTP_204_NO_CONTENT)
def mark_user_messages_deleted(
    user_id: UUID,
    api_client: ApiClient = Depends(deps.get_trusted_api_client),
    db: Session = Depends(deps.get_db),
):
    """
    Mark the messages of the given user as deleted.

    The calling client is resolved through ``deps.get_trusted_api_client``.
    The messages are looked up with ``PromptRepository.query_messages``
    (filtered only by ``user_id``) and handed to
    ``PromptRepository.mark_messages_deleted``. The endpoint answers with
    HTTP 204 and no body.
    """
    repository = PromptRepository(db, api_client)
    # NOTE(review): no `limit` is passed, so whatever default limit
    # query_messages has applies here. If that default is a small page size,
    # only part of the user's messages would be marked deleted — confirm.
    repository.mark_messages_deleted(repository.query_messages(user_id=user_id))