Added database to inference server (#1446)

* added db for inference

* fixed dockerfiles for inference
This commit is contained in:
Yannic Kilcher
2023-02-10 22:51:35 +01:00
committed by GitHub
parent 911fc2affc
commit 90c3d5640e
21 changed files with 627 additions and 192 deletions
+30 -23
View File
@@ -136,6 +136,23 @@ services:
- "3000:3000" - "3000:3000"
command: bash wait-for-postgres.sh node server.js command: bash wait-for-postgres.sh node server.js
# This DB is for Inference
inference-db:
image: postgres
restart: always
ports:
- 5434:5432
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: oasst_inference
healthcheck:
test: ["CMD", "pg_isready", "-U", "postgres"]
interval: 2s
timeout: 2s
retries: 10
profiles: ["inference"]
inference-server: inference-server:
build: build:
dockerfile: docker/inference/Dockerfile.server dockerfile: docker/inference/Dockerfile.server
@@ -145,13 +162,25 @@ services:
environment: environment:
- "PORT=8000" - "PORT=8000"
- "REDIS_HOST=redis" - "REDIS_HOST=redis"
- POSTGRES_HOST=inference-db
- POSTGRES_DB=oasst_inference
volumes: volumes:
- "./oasst-shared:/opt/inference/lib/oasst-shared" - "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/server:/opt/inference/server" - "./inference/server:/opt/inference/server"
restart: unless-stopped restart: unless-stopped
ports:
- "8000:8000"
depends_on: depends_on:
redis: redis:
condition: service_healthy condition: service_healthy
inference-db:
condition: service_healthy
profiles: ["inference"]
inference-text-generation-server:
image: ghcr.io/huggingface/text-generation-inference
environment:
- "MODEL_ID=distilgpt2"
profiles: ["inference"] profiles: ["inference"]
inference-worker: inference-worker:
@@ -167,29 +196,7 @@ services:
- "./oasst-shared:/opt/inference/lib/oasst-shared" - "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/worker:/opt/inference/worker" - "./inference/worker:/opt/inference/worker"
depends_on: depends_on:
- inference-server - inference-text-generation-server
deploy: deploy:
replicas: 1 replicas: 1
profiles: ["inference"] profiles: ["inference"]
inference-text-client:
build:
dockerfile: docker/inference/Dockerfile.text-client
context: .
image: oasst-inference-text-client
environment:
- "BACKEND_URL=http://inference-server:8000"
tty: true
stdin_open: true
volumes:
- "./inference/worker:/opt/inference/worker"
restart: unless-stopped
depends_on:
- inference-server
profiles: ["inference"]
inference-text-generation-server:
image: ghcr.io/huggingface/text-generation-inference
environment:
- "MODEL_ID=distilgpt2"
profiles: ["inference"]
+5 -2
View File
@@ -7,7 +7,7 @@ ARG APP_USER="${MODULE}-${SERVICE}"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}" ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
FROM python:3-slim as build FROM python:3.10-slim as build
ARG APP_RELATIVE_PATH ARG APP_RELATIVE_PATH
WORKDIR /build WORKDIR /build
@@ -22,7 +22,7 @@ RUN --mount=type=cache,target=/var/cache/pip \
FROM python:3.10-alpine3.17 as base-env FROM python:3.10-slim as base-env
ARG APP_USER ARG APP_USER
ARG APP_RELATIVE_PATH ARG APP_RELATIVE_PATH
ARG MODULE ARG MODULE
@@ -50,6 +50,9 @@ WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS} COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic alembic
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic.ini .
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/oasst_inference_server oasst_inference_server
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py . COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
-50
View File
@@ -1,50 +0,0 @@
# syntax=docker/dockerfile:1
ARG APP_USER="text-client"
ARG APP_RELATIVE_PATH="inference/text-client"
FROM python:3.10-alpine3.17 as build
ARG APP_RELATIVE_PATH
WORKDIR /build
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
RUN --mount=type=cache,target=/var/cache/pip \
pip install \
--cache-dir=/var/cache/pip \
--target=lib \
-r requirements.txt
FROM python:3.10-alpine3.17 as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
RUN adduser \
--disabled-password \
--no-create-home \
"${APP_USER}"
USER ${APP_USER}
WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
FROM base-env as prod
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
+1 -1
View File
@@ -48,7 +48,7 @@ WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS} COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py . COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}" CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
+4 -2
View File
@@ -3,9 +3,11 @@
# Creates a tmux window with splits for the individual services # Creates a tmux window with splits for the individual services
tmux new-session -d -s "inference-dev-setup" tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m tmux send-keys "docker run --rm -it -p 5432:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux split-window -h tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m tmux send-keys "docker run --rm -it -p 6379:6379 --name redis redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h tmux split-window -h
tmux send-keys "cd server" C-m tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m tmux send-keys "uvicorn main:app --reload" C-m
+105
View File
@@ -0,0 +1,105 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
hooks = black
black.type = console_scripts
black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+1
View File
@@ -0,0 +1 @@
Generic single-database configuration.
+78
View File
@@ -0,0 +1,78 @@
from logging.config import fileConfig
import sqlmodel
from alembic import context
from oasst_inference_server import models # noqa: F401
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = sqlmodel.SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.get_context()._ensure_version_table()
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
+25
View File
@@ -0,0 +1,25 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
@@ -0,0 +1,36 @@
"""initial commit
Revision ID: 3a4cd8777eb2
Revises:
Create Date: 2023-02-10 02:21:27.086772
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3a4cd8777eb2"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"chat",
sa.Column("conversation", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("pending_message_request", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("message_request_state", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("chat")
# ### end Alembic commands ###
+89 -113
View File
@@ -1,14 +1,22 @@
import asyncio import asyncio
import enum import contextlib
import uuid import time
from pathlib import Path
import alembic.command
import alembic.config
import fastapi import fastapi
import pydantic
import redis.asyncio as redis import redis.asyncio as redis
import sqlmodel
import websockets.exceptions import websockets.exceptions
from fastapi import Depends
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from oasst_shared.schemas import inference, protocol from oasst_inference_server import interface
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import db_engine
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from prometheus_fastapi_instrumentator import Instrumentator from prometheus_fastapi_instrumentator import Instrumentator
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
@@ -31,129 +39,97 @@ app.add_middleware(
) )
class Settings(pydantic.BaseSettings):
redis_host: str = "localhost"
redis_port: int = 6379
redis_db: int = 0
sse_retry_timeout: int = 15000
settings = Settings()
# create async redis client # create async redis client
redisClient = redis.Redis( redisClient = redis.Redis(
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
) )
class MessageRequest(pydantic.BaseModel): def create_session():
message: str = pydantic.Field(..., repr=False) with sqlmodel.Session(db_engine) as session:
model_name: str = "distilgpt2" yield session
max_new_tokens: int = 100
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
return self.model_name == worker_config.model_name
class TokenResponseEvent(pydantic.BaseModel): def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
token: inference.TokenResponse repository = ChatRepository(session)
return repository
class MessageRequestState(str, enum.Enum): if settings.update_alembic:
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
@app.on_event("startup")
def alembic_upgrade():
logger.info("Attempting to upgrade alembic on startup")
retry = 0
while True:
try:
alembic_ini_path = Path(__file__).parent / "alembic.ini"
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
alembic.command.upgrade(alembic_cfg, "head")
logger.info("Successfully upgraded alembic on startup")
break
except Exception:
logger.exception("Alembic upgrade failed on startup")
retry += 1
if retry >= settings.alembic_retries:
raise
class CreateChatRequest(pydantic.BaseModel): timeout = settings.alembic_retry_timeout * 2**retry
pass logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
time.sleep(timeout)
class ChatListEntry(pydantic.BaseModel):
id: str
class ChatEntry(pydantic.BaseModel):
id: str
conversation: protocol.Conversation
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
class DbChatEntry(pydantic.BaseModel):
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
pending_message_request: MessageRequest | None = None
message_request_state: MessageRequestState | None = None
def to_list_entry(self) -> ChatListEntry:
return ChatListEntry(id=self.id)
def to_entry(self) -> ChatEntry:
return ChatEntry(id=self.id, conversation=self.conversation)
# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}
@app.get("/chat") @app.get("/chat")
async def list_chats() -> ListChatsResponse: async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse:
"""Lists all chats.""" """Lists all chats."""
logger.info("Listing all chats.") logger.info("Listing all chats.")
chats = [chat.to_list_entry() for chat in CHATS.values()] chats = chat_repository.get_chat_list()
return ListChatsResponse(chats=chats) return interface.ListChatsResponse(chats=chats)
@app.post("/chat") @app.post("/chat")
async def create_chat(request: CreateChatRequest) -> ChatListEntry: async def create_chat(
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
) -> interface.ChatListEntry:
"""Allows a client to create a new chat.""" """Allows a client to create a new chat."""
logger.info(f"Received {request}") logger.info(f"Received {request}")
chat = DbChatEntry() chat = chat_repository.create_chat()
CHATS[chat.id] = chat return chat.to_list_entry()
return ChatListEntry(id=chat.id)
@app.get("/chat/{id}") @app.get("/chat/{id}")
async def get_chat(id: str) -> ChatEntry: async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ChatEntry:
"""Allows a client to get the current state of a chat.""" """Allows a client to get the current state of a chat."""
return CHATS[id].to_entry() chat = chat_repository.get_chat_entry_by_id(id)
return chat
@app.post("/chat/{id}/message") @app.post("/chat/{id}/message")
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request): async def create_message(
id: str,
message_request: interface.MessageRequest,
fastapi_request: fastapi.Request,
chat_repository: ChatRepository = Depends(create_chat_repository),
) -> EventSourceResponse:
"""Allows the client to stream the results of a request.""" """Allows the client to stream the results of a request."""
chat = CHATS[id] try:
if not chat.conversation.is_prompter_turn: chat_repository.add_prompter_message(id=id, message_request=message_request)
raise fastapi.HTTPException(status_code=400, detail="Not your turn") except Exception:
if chat.pending_message_request is not None: logger.exception("Error adding prompter message")
raise fastapi.HTTPException(status_code=400, detail="Already pending") return fastapi.Response(status_code=500)
chat.conversation.messages.append( async def event_generator(id):
protocol.ConversationMessage(
text=message_request.message,
is_assistant=False,
)
)
chat.pending_message_request = message_request
chat.message_request_state = MessageRequestState.pending
async def event_generator():
result_data = [] result_data = []
try: try:
while True: while True:
if await fastapi_request.is_disconnected(): if await fastapi_request.is_disconnected():
logger.warning("Client disconnected") logger.warning("Client disconnected")
break return
item = await redisClient.blpop(chat.id, 1) item = await redisClient.blpop(id, 1)
if item is None: if item is None:
continue continue
@@ -166,47 +142,44 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
yield { yield {
"retry": settings.sse_retry_timeout, "retry": settings.sse_retry_timeout,
"data": TokenResponseEvent(token=response_packet.token).json(), "data": interface.TokenResponseEvent(token=response_packet.token).json(),
} }
logger.info(f"Finished streaming {chat.id} {len(result_data)=}") logger.info(f"Finished streaming {id} {len(result_data)=}")
except Exception: except Exception:
logger.exception(f"Error streaming {chat.id}") logger.exception(f"Error streaming {id}")
raise
chat.conversation.messages.append( try:
protocol.ConversationMessage( with contextlib.contextmanager(create_session)() as session:
text=response_packet.generated_text.text, chat_repository = create_chat_repository(session)
is_assistant=True, chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
) except Exception:
) logger.exception("Error adding assistant message")
chat.pending_message_request = None
return EventSourceResponse(event_generator()) return EventSourceResponse(event_generator(id))
@app.websocket("/work") @app.websocket("/work")
async def work(websocket: fastapi.WebSocket): async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)):
await websocket.accept() await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
try: try:
while True: while True:
print(websocket.client_state)
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
logger.warning("Worker disconnected") logger.warning("Worker disconnected")
break break
# find a pending task that matches the worker's config # find a pending task that matches the worker's config
# could also be implemented using task queues # could also be implemented using task queues
# but general compatibility matching is tricky # but general compatibility matching is tricky
for chat in CHATS.values(): for chat in chat_repository.get_pending_chats():
if (request := chat.pending_message_request) is not None: request = chat.pending_message_request
if chat.message_request_state == MessageRequestState.pending: if request.compatible_with(worker_config):
if request.compatible_with(worker_config): break
break
else: else:
logger.debug("No pending tasks")
await asyncio.sleep(1) await asyncio.sleep(1)
continue continue
chat.message_request_state = MessageRequestState.in_progress chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
work_request = inference.WorkRequest( work_request = inference.WorkRequest(
conversation=chat.conversation, conversation=chat.conversation,
@@ -214,15 +187,17 @@ async def work(websocket: fastapi.WebSocket):
max_new_tokens=request.max_new_tokens, max_new_tokens=request.max_new_tokens,
) )
logger.info(f"Created {work_request}") logger.info(f"Created {work_request=}")
try: try:
await websocket.send_text(work_request.json()) await websocket.send_text(work_request.json())
except websockets.exceptions.ConnectionClosedError: except websockets.exceptions.ConnectionClosedError:
logger.warning("Worker disconnected") logger.warning("Worker disconnected")
websocket.close() websocket.close()
chat.message_request_state = MessageRequestState.pending chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
break break
logger.debug(f"Sent {work_request=} to worker.")
try: try:
in_progress = False in_progress = False
while True: while True:
@@ -232,18 +207,19 @@ async def work(websocket: fastapi.WebSocket):
in_progress = True in_progress = True
await redisClient.rpush(chat.id, response_packet.json()) await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end: if response_packet.is_end:
logger.debug(f"Received {response_packet=} from worker. Ending.")
break break
except fastapi.WebSocketException: except fastapi.WebSocketException:
# TODO: handle this better # TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}") logger.exception(f"Websocket closed during handling of {chat.id}")
if in_progress: if in_progress:
logger.warning(f"Aborting {chat.id=}") logger.warning(f"Aborting {chat.id=}")
chat.message_request_state = MessageRequestState.aborted_by_worker chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
else: else:
logger.warning(f"Marking {chat.id=} as pending since no work was done.") logger.warning(f"Marking {chat.id=} as pending since no work was done.")
chat.message_request_state = MessageRequestState.pending chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
raise raise
chat.message_request_state = MessageRequestState.complete chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
except fastapi.WebSocketException: except fastapi.WebSocketException:
logger.exception("Websocket closed") logger.exception("Websocket closed")
@@ -0,0 +1,79 @@
import fastapi
import sqlmodel
from loguru import logger
from oasst_inference_server import interface, models
from oasst_shared.schemas import protocol
from sqlalchemy.sql.operators import is_not
class ChatRepository:
def __init__(self, session: sqlmodel.Session) -> None:
self.session = session
def get_chats(self) -> list[models.DbChatEntry]:
return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()
def get_pending_chats(self) -> list[models.DbChatEntry]:
return self.session.exec(
sqlmodel.select(models.DbChatEntry).where(
is_not(models.DbChatEntry.pending_message_request, None),
models.DbChatEntry.message_request_state == interface.MessageRequestState.pending,
)
).all()
def get_chat_list(self) -> list[interface.ChatListEntry]:
chats = self.get_chats()
return [chat.to_list_entry() for chat in chats]
def get_chat_by_id(self, id: str) -> models.DbChatEntry:
chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one()
return chat
def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry:
return self.get_chat_by_id(id).to_entry()
def create_chat(self) -> models.DbChatEntry:
chat = models.DbChatEntry()
self.session.add(chat)
self.session.commit()
return chat
def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None:
logger.info(f"Adding prompter message {message_request} to chat {id}")
chat = self.get_chat_by_id(id)
if not chat.conversation.is_prompter_turn:
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
if chat.pending_message_request is not None:
raise fastapi.HTTPException(status_code=400, detail="Already pending")
chat.conversation.messages.append(
protocol.ConversationMessage(
text=message_request.message,
is_assistant=False,
)
)
chat.pending_message_request = message_request
chat.message_request_state = interface.MessageRequestState.pending
self.session.commit()
logger.debug(f"Added prompter message {message_request} to chat {id}")
def add_assistant_message(self, id: str, text: str) -> None:
logger.info(f"Adding assistant message {text} to chat {id}")
chat = self.get_chat_by_id(id)
chat.conversation.messages.append(
protocol.ConversationMessage(
text=text,
is_assistant=True,
)
)
chat.pending_message_request = None
self.session.commit()
logger.debug(f"Added assistant message {text} to chat {id}")
def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None:
logger.info(f"Setting chat {id} state to {state}")
chat = self.get_chat_by_id(id)
chat.message_request_state = state
self.session.commit()
logger.debug(f"Set chat {id} state to {state}")
@@ -0,0 +1,41 @@
import json
import pydantic.json
import sqlmodel
from loguru import logger
from oasst_inference_server import models
from oasst_inference_server.settings import settings
def default_json_serializer(obj):
class_name = obj.__class__.__name__
encoded = pydantic.json.pydantic_encoder(obj)
encoded["_classname_"] = class_name
return encoded
def custom_json_serializer(obj):
return json.dumps(obj, default=default_json_serializer)
def custom_json_deserializer(s):
d = json.loads(s)
if not isinstance(d, dict):
return d
match d.get("_classname_"):
case "Conversation":
return models.protocol.Conversation.parse_obj(d)
case "MessageRequest":
return models.interface.MessageRequest.parse_obj(d)
case None:
return d
case _:
logger.error(f"Unknown class {d['_classname_']}")
raise ValueError(f"Unknown class {d['_classname_']}")
db_engine = sqlmodel.create_engine(
settings.database_uri,
json_serializer=custom_json_serializer,
json_deserializer=custom_json_deserializer,
)
@@ -0,0 +1,41 @@
import enum
import pydantic
from oasst_shared.schemas import inference, protocol
class MessageRequest(pydantic.BaseModel):
message: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_new_tokens: int = 100
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
return self.model_name == worker_config.model_name
class TokenResponseEvent(pydantic.BaseModel):
token: inference.TokenResponse
class MessageRequestState(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
class CreateChatRequest(pydantic.BaseModel):
pass
class ChatListEntry(pydantic.BaseModel):
id: str
class ChatEntry(pydantic.BaseModel):
id: str
conversation: protocol.Conversation
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
@@ -0,0 +1,23 @@
from uuid import uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_inference_server import interface
from oasst_shared.schemas import protocol
from sqlmodel import Field, SQLModel
class DbChatEntry(SQLModel, table=True):
__tablename__ = "chat"
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
conversation: protocol.Conversation = Field(default_factory=protocol.Conversation, sa_column=sa.Column(pg.JSONB))
pending_message_request: interface.MessageRequest | None = Field(None, sa_column=sa.Column(pg.JSONB))
message_request_state: interface.MessageRequestState | None = Field(None, sa_column=sa.Column(pg.JSONB))
def to_list_entry(self) -> interface.ChatListEntry:
return interface.ChatListEntry(id=self.id)
def to_entry(self) -> interface.ChatEntry:
return interface.ChatEntry(id=self.id, conversation=self.conversation)
@@ -0,0 +1,38 @@
from typing import Any
import pydantic
class Settings(pydantic.BaseSettings):
redis_host: str = "localhost"
redis_port: int = 6379
redis_db: int = 0
sse_retry_timeout: int = 15000
update_alembic: bool = True
alembic_retries: int = 5
alembic_retry_timeout: int = 1
postgres_host: str = "localhost"
postgres_port: str = "5432"
postgres_user: str = "postgres"
postgres_password: str = "postgres"
postgres_db: str = "postgres"
database_uri: str | None = None
@pydantic.validator("database_uri", pre=True)
def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
if isinstance(v, str):
return v
return pydantic.PostgresDsn.build(
scheme="postgresql",
user=values.get("postgres_user"),
password=values.get("postgres_password"),
host=values.get("postgres_host"),
port=values.get("postgres_port"),
path=f"/{values.get('postgres_db') or ''}",
)
settings = Settings()
+4 -1
View File
@@ -1,7 +1,10 @@
alembic
fastapi[all] fastapi[all]
loguru loguru
prometheus-fastapi-instrumentator==5.9.1 prometheus-fastapi-instrumentator
psycopg2-binary
pydantic pydantic
redis redis
sqlmodel
sse-starlette sse-starlette
websockets websockets
+2
View File
@@ -1,6 +1,7 @@
"""Simple REPL frontend.""" """Simple REPL frontend."""
import json import json
import time
import requests import requests
import sseclient import sseclient
@@ -42,6 +43,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
break break
except Exception: except Exception:
typer.echo("Error, restarting chat...") typer.echo("Error, restarting chat...")
time.sleep(1)
if __name__ == "__main__": if __name__ == "__main__":
+3
View File
@@ -17,6 +17,8 @@ def main(
model_name: str = "distilgpt2", model_name: str = "distilgpt2",
inference_server_url: str = "http://localhost:8001", inference_server_url: str = "http://localhost:8001",
): ):
utils.wait_for_inference_server(inference_server_url)
def on_open(ws: websocket.WebSocket): def on_open(ws: websocket.WebSocket):
logger.info("Connected to backend, sending config...") logger.info("Connected to backend, sending config...")
worker_config = inference.WorkerConfig(model_name=model_name) worker_config = inference.WorkerConfig(model_name=model_name)
@@ -93,6 +95,7 @@ def main(
), ),
).json() ).json()
) )
logger.info("Work complete. Waiting for more work...")
def on_error(ws: websocket.WebSocket, error: Exception): def on_error(ws: websocket.WebSocket, error: Exception):
try: try:
+22
View File
@@ -1,7 +1,11 @@
import collections import collections
import random
import time
from typing import Literal from typing import Literal
import interface import interface
import requests
from loguru import logger
class TokenBuffer: class TokenBuffer:
@@ -38,3 +42,21 @@ class TokenBuffer:
yield from self.tokens yield from self.tokens
else: else:
yield from self.tokens yield from self.tokens
def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
health_url = f"{inference_server_url}/health"
time_limit = time.time() + timeout
while True:
try:
response = requests.get(health_url)
response.raise_for_status()
except (requests.HTTPError, requests.ConnectionError):
if time.time() > time_limit:
raise
sleep_duration = random.uniform(0, 10)
logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
time.sleep(sleep_duration)
else:
logger.info("Inference server is ready")
break