From 144a5bc424679c5cff50365d764dfd7b2c4f98a3 Mon Sep 17 00:00:00 2001 From: kiritowu Date: Mon, 2 Jan 2023 08:04:59 +0800 Subject: [PATCH] Initialise redis connection for Fastapi limiter --- backend/main.py | 26 ++++++++++++++++++++++++++ backend/oasst_backend/config.py | 4 ++++ 2 files changed, 30 insertions(+) diff --git a/backend/main.py b/backend/main.py index 9cf43701..81adcccd 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from http import HTTPStatus +from math import ceil from pathlib import Path from typing import Optional @@ -7,6 +8,8 @@ import alembic.command import alembic.config import fastapi import pydantic +import redis.asyncio as redis +from fastapi_limiter import FastAPILimiter from loguru import logger from oasst_backend.api.deps import get_dummy_api_client from oasst_backend.api.v1.api import api_router @@ -63,6 +66,29 @@ if settings.UPDATE_ALEMBIC: logger.exception("Alembic upgrade failed on startup") +if settings.RATE_LIMIT: + + @app.on_event("startup") + async def connect_redis(): + async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int): + """Error callback function when too many requests""" + expire = ceil(pexpire / 1000) + raise OasstError( + f"Too Many Requests. Retry After {expire} seconds.", + OasstErrorCode.TOO_MANY_REQUESTS, + HTTPStatus.TOO_MANY_REQUESTS, + ) + + try: + redis_client = redis.from_url( + f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True + ) + logger.info(f"Connected to {redis_client=}") + await FastAPILimiter.init(redis_client, http_callback=http_callback) + except Exception: + logger.exception("Failed to establish Redis connection") + + if settings.DEBUG_USE_SEED_DATA: @app.on_event("startup") diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 602780be..cfe0d4c7 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -15,6 +15,10 @@ class Settings(BaseSettings): POSTGRES_DB: str = "postgres" DATABASE_URI: Optional[PostgresDsn] = None + RATE_LIMIT: bool = True + REDIS_HOST: str = "localhost" + REDIS_PORT: str = "6379" + DEBUG_ALLOW_ANY_API_KEY: bool = False DEBUG_SKIP_API_KEY_CHECK: bool = False DEBUG_USE_SEED_DATA: bool = False