introduced api keys to inference

This commit is contained in:
Yannic Kilcher
2023-02-11 10:50:51 +01:00
parent 48212079f4
commit edd9268c9a
9 changed files with 217 additions and 25 deletions
+1 -1
View File
@@ -10,7 +10,7 @@ tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
tmux send-keys "DEBUG_API_KEYS='[\"0000\"]' uvicorn main:app --reload" C-m
tmux split-window -h
tmux send-keys "cd worker" C-m
tmux send-keys "python __main__.py" C-m
@@ -0,0 +1,35 @@
"""added worker db entry
Revision ID: 569cd595bb10
Revises: 3a4cd8777eb2
Create Date: 2023-02-11 01:47:34.880485
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "569cd595bb10"  # id of this migration (matches "Revision ID" in the module docstring)
down_revision = "3a4cd8777eb2"  # parent migration this one revises
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``worker`` table and a non-unique index on its ``api_key`` column."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "worker"
    auto_string = sqlmodel.sql.sqltypes.AutoString
    columns = [
        sa.Column("id", auto_string(), nullable=False),
        sa.Column("api_key", auto_string(), nullable=False),
    ]
    op.create_table(table_name, *columns, sa.PrimaryKeyConstraint("id"))
    op.create_index(op.f("ix_worker_api_key"), table_name, ["api_key"], unique=False)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``api_key`` index and then the ``worker`` table itself."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "worker"
    index_name = op.f("ix_worker_api_key")
    op.drop_index(index_name, table_name=table_name)
    op.drop_table(table_name)
    # ### end Alembic commands ###
@@ -0,0 +1,28 @@
"""added name field
Revision ID: 4b846e7314b4
Revises: 569cd595bb10
Create Date: 2023-02-11 01:51:16.288860
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
# revision identifiers, used by Alembic.
revision = "4b846e7314b4"  # id of this migration (matches "Revision ID" in the module docstring)
down_revision = "569cd595bb10"  # parent migration this one revises
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the required ``name`` column to the ``worker`` table.

    The column is NOT NULL. Adding such a column without a default fails on a
    database whose ``worker`` table already contains rows, so the column is
    created with a temporary empty-string server default (which backfills any
    existing rows) and the default is removed again right after, leaving the
    final schema with a plain NOT NULL column.
    """
    # ### commands auto generated by Alembic - adjusted to tolerate existing rows ###
    op.add_column(
        "worker",
        sa.Column(
            "name",
            sqlmodel.sql.sqltypes.AutoString(),
            nullable=False,
            server_default="",
        ),
    )
    # server_default=None asks Alembic to drop the column default again.
    op.alter_column("worker", "name", server_default=None)
    # ### end Alembic commands ###
def downgrade() -> None:
    """Remove the ``name`` column from the ``worker`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name, column_name = "worker", "name"
    op.drop_column(table_name, column_name)
    # ### end Alembic commands ###
+115 -4
View File
@@ -12,7 +12,7 @@ import websockets.exceptions
from fastapi import Depends
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_inference_server import interface, queueing
from oasst_inference_server import interface, models, queueing
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import db_engine
from oasst_inference_server.settings import settings
@@ -50,6 +50,12 @@ def create_session():
yield session
@contextlib.contextmanager
def manual_create_session():
    """Yield a database session from ``create_session`` for use in a plain ``with`` block.

    ``create_session`` is a generator function, so it is wrapped with
    ``contextlib.contextmanager`` here to turn it into a context manager.
    """
    session_scope = contextlib.contextmanager(create_session)
    with session_scope() as db_session:
        yield db_session
def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
    """Construct a ``ChatRepository`` from the given database session."""
    return ChatRepository(session)
@@ -57,14 +63,58 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
@contextlib.contextmanager
def manual_chat_repository():
with contextlib.contextmanager(create_session)() as session:
with manual_create_session() as session:
yield create_chat_repository(session)
if settings.update_alembic:
# Header parameter declaration for the "X-API-Key" request header; its default is None.
api_key_header = fastapi.Header(None, alias="X-API-Key")
def get_api_key(api_key_header: str = api_key_header) -> str:
    """Return the API key taken from the request header.

    Raises:
        fastapi.HTTPException: 401 with detail "Missing API key" when the
            header value is ``None``.
    """
    if api_key_header is not None:
        return api_key_header
    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Missing API key",
    )
def get_worker(
    api_key: str = Depends(get_api_key),
    session: sqlmodel.Session = Depends(create_session),
) -> models.DbWorkerEntry:
    """Look up the worker entry whose ``api_key`` equals the supplied key.

    Raises:
        fastapi.HTTPException: 401 with detail "Invalid API key" when no
            worker entry matches.
    """
    query = sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
    matching_worker = session.exec(query).one_or_none()
    if matching_worker is not None:
        return matching_worker
    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )
def get_bearer_token(
    authorization_header: str = fastapi.Header(None, alias="Authorization"),
) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` request header.

    The parameter is declared as a header explicitly: a bare ``str`` parameter
    on a FastAPI dependency is read from the query string, which would never
    see the ``Authorization`` header.

    Args:
        authorization_header: Raw value of the ``Authorization`` header, or
            ``None`` when the request did not send one.

    Returns:
        The part of the header value after the ``"Bearer "`` prefix.

    Raises:
        fastapi.HTTPException: 401 when the header is missing or does not
            start with ``"Bearer "`` (instead of an unhandled error that would
            surface as a 500).
    """
    prefix = "Bearer "
    if authorization_header is None or not authorization_header.startswith(prefix):
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header must start with 'Bearer '",
        )
    return authorization_header[len(prefix) :]
def get_root_token(token: str = Depends(get_bearer_token)) -> str:
    """Validate the bearer token against the configured root token.

    Args:
        token: Bearer token resolved by the ``get_bearer_token`` dependency.

    Returns:
        The token, unchanged, when it equals ``settings.root_token``.

    Raises:
        fastapi.HTTPException: 401 with detail "Invalid token" otherwise.
    """
    # Local import keeps this block self-contained without touching the file's import section.
    import secrets

    expected = settings.root_token
    # Constant-time comparison so response timing does not leak how much of the
    # secret matched. Bytes are compared because compare_digest rejects
    # non-ASCII ``str`` arguments.
    if secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        return token
    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Invalid token",
    )
@app.on_event("startup")
def alembic_upgrade():
if not settings.update_alembic:
logger.info("Skipping alembic upgrade on startup (update_alembic is False)")
return
logger.info("Attempting to upgrade alembic on startup")
retry = 0
while True:
@@ -86,6 +136,28 @@ if settings.update_alembic:
time.sleep(timeout)
@app.on_event("startup")
def maybe_add_debug_api_keys():
    """On startup, insert a worker entry for each configured debug API key that is not stored yet.

    Does nothing (besides logging) when ``settings.debug_api_keys`` is empty.
    Each missing key is added under the name "Debug API Key" and committed
    individually.
    """
    debug_keys = settings.debug_api_keys
    if not debug_keys:
        logger.info("No debug API keys configured, skipping")
        return

    logger.info("Adding debug API keys")
    with manual_create_session() as session:
        for api_key in debug_keys:
            logger.info(f"Checking if debug API key {api_key} exists")
            lookup = sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
            if session.exec(lookup).one_or_none() is not None:
                logger.info(f"Debug API key {api_key} already exists")
                continue
            logger.info(f"Adding debug API key {api_key}")
            session.add(models.DbWorkerEntry(api_key=api_key, name="Debug API Key"))
            session.commit()
@app.get("/chat")
async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse:
"""Lists all chats."""
@@ -168,7 +240,7 @@ async def create_message(
@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
queue_id = f"work:{worker_config.compat_hash}"
@@ -240,3 +312,42 @@ async def work(websocket: fastapi.WebSocket):
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
except fastapi.WebSocketException:
logger.exception("Websocket closed")
@app.put("/worker")
def create_worker(
    request: interface.CreateWorkerRequest,
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Allows a client to register a worker."""
    # root_token is unused in the body; depending on get_root_token is what
    # restricts this route to callers presenting the root token.
    # Only the name is taken from the request body.
    new_worker = models.DbWorkerEntry(name=request.name)
    session.add(new_worker)
    session.commit()
    # Reload the committed row before handing it back to the caller.
    session.refresh(new_worker)
    return new_worker
