mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-26 16:00:18 +08:00
introduced api keys to inference
This commit is contained in:
@@ -10,7 +10,7 @@ tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd server" C-m
|
||||
tmux send-keys "uvicorn main:app --reload" C-m
|
||||
tmux send-keys "DEBUG_API_KEYS='[\"0000\"]' uvicorn main:app --reload" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd worker" C-m
|
||||
tmux send-keys "python __main__.py" C-m
|
||||
|
||||
+35
@@ -0,0 +1,35 @@
|
||||
"""added worker db entry
|
||||
|
||||
Revision ID: 569cd595bb10
|
||||
Revises: 3a4cd8777eb2
|
||||
Create Date: 2023-02-11 01:47:34.880485
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "569cd595bb10"
|
||||
down_revision = "3a4cd8777eb2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the ``worker`` table and a non-unique index on its ``api_key`` column."""
    table_name = "worker"
    string_type = sqlmodel.sql.sqltypes.AutoString

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        table_name,
        sa.Column("id", string_type(), nullable=False),
        sa.Column("api_key", string_type(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    api_key_index = op.f("ix_worker_api_key")
    op.create_index(api_key_index, table_name, ["api_key"], unique=False)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``api_key`` index first, then the ``worker`` table itself."""
    table_name = "worker"

    # ### commands auto generated by Alembic - please adjust! ###
    api_key_index = op.f("ix_worker_api_key")
    op.drop_index(api_key_index, table_name=table_name)
    op.drop_table(table_name)
    # ### end Alembic commands ###
|
||||
@@ -0,0 +1,28 @@
|
||||
"""added name field
|
||||
|
||||
Revision ID: 4b846e7314b4
|
||||
Revises: 569cd595bb10
|
||||
Create Date: 2023-02-11 01:51:16.288860
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4b846e7314b4"
|
||||
down_revision = "569cd595bb10"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the required ``name`` column to the ``worker`` table.

    The column is NOT NULL. Adding such a column without a default fails on a
    database whose ``worker`` table already contains rows, so the column is
    created with a temporary empty-string server default (which back-fills
    existing rows) and the default is removed again right afterwards. The
    resulting schema is the same as before: a NOT NULL string column with no
    server default.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "worker",
        sa.Column(
            "name",
            sqlmodel.sql.sqltypes.AutoString(),
            nullable=False,
            server_default="",  # only needed to back-fill pre-existing rows
        ),
    )
    # Drop the temporary default so new rows must supply a name explicitly.
    op.alter_column("worker", "name", server_default=None)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the ``name`` column from the ``worker`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column(table_name="worker", column_name="name")
    # ### end Alembic commands ###
|
||||
+135
-24
@@ -12,7 +12,7 @@ import websockets.exceptions
|
||||
from fastapi import Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_inference_server import interface, queueing
|
||||
from oasst_inference_server import interface, models, queueing
|
||||
from oasst_inference_server.chat_repository import ChatRepository
|
||||
from oasst_inference_server.database import db_engine
|
||||
from oasst_inference_server.settings import settings
|
||||
@@ -50,6 +50,12 @@ def create_session():
|
||||
yield session
|
||||
|
||||
|
||||
@contextlib.contextmanager
def manual_create_session():
    """Context-manager form of ``create_session`` for use outside FastAPI's ``Depends``.

    Wraps the ``create_session`` generator with ``contextlib.contextmanager``
    and yields the session it produces.
    """
    open_session = contextlib.contextmanager(create_session)
    with open_session() as db_session:
        yield db_session
|
||||
|
||||
|
||||
def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
    """Build a ``ChatRepository`` bound to the given (or injected) session."""
    return ChatRepository(session)
|
||||
@@ -57,33 +63,99 @@ def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manual_chat_repository():
|
||||
with contextlib.contextmanager(create_session)() as session:
|
||||
with manual_create_session() as session:
|
||||
yield create_chat_repository(session)
|
||||
|
||||
|
||||
if settings.update_alembic:
|
||||
api_key_header = fastapi.Header(None, alias="X-API-Key")
|
||||
|
||||
@app.on_event("startup")
|
||||
def alembic_upgrade():
|
||||
logger.info("Attempting to upgrade alembic on startup")
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
|
||||
alembic.command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Successfully upgraded alembic on startup")
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
retry += 1
|
||||
if retry >= settings.alembic_retries:
|
||||
raise
|
||||
|
||||
timeout = settings.alembic_retry_timeout * 2**retry
|
||||
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
|
||||
time.sleep(timeout)
|
||||
def get_api_key(api_key_header: str = api_key_header) -> str:
    """Return the API key supplied by the client.

    The default is the module-level ``api_key_header`` parameter declaration,
    so FastAPI fills the argument in when this is used as a dependency.

    Raises:
        fastapi.HTTPException: 401 when no API key was supplied.
    """
    if api_key_header is not None:
        return api_key_header

    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Missing API key",
    )
|
||||
|
||||
|
||||
def get_worker(
    api_key: str = Depends(get_api_key),
    session: sqlmodel.Session = Depends(create_session),
) -> models.DbWorkerEntry:
    """Resolve the worker row that owns ``api_key``.

    Raises:
        fastapi.HTTPException: 401 when no worker is registered with this key.
    """
    query = sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
    matching_worker = session.exec(query).one_or_none()
    if matching_worker is not None:
        return matching_worker

    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API key",
    )
|
||||
|
||||
|
||||
def get_bearer_token(
    authorization_header: str = fastapi.Header(None, alias="Authorization"),
) -> str:
    """Extract the token from an ``Authorization: Bearer <token>`` header.

    The parameter is declared as a header so that, when this function is used
    as a FastAPI dependency, the value is read from the ``Authorization``
    request header. A bare ``str`` parameter would instead be treated as a
    required query parameter named ``authorization_header``.

    Args:
        authorization_header: Raw value of the ``Authorization`` header.

    Returns:
        The text following the ``"Bearer "`` prefix.

    Raises:
        fastapi.HTTPException: 401 when the header is missing or does not
            start with ``"Bearer "`` (rather than an unhandled error that the
            client would see as a 500).
    """
    prefix = "Bearer "
    # isinstance also covers a direct call without arguments, where the
    # default is the Header declaration object rather than a string.
    if not isinstance(authorization_header, str) or not authorization_header.startswith(prefix):
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
            detail="Authorization header must start with 'Bearer '",
        )
    return authorization_header[len(prefix) :]
|
||||
|
||||
|
||||
def get_root_token(token: str = Depends(get_bearer_token)) -> str:
    """Validate that the bearer token equals the configured root token.

    Args:
        token: Bearer token taken from the request.

    Returns:
        The token, unchanged, when it matches ``settings.root_token``.

    Raises:
        fastapi.HTTPException: 401 when the token does not match.
    """
    import secrets  # local import keeps this block self-contained

    # Constant-time comparison so response timing does not reveal how much of
    # the secret a guess got right. Comparing bytes avoids compare_digest's
    # ASCII-only restriction on str arguments.
    expected_token = settings.root_token
    if secrets.compare_digest(token.encode("utf-8"), expected_token.encode("utf-8")):
        return token

    raise fastapi.HTTPException(
        status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
        detail="Invalid token",
    )
|
||||
|
||||
|
||||
@app.on_event("startup")
def alembic_upgrade():
    """Run ``alembic upgrade head`` at startup, retrying on failure.

    Does nothing when ``settings.update_alembic`` is false. After each failed
    attempt the wait is ``settings.alembic_retry_timeout * 2**attempts``
    seconds; once ``settings.alembic_retries`` attempts have failed, the last
    exception is re-raised.
    """
    if not settings.update_alembic:
        logger.info("Skipping alembic upgrade on startup (update_alembic is False)")
        return

    logger.info("Attempting to upgrade alembic on startup")
    failed_attempts = 0
    while True:
        try:
            ini_file = Path(__file__).parent / "alembic.ini"
            config = alembic.config.Config(str(ini_file))
            # Point alembic at the database configured for this server.
            config.set_main_option("sqlalchemy.url", settings.database_uri)
            alembic.command.upgrade(config, "head")
            logger.info("Successfully upgraded alembic on startup")
            return
        except Exception:
            logger.exception("Alembic upgrade failed on startup")
            failed_attempts += 1
            if failed_attempts >= settings.alembic_retries:
                raise

            # Delay doubles with every failed attempt.
            timeout = settings.alembic_retry_timeout * 2**failed_attempts
            logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
            time.sleep(timeout)
|
||||
|
||||
|
||||
@app.on_event("startup")
def maybe_add_debug_api_keys():
    """Ensure every key in ``settings.debug_api_keys`` has a worker row.

    Runs at startup. Keys that already exist are left untouched; missing ones
    are inserted as workers named "Debug API Key" and committed one by one.
    """
    if not settings.debug_api_keys:
        logger.info("No debug API keys configured, skipping")
        return

    logger.info("Adding debug API keys")
    with manual_create_session() as session:
        for api_key in settings.debug_api_keys:
            logger.info(f"Checking if debug API key {api_key} exists")
            lookup = sqlmodel.select(models.DbWorkerEntry).where(models.DbWorkerEntry.api_key == api_key)
            existing_worker = session.exec(lookup).one_or_none()
            if existing_worker is not None:
                logger.info(f"Debug API key {api_key} already exists")
                continue

            logger.info(f"Adding debug API key {api_key}")
            session.add(models.DbWorkerEntry(api_key=api_key, name="Debug API Key"))
            session.commit()
|
||||
|
||||
|
||||
@app.get("/chat")
|
||||
@@ -168,7 +240,7 @@ async def create_message(
|
||||
|
||||
|
||||
@app.websocket("/work")
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
async def work(websocket: fastapi.WebSocket, worker: models.DbWorkerEntry = Depends(get_worker)):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
queue_id = f"work:{worker_config.compat_hash}"
|
||||
@@ -240,3 +312,42 @@ async def work(websocket: fastapi.WebSocket):
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
|
||||
except fastapi.WebSocketException:
|
||||
logger.exception("Websocket closed")
|
||||
|
||||
|
||||
@app.put("/worker")
def create_worker(
    request: interface.CreateWorkerRequest,
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Allows a client to register a worker."""
    # Only the name comes from the request; the remaining fields are left to
    # the model's own defaults.
    new_worker = models.DbWorkerEntry(name=request.name)
    session.add(new_worker)
    session.commit()
    # Reload the row so the returned object reflects what was stored.
    session.refresh(new_worker)
    return new_worker
|
||||
|
||||
|
||||
@app.get("/worker")
def list_workers(
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Lists all workers."""
    all_workers_query = sqlmodel.select(models.DbWorkerEntry)
    return list(session.exec(all_workers_query).all())
|
||||
|
||||
|
||||
@app.delete("/worker/{worker_id}")
def delete_worker(
    worker_id: str,
    root_token: str = fastapi.Depends(get_root_token),
    session: sqlmodel.Session = fastapi.Depends(create_session),
):
    """Deletes a worker."""
    worker = session.get(models.DbWorkerEntry, worker_id)
    # session.get returns None for an unknown primary key; passing that to
    # session.delete would raise and surface as a 500, so answer 404 instead.
    if worker is None:
        raise fastapi.HTTPException(
            status_code=fastapi.status.HTTP_404_NOT_FOUND,
            detail="Worker not found",
        )
    session.delete(worker)
    session.commit()
    return fastapi.Response(status_code=200)
|
||||
|
||||
@@ -40,3 +40,7 @@ class ChatEntry(pydantic.BaseModel):
|
||||
|
||||
class ListChatsResponse(pydantic.BaseModel):
|
||||
chats: list[ChatListEntry]
|
||||
|
||||
|
||||
# Request body for registering a new worker.
class CreateWorkerRequest(pydantic.BaseModel):
    # Name to store on the new worker entry.
    name: str
|
||||
|
||||
@@ -21,3 +21,11 @@ class DbChatEntry(SQLModel, table=True):
|
||||
|
||||
def to_entry(self) -> interface.ChatEntry:
|
||||
return interface.ChatEntry(id=self.id, conversation=self.conversation)
|
||||
|
||||
|
||||
def _new_uuid_str() -> str:
    # Fresh random UUID rendered as a string; used for generated id / api_key values.
    return str(uuid4())


# Database row for a registered worker, stored in the "worker" table.
class DbWorkerEntry(SQLModel, table=True):
    __tablename__ = "worker"

    # Primary key, generated when not supplied.
    id: str = Field(default_factory=_new_uuid_str, primary_key=True)
    # Indexed secret the worker presents; generated when not supplied.
    api_key: str = Field(default_factory=_new_uuid_str, index=True)
    # Required label for the worker.
    name: str
|
||||
|
||||
@@ -36,5 +36,9 @@ class Settings(pydantic.BaseSettings):
|
||||
path=f"/{values.get('postgres_db') or ''}",
|
||||
)
|
||||
|
||||
root_token: str = "1234"
|
||||
|
||||
debug_api_keys: list[str] = []
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
@@ -107,6 +107,7 @@ def main():
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
on_open=on_open,
|
||||
header={"X-API-Key": settings.api_key},
|
||||
)
|
||||
|
||||
ws.run_forever(dispatcher=rel, reconnect=5)
|
||||
|
||||
@@ -5,6 +5,7 @@ class Settings(pydantic.BaseSettings):
|
||||
backend_url: str = "ws://localhost:8000"
|
||||
model_id: str = "distilgpt2"
|
||||
inference_server_url: str = "http://localhost:8001"
|
||||
api_key: str = "0000"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
Reference in New Issue
Block a user