diff --git a/backend/main.py b/backend/main.py index d9c35095..4eb4a436 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,4 +1,5 @@ from http import HTTPStatus +from math import ceil from pathlib import Path from typing import Optional @@ -6,6 +7,8 @@ import alembic.command import alembic.config import fastapi import pydantic +import redis.asyncio as redis +from fastapi_limiter import FastAPILimiter from loguru import logger from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router @@ -62,6 +65,29 @@ if settings.UPDATE_ALEMBIC: logger.exception("Alembic upgrade failed on startup") +if settings.RATE_LIMIT: + + @app.on_event("startup") + async def connect_redis(): + async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int): + """Error callback function when too many requests""" + expire = ceil(pexpire / 1000) + raise OasstError( + f"Too Many Requests. Retry After {expire} seconds.", + OasstErrorCode.TOO_MANY_REQUESTS, + HTTPStatus.TOO_MANY_REQUESTS, + ) + + try: + redis_client = redis.from_url( + f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True + ) + logger.info(f"Connected to {redis_client=}") + await FastAPILimiter.init(redis_client, http_callback=http_callback) + except Exception: + logger.exception("Failed to establish Redis connection") + + if settings.DEBUG_USE_SEED_DATA: @app.on_event("startup") diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index c675c148..fef59832 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -14,6 +14,10 @@ class Settings(BaseSettings): POSTGRES_DB: str = "postgres" DATABASE_URI: Optional[PostgresDsn] = None + RATE_LIMIT: bool = True + REDIS_HOST: str = "localhost" + REDIS_PORT: str = "6379" + DEBUG_ALLOW_ANY_API_KEY: bool = False DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False