added root tokens and endpoint for adding api keys (#742)

* added root tokens and endpoint for adding api keys

* Change down revision to current alembic head

* removed added_by_root_token

* refactored description

* fixed jinja errors

Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
This commit is contained in:
Yannic Kilcher
2023-01-15 22:24:25 +01:00
committed by GitHub
parent e58ffd64fa
commit cc03376d86
8 changed files with 97 additions and 19 deletions
+6 -3
View File
@@ -32,6 +32,7 @@
name: "oasst-{{ stack_name }}-redis"
image: redis
state: started
recreate: "{{ (stack_name == 'dev') | bool }}"
restart_policy: always
network_mode: "oasst-{{ stack_name }}"
healthcheck:
@@ -48,6 +49,7 @@
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
image: postgres:15
state: started
recreate: "{{ (stack_name == 'dev') | bool }}"
restart_policy: always
network_mode: "oasst-{{ stack_name }}"
env:
@@ -75,11 +77,12 @@
env:
POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend"
REDIS_HOST: "oasst-{{ stack_name }}-redis"
DEBUG_ALLOW_ANY_API_KEY: "true"
DEBUG_ALLOW_DEBUG_API_KEY: "true"
DEBUG_USE_SEED_DATA: "true"
DEBUG_ALLOW_SELF_LABELING: "true"
DEBUG_ALLOW_SELF_LABELING:
"{{ 'true' if stack_name == 'dev' else 'false' }}"
MAX_WORKERS: "1"
RATE_LIMIT: "false"
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
ports:
+52 -13
View File
@@ -1,9 +1,9 @@
from http import HTTPStatus
from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Depends, Request, Response, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from fastapi_limiter.depends import RateLimiter
from loguru import logger
@@ -22,6 +22,8 @@ def get_db() -> Generator:
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
bearer_token = HTTPBearer(auto_error=False)
async def get_api_key(
api_key_query: str = Security(api_key_query),
@@ -33,22 +35,47 @@ async def get_api_key(
return api_key_header
def get_dummy_api_client(db: Session) -> ApiClient:
def create_api_client(
*,
session: Session,
description: str,
frontend_type: str,
trusted: bool | None = False,
admin_email: str | None = None,
api_key: str | None = None,
) -> ApiClient:
if api_key is None:
api_key = token_hex(32)
logger.info(f"Creating new api client with {api_key=}")
api_client = ApiClient(
api_key=api_key,
description=description,
frontend_type=frontend_type,
trusted=trusted,
admin_email=admin_email,
)
session.add(api_client)
session.commit()
session.refresh(api_client)
return api_client
def get_dummy_api_client(session: Session) -> ApiClient:
# make sure that a dummy api key exits in db (foreign key references)
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
DUMMY_API_KEY = "1234"
api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first()
if api_client is None:
token = token_hex(32)
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
api_client = ApiClient(
id=ANY_API_KEY_ID,
api_key=token,
description="ANY_API_KEY, random token",
logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}")
api_client = create_api_client(
session=session,
api_key=DUMMY_API_KEY,
description="Dummy api key for debugging",
trusted=True,
frontend_type="Test frontend",
)
db.add(api_client)
db.commit()
session.add(api_client)
session.commit()
return api_client
@@ -58,7 +85,7 @@ def api_auth(
) -> ApiClient:
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY:
return get_dummy_api_client(db)
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
@@ -93,6 +120,18 @@ def get_trusted_api_client(
return client
def get_root_token(bearer_token: HTTPAuthorizationCredentials = Security(bearer_token)) -> str:
if bearer_token:
token = bearer_token.credentials
if token and token in settings.ROOT_TOKENS:
return token
raise OasstError(
"Could not validate credentials",
error_code=OasstErrorCode.ROOT_TOKEN_NOT_AUTHORIZED,
http_status_code=HTTPStatus.FORBIDDEN,
)
class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
+31
View File
@@ -0,0 +1,31 @@
import pydantic
from fastapi import APIRouter, Depends
from loguru import logger
from oasst_backend.api import deps
router = APIRouter()
class CreateApiClientRequest(pydantic.BaseModel):
description: str
frontend_type: str
trusted: bool | None = False
admin_email: str | None = None
@router.post("/api_client")
async def create_api_client(
request: CreateApiClientRequest,
root_token: str = Depends(deps.get_root_token),
session: deps.Session = Depends(deps.get_db),
) -> str:
logger.info(f"Creating new api client with {request=}")
api_client = deps.create_api_client(
session=session,
description=request.description,
frontend_type=request.frontend_type,
trusted=request.trusted,
admin_email=request.admin_email,
)
logger.info(f"Created api_client with key {api_client.api_key}")
return api_client.api_key
+2
View File
@@ -1,5 +1,6 @@
from fastapi import APIRouter
from oasst_backend.api.v1 import (
admin,
frontend_messages,
frontend_users,
hugging_face,
@@ -21,3 +22,4 @@ api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
+3 -1
View File
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
REDIS_HOST: str = "localhost"
REDIS_PORT: str = "6379"
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_ALLOW_DEBUG_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
@@ -83,6 +83,8 @@ class Settings(BaseSettings):
HUGGING_FACE_API_KEY: str = ""
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
@validator("DATABASE_URI", pre=True)
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
if isinstance(v, str):
+1 -1
View File
@@ -29,7 +29,7 @@ environments:
variables:
# Note: this has to be a valid JSON list for Pydantic to parse it.
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
DEBUG_ALLOW_ANY_API_KEY: True
DEBUG_ALLOW_DEBUG_API_KEY: True
DEBUG_SKIP_API_KEY_CHECK: True
MAX_WORKERS: 1
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
GENERIC_ERROR = 0
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
ROOT_TOKEN_NOT_AUTHORIZED = 3
TOO_MANY_REQUESTS = 429
SERVER_ERROR0 = 500
+1 -1
View File
@@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
# switch to backend directory
pushd "$parent_path/../../backend"
export DEBUG_SKIP_API_KEY_CHECK=True
export DEBUG_SKIP_API_KEY_CHECK=False
export DEBUG_USE_SEED_DATA=True
export DEBUG_SKIP_TOXICITY_CALCULATION=True
export DEBUG_ALLOW_SELF_LABELING=True