import json
from http import HTTPStatus
from math import ceil
from pathlib import Path
from typing import Optional

import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
from fastapi_limiter import FastAPILimiter
from loguru import logger
from oasst_backend.api.deps import get_dummy_api_client
from oasst_backend.api.v1.api import api_router
from oasst_backend.config import settings
from oasst_backend.database import engine
from oasst_backend.prompt_repository import PromptRepository
from oasst_shared.exceptions import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from sqlmodel import Session
from starlette.middleware.cors import CORSMiddleware

app = fastapi.FastAPI(title=settings.PROJECT_NAME, openapi_url=f"{settings.API_V1_STR}/openapi.json")


@app.exception_handler(OasstError)
async def oasst_exception_handler(request: fastapi.Request, ex: OasstError):
    """Translate a domain ``OasstError`` into a JSON ``OasstErrorResponse``.

    The HTTP status is taken from ``ex.http_status_code``; the body carries the
    error message and its ``OasstErrorCode``.
    """
    logger.error(f"{request.method} {request.url} failed: {repr(ex)}")
    return fastapi.responses.JSONResponse(
        status_code=int(ex.http_status_code),
        content=protocol_schema.OasstErrorResponse(
            message=ex.message,
            error_code=OasstErrorCode(ex.error_code),
        ).dict(),
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: fastapi.Request, ex: Exception):
    """Catch-all handler: log the full traceback and return a generic 500 response."""
    logger.exception(f"{request.method} {request.url} failed [UNHANDLED]: {repr(ex)}")
    status = HTTPStatus.INTERNAL_SERVER_ERROR
    return fastapi.responses.JSONResponse(
        status_code=status.value, content={"message": status.name, "error_code": OasstErrorCode.GENERIC_ERROR}
    )


# Set all CORS enabled origins
if settings.BACKEND_CORS_ORIGINS:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

if settings.UPDATE_ALEMBIC:

    @app.on_event("startup")
    def alembic_upgrade():
        """Run ``alembic upgrade head`` against the configured database on startup.

        Uses the ``alembic.ini`` next to this file, with its ``sqlalchemy.url``
        overridden by ``settings.DATABASE_URI``. Failures are logged and
        swallowed so the API still starts.
        """
        logger.info("Attempting to upgrade alembic on startup")
        try:
            alembic_ini_path = Path(__file__).parent / "alembic.ini"
            alembic_cfg = alembic.config.Config(str(alembic_ini_path))
            # Alembic keeps options in a ConfigParser, which treats "%" as an
            # interpolation marker. URL-encoded credentials (e.g. "%40") would
            # otherwise raise an interpolation error, so escape them as "%%".
            db_url = str(settings.DATABASE_URI).replace("%", "%%")
            alembic_cfg.set_main_option("sqlalchemy.url", db_url)
            alembic.command.upgrade(alembic_cfg, "head")
            logger.info("Successfully upgraded alembic on startup")
        except Exception:
            logger.exception("Alembic upgrade failed on startup")


if settings.RATE_LIMIT:

    @app.on_event("startup")
    async def connect_redis():
        """Create the Redis client and initialise ``FastAPILimiter`` with it.

        Failures are logged and swallowed so the API still starts (without rate
        limiting).
        """

        async def http_callback(request: fastapi.Request, response: fastapi.Response, pexpire: int):
            """Error callback function when too many requests.

            ``pexpire`` is the remaining window in milliseconds; it is rounded up
            to whole seconds for the message.
            """
            expire = ceil(pexpire / 1000)
            raise OasstError(
                f"Too Many Requests. Retry After {expire} seconds.",
                OasstErrorCode.TOO_MANY_REQUESTS,
                HTTPStatus.TOO_MANY_REQUESTS,
            )

        try:
            redis_client = redis.from_url(
                f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/0", encoding="utf-8", decode_responses=True
            )
            # from_url() only builds a client; the limiter's init performs the
            # first round trip, so report "connected" only once that succeeds.
            await FastAPILimiter.init(redis_client, http_callback=http_callback)
            logger.info(f"Connected to {redis_client=}")
        except Exception:
            logger.exception("Failed to establish Redis connection")


