Added database to inference server (#1446)

* added db for inference

* fixed dockerfiles for inference
This commit is contained in:
Yannic Kilcher
2023-02-10 22:51:35 +01:00
committed by GitHub
parent 911fc2affc
commit 90c3d5640e
21 changed files with 627 additions and 192 deletions
+30 -23
View File
@@ -136,6 +136,23 @@ services:
- "3000:3000"
command: bash wait-for-postgres.sh node server.js
# PostgreSQL instance used by the inference server (enabled with the
# "inference" profile). The container's port 5432 is published on host
# port 5434.
inference-db:
  image: postgres
  restart: always
  ports:
    # Quoted like the other port mappings in this file so the value is
    # always read as a string.
    - "5434:5432"
  environment:
    POSTGRES_USER: postgres
    POSTGRES_PASSWORD: postgres
    POSTGRES_DB: oasst_inference
  healthcheck:
    # Services that depend on this one with `condition: service_healthy`
    # wait until pg_isready succeeds.
    test: ["CMD", "pg_isready", "-U", "postgres"]
    interval: 2s
    timeout: 2s
    retries: 10
  profiles: ["inference"]
inference-server:
build:
dockerfile: docker/inference/Dockerfile.server
@@ -145,13 +162,25 @@ services:
environment:
- "PORT=8000"
- "REDIS_HOST=redis"
- POSTGRES_HOST=inference-db
- POSTGRES_DB=oasst_inference
volumes:
- "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/server:/opt/inference/server"
restart: unless-stopped
ports:
- "8000:8000"
depends_on:
redis:
condition: service_healthy
inference-db:
condition: service_healthy
profiles: ["inference"]
inference-text-generation-server:
image: ghcr.io/huggingface/text-generation-inference
environment:
- "MODEL_ID=distilgpt2"
profiles: ["inference"]
inference-worker:
@@ -167,29 +196,7 @@ services:
- "./oasst-shared:/opt/inference/lib/oasst-shared"
- "./inference/worker:/opt/inference/worker"
depends_on:
- inference-server
- inference-text-generation-server
deploy:
replicas: 1
profiles: ["inference"]
inference-text-client:
build:
dockerfile: docker/inference/Dockerfile.text-client
context: .
image: oasst-inference-text-client
environment:
- "BACKEND_URL=http://inference-server:8000"
tty: true
stdin_open: true
volumes:
- "./inference/worker:/opt/inference/worker"
restart: unless-stopped
depends_on:
- inference-server
profiles: ["inference"]
inference-text-generation-server:
image: ghcr.io/huggingface/text-generation-inference
environment:
- "MODEL_ID=distilgpt2"
profiles: ["inference"]
+5 -2
View File
@@ -7,7 +7,7 @@ ARG APP_USER="${MODULE}-${SERVICE}"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
FROM python:3-slim as build
FROM python:3.10-slim as build
ARG APP_RELATIVE_PATH
WORKDIR /build
@@ -22,7 +22,7 @@ RUN --mount=type=cache,target=/var/cache/pip \
FROM python:3.10-alpine3.17 as base-env
FROM python:3.10-slim as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ARG MODULE
@@ -50,6 +50,9 @@ WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic alembic
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic.ini .
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/oasst_inference_server oasst_inference_server
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
-50
View File
@@ -1,50 +0,0 @@
# syntax=docker/dockerfile:1
ARG APP_USER="text-client"
ARG APP_RELATIVE_PATH="inference/text-client"
FROM python:3.10-alpine3.17 as build
ARG APP_RELATIVE_PATH
WORKDIR /build
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
RUN --mount=type=cache,target=/var/cache/pip \
pip install \
--cache-dir=/var/cache/pip \
--target=lib \
-r requirements.txt
FROM python:3.10-alpine3.17 as base-env
ARG APP_USER
ARG APP_RELATIVE_PATH
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
ENV PATH="${PATH}:${APP_LIBS}/bin"
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
RUN adduser \
--disabled-password \
--no-create-home \
"${APP_USER}"
USER ${APP_USER}
WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
FROM base-env as prod
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
+1 -1
View File
@@ -48,7 +48,7 @@ WORKDIR ${APP_ROOT}
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
+4 -2
View File
@@ -3,9 +3,11 @@
# Creates a tmux window with splits for the individual services
tmux new-session -d -s "inference-dev-setup"
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
tmux send-keys "docker run --rm -it -p 5432:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
tmux send-keys "docker run --rm -it -p 6379:6379 --name redis redis" C-m
tmux split-window -h
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
tmux split-window -h
tmux send-keys "cd server" C-m
tmux send-keys "uvicorn main:app --reload" C-m
+105
View File
@@ -0,0 +1,105 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
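# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
# NOTE: the URL below points at a PostgreSQL server on localhost with the
# default "postgres" credentials; change or override it for other environments.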
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
hooks = black
black.type = console_scripts
black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
+1
View File
@@ -0,0 +1 @@
Generic single-database configuration.
+78
View File
@@ -0,0 +1,78 @@
from logging.config import fileConfig
import sqlmodel
from alembic import context
from oasst_inference_server import models # noqa: F401
from sqlalchemy import engine_from_config, pool
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = sqlmodel.SQLModel.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    The context is configured from the ``sqlalchemy.url`` option only, so no
    Engine is created and no DBAPI has to be installed. Statements passed to
    ``context.execute()`` are written to the script output instead of being
    sent to a database.
    """
    offline_options = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**offline_options)

    with context.begin_transaction():
        context.run_migrations()
def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Builds an Engine from the ``sqlalchemy.*`` options of the active ini
    section (without connection pooling), opens a connection and runs the
    migrations inside a single transaction.
    """
    # Local import: textual SQL has to be wrapped in text() to be executable
    # on SQLAlchemy 2.x; passing a bare string is only tolerated by 1.x.
    from sqlalchemy import text

    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            # Create alembic_version if it is missing (private Alembic API),
            # then lock it for the rest of the transaction -- presumably so
            # that concurrent upgrade runs are serialized.
            context.get_context()._ensure_version_table()
            connection.execute(text("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE"))
            context.run_migrations()
# Entry point: choose the runner that matches the mode Alembic was started in.
if not context.is_offline_mode():
    run_migrations_online()
else:
    run_migrations_offline()
+25
View File
@@ -0,0 +1,25 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
import sqlmodel
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
@@ -0,0 +1,36 @@
"""initial commit
Revision ID: 3a4cd8777eb2
Revises:
Create Date: 2023-02-10 02:21:27.086772
"""
import sqlalchemy as sa
import sqlmodel
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3a4cd8777eb2"
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``chat`` table: three nullable JSONB columns and a string primary key ``id``."""
    # ### commands auto generated by Alembic - please adjust! ###
    jsonb_columns = [
        sa.Column(column_name, postgresql.JSONB(astext_type=sa.Text()), nullable=True)
        for column_name in ("conversation", "pending_message_request", "message_request_state")
    ]
    op.create_table(
        "chat",
        *jsonb_columns,
        sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Undo the initial revision by dropping the ``chat`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_to_drop = "chat"
    op.drop_table(table_to_drop)
    # ### end Alembic commands ###
+89 -113
View File
@@ -1,14 +1,22 @@
import asyncio
import enum
import uuid
import contextlib
import time
from pathlib import Path
import alembic.command
import alembic.config
import fastapi
import pydantic
import redis.asyncio as redis
import sqlmodel
import websockets.exceptions
from fastapi import Depends
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from oasst_shared.schemas import inference, protocol
from oasst_inference_server import interface
from oasst_inference_server.chat_repository import ChatRepository
from oasst_inference_server.database import db_engine
from oasst_inference_server.settings import settings
from oasst_shared.schemas import inference
from prometheus_fastapi_instrumentator import Instrumentator
from sse_starlette.sse import EventSourceResponse
@@ -31,129 +39,97 @@ app.add_middleware(
)
class Settings(pydantic.BaseSettings):
redis_host: str = "localhost"
redis_port: int = 6379
redis_db: int = 0
sse_retry_timeout: int = 15000
settings = Settings()
# create async redis client
redisClient = redis.Redis(
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
)
class MessageRequest(pydantic.BaseModel):
message: str = pydantic.Field(..., repr=False)
model_name: str = "distilgpt2"
max_new_tokens: int = 100
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
return self.model_name == worker_config.model_name
def create_session():
with sqlmodel.Session(db_engine) as session:
yield session
class TokenResponseEvent(pydantic.BaseModel):
token: inference.TokenResponse
def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
repository = ChatRepository(session)
return repository
class MessageRequestState(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
if settings.update_alembic:
@app.on_event("startup")
def alembic_upgrade():
logger.info("Attempting to upgrade alembic on startup")
retry = 0
while True:
try:
alembic_ini_path = Path(__file__).parent / "alembic.ini"
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
alembic.command.upgrade(alembic_cfg, "head")
logger.info("Successfully upgraded alembic on startup")
break
except Exception:
logger.exception("Alembic upgrade failed on startup")
retry += 1
if retry >= settings.alembic_retries:
raise
class CreateChatRequest(pydantic.BaseModel):
pass
class ChatListEntry(pydantic.BaseModel):
id: str
class ChatEntry(pydantic.BaseModel):
id: str
conversation: protocol.Conversation
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
class DbChatEntry(pydantic.BaseModel):
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
pending_message_request: MessageRequest | None = None
message_request_state: MessageRequestState | None = None
def to_list_entry(self) -> ChatListEntry:
return ChatListEntry(id=self.id)
def to_entry(self) -> ChatEntry:
return ChatEntry(id=self.id, conversation=self.conversation)
# TODO: make real database
CHATS: dict[str, DbChatEntry] = {}
timeout = settings.alembic_retry_timeout * 2**retry
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
time.sleep(timeout)
@app.get("/chat")
async def list_chats() -> ListChatsResponse:
async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse:
"""Lists all chats."""
logger.info("Listing all chats.")
chats = [chat.to_list_entry() for chat in CHATS.values()]
return ListChatsResponse(chats=chats)
chats = chat_repository.get_chat_list()
return interface.ListChatsResponse(chats=chats)
@app.post("/chat")
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
async def create_chat(
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
) -> interface.ChatListEntry:
"""Allows a client to create a new chat."""
logger.info(f"Received {request}")
chat = DbChatEntry()
CHATS[chat.id] = chat
return ChatListEntry(id=chat.id)
chat = chat_repository.create_chat()
return chat.to_list_entry()
@app.get("/chat/{id}")
async def get_chat(id: str) -> ChatEntry:
async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ChatEntry:
"""Allows a client to get the current state of a chat."""
return CHATS[id].to_entry()
chat = chat_repository.get_chat_entry_by_id(id)
return chat
@app.post("/chat/{id}/message")
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
async def create_message(
id: str,
message_request: interface.MessageRequest,
fastapi_request: fastapi.Request,
chat_repository: ChatRepository = Depends(create_chat_repository),
) -> EventSourceResponse:
"""Allows the client to stream the results of a request."""
chat = CHATS[id]
if not chat.conversation.is_prompter_turn:
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
if chat.pending_message_request is not None:
raise fastapi.HTTPException(status_code=400, detail="Already pending")
try:
chat_repository.add_prompter_message(id=id, message_request=message_request)
except Exception:
logger.exception("Error adding prompter message")
return fastapi.Response(status_code=500)
chat.conversation.messages.append(
protocol.ConversationMessage(
text=message_request.message,
is_assistant=False,
)
)
chat.pending_message_request = message_request
chat.message_request_state = MessageRequestState.pending
async def event_generator():
async def event_generator(id):
result_data = []
try:
while True:
if await fastapi_request.is_disconnected():
logger.warning("Client disconnected")
break
return
item = await redisClient.blpop(chat.id, 1)
item = await redisClient.blpop(id, 1)
if item is None:
continue
@@ -166,47 +142,44 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
yield {
"retry": settings.sse_retry_timeout,
"data": TokenResponseEvent(token=response_packet.token).json(),
"data": interface.TokenResponseEvent(token=response_packet.token).json(),
}
logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
logger.info(f"Finished streaming {id} {len(result_data)=}")
except Exception:
logger.exception(f"Error streaming {chat.id}")
logger.exception(f"Error streaming {id}")
raise
chat.conversation.messages.append(
protocol.ConversationMessage(
text=response_packet.generated_text.text,
is_assistant=True,
)
)
chat.pending_message_request = None
try:
with contextlib.contextmanager(create_session)() as session:
chat_repository = create_chat_repository(session)
chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
except Exception:
logger.exception("Error adding assistant message")
return EventSourceResponse(event_generator())
return EventSourceResponse(event_generator(id))
@app.websocket("/work")
async def work(websocket: fastapi.WebSocket):
async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)):
await websocket.accept()
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
try:
while True:
print(websocket.client_state)
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
logger.warning("Worker disconnected")
break
# find a pending task that matches the worker's config
# could also be implemented using task queues
# but general compatibility matching is tricky
for chat in CHATS.values():
if (request := chat.pending_message_request) is not None:
if chat.message_request_state == MessageRequestState.pending:
if request.compatible_with(worker_config):
break
for chat in chat_repository.get_pending_chats():
request = chat.pending_message_request
if request.compatible_with(worker_config):
break
else:
logger.debug("No pending tasks")
await asyncio.sleep(1)
continue
chat.message_request_state = MessageRequestState.in_progress
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
work_request = inference.WorkRequest(
conversation=chat.conversation,
@@ -214,15 +187,17 @@ async def work(websocket: fastapi.WebSocket):
max_new_tokens=request.max_new_tokens,
)
logger.info(f"Created {work_request}")
logger.info(f"Created {work_request=}")
try:
await websocket.send_text(work_request.json())
except websockets.exceptions.ConnectionClosedError:
logger.warning("Worker disconnected")
websocket.close()
chat.message_request_state = MessageRequestState.pending
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
break
logger.debug(f"Sent {work_request=} to worker.")
try:
in_progress = False
while True:
@@ -232,18 +207,19 @@ async def work(websocket: fastapi.WebSocket):
in_progress = True
await redisClient.rpush(chat.id, response_packet.json())
if response_packet.is_end:
logger.debug(f"Received {response_packet=} from worker. Ending.")
break
except fastapi.WebSocketException:
# TODO: handle this better
logger.exception(f"Websocket closed during handling of {chat.id}")
if in_progress:
logger.warning(f"Aborting {chat.id=}")
chat.message_request_state = MessageRequestState.aborted_by_worker
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
else:
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
chat.message_request_state = MessageRequestState.pending
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
raise
chat.message_request_state = MessageRequestState.complete
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
except fastapi.WebSocketException:
logger.exception("Websocket closed")
@@ -0,0 +1,79 @@
import fastapi
import sqlmodel
from loguru import logger
from oasst_inference_server import interface, models
from oasst_shared.schemas import protocol
from sqlalchemy.orm.attributes import flag_modified
from sqlalchemy.sql.operators import is_not
class ChatRepository:
    """Data-access helper for chats, operating on one SQLModel session.

    Every mutating method commits the session before returning.
    """

    def __init__(self, session: sqlmodel.Session) -> None:
        self.session = session

    def get_chats(self) -> list[models.DbChatEntry]:
        """Return all stored chats."""
        return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()

    def get_pending_chats(self) -> list[models.DbChatEntry]:
        """Return chats that have a message request and whose state is ``pending``."""
        return self.session.exec(
            sqlmodel.select(models.DbChatEntry).where(
                is_not(models.DbChatEntry.pending_message_request, None),
                models.DbChatEntry.message_request_state == interface.MessageRequestState.pending,
            )
        ).all()

    def get_chat_list(self) -> list[interface.ChatListEntry]:
        """Return the list-entry view of every stored chat."""
        chats = self.get_chats()
        return [chat.to_list_entry() for chat in chats]

    def get_chat_by_id(self, id: str) -> models.DbChatEntry:
        """Return the chat with the given id.

        Uses ``.one()``, so the call raises unless exactly one row matches.
        """
        chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one()
        return chat

    def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry:
        """Return the API view (id and conversation) of the chat with the given id."""
        return self.get_chat_by_id(id).to_entry()

    def create_chat(self) -> models.DbChatEntry:
        """Create, persist and return a new chat with default field values."""
        chat = models.DbChatEntry()
        self.session.add(chat)
        self.session.commit()
        return chat

    def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None:
        """Append a prompter message to the chat and mark its request as pending.

        Raises:
            fastapi.HTTPException: status 400 if it is not the prompter's turn
                or if the chat already has a pending message request.
        """
        logger.info(f"Adding prompter message {message_request} to chat {id}")
        chat = self.get_chat_by_id(id)
        if not chat.conversation.is_prompter_turn:
            raise fastapi.HTTPException(status_code=400, detail="Not your turn")
        if chat.pending_message_request is not None:
            raise fastapi.HTTPException(status_code=400, detail="Already pending")

        chat.conversation.messages.append(
            protocol.ConversationMessage(
                text=message_request.message,
                is_assistant=False,
            )
        )
        # The append above mutates the JSONB value in place, which the ORM does
        # not detect; flag the attribute so the commit writes the conversation.
        flag_modified(chat, "conversation")

        chat.pending_message_request = message_request
        chat.message_request_state = interface.MessageRequestState.pending
        self.session.commit()
        logger.debug(f"Added prompter message {message_request} to chat {id}")

    def add_assistant_message(self, id: str, text: str) -> None:
        """Append an assistant message to the chat and clear its pending request."""
        logger.info(f"Adding assistant message {text} to chat {id}")
        chat = self.get_chat_by_id(id)
        chat.conversation.messages.append(
            protocol.ConversationMessage(
                text=text,
                is_assistant=True,
            )
        )
        # In-place mutation of the JSONB value is not tracked by the ORM; mark
        # the attribute explicitly so the new message is included in the UPDATE.
        flag_modified(chat, "conversation")

        chat.pending_message_request = None
        self.session.commit()
        logger.debug(f"Added assistant message {text} to chat {id}")

    def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None:
        """Set the message-request state of the chat with the given id and commit."""
        logger.info(f"Setting chat {id} state to {state}")
        chat = self.get_chat_by_id(id)
        chat.message_request_state = state
        self.session.commit()
        logger.debug(f"Set chat {id} state to {state}")
@@ -0,0 +1,41 @@
import json
import pydantic.json
import sqlmodel
from loguru import logger
from oasst_inference_server import models
from oasst_inference_server.settings import settings
def default_json_serializer(obj):
    """``json.dumps`` fallback for objects the json module cannot encode itself.

    The object is encoded with pydantic's encoder. When the encoding is a
    dict, the object's class name is stored under ``"_classname_"``.

    Returns:
        The JSON-compatible encoding of ``obj``.
    """
    encoded = pydantic.json.pydantic_encoder(obj)
    # pydantic_encoder returns non-dict values (strings, lists, ...) for types
    # such as datetimes, UUIDs or sets; those cannot carry the marker, and
    # item assignment on them would raise TypeError.
    if isinstance(encoded, dict):
        encoded["_classname_"] = obj.__class__.__name__
    return encoded
def custom_json_serializer(obj):
    """Dump ``obj`` to a JSON string, using ``default_json_serializer`` for non-native values."""
    serialized = json.dumps(obj, default=default_json_serializer)
    return serialized
def custom_json_deserializer(s):
    """Parse a JSON string, rebuilding known models from their ``"_classname_"`` marker.

    Non-dict values and dicts without a marker are returned as parsed.

    Raises:
        ValueError: if the marker names a class this function does not know.
    """
    payload = json.loads(s)
    # Scalars and lists never carry a class marker.
    if not isinstance(payload, dict):
        return payload

    class_name = payload.get("_classname_")
    if class_name is None:
        return payload
    if class_name == "Conversation":
        return models.protocol.Conversation.parse_obj(payload)
    if class_name == "MessageRequest":
        return models.interface.MessageRequest.parse_obj(payload)

    logger.error(f"Unknown class {payload['_classname_']}")
    raise ValueError(f"Unknown class {payload['_classname_']}")
db_engine = sqlmodel.create_engine(
settings.database_uri,
json_serializer=custom_json_serializer,
json_deserializer=custom_json_deserializer,
)
@@ -0,0 +1,41 @@
import enum
import pydantic
from oasst_shared.schemas import inference, protocol
class MessageRequest(pydantic.BaseModel):
    """A prompter's request for a new assistant message."""

    # Prompt text; excluded from the model's repr.
    message: str = pydantic.Field(..., repr=False)
    # Model the request should be served by.
    model_name: str = "distilgpt2"
    max_new_tokens: int = 100

    def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
        """Return True when the worker's configured model is the one this request names."""
        requested_model = self.model_name
        served_model = worker_config.model_name
        return requested_model == served_model
# Event payload that wraps a single token response.
class TokenResponseEvent(pydantic.BaseModel):
token: inference.TokenResponse
# Lifecycle states of a chat's message request. Members are str-valued, so
# each one compares equal to its plain string value.
class MessageRequestState(str, enum.Enum):
pending = "pending"
in_progress = "in_progress"
complete = "complete"
aborted_by_worker = "aborted_by_worker"
# Request model for creating a chat; it currently declares no fields.
class CreateChatRequest(pydantic.BaseModel):
pass
# Summary of a chat for listings: only its id.
class ChatListEntry(pydantic.BaseModel):
id: str
# Detailed view of a chat: its id together with its conversation.
class ChatEntry(pydantic.BaseModel):
id: str
conversation: protocol.Conversation
# Response model holding the list entries of all chats.
class ListChatsResponse(pydantic.BaseModel):
chats: list[ChatListEntry]
@@ -0,0 +1,23 @@
from uuid import uuid4
import sqlalchemy as sa
import sqlalchemy.dialects.postgresql as pg
from oasst_inference_server import interface
from oasst_shared.schemas import protocol
from sqlmodel import Field, SQLModel
# SQLModel table model for a chat, mapped to the "chat" table.
class DbChatEntry(SQLModel, table=True):
__tablename__ = "chat"
# Primary key; defaults to a newly generated UUID4 rendered as a string.
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
# Conversation held in a JSONB column; defaults to a default-constructed protocol.Conversation.
conversation: protocol.Conversation = Field(default_factory=protocol.Conversation, sa_column=sa.Column(pg.JSONB))
# Outstanding message request, if any (JSONB column, defaults to None).
pending_message_request: interface.MessageRequest | None = Field(None, sa_column=sa.Column(pg.JSONB))
# State of the message request (JSONB column, defaults to None).
message_request_state: interface.MessageRequestState | None = Field(None, sa_column=sa.Column(pg.JSONB))
# Build the list-entry view, which carries only the id.
def to_list_entry(self) -> interface.ChatListEntry:
return interface.ChatListEntry(id=self.id)
# Build the detailed view with the id and the conversation.
def to_entry(self) -> interface.ChatEntry:
return interface.ChatEntry(id=self.id, conversation=self.conversation)
@@ -0,0 +1,38 @@
from typing import Any
import pydantic
class Settings(pydantic.BaseSettings):
    """Configuration of the inference server, with defaults for local use."""

    # Redis connection.
    redis_host: str = "localhost"
    redis_port: int = 6379
    redis_db: int = 0

    sse_retry_timeout: int = 15000

    # Alembic upgrade handling.
    update_alembic: bool = True
    alembic_retries: int = 5
    alembic_retry_timeout: int = 1

    # PostgreSQL connection parts; used to build database_uri when it is not given.
    postgres_host: str = "localhost"
    postgres_port: str = "5432"
    postgres_user: str = "postgres"
    postgres_password: str = "postgres"
    postgres_db: str = "postgres"

    database_uri: str | None = None

    @pydantic.validator("database_uri", pre=True)
    def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
        """Return ``v`` if it is a string, else a PostgreSQL DSN built from the ``postgres_*`` values."""
        if not isinstance(v, str):
            database_name = values.get("postgres_db") or ""
            return pydantic.PostgresDsn.build(
                scheme="postgresql",
                user=values.get("postgres_user"),
                password=values.get("postgres_password"),
                host=values.get("postgres_host"),
                port=values.get("postgres_port"),
                path=f"/{database_name}",
            )
        return v
settings = Settings()
+4 -1
View File
@@ -1,7 +1,10 @@
alembic
fastapi[all]
loguru
prometheus-fastapi-instrumentator==5.9.1
prometheus-fastapi-instrumentator
psycopg2-binary
pydantic
redis
sqlmodel
sse-starlette
websockets
+2
View File
@@ -1,6 +1,7 @@
"""Simple REPL frontend."""
import json
import time
import requests
import sseclient
@@ -42,6 +43,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
break
except Exception:
typer.echo("Error, restarting chat...")
time.sleep(1)
if __name__ == "__main__":
+3
View File
@@ -17,6 +17,8 @@ def main(
model_name: str = "distilgpt2",
inference_server_url: str = "http://localhost:8001",
):
utils.wait_for_inference_server(inference_server_url)
def on_open(ws: websocket.WebSocket):
logger.info("Connected to backend, sending config...")
worker_config = inference.WorkerConfig(model_name=model_name)
@@ -93,6 +95,7 @@ def main(
),
).json()
)
logger.info("Work complete. Waiting for more work...")
def on_error(ws: websocket.WebSocket, error: Exception):
try:
+22
View File
@@ -1,7 +1,11 @@
import collections
import random
import time
from typing import Literal
import interface
import requests
from loguru import logger
class TokenBuffer:
@@ -38,3 +42,21 @@ class TokenBuffer:
yield from self.tokens
else:
yield from self.tokens
def wait_for_inference_server(inference_server_url: str, timeout: int = 600, request_timeout: float = 10.0):
    """Block until ``{inference_server_url}/health`` answers with a success status.

    Args:
        inference_server_url: Base URL of the inference server.
        timeout: Overall time budget in seconds. Once it is exceeded, the
            next failed probe re-raises its error.
        request_timeout: Timeout in seconds applied to each health request.

    Raises:
        requests.HTTPError, requests.ConnectionError, requests.Timeout:
            the error of the last probe once ``timeout`` has elapsed.
    """
    health_url = f"{inference_server_url}/health"
    time_limit = time.time() + timeout
    while True:
        try:
            # Bound every probe: without a per-request timeout a stalled
            # connection could block forever and the deadline below would
            # never be checked.
            response = requests.get(health_url, timeout=request_timeout)
            response.raise_for_status()
        except (requests.HTTPError, requests.ConnectionError, requests.Timeout):
            if time.time() > time_limit:
                raise
            # Random pause between probes.
            sleep_duration = random.uniform(0, 10)
            logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
            time.sleep(sleep_duration)
        else:
            logger.info("Inference server is ready")
            break