mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-29 16:30:24 +08:00
Merge pull request #281 from kiritowu/implement_fastapi_limiter
Implement Backend Fastapi limiter
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from http import HTTPStatus
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,6 +7,8 @@ import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
from fastapi_limiter import FastAPILimiter
|
||||
from loguru import logger
|
||||
from oasst_backend.api.deps import get_dummy_api_client
|
||||
from oasst_backend.api.v1.api import api_router
|
||||
@@ -62,6 +65,29 @@ if settings.UPDATE_ALEMBIC:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
|
||||
|
||||
if settings.RATE_LIMIT:
|
||||
|
||||
@app.on_event("startup")
|
||||
async def connect_redis():
|
||||
async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
|
||||
"""Error callback function when too many requests"""
|
||||
expire = ceil(pexpire / 1000)
|
||||
raise OasstError(
|
||||
f"Too Many Requests. Retry After {expire} seconds.",
|
||||
OasstErrorCode.TOO_MANY_REQUESTS,
|
||||
HTTPStatus.TOO_MANY_REQUESTS,
|
||||
)
|
||||
|
||||
try:
|
||||
redis_client = redis.from_url(
|
||||
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
|
||||
)
|
||||
logger.info(f"Connected to {redis_client=}")
|
||||
await FastAPILimiter.init(redis_client, http_callback=http_callback)
|
||||
except Exception:
|
||||
logger.exception("Failed to establish Redis connection")
|
||||
|
||||
|
||||
if settings.DEBUG_USE_SEED_DATA:
|
||||
|
||||
@app.on_event("startup")
|
||||
|
||||
@@ -3,8 +3,9 @@ from secrets import token_hex
|
||||
from typing import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, Security
|
||||
from fastapi import Depends, Request, Response, Security
|
||||
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
|
||||
from fastapi_limiter.depends import RateLimiter
|
||||
from loguru import logger
|
||||
from oasst_backend.config import settings
|
||||
from oasst_backend.database import engine
|
||||
@@ -84,3 +85,58 @@ def get_trusted_api_client(
|
||||
http_status_code=HTTPStatus.FORBIDDEN,
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class UserRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
) -> None:
|
||||
async def identifier(request: Request) -> str:
|
||||
"""Identify a request based on api_key and user.id"""
|
||||
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
|
||||
user = (await request.json()).get("user")
|
||||
return f"{api_key}:{user.get('id')}"
|
||||
|
||||
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
|
||||
|
||||
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
|
||||
# Skip if rate limiting is disabled
|
||||
if not settings.RATE_LIMIT:
|
||||
return
|
||||
|
||||
# Attempt to retrieve api_key and user information
|
||||
user = (await request.json()).get("user")
|
||||
|
||||
# Skip when api_key and user information are not available
|
||||
# (such that it will be handled by `APIClientRateLimiter`)
|
||||
if not api_key or not user or not user.get("id"):
|
||||
return
|
||||
|
||||
return await super().__call__(request, response)
|
||||
|
||||
|
||||
class APIClientRateLimiter(RateLimiter):
|
||||
def __init__(
|
||||
self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
|
||||
) -> None:
|
||||
async def identifier(request: Request) -> str:
|
||||
"""Identify a request based on api_key and user.id"""
|
||||
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
|
||||
return f"{api_key}"
|
||||
|
||||
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
|
||||
|
||||
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
|
||||
# Skip if rate limiting is disabled
|
||||
if not settings.RATE_LIMIT:
|
||||
return
|
||||
|
||||
# Attempt to retrieve api_key and user information
|
||||
user = (await request.json()).get("user")
|
||||
|
||||
# Skip if user information is available
|
||||
# (such that it will be handled by `UserRateLimiter`)
|
||||
if not api_key or user:
|
||||
return
|
||||
|
||||
return await super().__call__(request, response)
|
||||
|
||||
@@ -127,7 +127,14 @@ def generate_task(
|
||||
return task, message_tree_id, parent_message_id
|
||||
|
||||
|
||||
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
|
||||
@router.post(
|
||||
"/",
|
||||
response_model=protocol_schema.AnyTask,
|
||||
dependencies=[
|
||||
Depends(deps.UserRateLimiter(times=100, minutes=5)),
|
||||
Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
|
||||
],
|
||||
) # work with Union once more types are added
|
||||
def request_task(
|
||||
*,
|
||||
db: Session = Depends(deps.get_db),
|
||||
|
||||
@@ -14,6 +14,10 @@ class Settings(BaseSettings):
|
||||
POSTGRES_DB: str = "postgres"
|
||||
DATABASE_URI: Optional[PostgresDsn] = None
|
||||
|
||||
RATE_LIMIT: bool = True
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: str = "6379"
|
||||
|
||||
DEBUG_ALLOW_ANY_API_KEY: bool = False
|
||||
DEBUG_SKIP_API_KEY_CHECK: bool = False
|
||||
DEBUG_USE_SEED_DATA: bool = False
|
||||
|
||||
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
|
||||
DATABASE_URI_NOT_SET = 1
|
||||
API_CLIENT_NOT_AUTHORIZED = 2
|
||||
SERVER_ERROR = 3
|
||||
TOO_MANY_REQUESTS = 429
|
||||
|
||||
# 1000-2000: tasks endpoint
|
||||
TASK_INVALID_REQUEST_TYPE = 1000
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
alembic==1.8.1
|
||||
fastapi==0.88.0
|
||||
fastapi-limiter==0.1.5
|
||||
loguru==0.6.0
|
||||
numpy==1.22.4
|
||||
psycopg2-binary==2.9.5
|
||||
|
||||
+2
-1
@@ -4,7 +4,7 @@ services:
|
||||
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
|
||||
backend-dev:
|
||||
image: sverrirab/sleep
|
||||
depends_on: [db, adminer]
|
||||
depends_on: [db, adminer, redis, redis-insights]
|
||||
|
||||
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
|
||||
frontend-dev:
|
||||
@@ -91,6 +91,7 @@ services:
|
||||
image: oasst-backend
|
||||
environment:
|
||||
- POSTGRES_HOST=db
|
||||
- REDIS_HOST=redis
|
||||
- DEBUG_SKIP_API_KEY_CHECK=True
|
||||
- DEBUG_USE_SEED_DATA=True
|
||||
- MAX_WORKERS=1
|
||||
|
||||
Reference in New Issue
Block a user