Merge pull request #281 from kiritowu/implement_fastapi_limiter

Implement Backend Fastapi limiter
This commit is contained in:
Yannic Kilcher
2023-01-02 18:35:13 +01:00
committed by GitHub
7 changed files with 99 additions and 3 deletions
+26
View File
@@ -1,4 +1,5 @@
from http import HTTPStatus
from math import ceil
from pathlib import Path
from typing import Optional
@@ -6,6 +7,8 @@ import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
@@ -62,6 +65,29 @@ if settings.UPDATE_ALEMBIC:
logger.exception("Alembic upgrade failed on startup")
if settings.RATE_LIMIT:
@app.on_event("startup")
async def connect_redis():
async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
"""Error callback function when too many requests"""
expire = ceil(pexpire / 1000)
raise OasstError(
f"Too Many Requests. Retry After {expire} seconds.",
OasstErrorCode.TOO_MANY_REQUESTS,
HTTPStatus.TOO_MANY_REQUESTS,
)
try:
redis_client = redis.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
)
logger.info(f"Connected to {redis_client=}")
await FastAPILimiter.init(redis_client, http_callback=http_callback)
except Exception:
logger.exception("Failed to establish Redis connection")
if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
+57 -1
View File
@@ -3,8 +3,9 @@ from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Depends, Security
from fastapi import Depends, Request, Response, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from fastapi_limiter.depends import RateLimiter
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
@@ -84,3 +85,58 @@ def get_trusted_api_client(
http_status_code=HTTPStatus.FORBIDDEN,
)
return client
class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
user = (await request.json()).get("user")
return f"{api_key}:{user.get('id')}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip when api_key and user information are not available
# (such that it will be handled by `APIClientRateLimiter`)
if not api_key or not user or not user.get("id"):
return
return await super().__call__(request, response)
class APIClientRateLimiter(RateLimiter):
def __init__(
self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
return f"{api_key}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip if user information is available
# (such that it will be handled by `UserRateLimiter`)
if not api_key or user:
return
return await super().__call__(request, response)
+8 -1
View File
@@ -127,7 +127,14 @@ def generate_task(
return task, message_tree_id, parent_message_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@router.post(
"/",
response_model=protocol_schema.AnyTask,
dependencies=[
Depends(deps.UserRateLimiter(times=100, minutes=5)),
Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
],
) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
+4
View File
@@ -14,6 +14,10 @@ class Settings(BaseSettings):
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
RATE_LIMIT: bool = True
REDIS_HOST: str = "localhost"
REDIS_PORT: str = "6379"
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
+1
View File
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
SERVER_ERROR = 3
TOO_MANY_REQUESTS = 429
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
+1
View File
@@ -1,5 +1,6 @@
alembic==1.8.1
fastapi==0.88.0
fastapi-limiter==0.1.5
loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
+2 -1
View File
@@ -4,7 +4,7 @@ services:
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
backend-dev:
image: sverrirab/sleep
depends_on: [db, adminer]
depends_on: [db, adminer, redis, redis-insights]
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
frontend-dev:
@@ -91,6 +91,7 @@ services:
image: oasst-backend
environment:
- POSTGRES_HOST=db
- REDIS_HOST=redis
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1