mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
added root tokens and endpoint for adding api keys (#742)
* added root tokens and endpoint for adding api keys * Change down revision to current alembic head * removed added_by_root_token * refactored description * fixed jinja errors Co-authored-by: Andreas Köpf <andreas.koepf@provisio.com>
This commit is contained in:
@@ -32,6 +32,7 @@
|
||||
name: "oasst-{{ stack_name }}-redis"
|
||||
image: redis
|
||||
state: started
|
||||
recreate: "{{ (stack_name == 'dev') | bool }}"
|
||||
restart_policy: always
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
healthcheck:
|
||||
@@ -48,6 +49,7 @@
|
||||
name: "oasst-{{ stack_name }}-postgres-{{ item.name }}"
|
||||
image: postgres:15
|
||||
state: started
|
||||
recreate: "{{ (stack_name == 'dev') | bool }}"
|
||||
restart_policy: always
|
||||
network_mode: "oasst-{{ stack_name }}"
|
||||
env:
|
||||
@@ -75,11 +77,12 @@
|
||||
env:
|
||||
POSTGRES_HOST: "oasst-{{ stack_name }}-postgres-backend"
|
||||
REDIS_HOST: "oasst-{{ stack_name }}-redis"
|
||||
DEBUG_ALLOW_ANY_API_KEY: "true"
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: "true"
|
||||
DEBUG_USE_SEED_DATA: "true"
|
||||
DEBUG_ALLOW_SELF_LABELING: "true"
|
||||
DEBUG_ALLOW_SELF_LABELING:
|
||||
"{{ 'true' if stack_name == 'dev' else 'false' }}"
|
||||
MAX_WORKERS: "1"
|
||||
RATE_LIMIT: "false"
|
||||
RATE_LIMIT: "{{ 'false' if stack_name == 'dev' else 'true' }}"
|
||||
DEBUG_SKIP_EMBEDDING_COMPUTATION: "true"
|
||||
DEBUG_SKIP_TOXICITY_CALCULATION: "true"
|
||||
ports:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from http import HTTPStatus
|
||||
from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from loguru import logger
|
||||
@@ -22,6 +22,8 @@ def get_db() -> Generator:
|
||||
api_key_query = APIKeyQuery(name="api_key", auto_error=False)
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
bearer_token = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
async def get_api_key(
|
||||
api_key_query: str = Security(api_key_query),
|
||||
@@ -33,22 +35,47 @@ async def get_api_key(
|
||||
return api_key_header
|
||||
|
||||
|
||||
def get_dummy_api_client(db: Session) -> ApiClient:
|
||||
def create_api_client(
|
||||
*,
|
||||
session: Session,
|
||||
description: str,
|
||||
frontend_type: str,
|
||||
trusted: bool | None = False,
|
||||
admin_email: str | None = None,
|
||||
api_key: str | None = None,
|
||||
) -> ApiClient:
|
||||
if api_key is None:
|
||||
api_key = token_hex(32)
|
||||
|
||||
logger.info(f"Creating new api client with {api_key=}")
|
||||
api_client = ApiClient(
|
||||
api_key=api_key,
|
||||
description=description,
|
||||
frontend_type=frontend_type,
|
||||
trusted=trusted,
|
||||
admin_email=admin_email,
|
||||
)
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
session.refresh(api_client)
|
||||
return api_client
|
||||
|
||||
|
||||
def get_dummy_api_client(session: Session) -> ApiClient:
|
||||
# make sure that a dummy api key exits in db (foreign key references)
|
||||
ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444")
|
||||
api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first()
|
||||
DUMMY_API_KEY = "1234"
|
||||
api_client: ApiClient = session.query(ApiClient).filter(ApiClient.api_key == DUMMY_API_KEY).first()
|
||||
if api_client is None:
|
||||
token = token_hex(32)
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {token}")
|
||||
api_client = ApiClient(
|
||||
id=ANY_API_KEY_ID,
|
||||
api_key=token,
|
||||
description="ANY_API_KEY, random token",
|
||||
logger.info(f"ANY_API_KEY missing, inserting api_key: {DUMMY_API_KEY}")
|
||||
api_client = create_api_client(
|
||||
session=session,
|
||||
api_key=DUMMY_API_KEY,
|
||||
description="Dummy api key for debugging",
|
||||
trusted=True,
|
||||
frontend_type="Test frontend",
|
||||
)
|
||||
db.add(api_client)
|
||||
db.commit()
|
||||
session.add(api_client)
|
||||
session.commit()
|
||||
return api_client
|
||||
|
||||
|
||||
@@ -58,7 +85,7 @@ def api_auth(
|
||||
) -> ApiClient:
|
||||
if api_key or settings.DEBUG_SKIP_API_KEY_CHECK:
|
||||
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY:
|
||||
if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_DEBUG_API_KEY:
|
||||
return get_dummy_api_client(db)
|
||||
|
||||
api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first()
|
||||
@@ -93,6 +120,18 @@ def get_trusted_api_client(
|
||||
return client
|
||||
|
||||
|
||||
def get_root_token(bearer_token: HTTPAuthorizationCredentials = Security(bearer_token)) -> str:
|
||||
if bearer_token:
|
||||
token = bearer_token.credentials
|
||||
if token and token in settings.ROOT_TOKENS:
|
||||
return token
|
||||
raise OasstError(
|
||||
"Could not validate credentials",
|
||||
error_code=OasstErrorCode.ROOT_TOKEN_NOT_AUTHORIZED,
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
class UserRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import pydantic
|
||||
from fastapi import APIRouter, Depends
|
||||
from loguru import logger
|
||||
from oasst_backend.api import deps
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CreateApiClientRequest(pydantic.BaseModel):
|
||||
description: str
|
||||
frontend_type: str
|
||||
trusted: bool | None = False
|
||||
admin_email: str | None = None
|
||||
|
||||
|
||||
@router.post("/api_client")
|
||||
async def create_api_client(
|
||||
request: CreateApiClientRequest,
|
||||
root_token: str = Depends(deps.get_root_token),
|
||||
session: deps.Session = Depends(deps.get_db),
|
||||
) -> str:
|
||||
logger.info(f"Creating new api client with {request=}")
|
||||
api_client = deps.create_api_client(
|
||||
session=session,
|
||||
description=request.description,
|
||||
frontend_type=request.frontend_type,
|
||||
trusted=request.trusted,
|
||||
admin_email=request.admin_email,
|
||||
)
|
||||
logger.info(f"Created api_client with key {api_client.api_key}")
|
||||
return api_client.api_key
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from oasst_backend.api.v1 import (
|
||||
admin,
|
||||
frontend_messages,
|
||||
frontend_users,
|
||||
hugging_face,
|
||||
@@ -21,3 +22,4 @@ api_router.include_router(frontend_users.router, prefix="/frontend_users", tags=
|
||||
api_router.include_router(stats.router, prefix="/stats", tags=["stats"])
|
||||
api_router.include_router(leaderboards.router, prefix="/leaderboards", tags=["leaderboards"])
|
||||
api_router.include_router(hugging_face.router, prefix="/hf", tags=["hugging_face"])
|
||||
api_router.include_router(admin.router, prefix="/admin", tags=["admin"])
|
||||
|
||||
@@ -71,7 +71,7 @@ class Settings(BaseSettings):
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: str = "6379"
|
||||
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
DEBUG_USE_SEED_DATA_PATH: Optional[FilePath] = (
|
||||
@@ -83,6 +83,8 @@ class Settings(BaseSettings):
|
||||
|
||||
HUGGING_FACE_API_KEY: str = ""
|
||||
|
||||
ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list
|
||||
|
||||
@validator("DATABASE_URI", pre=True)
|
||||
def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
|
||||
@@ -29,7 +29,7 @@ environments:
|
||||
variables:
|
||||
# Note: this has to be a valid JSON list for Pydantic to parse it.
|
||||
BACKEND_CORS_ORIGINS: '["https://web.staging.open-assistant.surfacedata.org"]'
|
||||
DEBUG_ALLOW_ANY_API_KEY: True
|
||||
DEBUG_ALLOW_DEBUG_API_KEY: True
|
||||
DEBUG_SKIP_API_KEY_CHECK: True
|
||||
MAX_WORKERS: 1
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
|
||||
GENERIC_ERROR = 0
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
ROOT_TOKEN_NOT_AUTHORIZED = 3
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
SERVER_ERROR0 = 500
|
||||
|
||||
@@ -4,7 +4,7 @@ parent_path=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P )
|
||||
# switch to backend directory
|
||||
pushd "$parent_path/../../backend"
|
||||
|
||||
export DEBUG_SKIP_API_KEY_CHECK=True
|
||||
export DEBUG_SKIP_API_KEY_CHECK=False
|
||||
export DEBUG_USE_SEED_DATA=True
|
||||
export DEBUG_SKIP_TOXICITY_CALCULATION=True
|
||||
export DEBUG_ALLOW_SELF_LABELING=True
|
||||
|
||||
Reference in New Issue
Block a user