From 90c3d5640e3b302eb5ef1e2ea57e7e1980734db6 Mon Sep 17 00:00:00 2001 From: Yannic Kilcher Date: Fri, 10 Feb 2023 22:51:35 +0100 Subject: [PATCH] Added database to inference server (#1446) * added db for inference * fixed dockerfiles for inference --- docker-compose.yaml | 53 +++-- docker/inference/Dockerfile.server | 7 +- docker/inference/Dockerfile.text-client | 50 ----- docker/inference/Dockerfile.worker | 2 +- inference/full-dev-setup.sh | 6 +- inference/server/alembic.ini | 105 +++++++++ inference/server/alembic/README | 1 + inference/server/alembic/env.py | 78 +++++++ inference/server/alembic/script.py.mako | 25 +++ inference/server/alembic/versions/.gitinclude | 0 ..._02_10_0221-3a4cd8777eb2_initial_commit.py | 36 ++++ inference/server/main.py | 202 ++++++++---------- .../oasst_inference_server/chat_repository.py | 79 +++++++ .../server/oasst_inference_server/database.py | 41 ++++ .../oasst_inference_server/interface.py | 41 ++++ .../server/oasst_inference_server/models.py | 23 ++ .../server/oasst_inference_server/settings.py | 38 ++++ inference/server/requirements.txt | 5 +- inference/text-client/__main__.py | 2 + inference/worker/__main__.py | 3 + inference/worker/utils.py | 22 ++ 21 files changed, 627 insertions(+), 192 deletions(-) delete mode 100644 docker/inference/Dockerfile.text-client create mode 100644 inference/server/alembic.ini create mode 100644 inference/server/alembic/README create mode 100644 inference/server/alembic/env.py create mode 100644 inference/server/alembic/script.py.mako create mode 100644 inference/server/alembic/versions/.gitinclude create mode 100644 inference/server/alembic/versions/2023_02_10_0221-3a4cd8777eb2_initial_commit.py create mode 100644 inference/server/oasst_inference_server/chat_repository.py create mode 100644 inference/server/oasst_inference_server/database.py create mode 100644 inference/server/oasst_inference_server/interface.py create mode 100644 inference/server/oasst_inference_server/models.py create mode 100644 inference/server/oasst_inference_server/settings.py diff --git a/docker-compose.yaml b/docker-compose.yaml index 3cfb4c21..52085f62 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -136,6 +136,23 @@ services: - "3000:3000" command: bash wait-for-postgres.sh node server.js + # This DB is for Inference + inference-db: + image: postgres + restart: always + ports: + - 5434:5432 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: oasst_inference + healthcheck: + test: ["CMD", "pg_isready", "-U", "postgres"] + interval: 2s + timeout: 2s + retries: 10 + profiles: ["inference"] + inference-server: build: dockerfile: docker/inference/Dockerfile.server @@ -145,13 +162,25 @@ services: environment: - "PORT=8000" - "REDIS_HOST=redis" + - POSTGRES_HOST=inference-db + - POSTGRES_DB=oasst_inference volumes: - "./oasst-shared:/opt/inference/lib/oasst-shared" - "./inference/server:/opt/inference/server" restart: unless-stopped + ports: + - "8000:8000" depends_on: redis: condition: service_healthy + inference-db: + condition: service_healthy + profiles: ["inference"] + + inference-text-generation-server: + image: ghcr.io/huggingface/text-generation-inference + environment: + - "MODEL_ID=distilgpt2" profiles: ["inference"] inference-worker: @@ -167,29 +196,7 @@ services: - "./oasst-shared:/opt/inference/lib/oasst-shared" - "./inference/worker:/opt/inference/worker" depends_on: - - inference-server + - inference-text-generation-server deploy: replicas: 1 profiles: ["inference"] - - inference-text-client: - build: - dockerfile: docker/inference/Dockerfile.text-client - context: . - image: oasst-inference-text-client - environment: - - "BACKEND_URL=http://inference-server:8000" - tty: true - stdin_open: true - volumes: - - "./inference/worker:/opt/inference/worker" - restart: unless-stopped - depends_on: - - inference-server - profiles: ["inference"] - - inference-text-generation-server: - image: ghcr.io/huggingface/text-generation-inference - environment: - - "MODEL_ID=distilgpt2" - profiles: ["inference"] diff --git a/docker/inference/Dockerfile.server b/docker/inference/Dockerfile.server index 0838a21e..f5823a9a 100644 --- a/docker/inference/Dockerfile.server +++ b/docker/inference/Dockerfile.server @@ -7,7 +7,7 @@ ARG APP_USER="${MODULE}-${SERVICE}" ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}" -FROM python:3-slim as build +FROM python:3.10-slim as build ARG APP_RELATIVE_PATH WORKDIR /build @@ -22,7 +22,7 @@ RUN --mount=type=cache,target=/var/cache/pip \ -FROM python:3.10-alpine3.17 as base-env +FROM python:3.10-slim as base-env ARG APP_USER ARG APP_RELATIVE_PATH ARG MODULE @@ -50,6 +50,9 @@ WORKDIR ${APP_ROOT} COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS} +COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic alembic +COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/alembic.ini . +COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/oasst_inference_server oasst_inference_server COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/main.py . diff --git a/docker/inference/Dockerfile.text-client b/docker/inference/Dockerfile.text-client deleted file mode 100644 index 23a54abe..00000000 --- a/docker/inference/Dockerfile.text-client +++ /dev/null @@ -1,50 +0,0 @@ -# syntax=docker/dockerfile:1 - -ARG APP_USER="text-client" -ARG APP_RELATIVE_PATH="inference/text-client" - - -FROM python:3.10-alpine3.17 as build -ARG APP_RELATIVE_PATH - -WORKDIR /build - -COPY ./${APP_RELATIVE_PATH}/requirements.txt . - -RUN --mount=type=cache,target=/var/cache/pip \ - pip install \ - --cache-dir=/var/cache/pip \ - --target=lib \ - -r requirements.txt - - - -FROM python:3.10-alpine3.17 as base-env -ARG APP_USER -ARG APP_RELATIVE_PATH - -ENV APP_ROOT="/opt/${APP_RELATIVE_PATH}" -ENV APP_LIBS="/var/opt/${APP_RELATIVE_PATH}/lib" - -ENV PATH="${PATH}:${APP_LIBS}/bin" -ENV PYTHONPATH="${PYTHONPATH}:${APP_LIBS}" - - -RUN adduser \ - --disabled-password \ - --no-create-home \ - "${APP_USER}" - -USER ${APP_USER} - -WORKDIR ${APP_ROOT} - -COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS} -COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py . - - - -FROM base-env as prod - - -CMD python3 __main__.py --backend-url "${BACKEND_URL}" diff --git a/docker/inference/Dockerfile.worker b/docker/inference/Dockerfile.worker index 06f040ab..64e8655a 100644 --- a/docker/inference/Dockerfile.worker +++ b/docker/inference/Dockerfile.worker @@ -48,7 +48,7 @@ WORKDIR ${APP_ROOT} COPY --chown="${APP_USER}:${APP_USER}" --from=build /build/lib ${APP_LIBS} -COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/__main__.py . +COPY --chown="${APP_USER}:${APP_USER}" ./${APP_RELATIVE_PATH}/*.py . CMD python3 __main__.py --backend-url "${BACKEND_URL}" --inference-server-url "${INFERENCE_SERVER_URL}" diff --git a/inference/full-dev-setup.sh b/inference/full-dev-setup.sh index 5ef754d2..c5afb412 100755 --- a/inference/full-dev-setup.sh +++ b/inference/full-dev-setup.sh @@ -3,9 +3,11 @@ # Creates a tmux window with splits for the individual services tmux new-session -d -s "inference-dev-setup" -tmux send-keys "docker run --rm -it -p 6379:6379 redis" C-m +tmux send-keys "docker run --rm -it -p 5432:5432 -e POSTGRES_PASSWORD=postgres --name postgres postgres" C-m tmux split-window -h -tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 ghcr.io/huggingface/text-generation-inference" C-m +tmux send-keys "docker run --rm -it -p 6379:6379 --name redis redis" C-m +tmux split-window -h +tmux send-keys "docker run --rm -it -p 8001:80 -e MODEL_ID=distilgpt2 -v $HOME/.cache/huggingface:/root/.cache/huggingface --name text-generation-inference ghcr.io/huggingface/text-generation-inference" C-m tmux split-window -h tmux send-keys "cd server" C-m tmux send-keys "uvicorn main:app --reload" C-m diff --git a/inference/server/alembic.ini b/inference/server/alembic.ini new file mode 100644 index 00000000..44829313 --- /dev/null +++ b/inference/server/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = %(here)s/alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# sqlalchemy.url = postgresql://:@/ +sqlalchemy.url = postgresql://postgres:postgres@localhost:5432/postgres + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +hooks = black +black.type = console_scripts +black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/inference/server/alembic/README b/inference/server/alembic/README new file mode 100644 index 00000000..2500aa1b --- /dev/null +++ b/inference/server/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. diff --git a/inference/server/alembic/env.py b/inference/server/alembic/env.py new file mode 100644 index 00000000..55e16d32 --- /dev/null +++ b/inference/server/alembic/env.py @@ -0,0 +1,78 @@ +from logging.config import fileConfig + +import sqlmodel +from alembic import context +from oasst_inference_server import models # noqa: F401 +from sqlalchemy import engine_from_config, pool + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = sqlmodel.SQLModel.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.get_context()._ensure_version_table() + connection.execute("LOCK TABLE alembic_version IN ACCESS EXCLUSIVE MODE") + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/inference/server/alembic/script.py.mako b/inference/server/alembic/script.py.mako new file mode 100644 index 00000000..3124b62c --- /dev/null +++ b/inference/server/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +import sqlmodel +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/inference/server/alembic/versions/.gitinclude b/inference/server/alembic/versions/.gitinclude new file mode 100644 index 00000000..e69de29b diff --git a/inference/server/alembic/versions/2023_02_10_0221-3a4cd8777eb2_initial_commit.py b/inference/server/alembic/versions/2023_02_10_0221-3a4cd8777eb2_initial_commit.py new file mode 100644 index 00000000..3fa7cd73 --- /dev/null +++ b/inference/server/alembic/versions/2023_02_10_0221-3a4cd8777eb2_initial_commit.py @@ -0,0 +1,36 @@ +"""initial commit + +Revision ID: 3a4cd8777eb2 +Revises: +Create Date: 2023-02-10 02:21:27.086772 + +""" +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "3a4cd8777eb2" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "chat", + sa.Column("conversation", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("pending_message_request", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("message_request_state", postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column("id", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("chat") + # ### end Alembic commands ### diff --git a/inference/server/main.py b/inference/server/main.py index de0f607d..072bcbde 100644 --- a/inference/server/main.py +++ b/inference/server/main.py @@ -1,14 +1,22 @@ import asyncio -import enum -import uuid +import contextlib +import time +from pathlib import Path +import alembic.command +import alembic.config import fastapi -import pydantic import redis.asyncio as redis +import sqlmodel import websockets.exceptions +from fastapi import Depends from fastapi.middleware.cors import CORSMiddleware from loguru import logger -from oasst_shared.schemas import inference, protocol +from oasst_inference_server import interface +from oasst_inference_server.chat_repository import ChatRepository +from oasst_inference_server.database import db_engine +from oasst_inference_server.settings import settings +from oasst_shared.schemas import inference from prometheus_fastapi_instrumentator import Instrumentator from sse_starlette.sse import EventSourceResponse @@ -31,129 +39,97 @@ app.add_middleware( ) -class Settings(pydantic.BaseSettings): - redis_host: str = "localhost" - redis_port: int = 6379 - redis_db: int = 0 - - sse_retry_timeout: int = 15000 - - -settings = Settings() - # create async redis client redisClient = redis.Redis( host=settings.redis_host, port=settings.redis_port, db=settings.redis_db, decode_responses=True ) -class MessageRequest(pydantic.BaseModel): - message: str = pydantic.Field(..., repr=False) - model_name: str = "distilgpt2" - max_new_tokens: int = 100 - - def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: - return self.model_name == worker_config.model_name +def create_session(): + with sqlmodel.Session(db_engine) as session: + yield session -class TokenResponseEvent(pydantic.BaseModel): - token: inference.TokenResponse +def create_chat_repository(session: sqlmodel.Session = Depends(create_session)): + repository = ChatRepository(session) + return repository -class MessageRequestState(str, enum.Enum): - pending = "pending" - in_progress = "in_progress" - complete = "complete" - aborted_by_worker = "aborted_by_worker" +if settings.update_alembic: + @app.on_event("startup") + def alembic_upgrade(): + logger.info("Attempting to upgrade alembic on startup") + retry = 0 + while True: + try: + alembic_ini_path = Path(__file__).parent / "alembic.ini" + alembic_cfg = alembic.config.Config(str(alembic_ini_path)) + alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri) + alembic.command.upgrade(alembic_cfg, "head") + logger.info("Successfully upgraded alembic on startup") + break + except Exception: + logger.exception("Alembic upgrade failed on startup") + retry += 1 + if retry >= settings.alembic_retries: + raise -class CreateChatRequest(pydantic.BaseModel): - pass - - -class ChatListEntry(pydantic.BaseModel): - id: str - - -class ChatEntry(pydantic.BaseModel): - id: str - conversation: protocol.Conversation - - -class ListChatsResponse(pydantic.BaseModel): - chats: list[ChatListEntry] - - -class DbChatEntry(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=lambda: str(uuid.uuid4())) - conversation: protocol.Conversation = pydantic.Field(default_factory=protocol.Conversation) - pending_message_request: MessageRequest | None = None - message_request_state: MessageRequestState | None = None - - def to_list_entry(self) -> ChatListEntry: - return ChatListEntry(id=self.id) - - def to_entry(self) -> ChatEntry: - return ChatEntry(id=self.id, conversation=self.conversation) - - -# TODO: make real database -CHATS: dict[str, DbChatEntry] = {} + timeout = settings.alembic_retry_timeout * 2**retry + logger.warning(f"Retrying alembic upgrade in {timeout} seconds") + time.sleep(timeout) @app.get("/chat") -async def list_chats() -> ListChatsResponse: +async def list_chats(chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ListChatsResponse: """Lists all chats.""" logger.info("Listing all chats.") - chats = [chat.to_list_entry() for chat in CHATS.values()] - return ListChatsResponse(chats=chats) + chats = chat_repository.get_chat_list() + return interface.ListChatsResponse(chats=chats) @app.post("/chat") -async def create_chat(request: CreateChatRequest) -> ChatListEntry: +async def create_chat( + request: interface.CreateChatRequest, chat_repository: ChatRepository = Depends(create_chat_repository) +) -> interface.ChatListEntry: """Allows a client to create a new chat.""" logger.info(f"Received {request}") - chat = DbChatEntry() - CHATS[chat.id] = chat - return ChatListEntry(id=chat.id) + chat = chat_repository.create_chat() + return chat.to_list_entry() @app.get("/chat/{id}") -async def get_chat(id: str) -> ChatEntry: +async def get_chat(id: str, chat_repository: ChatRepository = Depends(create_chat_repository)) -> interface.ChatEntry: """Allows a client to get the current state of a chat.""" - return CHATS[id].to_entry() + chat = chat_repository.get_chat_entry_by_id(id) + return chat @app.post("/chat/{id}/message") -async def create_message(id: str, message_request: MessageRequest, fastapi_request: fastapi.Request): +async def create_message( + id: str, + message_request: interface.MessageRequest, + fastapi_request: fastapi.Request, + chat_repository: ChatRepository = Depends(create_chat_repository), +) -> EventSourceResponse: """Allows the client to stream the results of a request.""" - chat = CHATS[id] - if not chat.conversation.is_prompter_turn: - raise fastapi.HTTPException(status_code=400, detail="Not your turn") - if chat.pending_message_request is not None: - raise fastapi.HTTPException(status_code=400, detail="Already pending") + try: + chat_repository.add_prompter_message(id=id, message_request=message_request) + except Exception: + logger.exception("Error adding prompter message") + return fastapi.Response(status_code=500) - chat.conversation.messages.append( - protocol.ConversationMessage( - text=message_request.message, - is_assistant=False, - ) - ) - - chat.pending_message_request = message_request - chat.message_request_state = MessageRequestState.pending - - async def event_generator(): + async def event_generator(id): result_data = [] try: while True: if await fastapi_request.is_disconnected(): logger.warning("Client disconnected") - break + return - item = await redisClient.blpop(chat.id, 1) + item = await redisClient.blpop(id, 1) if item is None: continue @@ -166,47 +142,44 @@ async def create_message(id: str, message_request: MessageRequest, fastapi_reque yield { "retry": settings.sse_retry_timeout, - "data": TokenResponseEvent(token=response_packet.token).json(), + "data": interface.TokenResponseEvent(token=response_packet.token).json(), } - logger.info(f"Finished streaming {chat.id} {len(result_data)=}") + logger.info(f"Finished streaming {id} {len(result_data)=}") except Exception: - logger.exception(f"Error streaming {chat.id}") + logger.exception(f"Error streaming {id}") + raise - chat.conversation.messages.append( - protocol.ConversationMessage( - text=response_packet.generated_text.text, - is_assistant=True, - ) - ) - chat.pending_message_request = None + try: + with contextlib.contextmanager(create_session)() as session: + chat_repository = create_chat_repository(session) + chat_repository.add_assistant_message(id=id, text=response_packet.generated_text.text) + except Exception: + logger.exception("Error adding assistant message") - return EventSourceResponse(event_generator()) + return EventSourceResponse(event_generator(id)) @app.websocket("/work") -async def work(websocket: fastapi.WebSocket): +async def work(websocket: fastapi.WebSocket, chat_repository: ChatRepository = Depends(create_chat_repository)): await websocket.accept() worker_config = inference.WorkerConfig.parse_raw(await websocket.receive_text()) try: while True: - print(websocket.client_state) if websocket.client_state == fastapi.websockets.WebSocketState.DISCONNECTED: logger.warning("Worker disconnected") break # find a pending task that matches the worker's config # could also be implemented using task queues # but general compatibility matching is tricky - for chat in CHATS.values(): - if (request := chat.pending_message_request) is not None: - if chat.message_request_state == MessageRequestState.pending: - if request.compatible_with(worker_config): - break + for chat in chat_repository.get_pending_chats(): + request = chat.pending_message_request + if request.compatible_with(worker_config): + break else: - logger.debug("No pending tasks") await asyncio.sleep(1) continue - chat.message_request_state = MessageRequestState.in_progress + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.in_progress) work_request = inference.WorkRequest( conversation=chat.conversation, @@ -214,15 +187,17 @@ async def work(websocket: fastapi.WebSocket): max_new_tokens=request.max_new_tokens, ) - logger.info(f"Created {work_request}") + logger.info(f"Created {work_request=}") try: await websocket.send_text(work_request.json()) except websockets.exceptions.ConnectionClosedError: logger.warning("Worker disconnected") websocket.close() - chat.message_request_state = MessageRequestState.pending + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) break + logger.debug(f"Sent {work_request=} to worker.") + try: in_progress = False while True: @@ -232,18 +207,19 @@ async def work(websocket: fastapi.WebSocket): in_progress = True await redisClient.rpush(chat.id, response_packet.json()) if response_packet.is_end: + logger.debug(f"Received {response_packet=} from worker. Ending.") break except fastapi.WebSocketException: # TODO: handle this better logger.exception(f"Websocket closed during handling of {chat.id}") if in_progress: logger.warning(f"Aborting {chat.id=}") - chat.message_request_state = MessageRequestState.aborted_by_worker + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.aborted_by_worker) else: logger.warning(f"Marking {chat.id=} as pending since no work was done.") - chat.message_request_state = MessageRequestState.pending + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.pending) raise - chat.message_request_state = MessageRequestState.complete + chat_repository.set_chat_state(chat.id, interface.MessageRequestState.complete) except fastapi.WebSocketException: logger.exception("Websocket closed") diff --git a/inference/server/oasst_inference_server/chat_repository.py b/inference/server/oasst_inference_server/chat_repository.py new file mode 100644 index 00000000..52cb9543 --- /dev/null +++ b/inference/server/oasst_inference_server/chat_repository.py @@ -0,0 +1,79 @@ +import fastapi +import sqlmodel +from loguru import logger +from oasst_inference_server import interface, models +from oasst_shared.schemas import protocol +from sqlalchemy.sql.operators import is_not + + +class ChatRepository: + def __init__(self, session: sqlmodel.Session) -> None: + self.session = session + + def get_chats(self) -> list[models.DbChatEntry]: + return self.session.exec(sqlmodel.select(models.DbChatEntry)).all() + + def get_pending_chats(self) -> list[models.DbChatEntry]: + return self.session.exec( + sqlmodel.select(models.DbChatEntry).where( + is_not(models.DbChatEntry.pending_message_request, None), + models.DbChatEntry.message_request_state == interface.MessageRequestState.pending, + ) + ).all() + + def get_chat_list(self) -> list[interface.ChatListEntry]: + chats = self.get_chats() + return [chat.to_list_entry() for chat in chats] + + def get_chat_by_id(self, id: str) -> models.DbChatEntry: + chat = self.session.exec(sqlmodel.select(models.DbChatEntry).where(models.DbChatEntry.id == id)).one() + return chat + + def get_chat_entry_by_id(self, id: str) -> interface.ChatEntry: + return self.get_chat_by_id(id).to_entry() + + def create_chat(self) -> models.DbChatEntry: + chat = models.DbChatEntry() + self.session.add(chat) + self.session.commit() + return chat + + def add_prompter_message(self, id: str, message_request: interface.MessageRequest) -> None: + logger.info(f"Adding prompter message {message_request} to chat {id}") + chat = self.get_chat_by_id(id) + if not chat.conversation.is_prompter_turn: + raise fastapi.HTTPException(status_code=400, detail="Not your turn") + if chat.pending_message_request is not None: + raise fastapi.HTTPException(status_code=400, detail="Already pending") + + chat.conversation.messages.append( + protocol.ConversationMessage( + text=message_request.message, + is_assistant=False, + ) + ) + + chat.pending_message_request = message_request + chat.message_request_state = interface.MessageRequestState.pending + self.session.commit() + logger.debug(f"Added prompter message {message_request} to chat {id}") + + def add_assistant_message(self, id: str, text: str) -> None: + logger.info(f"Adding assistant message {text} to chat {id}") + chat = self.get_chat_by_id(id) + chat.conversation.messages.append( + protocol.ConversationMessage( + text=text, + is_assistant=True, + ) + ) + chat.pending_message_request = None + self.session.commit() + logger.debug(f"Added assistant message {text} to chat {id}") + + def set_chat_state(self, id: str, state: interface.MessageRequestState) -> None: + logger.info(f"Setting chat {id} state to {state}") + chat = self.get_chat_by_id(id) + chat.message_request_state = state + self.session.commit() + logger.debug(f"Set chat {id} state to {state}") diff --git a/inference/server/oasst_inference_server/database.py b/inference/server/oasst_inference_server/database.py new file mode 100644 index 00000000..c714b6c9 --- /dev/null +++ b/inference/server/oasst_inference_server/database.py @@ -0,0 +1,41 @@ +import json + +import pydantic.json +import sqlmodel +from loguru import logger +from oasst_inference_server import models +from oasst_inference_server.settings import settings + + +def default_json_serializer(obj): + class_name = obj.__class__.__name__ + encoded = pydantic.json.pydantic_encoder(obj) + encoded["_classname_"] = class_name + return encoded + + +def custom_json_serializer(obj): + return json.dumps(obj, default=default_json_serializer) + + +def custom_json_deserializer(s): + d = json.loads(s) + if not isinstance(d, dict): + return d + match d.get("_classname_"): + case "Conversation": + return models.protocol.Conversation.parse_obj(d) + case "MessageRequest": + return models.interface.MessageRequest.parse_obj(d) + case None: + return d + case _: + logger.error(f"Unknown class {d['_classname_']}") + raise ValueError(f"Unknown class {d['_classname_']}") + + +db_engine = sqlmodel.create_engine( + settings.database_uri, + json_serializer=custom_json_serializer, + json_deserializer=custom_json_deserializer, +) diff --git a/inference/server/oasst_inference_server/interface.py b/inference/server/oasst_inference_server/interface.py new file mode 100644 index 00000000..7fecffa1 --- /dev/null +++ b/inference/server/oasst_inference_server/interface.py @@ -0,0 +1,41 @@ +import enum + +import pydantic +from oasst_shared.schemas import inference, protocol + + +class MessageRequest(pydantic.BaseModel): + message: str = pydantic.Field(..., repr=False) + model_name: str = "distilgpt2" + max_new_tokens: int = 100 + + def compatible_with(self, worker_config: inference.WorkerConfig) -> bool: + return self.model_name == worker_config.model_name + + +class TokenResponseEvent(pydantic.BaseModel): + token: inference.TokenResponse + + +class MessageRequestState(str, enum.Enum): + pending = "pending" + in_progress = "in_progress" + complete = "complete" + aborted_by_worker = "aborted_by_worker" + + +class CreateChatRequest(pydantic.BaseModel): + pass + + +class ChatListEntry(pydantic.BaseModel): + id: str + + +class ChatEntry(pydantic.BaseModel): + id: str + conversation: protocol.Conversation + + +class ListChatsResponse(pydantic.BaseModel): + chats: list[ChatListEntry] diff --git a/inference/server/oasst_inference_server/models.py b/inference/server/oasst_inference_server/models.py new file mode 100644 index 00000000..f1a32438 --- /dev/null +++ b/inference/server/oasst_inference_server/models.py @@ -0,0 +1,23 @@ +from uuid import uuid4 + +import sqlalchemy as sa +import sqlalchemy.dialects.postgresql as pg +from oasst_inference_server import interface +from oasst_shared.schemas import protocol +from sqlmodel import Field, SQLModel + + +class DbChatEntry(SQLModel, table=True): + __tablename__ = "chat" + + id: str = Field(default_factory=lambda: str(uuid4()), primary_key=True) + + conversation: protocol.Conversation = Field(default_factory=protocol.Conversation, sa_column=sa.Column(pg.JSONB)) + pending_message_request: interface.MessageRequest | None = Field(None, sa_column=sa.Column(pg.JSONB)) + message_request_state: interface.MessageRequestState | None = Field(None, sa_column=sa.Column(pg.JSONB)) + + def to_list_entry(self) -> interface.ChatListEntry: + return interface.ChatListEntry(id=self.id) + + def to_entry(self) -> interface.ChatEntry: + return interface.ChatEntry(id=self.id, conversation=self.conversation) diff --git a/inference/server/oasst_inference_server/settings.py b/inference/server/oasst_inference_server/settings.py new file mode 100644 index 00000000..e0a4d914 --- /dev/null +++ b/inference/server/oasst_inference_server/settings.py @@ -0,0 +1,38 @@ +from typing import Any + +import pydantic + + +class Settings(pydantic.BaseSettings): + redis_host: str = "localhost" + redis_port: int = 6379 + redis_db: int = 0 + + sse_retry_timeout: int = 15000 + update_alembic: bool = True + alembic_retries: int = 5 + alembic_retry_timeout: int = 1 + + postgres_host: str = "localhost" + postgres_port: str = "5432" + postgres_user: str = "postgres" + postgres_password: str = "postgres" + postgres_db: str = "postgres" + + database_uri: str | None = None + + @pydantic.validator("database_uri", pre=True) + def assemble_db_connection(cls, v: str | None, values: dict[str, Any]) -> Any: + if isinstance(v, str): + return v + return pydantic.PostgresDsn.build( + scheme="postgresql", + user=values.get("postgres_user"), + password=values.get("postgres_password"), + host=values.get("postgres_host"), + port=values.get("postgres_port"), + path=f"/{values.get('postgres_db') or ''}", + ) + + +settings = Settings() diff --git a/inference/server/requirements.txt b/inference/server/requirements.txt index 4790b471..fdc60ccd 100644 --- a/inference/server/requirements.txt +++ b/inference/server/requirements.txt @@ -1,7 +1,10 @@ +alembic fastapi[all] loguru -prometheus-fastapi-instrumentator==5.9.1 +prometheus-fastapi-instrumentator +psycopg2-binary pydantic redis +sqlmodel sse-starlette websockets diff --git a/inference/text-client/__main__.py b/inference/text-client/__main__.py index e56feaa1..4a0d9e47 100644 --- a/inference/text-client/__main__.py +++ b/inference/text-client/__main__.py @@ -1,6 +1,7 @@ """Simple REPL frontend.""" import json +import time import requests import sseclient @@ -42,6 +43,7 @@ def main(backend_url: str = "http://127.0.0.1:8000"): break except Exception: typer.echo("Error, restarting chat...") + time.sleep(1) if __name__ == "__main__": diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py index 2a6514b7..b1d984a4 100644 --- a/inference/worker/__main__.py +++ b/inference/worker/__main__.py @@ -17,6 +17,8 @@ def main( model_name: str = "distilgpt2", inference_server_url: str = "http://localhost:8001", ): + utils.wait_for_inference_server(inference_server_url) + def on_open(ws: websocket.WebSocket): logger.info("Connected to backend, sending config...") worker_config = inference.WorkerConfig(model_name=model_name) @@ -93,6 +95,7 @@ def main( ), ).json() ) + logger.info("Work complete. Waiting for more work...") def on_error(ws: websocket.WebSocket, error: Exception): try: diff --git a/inference/worker/utils.py b/inference/worker/utils.py index 2cababcf..414b6958 100644 --- a/inference/worker/utils.py +++ b/inference/worker/utils.py @@ -1,7 +1,11 @@ import collections +import random +import time from typing import Literal import interface +import requests +from loguru import logger class TokenBuffer: @@ -38,3 +42,21 @@ class TokenBuffer: yield from self.tokens else: yield from self.tokens + + +def wait_for_inference_server(inference_server_url: str, timeout: int = 600): + health_url = f"{inference_server_url}/health" + time_limit = time.time() + timeout + while True: + try: + response = requests.get(health_url) + response.raise_for_status() + except (requests.HTTPError, requests.ConnectionError): + if time.time() > time_limit: + raise + sleep_duration = random.uniform(0, 10) + logger.warning(f"Inference server not ready. Retrying in {sleep_duration} seconds") + time.sleep(sleep_duration) + else: + logger.info("Inference server is ready") + break