From 703c1129ddac20949769c7e8d1fe2d26cee0bc8d Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Sun, 25 Dec 2022 18:24:49 +0100 Subject: [PATCH] re-introduced ALLOW_ANY_API_KEY --- ansible/dev.yaml | 2 +- backend/oasst_backend/api/deps.py | 32 +++++++++++++++---------------- backend/oasst_backend/config.py | 1 + 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/ansible/dev.yaml b/ansible/dev.yaml index 6e03a9d2..1ff0e168 100644 --- a/ansible/dev.yaml +++ b/ansible/dev.yaml @@ -50,7 +50,7 @@ network_mode: oasst env: POSTGRES_HOST: oasst-postgres - DEBUG_SKIP_API_KEY_CHECK: "true" + DEBUG_ALLOW_ANY_API_KEY: "true" MAX_WORKERS: "1" ports: - 8080:8080 diff --git a/backend/oasst_backend/api/deps.py b/backend/oasst_backend/api/deps.py index 98b08078..505fa2c6 100644 --- a/backend/oasst_backend/api/deps.py +++ b/backend/oasst_backend/api/deps.py @@ -37,21 +37,21 @@ def api_auth( db: Session, ) -> ApiClient: - if api_key is not None or settings.DEBUG_SKIP_API_KEY_CHECK: - if settings.DEBUG_SKIP_API_KEY_CHECK: - # make sure that a dummy api key exits in db (foreign key references) - ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") - api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() - if api_client is None: - token = token_hex(32) - logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") - api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") - db.add(api_client) - db.commit() - return api_client + if api_key is None and not settings.DEBUG_SKIP_API_KEY_CHECK: + raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials") - api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() - if api_client is not None and api_client.enabled: - return api_client + if settings.DEBUG_SKIP_API_KEY_CHECK or settings.DEBUG_ALLOW_ANY_API_KEY: + # make sure that a dummy api key exits in db (foreign key references) + ANY_API_KEY_ID = UUID("00000000-1111-2222-3333-444444444444") + api_client: ApiClient = db.query(ApiClient).filter(ApiClient.id == ANY_API_KEY_ID).first() + if api_client is None: + token = token_hex(32) + logger.info(f"ANY_API_KEY missing, inserting api_key: {token}") + api_client = ApiClient(id=ANY_API_KEY_ID, api_key=token, description="ANY_API_KEY, random token") + db.add(api_client) + db.commit() + return api_client - raise HTTPException(status_code=HTTP_403_FORBIDDEN, detail="Could not validate credentials") + api_client = db.query(ApiClient).filter(ApiClient.api_key == api_key).first() + if api_client is not None and api_client.enabled: + return api_client diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 5516635f..96d6021e 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -15,6 +15,7 @@ class Settings(BaseSettings): POSTGRES_DB: str = "postgres" DATABASE_URI: Optional[PostgresDsn] = None + DEBUG_ALLOW_ANY_API_KEY: bool = False DEBUG_SKIP_API_KEY_CHECK: bool = False @validator("DATABASE_URI", pre=True)