mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-30 16:40:05 +08:00
536: Add endpoint to resolve frontend user by auth method and username (#539)
* Add endpoint to resolve frontend user by auth method and username * Require client ID for frontend user lookup * Remove unnecessary if check * Fix PromptRepository -> UserRepository * Convert to protocol User * Move User prep * Address review comments * 404 -> HTTP_404_NOT_FOUND
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
@@ -6,6 +7,7 @@ from oasst_backend.api import deps
|
||||
from oasst_backend.api.v1 import utils
|
||||
from oasst_backend.models import ApiClient
|
||||
from oasst_backend.prompt_repository import PromptRepository
|
||||
from oasst_backend.user_repository import UserRepository
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Session
|
||||
from starlette.status import HTTP_204_NO_CONTENT
|
||||
@@ -13,6 +15,22 @@ from starlette.status import HTTP_204_NO_CONTENT
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{auth_method}/{username}", response_model=protocol.User)
|
||||
def query_frontend_user(
|
||||
auth_method: str,
|
||||
username: str,
|
||||
api_client_id: Optional[UUID] = None,
|
||||
api_client: ApiClient = Depends(deps.get_api_client),
|
||||
db: Session = Depends(deps.get_db),
|
||||
):
|
||||
"""
|
||||
Query frontend user.
|
||||
"""
|
||||
ur = UserRepository(db, api_client)
|
||||
user = ur.query_frontend_user(auth_method, username, api_client_id)
|
||||
return protocol.User(id=user.username, display_name=user.display_name, auth_method=user.auth_method)
|
||||
|
||||
|
||||
@router.get("/{username}/messages", response_model=list[protocol.Message])
|
||||
def query_frontend_user_messages(
|
||||
username: str,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from oasst_backend.models import ApiClient, Message, User
|
||||
from oasst_shared.exceptions import OasstError, OasstErrorCode
|
||||
from oasst_shared.schemas import protocol as protocol_schema
|
||||
from oasst_shared.schemas.protocol import LeaderboardStats
|
||||
from sqlmodel import Session, func
|
||||
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
|
||||
|
||||
|
||||
class UserRepository:
|
||||
@@ -11,6 +14,27 @@ class UserRepository:
|
||||
self.db = db
|
||||
self.api_client = api_client
|
||||
|
||||
def query_frontend_user(
|
||||
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
|
||||
) -> Optional[User]:
|
||||
if not api_client_id:
|
||||
api_client_id = self.api_client.id
|
||||
|
||||
if not self.api_client.trusted and api_client_id != self.api_client.id:
|
||||
# Unprivileged API client asks for foreign user
|
||||
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
|
||||
|
||||
user: User = (
|
||||
self.db.query(User)
|
||||
.filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if user is None:
|
||||
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
|
||||
|
||||
return user
|
||||
|
||||
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
|
||||
if not client_user:
|
||||
return None
|
||||
|
||||
@@ -9,7 +9,7 @@ class OasstErrorCode(IntEnum):
|
||||
Ranges:
|
||||
0-1000: general errors
|
||||
1000-2000: tasks endpoint
|
||||
2000-3000: prompt_repository
|
||||
2000-3000: prompt_repository, task_repository, user_repository
|
||||
3000-4000: external resources
|
||||
"""
|
||||
|
||||
@@ -45,6 +45,7 @@ class OasstErrorCode(IntEnum):
|
||||
TASK_NOT_ACK = 2104
|
||||
TASK_ALREADY_DONE = 2105
|
||||
TASK_NOT_COLLECTIVE = 2106
|
||||
USER_NOT_FOUND = 2200
|
||||
|
||||
# 3000-4000: external resources
|
||||
HUGGINGFACE_API_ERROR = 3001
|
||||
|
||||
Reference in New Issue
Block a user