mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
Added database to inference server (#1446)
* added db for inference * fixed dockerfiles for inference
This commit is contained in:
+30
-23
@@ -136,6 +136,23 @@ services:
|
|||||||
- "3000:3000"
|
- "3000:3000"
|
||||||
command: bash wait-for-postgres.sh node server.js
|
command: bash wait-for-postgres.sh node server.js
|
||||||
|
|
||||||
|
# This DB is for Inference
|
||||||
|
inference-db:
|
||||||
|
image: postgres
|
||||||
|
restart: always
|
||||||
|
ports:
|
||||||
|
- 5434:5432
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: oasst_inference
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "pg_isready", "-U", "postgres"]
|
||||||
|
interval: 2s
|
||||||
|
timeout: 2s
|
||||||
|
retries: 10
|
||||||
|
profiles: ["inference"]
|
||||||
|
|
||||||
inference-server:
|
inference-server:
|
||||||
build:
|
build:
|
||||||
dockerfile: docker/inference/Dockerfile.server
|
dockerfile: docker/inference/Dockerfile.server
|
||||||
@@ -145,13 +162,25 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
- "PORT=8000"
|
- "PORT=8000"
|
||||||
- "REDIS_HOST=redis"
|
- "REDIS_HOST=redis"
|
||||||
|
- POSTGRES_HOST=inference-db
|
||||||
|
- POSTGRES_DB=oasst_inference
|
||||||
volumes:
|
volumes:
|
||||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||||
- "./inference/server:/opt/inference/server"
|
- "./inference/server:/opt/inference/server"
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
- "8000:8000"
|
||||||
depends_on:
|
depends_on:
|
||||||
redis:
|
redis:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
|
inference-db:
|
||||||
|
condition: service_healthy
|
||||||
|
profiles: ["inference"]
|
||||||
|
|
||||||
|
inference-text-generation-server:
|
||||||
|
image: ghcr.io/huggingface/text-generation-inference
|
||||||
|
environment:
|
||||||
|
- "MODEL_ID=distilgpt2"
|
||||||
profiles: ["inference"]
|
profiles: ["inference"]
|
||||||
|
|
||||||
inference-worker:
|
inference-worker:
|
||||||
@@ -167,29 +196,7 @@ services:
|
|||||||
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
- "./oasst-shared:/opt/inference/lib/oasst-shared"
|
||||||
- "./inference/worker:/opt/inference/worker"
|
- "./inference/worker:/opt/inference/worker"
|
||||||
depends_on:
|
depends_on:
|
||||||
- inference-server
|
- inference-text-generation-server
|
||||||
deploy:
|
deploy:
|
||||||
replicas: 1
|
replicas: 1
|
||||||
profiles: ["inference"]
|
profiles: ["inference"]
|
||||||
|
|
||||||
inference-text-client:
|
|
||||||
build:
|
|
||||||
dockerfile: docker/inference/Dockerfile.text-client
|
|
||||||
context: .
|
|
||||||
image: oasst-inference-text-client
|
|
||||||
environment:
|
|
||||||
- "BACKEND_URL=http://inference-server:8000"
|
|
||||||
tty: true
|
|
||||||
stdin_open: true
|
|
||||||
volumes:
|
|
||||||
- "./inference/worker:/opt/inference/worker"
|
|
||||||
restart: unless-stopped
|
|
||||||
depends_on:
|
|
||||||
- inference-server
|
|
||||||
profiles: ["inference"]
|
|
||||||
|
|
||||||
inference-text-generation-server:
|
|
||||||
image: ghcr.io/huggingface/text-generation-inference
|
|
||||||
environment:
|
|
||||||
- "MODEL_ID=distilgpt2"
|
|
||||||
profiles: ["inference"]
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ ARG APP_USER="${MODULE}-${SERVICE}"
|
|||||||
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
||||||
|
|
||||||
|
|
||||||
FROM python:3-slim as build
|
FROM python:3.10-slim as build
|
||||||
ARG APP_RELATIVE_PATH
|
ARG APP_RELATIVE_PATH
|
||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
@@ -22,7 +22,7 @@ RUN --mount=type=cache,target=/var/cache/pip \
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
FROM python:3.10-alpine3.17 as base-env
|
FROM python:3.10-slim as base-env
|
||||||
ARG APP_USER
|
ARG APP_USER
|
||||||
ARG APP_RELATIVE_PATH
|
ARG APP_RELATIVE_PATH
|
||||||
ARG MODULE
|
ARG MODULE
|
||||||
@@ -50,6 +50,9 @@ WORKDIR ${APP_ROOT}
|
|||||||
|
|
||||||
|
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||||
|
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic alembic
|
||||||
|
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic.ini .
|
||||||
|
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/oasst_inference_server oasst_inference_server
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
|
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py .
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
# syntax=docker/dockerfile:1
|
|
||||||
|
|
||||||
ARG APP_USER="text-client"
|
|
||||||
ARG APP_RELATIVE_PATH="inference/text-client"
|
|
||||||
|
|
||||||
|
|
||||||
FROM python:3.10-alpine3.17 as build
|
|
||||||
ARG APP_RELATIVE_PATH
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/var/cache/pip \
|
|
||||||
pip install \
|
|
||||||
--cache-dir=/var/cache/pip \
|
|
||||||
--target=lib \
|
|
||||||
-r requirements.txt
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FROM python:3.10-alpine3.17 as base-env
|
|
||||||
ARG APP_USER
|
|
||||||
ARG APP_RELATIVE_PATH
|
|
||||||
|
|
||||||
ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}"
|
|
||||||
ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib"
|
|
||||||
|
|
||||||
ENV PATH="${PATH}:${APP_LIBS}/bin"
|
|
||||||
ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}"
|
|
||||||
|
|
||||||
|
|
||||||
RUN adduser \
|
|
||||||
--disabled-password \
|
|
||||||
--no-create-home \
|
|
||||||
"${APP_USER}"
|
|
||||||
|
|
||||||
USER ${APP_USER}
|
|
||||||
|
|
||||||
WORKDIR ${APP_ROOT}
|
|
||||||
|
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
FROM base-env as prod
|
|
||||||
|
|
||||||
|
|
||||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}"
|
|
||||||
@@ -48,7 +48,7 @@ WORKDIR ${APP_ROOT}
|
|||||||
|
|
||||||
|
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS}
|
||||||
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py .
|
COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py .
|
||||||
|
|
||||||
|
|
||||||
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
|
CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}"
|
||||||
|
|||||||
@@ -3,9 +3,11 @@
|
|||||||
# Creates a tmux window with splits for the individual services
|
# Creates a tmux window with splits for the individual services
|
||||||
|
|
||||||
tmux new-session -d -s "inference-dev-setup"
|
tmux new-session -d -s "inference-dev-setup"
|
||||||
tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m
|
tmux send-keys "docker run --rm -it -p 5432:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m
|
||||||
tmux split-window -h
|
tmux split-window -h
|
||||||
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m
|
tmux send-keys "docker run --rm -it -p 6379:6379 --name redis redis" C-m
|
||||||
|
tmux split-window -h
|
||||||
|
tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m
|
||||||
tmux split-window -h
|
tmux split-window -h
|
||||||
tmux send-keys "cd server" C-m
|
tmux send-keys "cd server" C-m
|
||||||
tmux send-keys "uvicorn main:app --reload" C-m
|
tmux send-keys "uvicorn main:app --reload" C-m
|
||||||
|
|||||||
@@ -0,0 +1,105 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = %(here)s/alembic
|
||||||
|
|
||||||
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
|
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||||
|
# for all available tokens
|
||||||
|
file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python-dateutil library that can be
|
||||||
|
# installed by adding `alembic[tz]` to the pip requirements
|
||||||
|
# string value is passed to dateutil.tz.gettz()
|
||||||
|
# leave blank for localtime
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the
|
||||||
|
# "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to alembic/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||||
|
|
||||||
|
# version path separator; As mentioned above, this is the character used to split
|
||||||
|
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||||
|
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||||
|
# Valid values for version_path_separator are:
|
||||||
|
#
|
||||||
|
# version_path_separator = :
|
||||||
|
# version_path_separator = ;
|
||||||
|
# version_path_separator = space
|
||||||
|
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
# sqlalchemy.url = postgresql://<username>:<password>@<host>/<database_name>
|
||||||
|
sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
hooks = black
|
||||||
|
black.type = console_scripts
|
||||||
|
black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
Generic single-database configuration.
|
||||||
@@ -0,0 +1,78 @@
|
|||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
import sqlmodel
|
||||||
|
from alembic import context
|
||||||
|
from oasst_inference_server import models # noqa: F401
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
# from myapp import mymodel
|
||||||
|
# target_metadata = mymodel.Base.metadata
|
||||||
|
target_metadata = sqlmodel.SQLModel.metadata
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
# can be acquired:
|
||||||
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
|
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.get_context()._ensure_version_table()
|
||||||
|
connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE")
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = ${repr(up_revision)}
|
||||||
|
down_revision = ${repr(down_revision)}
|
||||||
|
branch_labels = ${repr(branch_labels)}
|
||||||
|
depends_on = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
"""initial commit
|
||||||
|
|
||||||
|
Revision ID: 3a4cd8777eb2
|
||||||
|
Revises:
|
||||||
|
Create Date: 2023-02-10 02:21:27.086772
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlmodel
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "3a4cd8777eb2"
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"chat",
|
||||||
|
sa.Column("conversation", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("pending_message_request", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("message_request_state", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||||
|
sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("chat")
|
||||||
|
# ### end Alembic commands ###
|
||||||
+89
-113
@@ -1,14 +1,22 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import enum
|
import contextlib
|
||||||
import uuid
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import alembic.command
|
||||||
|
import alembic.config
|
||||||
import fastapi
|
import fastapi
|
||||||
import pydantic
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
|
import sqlmodel
|
||||||
import websockets.exceptions
|
import websockets.exceptions
|
||||||
|
from fastapi import Depends
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from oasst_shared.schemas import inference, protocol
|
from oasst_inference_server import interface
|
||||||
|
from oasst_inference_server.chat_repository import ChatRepository
|
||||||
|
from oasst_inference_server.database import db_engine
|
||||||
|
from oasst_inference_server.settings import settings
|
||||||
|
from oasst_shared.schemas import inference
|
||||||
from prometheus_fastapi_instrumentator import Instrumentator
|
from prometheus_fastapi_instrumentator import Instrumentator
|
||||||
from sse_starlette.sse import EventSourceResponse
|
from sse_starlette.sse import EventSourceResponse
|
||||||
|
|
||||||
@@ -31,129 +39,97 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Settings(pydantic.BaseSettings):
|
|
||||||
redis_host: str = "localhost"
|
|
||||||
redis_port: int = 6379
|
|
||||||
redis_db: int = 0
|
|
||||||
|
|
||||||
sse_retry_timeout: int = 15000
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
|
|
||||||
# create async redis client
|
# create async redis client
|
||||||
redisClient = redis.Redis(
|
redisClient = redis.Redis(
|
||||||
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
|
host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MessageRequest(pydantic.BaseModel):
|
def create_session():
|
||||||
message: str = pydantic.Field(..., repr=False)
|
with sqlmodel.Session(db_engine) as session:
|
||||||
model_name: str = "distilgpt2"
|
yield session
|
||||||
max_new_tokens: int = 100
|
|
||||||
|
|
||||||
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
|
||||||
return self.model_name == worker_config.model_name
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponseEvent(pydantic.BaseModel):
|
def create_chat_repository(session: sqlmodel.Session = Depends(create_session)):
|
||||||
token: inference.TokenResponse
|
repository = ChatRepository(session)
|
||||||
|
return repository
|
||||||
|
|
||||||
|
|
||||||
class MessageRequestState(str, enum.Enum):
|
if settings.update_alembic:
|
||||||
pending = "pending"
|
|
||||||
in_progress = "in_progress"
|
|
||||||
complete = "complete"
|
|
||||||
aborted_by_worker = "aborted_by_worker"
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
def alembic_upgrade():
|
||||||
|
logger.info("Attempting to upgrade alembic on startup")
|
||||||
|
retry = 0
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
alembic_ini_path = Path(__file__).parent / "alembic.ini"
|
||||||
|
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
|
||||||
|
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)
|
||||||
|
alembic.command.upgrade(alembic_cfg, "head")
|
||||||
|
logger.info("Successfully upgraded alembic on startup")
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Alembic upgrade failed on startup")
|
||||||
|
retry += 1
|
||||||
|
if retry >= settings.alembic_retries:
|
||||||
|
raise
|
||||||
|
|
||||||
class CreateChatRequest(pydantic.BaseModel):
|
timeout = settings.alembic_retry_timeout * 2**retry
|
||||||
pass
|
logger.warning(f"Retrying alembic upgrade in {timeout} seconds")
|
||||||
|
time.sleep(timeout)
|
||||||
|
|
||||||
class ChatListEntry(pydantic.BaseModel):
|
|
||||||
id: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChatEntry(pydantic.BaseModel):
|
|
||||||
id: str
|
|
||||||
conversation: protocol.Conversation
|
|
||||||
|
|
||||||
|
|
||||||
class ListChatsResponse(pydantic.BaseModel):
|
|
||||||
chats: list[ChatListEntry]
|
|
||||||
|
|
||||||
|
|
||||||
class DbChatEntry(pydantic.BaseModel):
|
|
||||||
id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4()))
|
|
||||||
conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation)
|
|
||||||
pending_message_request: MessageRequest | None = None
|
|
||||||
message_request_state: MessageRequestState | None = None
|
|
||||||
|
|
||||||
def to_list_entry(self) -> ChatListEntry:
|
|
||||||
return ChatListEntry(id=self.id)
|
|
||||||
|
|
||||||
def to_entry(self) -> ChatEntry:
|
|
||||||
return ChatEntry(id=self.id, conversation=self.conversation)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: make real database
|
|
||||||
CHATS: dict[str, DbChatEntry] = {}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/chat")
|
@app.get("/chat")
|
||||||
async def list_chats() -> ListChatsResponse:
|
async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse:
|
||||||
"""Lists all chats."""
|
"""Lists all chats."""
|
||||||
logger.info("Listing all chats.")
|
logger.info("Listing all chats.")
|
||||||
chats = [chat.to_list_entry() for chat in CHATS.values()]
|
chats = chat_repository.get_chat_list()
|
||||||
return ListChatsResponse(chats=chats)
|
return interface.ListChatsResponse(chats=chats)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat")
|
@app.post("/chat")
|
||||||
async def create_chat(request: CreateChatRequest) -> ChatListEntry:
|
async def create_chat(
|
||||||
|
request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository)
|
||||||
|
) -> interface.ChatListEntry:
|
||||||
"""Allows a client to create a new chat."""
|
"""Allows a client to create a new chat."""
|
||||||
logger.info(f"Received {request}")
|
logger.info(f"Received {request}")
|
||||||
chat = DbChatEntry()
|
chat = chat_repository.create_chat()
|
||||||
CHATS[chat.id] = chat
|
return chat.to_list_entry()
|
||||||
return ChatListEntry(id=chat.id)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/chat/{id}")
|
@app.get("/chat/{id}")
|
||||||
async def get_chat(id: str) -> ChatEntry:
|
async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ChatEntry:
|
||||||
"""Allows a client to get the current state of a chat."""
|
"""Allows a client to get the current state of a chat."""
|
||||||
return CHATS[id].to_entry()
|
chat = chat_repository.get_chat_entry_by_id(id)
|
||||||
|
return chat
|
||||||
|
|
||||||
|
|
||||||
@app.post("/chat/{id}/message")
|
@app.post("/chat/{id}/message")
|
||||||
async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request):
|
async def create_message(
|
||||||
|
id: str,
|
||||||
|
message_request: interface.MessageRequest,
|
||||||
|
fastapi_request: fastapi.Request,
|
||||||
|
chat_repository: ChatRepository = Depends(create_chat_repository),
|
||||||
|
) -> EventSourceResponse:
|
||||||
"""Allows the client to stream the results of a request."""
|
"""Allows the client to stream the results of a request."""
|
||||||
|
|
||||||
chat = CHATS[id]
|
try:
|
||||||
if not chat.conversation.is_prompter_turn:
|
chat_repository.add_prompter_message(id=id, message_request=message_request)
|
||||||
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
except Exception:
|
||||||
if chat.pending_message_request is not None:
|
logger.exception("Error adding prompter message")
|
||||||
raise fastapi.HTTPException(status_code=400, detail="Already pending")
|
return fastapi.Response(status_code=500)
|
||||||
|
|
||||||
chat.conversation.messages.append(
|
async def event_generator(id):
|
||||||
protocol.ConversationMessage(
|
|
||||||
text=message_request.message,
|
|
||||||
is_assistant=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
chat.pending_message_request = message_request
|
|
||||||
chat.message_request_state = MessageRequestState.pending
|
|
||||||
|
|
||||||
async def event_generator():
|
|
||||||
result_data = []
|
result_data = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
if await fastapi_request.is_disconnected():
|
if await fastapi_request.is_disconnected():
|
||||||
logger.warning("Client disconnected")
|
logger.warning("Client disconnected")
|
||||||
break
|
return
|
||||||
|
|
||||||
item = await redisClient.blpop(chat.id, 1)
|
item = await redisClient.blpop(id, 1)
|
||||||
if item is None:
|
if item is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -166,47 +142,44 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque
|
|||||||
|
|
||||||
yield {
|
yield {
|
||||||
"retry": settings.sse_retry_timeout,
|
"retry": settings.sse_retry_timeout,
|
||||||
"data": TokenResponseEvent(token=response_packet.token).json(),
|
"data": interface.TokenResponseEvent(token=response_packet.token).json(),
|
||||||
}
|
}
|
||||||
logger.info(f"Finished streaming {chat.id} {len(result_data)=}")
|
logger.info(f"Finished streaming {id} {len(result_data)=}")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(f"Error streaming {chat.id}")
|
logger.exception(f"Error streaming {id}")
|
||||||
|
raise
|
||||||
|
|
||||||
chat.conversation.messages.append(
|
try:
|
||||||
protocol.ConversationMessage(
|
with contextlib.contextmanager(create_session)() as session:
|
||||||
text=response_packet.generated_text.text,
|
chat_repository = create_chat_repository(session)
|
||||||
is_assistant=True,
|
chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text)
|
||||||
)
|
except Exception:
|
||||||
)
|
logger.exception("Error adding assistant message")
|
||||||
chat.pending_message_request = None
|
|
||||||
|
|
||||||
return EventSourceResponse(event_generator())
|
return EventSourceResponse(event_generator(id))
|
||||||
|
|
||||||
|
|
||||||
@app.websocket("/work")
|
@app.websocket("/work")
|
||||||
async def work(websocket: fastapi.WebSocket):
|
async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)):
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text())
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
print(websocket.client_state)
|
|
||||||
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED:
|
||||||
logger.warning("Worker disconnected")
|
logger.warning("Worker disconnected")
|
||||||
break
|
break
|
||||||
# find a pending task that matches the worker's config
|
# find a pending task that matches the worker's config
|
||||||
# could also be implemented using task queues
|
# could also be implemented using task queues
|
||||||
# but general compatibility matching is tricky
|
# but general compatibility matching is tricky
|
||||||
for chat in CHATS.values():
|
for chat in chat_repository.get_pending_chats():
|
||||||
if (request := chat.pending_message_request) is not None:
|
request = chat.pending_message_request
|
||||||
if chat.message_request_state == MessageRequestState.pending:
|
if request.compatible_with(worker_config):
|
||||||
if request.compatible_with(worker_config):
|
break
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
logger.debug("No pending tasks")
|
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
chat.message_request_state = MessageRequestState.in_progress
|
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress)
|
||||||
|
|
||||||
work_request = inference.WorkRequest(
|
work_request = inference.WorkRequest(
|
||||||
conversation=chat.conversation,
|
conversation=chat.conversation,
|
||||||
@@ -214,15 +187,17 @@ async def work(websocket: fastapi.WebSocket):
|
|||||||
max_new_tokens=request.max_new_tokens,
|
max_new_tokens=request.max_new_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"Created {work_request}")
|
logger.info(f"Created {work_request=}")
|
||||||
try:
|
try:
|
||||||
await websocket.send_text(work_request.json())
|
await websocket.send_text(work_request.json())
|
||||||
except websockets.exceptions.ConnectionClosedError:
|
except websockets.exceptions.ConnectionClosedError:
|
||||||
logger.warning("Worker disconnected")
|
logger.warning("Worker disconnected")
|
||||||
websocket.close()
|
websocket.close()
|
||||||
chat.message_request_state = MessageRequestState.pending
|
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
logger.debug(f"Sent {work_request=} to worker.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
in_progress = False
|
in_progress = False
|
||||||
while True:
|
while True:
|
||||||
@@ -232,18 +207,19 @@ async def work(websocket: fastapi.WebSocket):
|
|||||||
in_progress = True
|
in_progress = True
|
||||||
await redisClient.rpush(chat.id, response_packet.json())
|
await redisClient.rpush(chat.id, response_packet.json())
|
||||||
if response_packet.is_end:
|
if response_packet.is_end:
|
||||||
|
logger.debug(f"Received {response_packet=} from worker. Ending.")
|
||||||
break
|
break
|
||||||
except fastapi.WebSocketException:
|
except fastapi.WebSocketException:
|
||||||
# TODO: handle this better
|
# TODO: handle this better
|
||||||
logger.exception(f"Websocket closed during handling of {chat.id}")
|
logger.exception(f"Websocket closed during handling of {chat.id}")
|
||||||
if in_progress:
|
if in_progress:
|
||||||
logger.warning(f"Aborting {chat.id=}")
|
logger.warning(f"Aborting {chat.id=}")
|
||||||
chat.message_request_state = MessageRequestState.aborted_by_worker
|
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
logger.warning(f"Marking {chat.id=} as pending since no work was done.")
|
||||||
chat.message_request_state = MessageRequestState.pending
|
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
chat.message_request_state = MessageRequestState.complete
|
chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete)
|
||||||
except fastapi.WebSocketException:
|
except fastapi.WebSocketException:
|
||||||
logger.exception("Websocket closed")
|
logger.exception("Websocket closed")
|
||||||
|
|||||||
@@ -0,0 +1,79 @@
|
|||||||
|
import fastapi
|
||||||
|
import sqlmodel
|
||||||
|
from loguru import logger
|
||||||
|
from oasst_inference_server import interface, models
|
||||||
|
from oasst_shared.schemas import protocol
|
||||||
|
from sqlalchemy.sql.operators import is_not
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRepository:
|
||||||
|
def __init__(self, session: sqlmodel.Session) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
def get_chats(self) -> list[models.DbChatEntry]:
|
||||||
|
return self.session.exec(sqlmodel.select(models.DbChatEntry)).all()
|
||||||
|
|
||||||
|
def get_pending_chats(self) -> list[models.DbChatEntry]:
|
||||||
|
return self.session.exec(
|
||||||
|
sqlmodel.select(models.DbChatEntry).where(
|
||||||
|
is_not(models.DbChatEntry.pending_message_request, None),
|
||||||
|
models.DbChatEntry.message_request_state == interface.MessageRequestState.pending,
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
|
||||||
|
def get_chat_list(self) -> list[interface.ChatListEntry]:
|
||||||
|
chats = self.get_chats()
|
||||||
|
return [chat.to_list_entry() for chat in chats]
|
||||||
|
|
||||||
|
def get_chat_by_id(self, id: str) -> models.DbChatEntry:
|
||||||
|
chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one()
|
||||||
|
return chat
|
||||||
|
|
||||||
|
def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry:
|
||||||
|
return self.get_chat_by_id(id).to_entry()
|
||||||
|
|
||||||
|
def create_chat(self) -> models.DbChatEntry:
|
||||||
|
chat = models.DbChatEntry()
|
||||||
|
self.session.add(chat)
|
||||||
|
self.session.commit()
|
||||||
|
return chat
|
||||||
|
|
||||||
|
def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None:
|
||||||
|
logger.info(f"Adding prompter message {message_request} to chat {id}")
|
||||||
|
chat = self.get_chat_by_id(id)
|
||||||
|
if not chat.conversation.is_prompter_turn:
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail="Not your turn")
|
||||||
|
if chat.pending_message_request is not None:
|
||||||
|
raise fastapi.HTTPException(status_code=400, detail="Already pending")
|
||||||
|
|
||||||
|
chat.conversation.messages.append(
|
||||||
|
protocol.ConversationMessage(
|
||||||
|
text=message_request.message,
|
||||||
|
is_assistant=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chat.pending_message_request = message_request
|
||||||
|
chat.message_request_state = interface.MessageRequestState.pending
|
||||||
|
self.session.commit()
|
||||||
|
logger.debug(f"Added prompter message {message_request} to chat {id}")
|
||||||
|
|
||||||
|
def add_assistant_message(self, id: str, text: str) -> None:
|
||||||
|
logger.info(f"Adding assistant message {text} to chat {id}")
|
||||||
|
chat = self.get_chat_by_id(id)
|
||||||
|
chat.conversation.messages.append(
|
||||||
|
protocol.ConversationMessage(
|
||||||
|
text=text,
|
||||||
|
is_assistant=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
chat.pending_message_request = None
|
||||||
|
self.session.commit()
|
||||||
|
logger.debug(f"Added assistant message {text} to chat {id}")
|
||||||
|
|
||||||
|
def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None:
|
||||||
|
logger.info(f"Setting chat {id} state to {state}")
|
||||||
|
chat = self.get_chat_by_id(id)
|
||||||
|
chat.message_request_state = state
|
||||||
|
self.session.commit()
|
||||||
|
logger.debug(f"Set chat {id} state to {state}")
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
import pydantic.json
|
||||||
|
import sqlmodel
|
||||||
|
from loguru import logger
|
||||||
|
from oasst_inference_server import models
|
||||||
|
from oasst_inference_server.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def default_json_serializer(obj):
|
||||||
|
class_name = obj.__class__.__name__
|
||||||
|
encoded = pydantic.json.pydantic_encoder(obj)
|
||||||
|
encoded["_classname_"] = class_name
|
||||||
|
return encoded
|
||||||
|
|
||||||
|
|
||||||
|
def custom_json_serializer(obj):
|
||||||
|
return json.dumps(obj, default=default_json_serializer)
|
||||||
|
|
||||||
|
|
||||||
|
def custom_json_deserializer(s):
|
||||||
|
d = json.loads(s)
|
||||||
|
if not isinstance(d, dict):
|
||||||
|
return d
|
||||||
|
match d.get("_classname_"):
|
||||||
|
case "Conversation":
|
||||||
|
return models.protocol.Conversation.parse_obj(d)
|
||||||
|
case "MessageRequest":
|
||||||
|
return models.interface.MessageRequest.parse_obj(d)
|
||||||
|
case None:
|
||||||
|
return d
|
||||||
|
case _:
|
||||||
|
logger.error(f"Unknown class {d['_classname_']}")
|
||||||
|
raise ValueError(f"Unknown class {d['_classname_']}")
|
||||||
|
|
||||||
|
|
||||||
|
db_engine = sqlmodel.create_engine(
|
||||||
|
settings.database_uri,
|
||||||
|
json_serializer=custom_json_serializer,
|
||||||
|
json_deserializer=custom_json_deserializer,
|
||||||
|
)
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
from oasst_shared.schemas import inference, protocol
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRequest(pydantic.BaseModel):
|
||||||
|
message: str = pydantic.Field(..., repr=False)
|
||||||
|
model_name: str = "distilgpt2"
|
||||||
|
max_new_tokens: int = 100
|
||||||
|
|
||||||
|
def compatible_with(self, worker_config: inference.WorkerConfig) -> bool:
|
||||||
|
return self.model_name == worker_config.model_name
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponseEvent(pydantic.BaseModel):
|
||||||
|
token: inference.TokenResponse
|
||||||
|
|
||||||
|
|
||||||
|
class MessageRequestState(str, enum.Enum):
|
||||||
|
pending = "pending"
|
||||||
|
in_progress = "in_progress"
|
||||||
|
complete = "complete"
|
||||||
|
aborted_by_worker = "aborted_by_worker"
|
||||||
|
|
||||||
|
|
||||||
|
class CreateChatRequest(pydantic.BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ChatListEntry(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatEntry(pydantic.BaseModel):
|
||||||
|
id: str
|
||||||
|
conversation: protocol.Conversation
|
||||||
|
|
||||||
|
|
||||||
|
class ListChatsResponse(pydantic.BaseModel):
|
||||||
|
chats: list[ChatListEntry]
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlalchemy.dialects.postgresql as pg
|
||||||
|
from oasst_inference_server import interface
|
||||||
|
from oasst_shared.schemas import protocol
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
|
||||||
|
class DbChatEntry(SQLModel, table=True):
|
||||||
|
__tablename__ = "chat"
|
||||||
|
|
||||||
|
id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True)
|
||||||
|
|
||||||
|
conversation: protocol.Conversation = Field(default_factory=protocol.Conversation, sa_column=sa.Column(pg.JSONB))
|
||||||
|
pending_message_request: interface.MessageRequest | None = Field(None, sa_column=sa.Column(pg.JSONB))
|
||||||
|
message_request_state: interface.MessageRequestState | None = Field(None, sa_column=sa.Column(pg.JSONB))
|
||||||
|
|
||||||
|
def to_list_entry(self) -> interface.ChatListEntry:
|
||||||
|
return interface.ChatListEntry(id=self.id)
|
||||||
|
|
||||||
|
def to_entry(self) -> interface.ChatEntry:
|
||||||
|
return interface.ChatEntry(id=self.id, conversation=self.conversation)
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(pydantic.BaseSettings):
|
||||||
|
redis_host: str = "localhost"
|
||||||
|
redis_port: int = 6379
|
||||||
|
redis_db: int = 0
|
||||||
|
|
||||||
|
sse_retry_timeout: int = 15000
|
||||||
|
update_alembic: bool = True
|
||||||
|
alembic_retries: int = 5
|
||||||
|
alembic_retry_timeout: int = 1
|
||||||
|
|
||||||
|
postgres_host: str = "localhost"
|
||||||
|
postgres_port: str = "5432"
|
||||||
|
postgres_user: str = "postgres"
|
||||||
|
postgres_password: str = "postgres"
|
||||||
|
postgres_db: str = "postgres"
|
||||||
|
|
||||||
|
database_uri: str | None = None
|
||||||
|
|
||||||
|
@pydantic.validator("database_uri", pre=True)
|
||||||
|
def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any:
|
||||||
|
if isinstance(v, str):
|
||||||
|
return v
|
||||||
|
return pydantic.PostgresDsn.build(
|
||||||
|
scheme="postgresql",
|
||||||
|
user=values.get("postgres_user"),
|
||||||
|
password=values.get("postgres_password"),
|
||||||
|
host=values.get("postgres_host"),
|
||||||
|
port=values.get("postgres_port"),
|
||||||
|
path=f"/{values.get('postgres_db') or ''}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
@@ -1,7 +1,10 @@
|
|||||||
|
alembic
|
||||||
fastapi[all]
|
fastapi[all]
|
||||||
loguru
|
loguru
|
||||||
prometheus-fastapi-instrumentator==5.9.1
|
prometheus-fastapi-instrumentator
|
||||||
|
psycopg2-binary
|
||||||
pydantic
|
pydantic
|
||||||
redis
|
redis
|
||||||
|
sqlmodel
|
||||||
sse-starlette
|
sse-starlette
|
||||||
websockets
|
websockets
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
"""Simple REPL frontend."""
|
"""Simple REPL frontend."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import sseclient
|
import sseclient
|
||||||
@@ -42,6 +43,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"):
|
|||||||
break
|
break
|
||||||
except Exception:
|
except Exception:
|
||||||
typer.echo("Error, restarting chat...")
|
typer.echo("Error, restarting chat...")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ def main(
|
|||||||
model_name: str = "distilgpt2",
|
model_name: str = "distilgpt2",
|
||||||
inference_server_url: str = "http://localhost:8001",
|
inference_server_url: str = "http://localhost:8001",
|
||||||
):
|
):
|
||||||
|
utils.wait_for_inference_server(inference_server_url)
|
||||||
|
|
||||||
def on_open(ws: websocket.WebSocket):
|
def on_open(ws: websocket.WebSocket):
|
||||||
logger.info("Connected to backend, sending config...")
|
logger.info("Connected to backend, sending config...")
|
||||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||||
@@ -93,6 +95,7 @@ def main(
|
|||||||
),
|
),
|
||||||
).json()
|
).json()
|
||||||
)
|
)
|
||||||
|
logger.info("Work complete. Waiting for more work...")
|
||||||
|
|
||||||
def on_error(ws: websocket.WebSocket, error: Exception):
|
def on_error(ws: websocket.WebSocket, error: Exception):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
import collections
|
import collections
|
||||||
|
import random
|
||||||
|
import time
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import interface
|
import interface
|
||||||
|
import requests
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
class TokenBuffer:
|
class TokenBuffer:
|
||||||
@@ -38,3 +42,21 @@ class TokenBuffer:
|
|||||||
yield from self.tokens
|
yield from self.tokens
|
||||||
else:
|
else:
|
||||||
yield from self.tokens
|
yield from self.tokens
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_inference_server(inference_server_url: str, timeout: int = 600):
|
||||||
|
health_url = f"{inference_server_url}/health"
|
||||||
|
time_limit = time.time() + timeout
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = requests.get(health_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
except (requests.HTTPError, requests.ConnectionError):
|
||||||
|
if time.time() > time_limit:
|
||||||
|
raise
|
||||||
|
sleep_duration = random.uniform(0, 10)
|
||||||
|
logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds")
|
||||||
|
time.sleep(sleep_duration)
|
||||||
|
else:
|
||||||
|
logger.info("Inference server is ready")
|
||||||
|
break
|
||||||
|
|||||||
Reference in New Issue
Block a user