diff --git a/backend/main.py b/backend/main.py index 807d52ce..de173f9d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -27,6 +27,7 @@ from oasst_backend.utils.database_utils import CommitMode, managed_tx_function from oasst_shared.exceptions import OasstError, OasstErrorCode from oasst_shared.schemas import protocol as protocol_schema from oasst_shared.utils import utcnow +from prometheus_fastapi_instrumentator import Instrumentator from pydantic import BaseModel from sqlmodel import Session, select from starlette.middleware.cors import CORSMiddleware @@ -100,6 +101,13 @@ if settings.OFFICIAL_WEB_API_KEY: ) +if settings.ENABLE_PROM_METRICS: + + @app.on_event("startup") + async def enable_prom_metrics(): + Instrumentator().instrument(app).expose(app) + + if settings.RATE_LIMIT: @app.on_event("startup") diff --git a/backend/oasst_backend/config.py b/backend/oasst_backend/config.py index 36de7d9b..43d08097 100644 --- a/backend/oasst_backend/config.py +++ b/backend/oasst_backend/config.py @@ -197,6 +197,8 @@ class Settings(BaseSettings): ROOT_TOKENS: List[str] = ["1234"] # supply a string that can be parsed to a json list + ENABLE_PROM_METRICS: bool = True # enable prometheus metrics at /metrics + @validator("DATABASE_URI", pre=True) def assemble_db_connection(cls, v: Optional[str], values: Dict[str, Any]) -> Any: if isinstance(v, str): diff --git a/backend/requirements.txt b/backend/requirements.txt index 4a0008bb..e0f8c725 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -6,6 +6,7 @@ fastapi-limiter==0.1.5 fastapi-utils==0.2.1 loguru==0.6.0 numpy==1.22.4 +prometheus-fastapi-instrumentator==5.9.1 psycopg2-binary==2.9.5 pydantic==1.9.1 pydantic[email]==1.9.1 diff --git a/inference/server/main.py b/inference/server/main.py index 4b2474da..de0f607d 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -9,10 +9,18 @@ import websockets.exceptions from fastapi.middleware.cors import CORSMiddleware from loguru import logger from oasst_shared.schemas import inference, protocol +from prometheus_fastapi_instrumentator import Instrumentator from sse_starlette.sse import EventSourceResponse app = fastapi.FastAPI() + +# add prometheus metrics at /metrics +@app.on_event("startup") +async def enable_prom_metrics(): + Instrumentator().instrument(app).expose(app) + + # Allow CORS app.add_middleware( CORSMiddleware, diff --git a/inference/server/requirements.txt b/inference/server/requirements.txt index e0a00339..4790b471 100644 --- a/inference/server/requirements.txt +++ b/inference/server/requirements.txt @@ -1,5 +1,6 @@ fastapi[all] loguru +prometheus-fastapi-instrumentator==5.9.1 pydantic redis sse-starlette