From e5abb2dc85b7cca615c1642bb1807aed70555ce8 Mon Sep 17 00:00:00 2001 From: Oliver Stanley Date: Tue, 10 Jan 2023 17:20:56 +0000 Subject: [PATCH] 536: Add endpoint to resolve frontend user by auth method and username (#539) * Add endpoint to resolve frontend user by auth method and username * Require client ID for frontend user lookup * Remove unnecessary if check * Fix PromptRepository -> UserRepository * Convert to protocol User * Move User prep * Address review comments * 404 -> HTTP_404_NOT_FOUND --- .../oasst_backend/api/v1/frontend_users.py | 18 ++++++++++++++ backend/oasst_backend/user_repository.py | 24 +++++++++++++++++++ .../exceptions/oasst_api_error.py | 3 ++- 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/backend/oasst_backend/api/v1/frontend_users.py b/backend/oasst_backend/api/v1/frontend_users.py index 8d56b7f9..31f14c64 100644 --- a/backend/oasst_backend/api/v1/frontend_users.py +++ b/backend/oasst_backend/api/v1/frontend_users.py @@ -1,4 +1,5 @@ import datetime +from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query @@ -6,6 +7,7 @@ from oasst_backend.api import deps from oasst_backend.api.v1 import utils from oasst_backend.models import ApiClient from oasst_backend.prompt_repository import PromptRepository +from oasst_backend.user_repository import UserRepository from oasst_shared.schemas import protocol from sqlmodel import Session from starlette.status import HTTP_204_NO_CONTENT @@ -13,6 +15,22 @@ from starlette.status import HTTP_204_NO_CONTENT router = APIRouter() +@router.get("/{auth_method}/{username}", response_model=protocol.User) +def query_frontend_user( + auth_method: str, + username: str, + api_client_id: Optional[UUID] = None, + api_client: ApiClient = Depends(deps.get_api_client), + db: Session = Depends(deps.get_db), +): + """ + Query frontend user. + """ + ur = UserRepository(db, api_client) + user = ur.query_frontend_user(auth_method, username, api_client_id) + return protocol.User(id=user.username, display_name=user.display_name, auth_method=user.auth_method) + + @router.get("/{username}/messages", response_model=list[protocol.Message]) def query_frontend_user_messages( username: str, diff --git a/backend/oasst_backend/user_repository.py b/backend/oasst_backend/user_repository.py index b5508899..3acb1751 100644 --- a/backend/oasst_backend/user_repository.py +++ b/backend/oasst_backend/user_repository.py @@ -1,9 +1,12 @@ from typing import Optional +from uuid import UUID from oasst_backend.models import ApiClient, Message, User +from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.schemas.protocol import LeaderboardStats from sqlmodel import Session, func +from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND class UserRepository: @@ -11,6 +14,27 @@ class UserRepository: self.db = db self.api_client = api_client + def query_frontend_user( + self, auth_method: str, username: str, api_client_id: Optional[UUID] = None + ) -> Optional[User]: + if not api_client_id: + api_client_id = self.api_client.id + + if not self.api_client.trusted and api_client_id != self.api_client.id: + # Unprivileged API client asks for foreign user + raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN) + + user: User = ( + self.db.query(User) + .filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id) + .first() + ) + + if user is None: + raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND) + + return user + def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]: if not client_user: return None diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index 6cc25918..ce08d31d 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -9,7 +9,7 @@ class OasstErrorCode(IntEnum): Ranges: 0-1000: general errors 1000-2000: tasks endpoint - 2000-3000: prompt_repository + 2000-3000: prompt_repository, task_repository, user_repository 3000-4000: external resources """ @@ -45,6 +45,7 @@ class OasstErrorCode(IntEnum): TASK_NOT_ACK = 2104 TASK_ALREADY_DONE = 2105 TASK_NOT_COLLECTIVE = 2106 + USER_NOT_FOUND = 2200 # 3000-4000: external resources HUGGINGFACE_API_ERROR = 3001