if settings.RATE_LIMIT:

    @app.on_event("startup")
    async def connect_redis():
        """Initialise ``FastAPILimiter`` with a Redis client at application startup.

        Any failure is logged and swallowed so the application still starts
        when Redis is unavailable.

        NOTE(review): if initialisation never succeeds, rate-limited routes
        presumably fail at request time inside fastapi-limiter -- confirm.
        """

        async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
            """Raise the project's 429 error when a client exceeds its limit.

            ``pexpire`` is divided by 1000 and rounded up, i.e. it is treated
            as the remaining block time in milliseconds.
            """
            expire = ceil(pexpire / 1000)
            raise OasstError(
                f"Too Many Requests. Retry After {expire} seconds.",
                OasstErrorCode.TOO_MANY_REQUESTS,
                HTTPStatus.TOO_MANY_REQUESTS,
            )

        try:
            # `from_url` only builds the client object; no connection has been
            # attempted yet at this point.
            redis_client = redis.from_url(
                f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
            )
            await FastAPILimiter.init(redis_client, http_callback=http_callback)
        except Exception:
            logger.exception("Failed to establish Redis connection")
        else:
            # Report success only once the limiter has actually been
            # initialised against Redis; previously this was logged before
            # anything could have failed.
            logger.info(f"Connected to {redis_client=}")
def _get_request_api_key(request: Request):
    """Return the raw API key from the ``X-API-Key`` header or ``api_key`` query parameter."""
    return request.headers.get("X-API-Key") or request.query_params.get("api_key")


async def _get_request_user_id(request: Request):
    """Return ``user.id`` from the JSON request body, or ``None`` if unavailable.

    ``None`` is returned when the body is empty or not valid JSON, when the
    payload or its ``user`` entry is not a JSON object, or when the id is
    missing or falsy. The rate limiters must never turn a malformed body into
    a server error; request validation is left to the route itself.
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
        return None

    if not isinstance(payload, dict):
        return None

    user = payload.get("user")
    if not isinstance(user, dict):
        return None

    return user.get("id") or None


class UserRateLimiter(RateLimiter):
    """Rate limiter keyed on the ``(api_key, user.id)`` pair.

    Applies only to requests that carry both an API key and a user id; every
    other request is left to ``APIClientRateLimiter``.
    """

    def __init__(
        self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
    ) -> None:
        async def identifier(request: Request) -> str:
            """Identify a request based on api_key and user.id"""
            api_key = _get_request_api_key(request)
            user_id = await _get_request_user_id(request)
            return f"{api_key}:{user_id}"

        super().__init__(times, milliseconds, seconds, minutes, hours, identifier)

    async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
        """Apply the per-user limit, or do nothing when it is not applicable."""
        # Skip if rate limiting is disabled
        if not settings.RATE_LIMIT:
            return

        # Skip when api_key or user id are not available
        # (such that it will be handled by `APIClientRateLimiter`)
        if not api_key or not await _get_request_user_id(request):
            return

        return await super().__call__(request, response)


class APIClientRateLimiter(RateLimiter):
    """Rate limiter keyed on the API key alone.

    Applies to requests that carry an API key but no usable user id, i.e.
    exactly the requests that ``UserRateLimiter`` skips.
    """

    def __init__(
        self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
    ) -> None:
        async def identifier(request: Request) -> str:
            """Identify a request based on api_key only"""
            return f"{_get_request_api_key(request)}"

        super().__init__(times, milliseconds, seconds, minutes, hours, identifier)

    async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
        """Apply the per-client limit, or do nothing when it is not applicable."""
        # Skip if rate limiting is disabled
        if not settings.RATE_LIMIT:
            return

        # Skip only when a user id is available (such that the request is
        # handled by `UserRateLimiter`). Checking the id rather than the mere
        # presence of a `user` object closes the gap where a user without an
        # id was throttled by neither limiter.
        if not api_key or await _get_request_user_id(request):
            return

        return await super().__call__(request, response)
fastapi==0.88.0 +fastapi-limiter==0.1.5 loguru==0.6.0 numpy==1.22.4 psycopg2-binary==2.9.5 diff --git a/docker-compose.yaml b/docker-compose.yaml index c8d1377e..dc147c73 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -4,7 +4,7 @@ services: # Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend. backend-dev: image: sverrirab/sleep - depends_on: [db, adminer] + depends_on: [db, adminer, redis, redis-insights] # Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend. frontend-dev: @@ -91,6 +91,7 @@ services: image: oasst-backend environment: - POSTGRES_HOST=db + - REDIS_HOST=redis - DEBUG_SKIP_API_KEY_CHECK=True - DEBUG_USE_SEED_DATA=True - MAX_WORKERS=1