diff --git a/docker/inference/Dockerfile.worker-full b/docker/inference/Dockerfile.worker-full
new file mode 100644
index 00000000..fa94f62e
--- /dev/null
+++ b/docker/inference/Dockerfile.worker-full
@@ -0,0 +1,26 @@
+# Combined image: the text-generation launcher plus the worker, both started
+# by worker_full_main.sh (installed below as /entrypoint.sh).
+FROM ghcr.io/huggingface/text-generation-inference
+
+ARG MODULE="inference"
+ARG SERVICE="worker"
+
+ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
+
+WORKDIR /worker
+COPY ./oasst-shared /tmp/oasst-shared
+# Dedicated conda env that holds the worker's Python dependencies.
+RUN conda create -n worker python=3.10 -y
+RUN /opt/miniconda/envs/worker/bin/pip install /tmp/oasst-shared
+
+COPY ./${APP_RELATIVE_PATH}/requirements.txt .
+RUN /opt/miniconda/envs/worker/bin/pip install -r requirements.txt
+
+COPY ./${APP_RELATIVE_PATH}/*.py .
+COPY ./${APP_RELATIVE_PATH}/worker_full_main.sh /entrypoint.sh
+
+# Runtime defaults; the worker reads them through the fields in settings.py.
+ENV MODEL_ID="distilgpt2"
+ENV INFERENCE_SERVER_URL="http://localhost:80"
+
+ENTRYPOINT ["/entrypoint.sh"]
diff --git a/inference/worker/__main__.py b/inference/worker/__main__.py
index b1d984a4..0d825ba4 100644
--- a/inference/worker/__main__.py
+++ b/inference/worker/__main__.py
@@ -2,26 +2,20 @@ import interface
 import rel
 import requests
 import sseclient
-import typer
 import utils
 import websocket
 from loguru import logger
 from oasst_shared.schemas import inference, protocol
-
-app = typer.Typer()
+from settings import settings
 
 
-@app.command()
-def main(
-    backend_url: str = "ws://localhost:8000",
-    model_name: str = "distilgpt2",
-    inference_server_url: str = "http://localhost:8001",
-):
-    utils.wait_for_inference_server(inference_server_url)
+def main():
+    """Run the worker, configured through ``settings`` instead of CLI options."""
+    utils.wait_for_inference_server(settings.inference_server_url)
 
     def on_open(ws: websocket.WebSocket):
         logger.info("Connected to backend, sending config...")
-        worker_config = inference.WorkerConfig(model_name=model_name)
+        worker_config = inference.WorkerConfig(model_name=settings.model_id)
         ws.send(worker_config.json())
         logger.info("Config sent, waiting for work...")
 
@@ -47,7 +41,7 @@ def main(
 
         parameters = interface.GenerateStreamParameters.from_work_request(work_request)
         response = requests.post(
-            f"{inference_server_url}/generate_stream",
+            f"{settings.inference_server_url}/generate_stream",
             json={
                 "inputs": prompt,
                 "parameters": parameters.dict(),
@@ -107,7 +101,7 @@ def main(
         logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
 
     ws = websocket.WebSocketApp(
-        f"{backend_url}/work",
+        f"{settings.backend_url}/work",
         on_message=on_message,
         on_error=on_error,
         on_close=on_close,
@@ -120,4 +114,4 @@ def main(
 
 
 if __name__ == "__main__":
-    app()
+    main()
diff --git a/inference/worker/requirements.txt b/inference/worker/requirements.txt
index 3afd3617..dd35a5a6 100644
--- a/inference/worker/requirements.txt
+++ b/inference/worker/requirements.txt
@@ -3,5 +3,4 @@ pydantic
 rel
 requests
 sseclient-py
-typer
 websocket-client
diff --git a/inference/worker/settings.py b/inference/worker/settings.py
new file mode 100644
index 00000000..c726479c
--- /dev/null
+++ b/inference/worker/settings.py
@@ -0,0 +1,16 @@
+import pydantic
+
+
+class Settings(pydantic.BaseSettings):
+    """Worker configuration with defaults for a local setup.
+
+    As a pydantic ``BaseSettings`` subclass, each field can be overridden by
+    an environment variable of the same name (e.g. ``MODEL_ID``).
+    """
+
+    backend_url: str = "ws://localhost:8000"
+    model_id: str = "distilgpt2"
+    inference_server_url: str = "http://localhost:8001"
+
+
+settings = Settings()
diff --git a/inference/worker/worker_full_main.sh b/inference/worker/worker_full_main.sh
new file mode 100755
index 00000000..c3c7788e
--- /dev/null
+++ b/inference/worker/worker_full_main.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+# Start the text-generation launcher and the worker side by side.
+text-generation-launcher &
+
+# /worker is a directory; python runs the __main__.py inside it.
+/opt/miniconda/envs/worker/bin/python /worker &
+
+# Block until the first of the two background jobs exits, then propagate its status.
+wait -n
+
+exit $?