mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-07-05 17:30:48 +08:00
94e5d50537
* add users/cursor endpoint
* add messages/cursor endpoint
* add user/{user_id}/messages/cursor, frontend_user/{auth_method}/{username}/messages/cursor
* user regex to parse cursor value
136 lines
4.2 KiB
Python
136 lines
4.2 KiB
Python
import datetime
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends, Query
|
|
from oasst_backend.api import deps
|
|
from oasst_backend.api.v1 import utils
|
|
from oasst_backend.api.v1.messages import get_messages_cursor
|
|
from oasst_backend.models import ApiClient
|
|
from oasst_backend.prompt_repository import PromptRepository
|
|
from oasst_backend.user_repository import UserRepository
|
|
from oasst_shared.schemas import protocol
|
|
from sqlmodel import Session
|
|
from starlette.status import HTTP_204_NO_CONTENT
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
@router.get("/", response_model=list[protocol.FrontEndUser], deprecated=True)
|
|
def get_users_ordered_by_username(
|
|
api_client_id: Optional[UUID] = None,
|
|
gte_username: Optional[str] = None,
|
|
gt_id: Optional[UUID] = None,
|
|
lte_username: Optional[str] = None,
|
|
lt_id: Optional[UUID] = None,
|
|
search_text: Optional[str] = None,
|
|
auth_method: Optional[str] = None,
|
|
max_count: Optional[int] = Query(100, gt=0, le=10000),
|
|
api_client: ApiClient = Depends(deps.get_api_client),
|
|
db: Session = Depends(deps.get_db),
|
|
):
|
|
ur = UserRepository(db, api_client)
|
|
users = ur.query_users_ordered_by_username(
|
|
api_client_id=api_client_id,
|
|
gte_username=gte_username,
|
|
gt_id=gt_id,
|
|
lte_username=lte_username,
|
|
lt_id=lt_id,
|
|
auth_method=auth_method,
|
|
search_text=search_text,
|
|
limit=max_count,
|
|
)
|
|
return [u.to_protocol_frontend_user() for u in users]
|
|
|
|
|
|
@router.get("/{auth_method}/{username}", response_model=protocol.FrontEndUser)
|
|
def query_frontend_user(
|
|
auth_method: str,
|
|
username: str,
|
|
api_client_id: Optional[UUID] = None,
|
|
api_client: ApiClient = Depends(deps.get_api_client),
|
|
db: Session = Depends(deps.get_db),
|
|
):
|
|
"""
|
|
Query frontend user.
|
|
"""
|
|
ur = UserRepository(db, api_client)
|
|
user = ur.query_frontend_user(auth_method, username, api_client_id)
|
|
return user.to_protocol_frontend_user()
|
|
|
|
|
|
@router.get("/{auth_method}/{username}/messages", response_model=list[protocol.Message])
|
|
def query_frontend_user_messages(
|
|
auth_method: str,
|
|
username: str,
|
|
api_client_id: UUID = None,
|
|
max_count: int = Query(10, gt=0, le=1000),
|
|
start_date: datetime.datetime = None,
|
|
end_date: datetime.datetime = None,
|
|
only_roots: bool = False,
|
|
desc: bool = True,
|
|
include_deleted: bool = False,
|
|
api_client: ApiClient = Depends(deps.get_api_client),
|
|
db: Session = Depends(deps.get_db),
|
|
):
|
|
"""
|
|
Query frontend user messages.
|
|
"""
|
|
pr = PromptRepository(db, api_client)
|
|
messages = pr.query_messages_ordered_by_created_date(
|
|
auth_method=auth_method,
|
|
username=username,
|
|
api_client_id=api_client_id,
|
|
desc=desc,
|
|
limit=max_count,
|
|
gte_created_date=start_date,
|
|
lte_created_date=end_date,
|
|
only_roots=only_roots,
|
|
deleted=None if include_deleted else False,
|
|
)
|
|
return utils.prepare_message_list(messages)
|
|
|
|
|
|
@router.get("/{auth_method}/{username}/messages/cursor", response_model=protocol.MessagePage)
|
|
def query_frontend_user_messages_cursor(
|
|
auth_method: str,
|
|
username: str,
|
|
lt: Optional[str] = None,
|
|
gt: Optional[str] = None,
|
|
only_roots: Optional[bool] = False,
|
|
include_deleted: Optional[bool] = False,
|
|
max_count: Optional[int] = Query(10, gt=0, le=1000),
|
|
desc: Optional[bool] = False,
|
|
api_client: ApiClient = Depends(deps.get_api_client),
|
|
db: Session = Depends(deps.get_db),
|
|
):
|
|
return get_messages_cursor(
|
|
lt=lt,
|
|
gt=gt,
|
|
auth_method=auth_method,
|
|
username=username,
|
|
only_roots=only_roots,
|
|
include_deleted=include_deleted,
|
|
max_count=max_count,
|
|
desc=desc,
|
|
api_client=api_client,
|
|
db=db,
|
|
)
|
|
|
|
|
|
@router.delete("/{auth_method}/{username}/messages", status_code=HTTP_204_NO_CONTENT)
|
|
def mark_frontend_user_messages_deleted(
|
|
auth_method: str,
|
|
username: str,
|
|
api_client: ApiClient = Depends(deps.get_trusted_api_client),
|
|
db: Session = Depends(deps.get_db),
|
|
):
|
|
pr = PromptRepository(db, api_client)
|
|
messages = pr.query_messages_ordered_by_created_date(
|
|
auth_method=auth_method,
|
|
username=username,
|
|
api_client_id=api_client.id,
|
|
limit=None,
|
|
)
|
|
pr.mark_messages_deleted(messages)
|