Merge branch 'LAION-AI:main' into add-devcontainer

This commit is contained in:
Andrew Maguire
2023-01-02 17:54:44 +00:00
committed by GitHub
8 changed files with 152 additions and 3 deletions
+26
View File
@@ -1,4 +1,5 @@
from http import HTTPStatus
from math import ceil
from pathlib import Path
from typing import Optional
@@ -6,6 +7,8 @@ import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
@@ -62,6 +65,29 @@ if settings.UPDATE_ALEMBIC:
logger.exception("Alembic upgrade failed on startup")
if settings.RATE_LIMIT:
@app.on_event("startup")
async def connect_redis():
async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
"""Error callback function when too many requests"""
expire = ceil(pexpire / 1000)
raise OasstError(
f"Too Many Requests. Retry After {expire} seconds.",
OasstErrorCode.TOO_MANY_REQUESTS,
HTTPStatus.TOO_MANY_REQUESTS,
)
try:
redis_client = redis.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
)
logger.info(f"Connected to {redis_client=}")
await FastAPILimiter.init(redis_client, http_callback=http_callback)
except Exception:
logger.exception("Failed to establish Redis connection")
if settings.DEBUG_USE_SEED_DATA:
@app.on_event("startup")
+57 -1
View File
@@ -3,8 +3,9 @@ from secrets import token_hex
from typing import Generator
from uuid import UUID
from fastapi import Depends, Security
from fastapi import Depends, Request, Response, Security
from fastapi.security.api_key import APIKey, APIKeyHeader, APIKeyQuery
from fastapi_limiter.depends import RateLimiter
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.database import engine
@@ -84,3 +85,58 @@ def get_trusted_api_client(
http_status_code=HTTPStatus.FORBIDDEN,
)
return client
class UserRateLimiter(RateLimiter):
def __init__(
self, times: int = 100, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
user = (await request.json()).get("user")
return f"{api_key}:{user.get('id')}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip when api_key and user information are not available
# (such that it will be handled by `APIClientRateLimiter`)
if not api_key or not user or not user.get("id"):
return
return await super().__call__(request, response)
class APIClientRateLimiter(RateLimiter):
def __init__(
self, times: int = 10_000, milliseconds: int = 0, seconds: int = 0, minutes: int = 1, hours: int = 0
) -> None:
async def identifier(request: Request) -> str:
"""Identify a request based on api_key and user.id"""
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
return f"{api_key}"
super().__init__(times, milliseconds, seconds, minutes, hours, identifier)
async def __call__(self, request: Request, response: Response, api_key: str = Depends(get_api_key)) -> None:
# Skip if rate limiting is disabled
if not settings.RATE_LIMIT:
return
# Attempt to retrieve api_key and user information
user = (await request.json()).get("user")
# Skip if user information is available
# (such that it will be handled by `UserRateLimiter`)
if not api_key or user:
return
return await super().__call__(request, response)
+8 -1
View File
@@ -127,7 +127,14 @@ def generate_task(
return task, message_tree_id, parent_message_id
@router.post("/", response_model=protocol_schema.AnyTask) # work with Union once more types are added
@router.post(
"/",
response_model=protocol_schema.AnyTask,
dependencies=[
Depends(deps.UserRateLimiter(times=100, minutes=5)),
Depends(deps.APIClientRateLimiter(times=10_000, minutes=1)),
],
) # work with Union once more types are added
def request_task(
*,
db: Session = Depends(deps.get_db),
+4
View File
@@ -14,6 +14,10 @@ class Settings(BaseSettings):
POSTGRES_DB: str = "postgres"
DATABASE_URI: Optional[PostgresDsn] = None
RATE_LIMIT: bool = True
REDIS_HOST: str = "localhost"
REDIS_PORT: str = "6379"
DEBUG_ALLOW_ANY_API_KEY: bool = False
DEBUG_SKIP_API_KEY_CHECK: bool = False
DEBUG_USE_SEED_DATA: bool = False
+1
View File
@@ -17,6 +17,7 @@ class OasstErrorCode(IntEnum):
DATABASE_URI_NOT_SET = 1
API_CLIENT_NOT_AUTHORIZED = 2
SERVER_ERROR = 3
TOO_MANY_REQUESTS = 429
# 1000-2000: tasks endpoint
TASK_INVALID_REQUEST_TYPE = 1000
+1
View File
@@ -1,5 +1,6 @@
alembic==1.8.1
fastapi==0.88.0
fastapi-limiter==0.1.5
loguru==0.6.0
numpy==1.22.4
psycopg2-binary==2.9.5
+2 -1
View File
@@ -4,7 +4,7 @@ services:
# Use `docker compose up backend-dev --build --attach-dependencies` to start a database and work and the backend.
backend-dev:
image: sverrirab/sleep
depends_on: [db, adminer]
depends_on: [db, adminer, redis, redis-insights]
# Use `docker compose up frontend-dev --build --attach-dependencies` to start all services needed to work on the frontend.
frontend-dev:
@@ -91,6 +91,7 @@ services:
image: oasst-backend
environment:
- POSTGRES_HOST=db
- REDIS_HOST=redis
- DEBUG_SKIP_API_KEY_CHECK=True
- DEBUG_USE_SEED_DATA=True
- MAX_WORKERS=1
+53
View File
@@ -0,0 +1,53 @@
#!/usr/bin/env python3
"""This file is for moderators to verify new users in the lobby.
First, moderators read the brief introduction people write in the lobby.
If all people's introductions are acceptable, moderators run this script.
Needs BOT_TOKEN environment variable to be set to the bot token.
"""
import discord
import pydantic
import tqdm.asyncio as tqdm
class Settings(pydantic.BaseSettings):
bot_token: str
settings = Settings()
intents = discord.Intents.default()
intents.message_content = True
intents.members = True
client = discord.Client(intents=intents)
@client.event
async def on_ready():
lobby_channel = discord.utils.get(client.get_all_channels(), name="lobby")
# obtain the role object for the verified role
verified_role = discord.utils.get(lobby_channel.guild.roles, name="verified")
async for message in tqdm.tqdm(lobby_channel.history(limit=None)):
if not isinstance(message.author, discord.Member):
print(f"{message.author} is not a member")
continue
for role in message.author.roles:
if role.name == "unverified":
print(f"{message.author} has the unverified role.")
break
else:
continue
# un-assign the unverified role
await message.author.remove_roles(role)
# assign the verified role
await message.author.add_roles(verified_role)
print(f"Assigned verified role to {message.author}")
await client.close()
client.run(settings.bot_token)