mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-26 16:00:18 +08:00
Added database to inference server (#1446)
* added db for inference * fixed dockerfiles for inference
This commit is contained in:
+30
-23
@@ -136,6 +136,23 @@ services:
|
||||
- "3000:3000"
|
||||
command: bash wait-for-postgres.sh node server.js
|
||||
|
||||
# This DB is for Inference
|
||||
inference-db:
|
||||
image: postgres
|
||||
restart: always
|
||||
ports:
|
||||
- 5434:5432
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: oasst_inference
|
||||
healthcheck:
|
||||
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||
interval: 2s
|
||||
timeout: 2s
|
||||
retries: 10
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-server:
|
||||
build:
|
||||
dockerfile: docker/inference/Dockerfile.server
|
||||
@@ -145,13 +162,25 @@ services:
|
||||
environment:
|
||||
- "PORT=8000"
|
||||
- "REDIS_HOST=redis"
|
||||
- POSTGRES_HOST=inference-db
|
||||
- POSTGRES_DB=oasst_inference
|
||||
volumes:
|
||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||
- "./inference/server:/opt/inference/server"
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8000:8000"
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
inference-db:
|
||||
condition: service_healthy
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-text-generation-server:
|
||||
image: ghcr.io/huggingface/text-generation-inference
|
||||
environment:
|
||||
- "MODEL_ID=distilgpt2"
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-worker:
|
||||
@@ -167,29 +196,7 @@ services:
|
||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||
- "./inference/worker:/opt/inference/worker"
|
||||
depends_on:
|
||||
- inference-server
|
||||
- inference-text-generation-server
|
||||
deploy:
|
||||
replicas: 1
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-text-client:
|
||||
build:
|
||||
dockerfile: docker/inference/Dockerfile.text-client
|
||||
context: .
|
||||
image: oasst-inference-text-client
|
||||
environment:
|
||||
- "BACKEND_URL=http://inference-server:8000"
|
||||
tty: true
|
||||
stdin_open: true
|
||||
volumes:
|
||||
- "./inference/worker:/opt/inference/worker"
|
||||
restart: unless-stopped
|
||||
depends_on:
|
||||
- inference-server
|
||||
profiles: ["inference"]
|
||||
|
||||
inference-text-generation-server:
|
||||
image: ghcr.io/huggingface/text-generation-inference
|
||||
environment:
|
||||
- "MODEL_ID=distilgpt2"
|
||||
profiles: ["inference"]
|
||||
|
||||
@@ -7,7 +7,7 @@ ARG APP_USER="${MODULE}-${SERVICE}"
|
||||
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
||||
|
||||
|
||||
FROM python:3-slim as build
|
||||
FROM python:3.10-slim as build
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
WORKDIR /build
|
||||
@@ -22,7 +22,7 @@ RUN --mount=type=cache,target=/var/cache/pip \
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as base-env
|
||||
FROM python:3.10-slim as base-env
|
||||
ARG APP_USER
|
||||
ARG APP_RELATIVE_PATH
|
||||
ARG MODULE
|
||||
@@ -50,6 +50,9 @@ WORKDIR ${APP_ROOT}
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic alembic
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic.ini .
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/oasst_inference_server oasst_inference_server
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
|
||||
|
||||
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
# syntax=docker/dockerfile:1
|
||||
|
||||
ARG APP_USER="text-client"
|
||||
ARG APP_RELATIVE_PATH="inference/text-client"
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as build
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/pip \
|
||||
pip install \
|
||||
--cache-dir=/var/cache/pip \
|
||||
--target=lib \
|
||||
-r requirements.txt
|
||||
|
||||
|
||||
|
||||
FROM python:3.10-alpine3.17 as base-env
|
||||
ARG APP_USER
|
||||
ARG APP_RELATIVE_PATH
|
||||
|
||||
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
|
||||
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
|
||||
|
||||
ENV PATH="${PATH}:${APP_LIBS}/bin"
|
||||
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
|
||||
|
||||
|
||||
RUN adduser \
|
||||
--disabled-password \
|
||||
--no-create-home \
|
||||
"${APP_USER}"
|
||||
|
||||
USER ${APP_USER}
|
||||
|
||||
WORKDIR ${APP_ROOT}
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
||||
|
||||
|
||||
|
||||
FROM base-env as prod
|
||||
|
||||
|
||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
|
||||
@@ -48,7 +48,7 @@ WORKDIR ${APP_ROOT}
|
||||
|
||||
|
||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .
|
||||
|
||||
|
||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
# Creates a tmux window with splits for the individual services
|
||||
|
||||
tmux new-session -d -s "inference-dev-setup"
|
||||
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
|
||||
tmux send-keys "docker run --rm -it -p 5432:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
|
||||
tmux send-keys "docker run --rm -it -p 6379:6379 --name redis redis" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
|
||||
tmux split-window -h
|
||||
tmux send-keys "cd server" C-m
|
||||
tmux send-keys "uvicorn main:app --reload" C-m
|
||||
|
||||
@@ -0,0 +1,105 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = %(here)s/alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to dateutil.tz.gettz()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
|
||||
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
hooks = black
|
||||
black.type = console_scripts
|
||||
black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
@@ -0,0 +1,78 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
import sqlmodel
|
||||
from alembic import context
|
||||
from oasst_inference_server import models # noqa: F401
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = sqlmodel.SQLModel.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.get_context()._ensure_version_table()
|
||||
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
@@ -0,0 +1,25 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,36 @@
|
||||
"""initial commit
|
||||
|
||||
Revision ID: 3a4cd8777eb2
|
||||
Revises:
|
||||
Create Date: 2023-02-10 02:21:27.086772
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
import sqlmodel
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3a4cd8777eb2"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"chat",
|
||||
sa.Column("conversation", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("pending_message_request", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("message_request_state", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("chat")
|
||||
# ### end Alembic commands ###
|
||||
+89
-113
@@ -1,14 +1,22 @@
|
||||
import asyncio
|
||||
import enum
|
||||
import uuid
|
||||
import contextlib
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import alembic.command
|
||||
import alembic.config
|
||||
import fastapi
|
||||
import pydantic
|
||||
import redis.asyncio as redis
|
||||
import sqlmodel
|
||||
import websockets.exceptions
|
||||
from fastapi import Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
from oasst_inference_server import interface
|
||||
from oasst_inference_server.chat_repository import ChatRepository
|
||||
from oasst_inference_server.database import db_engine
|
||||
from oasst_inference_server.settings import settings
|
||||
from oasst_shared.schemas import inference
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
@@ -31,129 +39,97 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
|
||||
class Settings(pydantic.BaseSettings):
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_db: int = 0
|
||||
|
||||
sse_retry_timeout: int = 15000
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# create async redis client
|
||||
redisClient = redis.Redis(
|
||||
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
class MessageRequest(pydantic.BaseModel):
|
||||
message: str = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
||||
return self.model_name == worker_config.model_name
|
||||
def create_session():
|
||||
with sqlmodel.Session(db_engine) as session:
|
||||
yield session
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
token: inference.TokenResponse
|
||||
def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
|
||||
repository = ChatRepository(session)
|
||||
return repository
|
||||
|
||||
|
||||
class MessageRequestState(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
complete = "complete"
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
if settings.update_alembic:
|
||||
|
||||
@app.on_event("startup")
|
||||
def alembic_upgrade():
|
||||
logger.info("Attempting to upgrade alembic on startup")
|
||||
retry = 0
|
||||
while True:
|
||||
try:
|
||||
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
|
||||
alembic.command.upgrade(alembic_cfg, "head")
|
||||
logger.info("Successfully upgraded alembic on startup")
|
||||
break
|
||||
except Exception:
|
||||
logger.exception("Alembic upgrade failed on startup")
|
||||
retry += 1
|
||||
if retry >= settings.alembic_retries:
|
||||
raise
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class ChatListEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class ChatEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
conversation: protocol.Conversation
|
||||
|
||||
|
||||
class ListChatsResponse(pydantic.BaseModel):
|
||||
chats: list[ChatListEntry]
|
||||
|
||||
|
||||
class DbChatEntry(pydantic.BaseModel):
|
||||
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
|
||||
pending_message_request: MessageRequest | None = None
|
||||
message_request_state: MessageRequestState | None = None
|
||||
|
||||
def to_list_entry(self) -> ChatListEntry:
|
||||
return ChatListEntry(id=self.id)
|
||||
|
||||
def to_entry(self) -> ChatEntry:
|
||||
return ChatEntry(id=self.id, conversation=self.conversation)
|
||||
|
||||
|
||||
# TODO: make real database
|
||||
CHATS: dict[str, DbChatEntry] = {}
|
||||
timeout = settings.alembic_retry_timeout * 2**retry
|
||||
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
|
||||
time.sleep(timeout)
|
||||
|
||||
|
||||
@app.get("/chat")
|
||||
async def list_chats() -> ListChatsResponse:
|
||||
async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse:
|
||||
"""Lists all chats."""
|
||||
logger.info("Listing all chats.")
|
||||
chats = [chat.to_list_entry() for chat in CHATS.values()]
|
||||
return ListChatsResponse(chats=chats)
|
||||
chats = chat_repository.get_chat_list()
|
||||
return interface.ListChatsResponse(chats=chats)
|
||||
|
||||
|
||||
@app.post("/chat")
|
||||
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
|
||||
async def create_chat(
|
||||
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
|
||||
) -> interface.ChatListEntry:
|
||||
"""Allows a client to create a new chat."""
|
||||
logger.info(f"Received {request}")
|
||||
chat = DbChatEntry()
|
||||
CHATS[chat.id] = chat
|
||||
return ChatListEntry(id=chat.id)
|
||||
chat = chat_repository.create_chat()
|
||||
return chat.to_list_entry()
|
||||
|
||||
|
||||
@app.get("/chat/{id}")
|
||||
async def get_chat(id: str) -> ChatEntry:
|
||||
async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ChatEntry:
|
||||
"""Allows a client to get the current state of a chat."""
|
||||
return CHATS[id].to_entry()
|
||||
chat = chat_repository.get_chat_entry_by_id(id)
|
||||
return chat
|
||||
|
||||
|
||||
@app.post("/chat/{id}/message")
|
||||
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
|
||||
async def create_message(
|
||||
id: str,
|
||||
message_request: interface.MessageRequest,
|
||||
fastapi_request: fastapi.Request,
|
||||
chat_repository: ChatRepository = Depends(create_chat_repository),
|
||||
) -> EventSourceResponse:
|
||||
"""Allows the client to stream the results of a request."""
|
||||
|
||||
chat = CHATS[id]
|
||||
if not chat.conversation.is_prompter_turn:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
||||
if chat.pending_message_request is not None:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Already pending")
|
||||
try:
|
||||
chat_repository.add_prompter_message(id=id, message_request=message_request)
|
||||
except Exception:
|
||||
logger.exception("Error adding prompter message")
|
||||
return fastapi.Response(status_code=500)
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=message_request.message,
|
||||
is_assistant=False,
|
||||
)
|
||||
)
|
||||
|
||||
chat.pending_message_request = message_request
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
|
||||
async def event_generator():
|
||||
async def event_generator(id):
|
||||
result_data = []
|
||||
|
||||
try:
|
||||
while True:
|
||||
if await fastapi_request.is_disconnected():
|
||||
logger.warning("Client disconnected")
|
||||
break
|
||||
return
|
||||
|
||||
item = await redisClient.blpop(chat.id, 1)
|
||||
item = await redisClient.blpop(id, 1)
|
||||
if item is None:
|
||||
continue
|
||||
|
||||
@@ -166,47 +142,44 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
||||
|
||||
yield {
|
||||
"retry": settings.sse_retry_timeout,
|
||||
"data": TokenResponseEvent(token=response_packet.token).json(),
|
||||
"data": interface.TokenResponseEvent(token=response_packet.token).json(),
|
||||
}
|
||||
logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
|
||||
logger.info(f"Finished streaming {id} {len(result_data)=}")
|
||||
except Exception:
|
||||
logger.exception(f"Error streaming {chat.id}")
|
||||
logger.exception(f"Error streaming {id}")
|
||||
raise
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=response_packet.generated_text.text,
|
||||
is_assistant=True,
|
||||
)
|
||||
)
|
||||
chat.pending_message_request = None
|
||||
try:
|
||||
with contextlib.contextmanager(create_session)() as session:
|
||||
chat_repository = create_chat_repository(session)
|
||||
chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
|
||||
except Exception:
|
||||
logger.exception("Error adding assistant message")
|
||||
|
||||
return EventSourceResponse(event_generator())
|
||||
return EventSourceResponse(event_generator(id))
|
||||
|
||||
|
||||
@app.websocket("/work")
|
||||
async def work(websocket: fastapi.WebSocket):
|
||||
async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)):
|
||||
await websocket.accept()
|
||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||
try:
|
||||
while True:
|
||||
print(websocket.client_state)
|
||||
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
||||
logger.warning("Worker disconnected")
|
||||
break
|
||||
# find a pending task that matches the worker's config
|
||||
# could also be implemented using task queues
|
||||
# but general compatibility matching is tricky
|
||||
for chat in CHATS.values():
|
||||
if (request := chat.pending_message_request) is not None:
|
||||
if chat.message_request_state == MessageRequestState.pending:
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
for chat in chat_repository.get_pending_chats():
|
||||
request = chat.pending_message_request
|
||||
if request.compatible_with(worker_config):
|
||||
break
|
||||
else:
|
||||
logger.debug("No pending tasks")
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
chat.message_request_state = MessageRequestState.in_progress
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
|
||||
|
||||
work_request = inference.WorkRequest(
|
||||
conversation=chat.conversation,
|
||||
@@ -214,15 +187,17 @@ async def work(websocket: fastapi.WebSocket):
|
||||
max_new_tokens=request.max_new_tokens,
|
||||
)
|
||||
|
||||
logger.info(f"Created {work_request}")
|
||||
logger.info(f"Created {work_request=}")
|
||||
try:
|
||||
await websocket.send_text(work_request.json())
|
||||
except websockets.exceptions.ConnectionClosedError:
|
||||
logger.warning("Worker disconnected")
|
||||
websocket.close()
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||
break
|
||||
|
||||
logger.debug(f"Sent {work_request=} to worker.")
|
||||
|
||||
try:
|
||||
in_progress = False
|
||||
while True:
|
||||
@@ -232,18 +207,19 @@ async def work(websocket: fastapi.WebSocket):
|
||||
in_progress = True
|
||||
await redisClient.rpush(chat.id, response_packet.json())
|
||||
if response_packet.is_end:
|
||||
logger.debug(f"Received {response_packet=} from worker. Ending.")
|
||||
break
|
||||
except fastapi.WebSocketException:
|
||||
# TODO: handle this better
|
||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||
if in_progress:
|
||||
logger.warning(f"Aborting {chat.id=}")
|
||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
|
||||
else:
|
||||
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
||||
chat.message_request_state = MessageRequestState.pending
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||
raise
|
||||
|
||||
chat.message_request_state = MessageRequestState.complete
|
||||
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
|
||||
except fastapi.WebSocketException:
|
||||
logger.exception("Websocket closed")
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
import fastapi
|
||||
import sqlmodel
|
||||
from loguru import logger
|
||||
from oasst_inference_server import interface, models
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlalchemy.sql.operators import is_not
|
||||
|
||||
|
||||
class ChatRepository:
|
||||
def __init__(self, session: sqlmodel.Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def get_chats(self) -> list[models.DbChatEntry]:
|
||||
return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()
|
||||
|
||||
def get_pending_chats(self) -> list[models.DbChatEntry]:
|
||||
return self.session.exec(
|
||||
sqlmodel.select(models.DbChatEntry).where(
|
||||
is_not(models.DbChatEntry.pending_message_request, None),
|
||||
models.DbChatEntry.message_request_state == interface.MessageRequestState.pending,
|
||||
)
|
||||
).all()
|
||||
|
||||
def get_chat_list(self) -> list[interface.ChatListEntry]:
|
||||
chats = self.get_chats()
|
||||
return [chat.to_list_entry() for chat in chats]
|
||||
|
||||
def get_chat_by_id(self, id: str) -> models.DbChatEntry:
|
||||
chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one()
|
||||
return chat
|
||||
|
||||
def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry:
|
||||
return self.get_chat_by_id(id).to_entry()
|
||||
|
||||
def create_chat(self) -> models.DbChatEntry:
|
||||
chat = models.DbChatEntry()
|
||||
self.session.add(chat)
|
||||
self.session.commit()
|
||||
return chat
|
||||
|
||||
def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None:
|
||||
logger.info(f"Adding prompter message {message_request} to chat {id}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
if not chat.conversation.is_prompter_turn:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
||||
if chat.pending_message_request is not None:
|
||||
raise fastapi.HTTPException(status_code=400, detail="Already pending")
|
||||
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=message_request.message,
|
||||
is_assistant=False,
|
||||
)
|
||||
)
|
||||
|
||||
chat.pending_message_request = message_request
|
||||
chat.message_request_state = interface.MessageRequestState.pending
|
||||
self.session.commit()
|
||||
logger.debug(f"Added prompter message {message_request} to chat {id}")
|
||||
|
||||
def add_assistant_message(self, id: str, text: str) -> None:
|
||||
logger.info(f"Adding assistant message {text} to chat {id}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
chat.conversation.messages.append(
|
||||
protocol.ConversationMessage(
|
||||
text=text,
|
||||
is_assistant=True,
|
||||
)
|
||||
)
|
||||
chat.pending_message_request = None
|
||||
self.session.commit()
|
||||
logger.debug(f"Added assistant message {text} to chat {id}")
|
||||
|
||||
def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None:
|
||||
logger.info(f"Setting chat {id} state to {state}")
|
||||
chat = self.get_chat_by_id(id)
|
||||
chat.message_request_state = state
|
||||
self.session.commit()
|
||||
logger.debug(f"Set chat {id} state to {state}")
|
||||
@@ -0,0 +1,41 @@
|
||||
import json
|
||||
|
||||
import pydantic.json
|
||||
import sqlmodel
|
||||
from loguru import logger
|
||||
from oasst_inference_server import models
|
||||
from oasst_inference_server.settings import settings
|
||||
|
||||
|
||||
def default_json_serializer(obj):
|
||||
class_name = obj.__class__.__name__
|
||||
encoded = pydantic.json.pydantic_encoder(obj)
|
||||
encoded["_classname_"] = class_name
|
||||
return encoded
|
||||
|
||||
|
||||
def custom_json_serializer(obj):
|
||||
return json.dumps(obj, default=default_json_serializer)
|
||||
|
||||
|
||||
def custom_json_deserializer(s):
|
||||
d = json.loads(s)
|
||||
if not isinstance(d, dict):
|
||||
return d
|
||||
match d.get("_classname_"):
|
||||
case "Conversation":
|
||||
return models.protocol.Conversation.parse_obj(d)
|
||||
case "MessageRequest":
|
||||
return models.interface.MessageRequest.parse_obj(d)
|
||||
case None:
|
||||
return d
|
||||
case _:
|
||||
logger.error(f"Unknown class {d['_classname_']}")
|
||||
raise ValueError(f"Unknown class {d['_classname_']}")
|
||||
|
||||
|
||||
db_engine = sqlmodel.create_engine(
|
||||
settings.database_uri,
|
||||
json_serializer=custom_json_serializer,
|
||||
json_deserializer=custom_json_deserializer,
|
||||
)
|
||||
@@ -0,0 +1,41 @@
|
||||
import enum
|
||||
|
||||
import pydantic
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
|
||||
|
||||
class MessageRequest(pydantic.BaseModel):
|
||||
message: str = pydantic.Field(..., repr=False)
|
||||
model_name: str = "distilgpt2"
|
||||
max_new_tokens: int = 100
|
||||
|
||||
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
||||
return self.model_name == worker_config.model_name
|
||||
|
||||
|
||||
class TokenResponseEvent(pydantic.BaseModel):
|
||||
token: inference.TokenResponse
|
||||
|
||||
|
||||
class MessageRequestState(str, enum.Enum):
|
||||
pending = "pending"
|
||||
in_progress = "in_progress"
|
||||
complete = "complete"
|
||||
aborted_by_worker = "aborted_by_worker"
|
||||
|
||||
|
||||
class CreateChatRequest(pydantic.BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class ChatListEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
class ChatEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
conversation: protocol.Conversation
|
||||
|
||||
|
||||
class ListChatsResponse(pydantic.BaseModel):
|
||||
chats: list[ChatListEntry]
|
||||
@@ -0,0 +1,23 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.dialects.postgresql as pg
|
||||
from oasst_inference_server import interface
|
||||
from oasst_shared.schemas import protocol
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
|
||||
class DbChatEntry(SQLModel, table=True):
|
||||
__tablename__ = "chat"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
|
||||
|
||||
conversation: protocol.Conversation = Field(default_factory=protocol.Conversation, sa_column=sa.Column(pg.JSONB))
|
||||
pending_message_request: interface.MessageRequest | None = Field(None, sa_column=sa.Column(pg.JSONB))
|
||||
message_request_state: interface.MessageRequestState | None = Field(None, sa_column=sa.Column(pg.JSONB))
|
||||
|
||||
def to_list_entry(self) -> interface.ChatListEntry:
|
||||
return interface.ChatListEntry(id=self.id)
|
||||
|
||||
def to_entry(self) -> interface.ChatEntry:
|
||||
return interface.ChatEntry(id=self.id, conversation=self.conversation)
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import Any
|
||||
|
||||
import pydantic
|
||||
|
||||
|
||||
class Settings(pydantic.BaseSettings):
|
||||
redis_host: str = "localhost"
|
||||
redis_port: int = 6379
|
||||
redis_db: int = 0
|
||||
|
||||
sse_retry_timeout: int = 15000
|
||||
update_alembic: bool = True
|
||||
alembic_retries: int = 5
|
||||
alembic_retry_timeout: int = 1
|
||||
|
||||
postgres_host: str = "localhost"
|
||||
postgres_port: str = "5432"
|
||||
postgres_user: str = "postgres"
|
||||
postgres_password: str = "postgres"
|
||||
postgres_db: str = "postgres"
|
||||
|
||||
database_uri: str | None = None
|
||||
|
||||
@pydantic.validator("database_uri", pre=True)
|
||||
def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
|
||||
if isinstance(v, str):
|
||||
return v
|
||||
return pydantic.PostgresDsn.build(
|
||||
scheme="postgresql",
|
||||
user=values.get("postgres_user"),
|
||||
password=values.get("postgres_password"),
|
||||
host=values.get("postgres_host"),
|
||||
port=values.get("postgres_port"),
|
||||
path=f"/{values.get('postgres_db') or ''}",
|
||||
)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,7 +1,10 @@
|
||||
alembic
|
||||
fastapi[all]
|
||||
loguru
|
||||
prometheus-fastapi-instrumentator==5.9.1
|
||||
prometheus-fastapi-instrumentator
|
||||
psycopg2-binary
|
||||
pydantic
|
||||
redis
|
||||
sqlmodel
|
||||
sse-starlette
|
||||
websockets
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Simple REPL frontend."""
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
import requests
|
||||
import sseclient
|
||||
@@ -42,6 +43,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
|
||||
break
|
||||
except Exception:
|
||||
typer.echo("Error, restarting chat...")
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,6 +17,8 @@ def main(
|
||||
model_name: str = "distilgpt2",
|
||||
inference_server_url: str = "http://localhost:8001",
|
||||
):
|
||||
utils.wait_for_inference_server(inference_server_url)
|
||||
|
||||
def on_open(ws: websocket.WebSocket):
|
||||
logger.info("Connected to backend, sending config...")
|
||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||
@@ -93,6 +95,7 @@ def main(
|
||||
),
|
||||
).json()
|
||||
)
|
||||
logger.info("Work complete. Waiting for more work...")
|
||||
|
||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import collections
|
||||
import random
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
import interface
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TokenBuffer:
|
||||
@@ -38,3 +42,21 @@ class TokenBuffer:
|
||||
yield from self.tokens
|
||||
else:
|
||||
yield from self.tokens
|
||||
|
||||
|
||||
def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
|
||||
health_url = f"{inference_server_url}/health"
|
||||
time_limit = time.time() + timeout
|
||||
while True:
|
||||
try:
|
||||
response = requests.get(health_url)
|
||||
response.raise_for_status()
|
||||
except (requests.HTTPError, requests.ConnectionError):
|
||||
if time.time() > time_limit:
|
||||
raise
|
||||
sleep_duration = random.uniform(0, 10)
|
||||
logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
|
||||
time.sleep(sleep_duration)
|
||||
else:
|
||||
logger.info("Inference server is ready")
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user