if settings.DEBUG_USE_SEED_DATA:

    @app.on_event("startup")
    def seed_data():
        """Insert a small, fixed set of dummy conversation messages for debugging.

        Each message is keyed by its frontend task id. Tasks that already exist
        and are acknowledged are left untouched. Unacknowledged ones are deleted
        and recreated. Failures are logged and swallowed so the API still starts.
        """

        class DummyMessage(pydantic.BaseModel):
            task_message_id: str  # frontend id bound to the task that produces this message
            user_message_id: str  # frontend id of the message itself
            parent_message_id: Optional[str]  # frontend id of the parent message; None starts a new tree
            text: str
            role: str

        try:
            logger.info("Seed data check began")
            with Session(engine) as db:
                api_client = get_dummy_api_client(db)
                dummy_user = protocol_schema.User(id="__dummy_user__", display_name="Dummy User", auth_method="local")
                pr = PromptRepository(db=db, api_client=api_client, user=dummy_user)

                dummy_messages = [
                    DummyMessage(
                        task_message_id="de111fa8",
                        user_message_id="6f1d0711",
                        parent_message_id=None,
                        text="Hi!",
                        role="prompter",
                    ),
                    DummyMessage(
                        task_message_id="74c381d4",
                        user_message_id="4a24530b",
                        parent_message_id="6f1d0711",
                        text="Hello! How can I help you?",
                        role="assistant",
                    ),
                    DummyMessage(
                        task_message_id="3d5dc440",
                        user_message_id="a8c01c04",
                        parent_message_id="4a24530b",
                        text="Do you have a recipe for potato soup?",
                        role="prompter",
                    ),
                    DummyMessage(
                        task_message_id="643716c1",
                        user_message_id="f43a93b7",
                        parent_message_id="4a24530b",
                        text="Who were the 8 presidents before George Washington?",
                        role="prompter",
                    ),
                    DummyMessage(
                        task_message_id="2e4e1e6",
                        user_message_id="c886920",
                        parent_message_id="6f1d0711",
                        text="Hey buddy! How can I serve you?",
                        role="assistant",
                    ),
                    DummyMessage(
                        task_message_id="970c437d",
                        user_message_id="cec432cf",
                        parent_message_id=None,
                        text="euirdteunvglfe23908230892309832098 AAAAAAAA",
                        role="prompter",
                    ),
                    DummyMessage(
                        task_message_id="6066118e",
                        user_message_id="4f85f637",
                        parent_message_id="cec432cf",
                        text="Sorry, I did not understand your request and it is unclear to me what you want me to do. Could you describe it in a different way?",
                        role="assistant",
                    ),
                    DummyMessage(
                        task_message_id="ba87780d",
                        user_message_id="0e276b98",
                        parent_message_id="cec432cf",
                        text="I'm unsure how to interpret this. Is it a riddle?",
                        role="assistant",
                    ),
                ]

                # Parents are listed before their children, so each parent exists
                # by the time a reply to it is stored.
                for msg in dummy_messages:
                    task = pr.fetch_task_by_frontend_message_id(msg.task_message_id)
                    if task and not task.ack:
                        # Unacknowledged task left over from an earlier run: drop it
                        # so it is recreated below.
                        logger.warning("Deleting unacknowledged seed data task")
                        db.delete(task)
                        task = None

                    if not task:
                        if msg.parent_message_id is None:
                            # Root message: an initial-prompt task that starts a new tree.
                            task = pr.store_task(
                                protocol_schema.InitialPromptTask(hint=""), message_tree_id=None, parent_message_id=None
                            )
                        else:
                            # Reply: attach to the parent's message tree. The
                            # conversation payload is only a "dummy" placeholder.
                            parent_message = pr.fetch_message_by_frontend_message_id(
                                msg.parent_message_id, fail_if_missing=True
                            )
                            task = pr.store_task(
                                protocol_schema.AssistantReplyTask(
                                    conversation=protocol_schema.Conversation(
                                        messages=[protocol_schema.ConversationMessage(text="dummy", is_assistant=False)]
                                    )
                                ),
                                message_tree_id=parent_message.message_tree_id,
                                parent_message_id=parent_message.id,
                            )
                        pr.bind_frontend_message_id(task.id, msg.task_message_id)
                        message = pr.store_text_reply(msg.text, msg.task_message_id, msg.user_message_id)
                        logger.info(
                            f"Inserted: message_id: {message.id}, payload: {message.payload.payload}, parent_message_id: {message.parent_id}"
                        )
                    else:
                        logger.debug(f"seed data task found: {task.id}")

                logger.info("Seed data check completed")
        except Exception:
            logger.exception("Seed data insertion failed")


app.include_router(api_router, prefix=settings.API_V1_STR)


def get_openapi_schema() -> str:
    """Return the application's OpenAPI schema serialized as a JSON string.

    ``json`` is imported at module level so this also works when the module is
    imported rather than run as a script.
    """
    return json.dumps(app.openapi())


if __name__ == "__main__":
    # Importing here so we don't import packages unnecessarily if we're
    # importing main as a module.
    import argparse

    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--print-openapi-schema",
        help="Dumps the openapi schema to stdout",
        action=argparse.BooleanOptionalAction,
    )
    # Defaults mirror uvicorn.run()'s own defaults. Without them, an omitted
    # flag would pass None through and override uvicorn's values. The port is
    # parsed as an int because uvicorn expects an integer port.
    parser.add_argument("--host", help="The host to run the server", default="127.0.0.1")
    parser.add_argument("--port", help="The port to run the server", type=int, default=8000)
    args = parser.parse_args()

    if args.print_openapi_schema:
        print(get_openapi_schema())
    else:
        uvicorn.run(app, host=args.host, port=args.port)