@app.get("/worker")
def list_workers(
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Lists all workers."""
    # root_token is unused in the body; the dependency guards the route.
    all_workers_query = sqlmodel.select(models.DbWorkerEntry)
    return list(session.exec(all_workers_query).all())
@app.delete("/worker/{worker_id}")
def delete_worker(
    worker_id: str,
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Deletes a worker.

    Responds with 404 if no worker with the given id exists.
    """
    # root_token is unused in the body; the dependency guards the route.
    worker = session.get(models.DbWorkerEntry, worker_id)
    if worker is None:
        # session.get returns None for an unknown primary key; passing that to
        # session.delete would raise and surface as a 500.
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_404_NOT_FOUND,
            detail="Worker not found",
        )
    session.delete(worker)
    session.commit()
    return fastapi.Response(status_code=200)
@@ -40,3 +40,7 @@ class ChatEntry(pydantic.BaseModel):
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
# Request body accepted by the worker-registration endpoint (PUT /worker).
class CreateWorkerRequest(pydantic.BaseModel):
    # Name to store on the newly created worker entry.
    name: str
@@ -21,3 +21,11 @@ class DbChatEntry(SQLModel, table=True):
def to_entry(self) -> interface.ChatEntry:
return interface.ChatEntry(id=self.id, conversation=self.conversation)
# Database model for a registered worker, stored in the "worker" table.
class DbWorkerEntry(SQLModel, table=True):
    __tablename__ = "worker"

    # Primary key; defaults to a freshly generated UUID4 string.
    id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
    # Key a worker presents to authenticate; indexed (not declared unique) and
    # defaults to a freshly generated UUID4 string.
    api_key: str = Field(default_factory=lambda: str(uuid4()), index=True)
    # Name given to the worker; required, no default.
    name: str
@@ -36,5 +36,9 @@ class Settings(pydantic.BaseSettings):
path=f"/{values.get('postgres_db') or ''}",
)
root_token: str = "1234"
debug_api_keys: list[str] = []
settings = Settings()
+1
View File
@@ -107,6 +107,7 @@ def main():
on_error=on_error,
on_close=on_close,
on_open=on_open,
header={"X-API-Key": settings.api_key},
)
ws.run_forever(dispatcher=rel, reconnect=5)
+1
View File
@@ -5,6 +5,7 @@ class Settings(pydantic.BaseSettings):
backend_url: str = "ws://localhost:8000"
model_id: str = "distilgpt2"
inference_server_url: str = "http://localhost:8001"
api_key: str = "0000"
settings = Settings()