536: Add endpoint to resolve frontend user by auth method and username (#539)

* Add endpoint to resolve frontend user by auth method and username

* Require client ID for frontend user lookup

* Remove unnecessary if check

* Fix PromptRepository -> UserRepository

* Convert to protocol User

* Move User prep

* Address review comments

* 404 -> HTTP_404_NOT_FOUND
This commit is contained in:
Oliver Stanley
2023-01-10 17:20:56 +00:00
committed by GitHub
parent 1318c213d3
commit e5abb2dc85
3 changed files with 44 additions and 1 deletions
@@ -1,4 +1,5 @@
import datetime
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
@@ -6,6 +7,7 @@ from oasst_backend.api import deps
from oasst_backend.api.v1 import utils
from oasst_backend.models import ApiClient
from oasst_backend.prompt_repository import PromptRepository
from oasst_backend.user_repository import UserRepository
from oasst_shared.schemas import protocol
from sqlmodel import Session
from starlette.status import HTTP_204_NO_CONTENT
@@ -13,6 +15,22 @@ from starlette.status import HTTP_204_NO_CONTENT
router = APIRouter()
@router.get("/{auth_method}/{username}", response_model=protocol.User)
def query_frontend_user(
auth_method: str,
username: str,
api_client_id: Optional[UUID] = None,
api_client: ApiClient = Depends(deps.get_api_client),
db: Session = Depends(deps.get_db),
):
"""
Query frontend user.
"""
ur = UserRepository(db, api_client)
user = ur.query_frontend_user(auth_method, username, api_client_id)
return protocol.User(id=user.username, display_name=user.display_name, auth_method=user.auth_method)
@router.get("/{username}/messages", response_model=list[protocol.Message])
def query_frontend_user_messages(
username: str,
+24
View File
@@ -1,9 +1,12 @@
from typing import Optional
from uuid import UUID
from oasst_backend.models import ApiClient, Message, User
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.schemas.protocol import LeaderboardStats
from sqlmodel import Session, func
from starlette.status import HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
class UserRepository:
@@ -11,6 +14,27 @@ class UserRepository:
self.db = db
self.api_client = api_client
def query_frontend_user(
self, auth_method: str, username: str, api_client_id: Optional[UUID] = None
) -> Optional[User]:
if not api_client_id:
api_client_id = self.api_client.id
if not self.api_client.trusted and api_client_id != self.api_client.id:
# Unprivileged API client asks for foreign user
raise OasstError("Forbidden", OasstErrorCode.API_CLIENT_NOT_AUTHORIZED, HTTP_403_FORBIDDEN)
user: User = (
self.db.query(User)
.filter(User.auth_method == auth_method, User.username == username, User.api_client_id == api_client_id)
.first()
)
if user is None:
raise OasstError("User not found", OasstErrorCode.USER_NOT_FOUND, HTTP_404_NOT_FOUND)
return user
def lookup_client_user(self, client_user: protocol_schema.User, create_missing: bool = True) -> Optional[User]:
if not client_user:
return None
@@ -9,7 +9,7 @@ class OasstErrorCode(IntEnum):
Ranges:
0-1000: general errors
1000-2000: tasks endpoint
2000-3000: prompt_repository
2000-3000: prompt_repository, task_repository, user_repository
3000-4000: external resources
"""
@@ -45,6 +45,7 @@ class OasstErrorCode(IntEnum):
TASK_NOT_ACK = 2104
TASK_ALREADY_DONE = 2105
TASK_NOT_COLLECTIVE = 2106
USER_NOT_FOUND = 2200
# 3000-4000: external resources
HUGGINGFACE_API_ERROR = 3001