diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh index c5afb412..93031b8d 100755 --- a/inference/full-dev-setup.sh +++ b/inference/full-dev-setup.sh @@ -10,7 +10,7 @@ tmux split-window -h tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m tmux split-window -h tmux send-keys "cd server" C-m -tmux send-keys "uvicorn main:app --reload" C-m +tmux send-keys "DEBUG_API_KEYS='[\"0000\"]' uvicorn main:app --reload" C-m tmux split-window -h tmux send-keys "cd worker" C-m tmux send-keys "python __main__.py" C-m diff --git a/inference/server/alembic/versions/2023_02_11_0147-569cd595bb10_added_worker_db_entry.py b/inference/server/alembic/versions/2023_02_11_0147-569cd595bb10_added_worker_db_entry.py new file mode 100644 index 00000000..887dd809 --- /dev/null +++ b/inference/server/alembic/versions/2023_02_11_0147-569cd595bb10_added_worker_db_entry.py @@ -0,0 +1,35 @@ +"""added worker db entry + +Revision ID: 569cd595bb10 +Revises: 3a4cd8777eb2 +Create Date: 2023-02-11 01:47:34.880485 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "569cd595bb10" +down_revision = "3a4cd8777eb2" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "worker", + sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("api_key", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index(op.f("ix_worker_api_key"), "worker", ["api_key"], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f("ix_worker_api_key"), table_name="worker") + op.drop_table("worker") + # ### end Alembic commands ### diff --git a/inference/server/alembic/versions/2023_02_11_0151-4b846e7314b4_added_name_field.py b/inference/server/alembic/versions/2023_02_11_0151-4b846e7314b4_added_name_field.py new file mode 100644 index 00000000..d2ced62b --- /dev/null +++ b/inference/server/alembic/versions/2023_02_11_0151-4b846e7314b4_added_name_field.py @@ -0,0 +1,28 @@ +"""added name field + +Revision ID: 4b846e7314b4 +Revises: 569cd595bb10 +Create Date: 2023-02-11 01:51:16.288860 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op + +# revision identifiers, used by Alembic. +revision = "4b846e7314b4" +down_revision = "569cd595bb10" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("worker", sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("worker", "name") + # ### end Alembic commands ### diff --git a/inference/server/main.py b/inference/server/main.py index 988ce05b..dc8c734a 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -12,7 +12,7 @@ import websockets.exceptions from fastapi import Depends from fastapi.middleware.cors import CORSMiddleware from loguru import logger -from oasst_inference_server import interface, queueing +from oasst_inference_server import interface, models, queueing from oasst_inference_server.chat_repository import ChatRepository from oasst_inference_server.database import db_engine from oasst_inference_server.settings import settings @@ -50,6 +50,12 @@ def create_session(): yield session +@contextlib.contextmanager +def manual_create_session(): + with contextlib.contextmanager(create_session)() as session: + yield session + + def create_chat_repository(session: sqlmodel.Session = Depends(create_session)): repository = ChatRepository(session) return repository @@ -57,33 +63,99 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)): @contextlib.contextmanager def manual_chat_repository(): - with contextlib.contextmanager(create_session)() as session: + with manual_create_session() as session: yield create_chat_repository(session) -if settings.update_alembic: +api_key_header = fastapi.Header(None, alias="X-API-Key") - @app.on_event("startup") - def alembic_upgrade(): - logger.info("Attempting to upgrade alembic on startup") - retry = 0 - while True: - try: - alembic_ini_path = Path(__file__).parent / "alembic.ini" - alembic_cfg = alembic.config.Config(str(alembic_ini_path)) - alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri) - alembic.command.upgrade(alembic_cfg, "head") - logger.info("Successfully upgraded alembic on startup") - break - except Exception: - logger.exception("Alembic upgrade failed on startup") - retry += 1 - if retry >= settings.alembic_retries: - raise - timeout = settings.alembic_retry_timeout * 2**retry - logger.warning(f"Retrying alembic upgrade in {timeout} seconds") - time.sleep(timeout) +def get_api_key(api_key_header: str = api_key_header) -> str: + if api_key_header is None: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Missing API key", + ) + return api_key_header + + +def get_worker( + api_key: str = Depends(get_api_key), + session: sqlmodel.Session = Depends(create_session), +) -> models.DbWorkerEntry: + worker = session.exec( + sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key) + ).one_or_none() + if worker is None: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + ) + return worker + + +def get_bearer_token(authorization_header: str) -> str: + if not authorization_header.startswith("Bearer "): + raise ValueError("Authorization header must start with 'Bearer '") + return authorization_header[len("Bearer ") :] + + +def get_root_token(token: str = Depends(get_bearer_token)) -> str: + root_token = settings.root_token + if token == root_token: + return token + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, + detail="Invalid token", + ) + + +@app.on_event("startup") +def alembic_upgrade(): + if not settings.update_alembic: + logger.info("Skipping alembic upgrade on startup (update_alembic is False)") + return + logger.info("Attempting to upgrade alembic on startup") + retry = 0 + while True: + try: + alembic_ini_path = Path(__file__).parent / "alembic.ini" + alembic_cfg = alembic.config.Config(str(alembic_ini_path)) + alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri) + alembic.command.upgrade(alembic_cfg, "head") + logger.info("Successfully upgraded alembic on startup") + break + except Exception: + logger.exception("Alembic upgrade failed on startup") + retry += 1 + if retry >= settings.alembic_retries: + raise + + timeout = settings.alembic_retry_timeout * 2**retry + logger.warning(f"Retrying alembic upgrade in {timeout} seconds") + time.sleep(timeout) + + +@app.on_event("startup") +def maybe_add_debug_api_keys(): + if not settings.debug_api_keys: + logger.info("No debug API keys configured, skipping") + return + logger.info("Adding debug API keys") + with manual_create_session() as session: + for api_key in settings.debug_api_keys: + logger.info(f"Checking if debug API key {api_key} exists") + if ( + session.exec( + sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key) + ).one_or_none() + is None + ): + logger.info(f"Adding debug API key {api_key}") + session.add(models.DbWorkerEntry(api_key=api_key, name="Debug API Key")) + session.commit() + else: + logger.info(f"Debug API key {api_key} already exists") @app.get("/chat") @@ -168,7 +240,7 @@ async def create_message( @app.websocket("/work") -async def work(websocket: fastapi.WebSocket): +async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) queue_id = f"work:{worker_config.compat_hash}" @@ -240,3 +312,42 @@ async def work(websocket: fastapi.WebSocket): chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete) except fastapi.WebSocketException: logger.exception("Websocket closed") + + +@app.put("/worker") +def create_worker( + request: interface.CreateWorkerRequest, + root_token: str = fastapi.Depends(get_root_token), + session: sqlmodel.Session = fastapi.Depends(create_session), +): + """Allows a client to register a worker.""" + worker = models.DbWorkerEntry( + name=request.name, + ) + session.add(worker) + session.commit() + session.refresh(worker) + return worker + + +@app.get("/worker") +def list_workers( + root_token: str = fastapi.Depends(get_root_token), + session: sqlmodel.Session = fastapi.Depends(create_session), +): + """Lists all workers.""" + workers = session.exec(sqlmodel.select(models.DbWorkerEntry)).all() + return list(workers) + + +@app.delete("/worker/{worker_id}") +def delete_worker( + worker_id: str, + root_token: str = fastapi.Depends(get_root_token), + session: sqlmodel.Session = fastapi.Depends(create_session), +): + """Deletes a worker.""" + worker = session.get(models.DbWorkerEntry, worker_id) + session.delete(worker) + session.commit() + return fastapi.Response(status_code=200) diff --git a/inference/server/oasst_inference_server/interface.py b/inference/server/oasst_inference_server/interface.py index 5c5d8fb7..55c24a3d 100644 --- a/inference/server/oasst_inference_server/interface.py +++ b/inference/server/oasst_inference_server/interface.py @@ -40,3 +40,7 @@ class ChatEntry(pydantic.BaseModel): class ListChatsResponse(pydantic.BaseModel): chats: list[ChatListEntry] + + +class CreateWorkerRequest(pydantic.BaseModel): + name: str diff --git a/inference/server/oasst_inference_server/models.py b/inference/server/oasst_inference_server/models.py index f1a32438..cd8f2bc5 100644 --- a/inference/server/oasst_inference_server/models.py +++ b/inference/server/oasst_inference_server/models.py @@ -21,3 +21,11 @@ class DbChatEntry(SQLModel, table=True): def to_entry(self) -> interface.ChatEntry: return interface.ChatEntry(id=self.id, conversation=self.conversation) + + +class DbWorkerEntry(SQLModel, table=True): + __tablename__ = "worker" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + api_key: str = Field(default_factory=lambda: str(uuid4()), index=True) + name: str diff --git a/inference/server/oasst_inference_server/settings.py b/inference/server/oasst_inference_server/settings.py index 5f79fb55..dadd5fd3 100644 --- a/inference/server/oasst_inference_server/settings.py +++ b/inference/server/oasst_inference_server/settings.py @@ -36,5 +36,9 @@ class Settings(pydantic.BaseSettings): path=f"/{values.get('postgres_db') or ''}", ) + root_token: str = "1234" + + debug_api_keys: list[str] = [] + settings = Settings() diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index 74cf4d1d..811c0805 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -107,6 +107,7 @@ def main(): on_error=on_error, on_close=on_close, on_open=on_open, + header={"X-API-Key": settings.api_key}, ) ws.run_forever(dispatcher=rel, reconnect=5) diff --git a/inference/worker/settings.py b/inference/worker/settings.py index c726479c..868e2140 100644 --- a/inference/worker/settings.py +++ b/inference/worker/settings.py @@ -5,6 +5,7 @@ class Settings(pydantic.BaseSettings): backend_url: str = "ws://localhost:8000" model_id: str = "distilgpt2" inference_server_url: str = "http://localhost:8001" + api_key: str = "0000" settings = Settings()