[CI/Build] Update pixtral tests to use JSON (#8436)

This commit is contained in:
Cyrus Leung
2024-09-13 11:47:52 +08:00
committed by GitHub
parent 3f79bc3d1a
commit 8427550488
6 changed files with 42 additions and 18 deletions
+1 -1
View File
@@ -76,7 +76,7 @@ exclude = [
[tool.codespell]
ignore-words-list = "dout, te, indicies, subtile"
skip = "./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build"
[tool.isort]
use_parentheses = true
File diff suppressed because one or more lines are too long
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
+39 -17
View File
@@ -2,9 +2,10 @@
Run `pytest tests/models/test_mistral.py`.
"""
import pickle
import json
import uuid
from typing import Any, Dict, List
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple
import pytest
from mistral_common.protocol.instruct.messages import ImageURLChunk
@@ -14,6 +15,7 @@ from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins
from vllm.sequence import Logprob, SampleLogprobs
from .utils import check_logprobs_close
@@ -81,13 +83,33 @@ SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
LIMIT_MM_PER_PROMPT = dict(image=4)
MAX_MODEL_LEN = [8192, 65536]
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.json"
FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.json"
OutputsLogprobs = List[Tuple[List[int], str, Optional[SampleLogprobs]]]
def load_logprobs(filename: str) -> Any:
with open(filename, 'rb') as f:
return pickle.load(f)
# For the test author to store golden output in JSON
def _dump_outputs_w_logprobs(outputs: OutputsLogprobs, filename: str) -> None:
json_data = [(tokens, text,
[{k: asdict(v)
for k, v in token_logprobs.items()}
for token_logprobs in (logprobs or [])])
for tokens, text, logprobs in outputs]
with open(filename, "w") as f:
json.dump(json_data, f)
def load_outputs_w_logprobs(filename: str) -> OutputsLogprobs:
with open(filename, "rb") as f:
json_data = json.load(f)
return [(tokens, text,
[{int(k): Logprob(**v)
for k, v in token_logprobs.items()}
for token_logprobs in logprobs])
for tokens, text, logprobs in json_data]
@pytest.mark.skip(
@@ -103,7 +125,7 @@ def test_chat(
model: str,
dtype: str,
) -> None:
EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
with vllm_runner(
model,
dtype=dtype,
@@ -120,10 +142,10 @@ def test_chat(
outputs.extend(output)
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
name_0="output",
name_1="h100_ref")
check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")
@pytest.mark.skip(
@@ -133,7 +155,7 @@ def test_chat(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
args = EngineArgs(
model=model,
tokenizer_mode="mistral",
@@ -162,7 +184,7 @@ def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
break
logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
check_logprobs_close(outputs_0_lst=logprobs,
outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
name_0="output",
name_1="h100_ref")
check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
outputs_1_lst=logprobs,
name_0="h100_ref",
name_1="output")