mirror of
https://github.com/wassname/Open-Assistant.git
synced 2026-06-27 16:10:30 +08:00
added dockerfile for worker-full
This commit is contained in:
@@ -0,0 +1,22 @@
|
||||
FROM ghcr.io/huggingface/text-generation-inference
|
||||
|
||||
ARG MODULE="inference"
|
||||
ARG SERVICE="worker"
|
||||
|
||||
ARG APP_RELATIVE_PATH="${MODULE}/${SERVICE}"
|
||||
|
||||
WORKDIR /worker
|
||||
COPY ./oasst-shared /tmp/oasst-shared
|
||||
RUN conda create -n worker python=3.10 -y
|
||||
RUN /opt/miniconda/envs/worker/bin/pip install /tmp/oasst-shared
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/requirements.txt .
|
||||
RUN /opt/miniconda/envs/worker/bin/pip install -r requirements.txt
|
||||
|
||||
COPY ./${APP_RELATIVE_PATH}/*.py .
|
||||
COPY ./${APP_RELATIVE_PATH}/worker_full_main.sh /entrypoint.sh
|
||||
|
||||
ENV MODEL_ID="distilgpt2"
|
||||
ENV INFERENCE_SERVER_URL="http://localhost:80"
|
||||
|
||||
ENTRYPOINT ["/entrypoint.sh"]
|
||||
@@ -2,26 +2,19 @@ import interface
|
||||
import rel
|
||||
import requests
|
||||
import sseclient
|
||||
import typer
|
||||
import utils
|
||||
import websocket
|
||||
from loguru import logger
|
||||
from oasst_shared.schemas import inference, protocol
|
||||
|
||||
app = typer.Typer()
|
||||
from settings import settings
|
||||
|
||||
|
||||
@app.command()
|
||||
def main(
|
||||
backend_url: str = "ws://localhost:8000",
|
||||
model_name: str = "distilgpt2",
|
||||
inference_server_url: str = "http://localhost:8001",
|
||||
):
|
||||
utils.wait_for_inference_server(inference_server_url)
|
||||
def main():
|
||||
utils.wait_for_inference_server(settings.inference_server_url)
|
||||
|
||||
def on_open(ws: websocket.WebSocket):
|
||||
logger.info("Connected to backend, sending config...")
|
||||
worker_config = inference.WorkerConfig(model_name=model_name)
|
||||
worker_config = inference.WorkerConfig(model_name=settings.model_name)
|
||||
ws.send(worker_config.json())
|
||||
logger.info("Config sent, waiting for work...")
|
||||
|
||||
@@ -47,7 +40,7 @@ def main(
|
||||
|
||||
parameters = interface.GenerateStreamParameters.from_work_request(work_request)
|
||||
response = requests.post(
|
||||
f"{inference_server_url}/generate_stream",
|
||||
f"{settings.inference_server_url}/generate_stream",
|
||||
json={
|
||||
"inputs": prompt,
|
||||
"parameters": parameters.dict(),
|
||||
@@ -107,7 +100,7 @@ def main(
|
||||
logger.warning(f"Connection closed: {close_status_code=} {close_msg=}")
|
||||
|
||||
ws = websocket.WebSocketApp(
|
||||
f"{backend_url}/work",
|
||||
f"{settings.backend_url}/work",
|
||||
on_message=on_message,
|
||||
on_error=on_error,
|
||||
on_close=on_close,
|
||||
@@ -120,4 +113,4 @@ def main(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
main()
|
||||
|
||||
@@ -3,5 +3,4 @@ pydantic
|
||||
rel
|
||||
requests
|
||||
sseclient-py
|
||||
typer
|
||||
websocket-client
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
import pydantic
|
||||
|
||||
|
||||
class Settings(pydantic.BaseSettings):
|
||||
backend_url: str = "ws://localhost:8000"
|
||||
model_id: str = "distilgpt2"
|
||||
inference_server_url: str = "http://localhost:8001"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
Executable
+9
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
text-generation-launcher &
|
||||
|
||||
/opt/miniconda/envs/worker/bin/python /worker &
|
||||
|
||||
wait -n
|
||||
|
||||
exit $?
|
||||
Reference in New Issue
Block a user