added dockerfile for worker-full

This commit is contained in:
Yannic Kilcher
2023-02-10 22:53:40 +01:00
parent 90c3d5640e
commit d1aea98ad5
5 changed files with 48 additions and 15 deletions
+22
View File
@@ -0,0 +1,22 @@
FROM ghcr.io/huggingface/text-generation-inference
ARG MODULE="inference"
ARG SERVICE="worker"
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
WORKDIR /worker
COPY ./oasst-shared /tmp/oasst-shared
RUN conda create -n worker python=3.10 -y
RUN /opt/miniconda/envs/worker/bin/pip install /tmp/oasst-shared
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
RUN /opt/miniconda/envs/worker/bin/pip install -r requirements.txt
COPY ./${APP_RELATIVE_PATH}/*.py .
COPY ./${APP_RELATIVE_PATH}/worker_full_main.sh /entrypoint.sh
ENV MODEL_ID="distilgpt2"
ENV INFERENCE_SERVER_URL="http://localhost:80"
ENTRYPOINT ["/entrypoint.sh"]
+7 -14
View File
@@ -2,26 +2,19 @@ import interface
import rel
import requests
import sseclient
import typer
import utils
import websocket
from loguru import logger
from oasst_shared.schemas import inference, protocol
app = typer.Typer()
from settings import settings
@app.command()
def main(
backend_url: str = "ws://localhost:8000",
model_name: str = "distilgpt2",
inference_server_url: str = "http://localhost:8001",
):
utils.wait_for_inference_server(inference_server_url)
def main():
utils.wait_for_inference_server(settings.inference_server_url)
def on_open(ws: websocket.WebSocket):
logger.info("Connected to backend, sending config...")
worker_config = inference.WorkerConfig(model_name=model_name)
worker_config = inference.WorkerConfig(model_name=settings.model_name)
ws.send(worker_config.json())
logger.info("Config sent, waiting for work...")
@@ -47,7 +40,7 @@ def main(
parameters = interface.GenerateStreamParameters.from_work_request(work_request)
response = requests.post(
f"{inference_server_url}/generate_stream",
f"{settings.inference_server_url}/generate_stream",
json={
"inputs": prompt,
"parameters": parameters.dict(),
@@ -107,7 +100,7 @@ def main(
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
ws = websocket.WebSocketApp(
f"{backend_url}/work",
f"{settings.backend_url}/work",
on_message=on_message,
on_error=on_error,
on_close=on_close,
@@ -120,4 +113,4 @@ def main(
if __name__ == "__main__":
app()
main()
-1
View File
@@ -3,5 +3,4 @@ pydantic
rel
requests
sseclient-py
typer
websocket-client
+10
View File
@@ -0,0 +1,10 @@
import pydantic
class Settings(pydantic.BaseSettings):
backend_url: str = "ws://localhost:8000"
model_id: str = "distilgpt2"
inference_server_url: str = "http://localhost:8001"
settings = Settings()
+9
View File
@@ -0,0 +1,9 @@
#!/bin/bash
text-generation-launcher &
/opt/miniconda/envs/worker/bin/python /worker &
wait -n
exit $?