OpenAI Server refactoring (#2360)

This commit is contained in:
FlorianJoncour
2024-01-17 05:33:14 +00:00
committed by GitHub
parent e1957c6ebd
commit 14cc317ba4
8 changed files with 954 additions and 643 deletions
@@ -1,12 +1,12 @@
from argparse import Namespace
from dataclasses import dataclass
import os
import pathlib
import pytest
from fastapi.testclient import TestClient
from vllm.entrypoints.openai.api_server import *
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
# Path to the ChatML Jinja template: start from the directory that contains
# this file, walk two directories up, then descend into "examples".
chatml_jinja_path = (
    pathlib.Path(os.path.dirname(os.path.abspath(__file__))).parent.parent
    / "examples/template_chatml.jinja"
)
@@ -48,7 +48,6 @@ TEST_MESSAGES = [
'content': 'What is the capital of'
},
]
# Module-level TestClient wrapping `app`; test_health_endpoint uses it to
# issue requests. NOTE(review): `app` is presumably provided by the wildcard
# import from vllm.entrypoints.openai.api_server — confirm.
client = TestClient(app)
@dataclass
@@ -56,13 +55,17 @@ class MockTokenizer:
chat_template = None
@dataclass
class MockServingChat:
    """Bare stand-in object that carries nothing but a tokenizer.

    The tests pass an instance of this class as the first argument to
    ``OpenAIServingChat._load_chat_template`` in place of a real serving
    object.
    """

    # Tokenizer whose ``chat_template`` attribute the tests inspect afterwards.
    tokenizer: MockTokenizer
def test_load_chat_template():
# Testing chatml template
mock_args = Namespace(chat_template=chatml_jinja_path)
tokenizer = MockTokenizer()
# Call the function with the mocked args
load_chat_template(mock_args, tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=chatml_jinja_path)
template_content = tokenizer.chat_template
@@ -76,11 +79,11 @@ def test_load_chat_template():
def test_no_load_chat_template():
# Testing chatml template
template = "../../examples/does_not_exist"
mock_args = Namespace(chat_template=template)
tokenizer = MockTokenizer()
# Call the function with the mocked args
load_chat_template(mock_args, tokenizer=tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
template_content = tokenizer.chat_template
# Test assertions
@@ -97,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
expected_output):
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
mock_args = Namespace(chat_template=template)
load_chat_template(mock_args, tokenizer)
mock_serving_chat = MockServingChat(tokenizer)
OpenAIServingChat._load_chat_template(mock_serving_chat,
chat_template=template)
# Create a mock request object using keyword arguments
mock_request = ChatCompletionRequest(
@@ -115,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt,
# Test assertion
assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}"
def test_health_endpoint():
    """Check that a GET request to /health returns HTTP status 200."""
    health_status = client.get("/health").status_code
    assert health_status == 200