From cc03376d86560c262be0e10aaecf682f7fd16354 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 15 Jan 2023 22:24:25 +0100 Subject: [PATCH] added root tokens and endpoint for adding api keys (#742) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added root tokens and endpoint for adding api keys * Change down revision to current alembic head * removed added_by_root_token * refactored description * fixed jinja errors Co-authored-by: Andreas Köpf --- ansible/deploy-dev.yaml | 9 ++- backend/oasst_backend/api/deps.py | 65 +++++++++++++++---- backend/oasst_backend/api/v1/admin.py | 31 +++++++++ backend/oasst_backend/api/v1/api.py | 2 + backend/oasst_backend/config.py | 4 +- copilot/api/manifest.yml | 2 +- .../exceptions/oasst_api_error.py | 1 + scripts/backend-development/run-local.sh | 2 +- 8 files changed, 97 insertions(+), 19 deletions(-) create mode 100644 backend/oasst_backend/api/v1/admin.py diff --git a/ansible/deploy-dev.yaml b/ansible/deploy-dev.yaml index adb694e3..8d701fb2 100644 --- a/ansible/deploy-dev.yaml +++ b/ansible/deploy-dev.yaml @@ -32,6 +32,7 @@ name: "oasst-{{ stack_name }}-redis" image: redis state: started + recreate: "{{ (stack_name == 'dev') | bool }}" restart_policy: always network_mode: "oasst-{{ stack_name }}" healthcheck: @@ -48,6 +49,7 @@ name: "oasst-{{ stack_name }}-postgres-{{ item.name }}" image: postgres:15 state: started + recreate: "{{ (stack_name == 'dev') | bool }}" restart_policy: always network_mode: "oasst-{{ stack_name }}" env: @@ -75,11 +77,12 @@ env: POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend" REDIS_HOST: "oasst-{{ stack_name }}-redis" - DEBUG_ALLOW_ANY_API_KEY: "true" + DEBUG_ALLOW_DEBUG_API_KEY: "true" DEBUG_USE_SEED_DATA: "true" - DEBUG_ALLOW_SELF_LABELING: "true" + DEBUG_ALLOW_SELF_LABELING: + "{{ 'true' if stack_name == 'dev' else 'false' }}" MAX_WORKERS: "1" - RATE_LIMIT: "false" + RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}" DEBUG_SKIP_EMBEDDING_COMPUTATION: "true" DEBUG_SKIP_TOXICITY_CALCULATION: "true" ports: diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index 901ec8ab..b4d27870 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -1,9 +1,9 @@ from http import HTTPStatus from secrets import token_hex from typing import Generator -from uuid import UUID from fastapi import Depends, Request, Response, Security +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery from fastapi_limiter.depends import RateLimiter from loguru import logger @@ -22,6 +22,8 @@ def get_db() -> Generator: api_key_query = APIKeyQuery(name="api_key", auto_error=False) api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) +bearer_token = HTTPBearer(auto_error=False) + async def get_api_key( api_key_query: str = Security(api_key_query), @@ -33,22 +35,47 @@ async def get_api_key( return api_key_header -def get_dummy_api_client(db: Session) -> ApiClient: +def create_api_client( + *, + session: Session, + description: str, + frontend_type: str, + trusted: bool | None = False, + admin_email: str | None = None, + api_key: str | None = None, +) -> ApiClient: + if api_key is None: + api_key = token_hex(32) + + logger.info(f"Creating new api client with {api_key=}") + api_client = ApiClient( + api_key=api_key, + description=description, + frontend_type=frontend_type, + trusted=trusted, + admin_email=admin_email, + ) + session.add(api_client) + session.commit() + session.refresh(api_client) + return api_client + + +def get_dummy_api_client(session: Session) -> ApiClient: # make sure that a dummy api key exits in db (foreign key references) - ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") - api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + DUMMY_API_KEY = "1234" + api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first() if api_client is None: - token = token_hex(32) - logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") - api_client = ApiClient( - id=ANY_API_KEY_ID, - api_key=token, - description="ANY_API_KEY, random token", + logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}") + api_client = create_api_client( + session=session, + api_key=DUMMY_API_KEY, + description="Dummy api key for debugging", trusted=True, frontend_type="Test frontend", ) - db.add(api_client) - db.commit() + session.add(api_client) + session.commit() return api_client @@ -58,7 +85,7 @@ def api_auth( ) -> ApiClient: if api_key or settings.DEBUG_SKIP_API_KEY_CHECK: - if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY: + if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY: return get_dummy_api_client(db) api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() @@ -93,6 +120,18 @@ def get_trusted_api_client( return client +def get_root_token(bearer_token: HTTPAuthorizationCredentials = Security(bearer_token)) -> str: + if bearer_token: + token = bearer_token.credentials + if token and token in settings.ROOT_TOKENS: + return token + raise OasstError( + "Could not validate credentials", + error_code=OasstErrorCode.ROOT_TOKEN_NOT_AUTHORIZED, + http_status_code=HTTPStatus.FORBIDDEN, + ) + + class UserRateLimiter(RateLimiter): def __init__( self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0 diff --git a/backend/oasst_backend/api/v1/admin.py b/backend/oasst_backend/api/v1/admin.py new file mode 100644 index 00000000..e8d3078e --- /dev/null +++ b/backend/oasst_backend/api/v1/admin.py @@ -0,0 +1,31 @@ +import pydantic +from fastapi import APIRouter, Depends +from loguru import logger +from oasst_backend.api import deps + +router = APIRouter() + + +class CreateApiClientRequest(pydantic.BaseModel): + description: str + frontend_type: str + trusted: bool | None = False + admin_email: str | None = None + + +@router.post("/api_client") +async def create_api_client( + request: CreateApiClientRequest, + root_token: str = Depends(deps.get_root_token), + session: deps.Session = Depends(deps.get_db), +) -> str: + logger.info(f"Creating new api client with {request=}") + api_client = deps.create_api_client( + session=session, + description=request.description, + frontend_type=request.frontend_type, + trusted=request.trusted, + admin_email=request.admin_email, + ) + logger.info(f"Created api_client with key {api_client.api_key}") + return api_client.api_key diff --git a/backend/oasst_backend/api/v1/api.py b/backend/oasst_backend/api/v1/api.py index 0c68b5c9..2931ac05 100644 --- a/backend/oasst_backend/api/v1/api.py +++ b/backend/oasst_backend/api/v1/api.py @@ -1,5 +1,6 @@ from fastapi import APIRouter from oasst_backend.api.v1 import ( + admin, frontend_messages, frontend_users, hugging_face, @@ -21,3 +22,4 @@ api_router.include_router(frontend_users.router, prefix="/frontend_users", tags= api_router.include_router(stats.router, prefix="/stats", tags=["stats"]) api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"]) api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"]) +api_router.include_router(admin.router, prefix="/admin", tags=["admin"]) diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 1b182a87..005f43dd 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -71,7 +71,7 @@ class Settings(BaseSettings): REDIS_HOST: str = "localhost" REDIS_PORT: str = "6379" - DEBUG_ALLOW_ANY_API_KEY: bool = False + DEBUG_ALLOW_DEBUG_API_KEY: bool = False DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = ( @@ -83,6 +83,8 @@ class Settings(BaseSettings): HUGGING_FACE_API_KEY: str = "" + ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list + @validator("DATABASE_URI", pre=True) def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any: if isinstance(v, str): diff --git a/copilot/api/manifest.yml b/copilot/api/manifest.yml index b6ff6cf7..59848a25 100644 --- a/copilot/api/manifest.yml +++ b/copilot/api/manifest.yml @@ -29,7 +29,7 @@ environments: variables: # Note: this has to be a valid JSON list for Pydantic to parse it. BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]' - DEBUG_ALLOW_ANY_API_KEY: True + DEBUG_ALLOW_DEBUG_API_KEY: True DEBUG_SKIP_API_KEY_CHECK: True MAX_WORKERS: 1 diff --git a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py index e6eba233..b4432252 100644 --- a/oasst-shared/oasst_shared/exceptions/oasst_api_error.py +++ b/oasst-shared/oasst_shared/exceptions/oasst_api_error.py @@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum): GENERIC_ERROR = 0 DATABASE_URI_NOT_SET = 1 API_CLIENT_NOT_AUTHORIZED = 2 + ROOT_TOKEN_NOT_AUTHORIZED = 3 TOO_MANY_REQUESTS = 429 SERVER_ERROR0 = 500 diff --git a/scripts/backend-development/run-local.sh b/scripts/backend-development/run-local.sh index 22701ace..7d3f715c 100755 --- a/scripts/backend-development/run-local.sh +++ b/scripts/backend-development/run-local.sh @@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) # switch to backend directory pushd "$parent_path/../../backend" -export DEBUG_SKIP_API_KEY_CHECK=True +export DEBUG_SKIP_API_KEY_CHECK=False export DEBUG_USE_SEED_DATA=True export DEBUG_SKIP_TOXICITY_CALCULATION=True export DEBUG_ALLOW_SELF_LABELING=True