[VLM] Merged multi-modal processors for LLaVA-NeXT-Video and LLaVA-OneVision (vllm-project#11717)

Signed-off-by: DarkLight1337 <[email protected]>
DarkLight1337 authored Jan 4, 2025
1 parent 300acb8 commit eed11eb
Showing 31 changed files with 1,114 additions and 983 deletions.
@@ -0,0 +1,58 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext

from ....utils import build_model_context


# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_next():
from vllm.model_executor.models.llava_next import (
LlavaNextMultiModalProcessor)
return LlavaNextMultiModalProcessor


# FIXME: image_size [(198, 176), (176, 198)]
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_next,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaNextMultiModalProcessor handles prompt replacement properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}

# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_next(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs

first_placeholder = image_placeholders[0]

# NOTE: There is a BOS token
assert first_placeholder["offset"] == 1
assert first_placeholder["length"] == (
len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs
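For readers skimming the diff, here is a toy illustration (not part of the commit) of the placeholder bookkeeping the assertions above rely on; the token values are made up:

# Toy illustration only; token values are invented, not real tokenizer output.
num_imgs = 2
prompt_token_ids = [1] + [32000] * 10      # BOS at position 0, then 5 image tokens per image
image_placeholders = [
    {"offset": 1, "length": 5},            # first image starts right after BOS
    {"offset": 6, "length": 5},
]
assert image_placeholders[0]["offset"] == 1
assert image_placeholders[0]["length"] == (len(prompt_token_ids) - 1) // num_imgs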
@@ -0,0 +1,59 @@
import pytest
from PIL import Image
from transformers import AutoTokenizer

from vllm.inputs import InputProcessingContext

from ....utils import build_model_context


# Fixtures lazy import to avoid initializing CUDA during test collection
@pytest.fixture()
def processor_for_llava_onevision():
from vllm.model_executor.models.llava_onevision import (
LlavaOnevisionMultiModalProcessor)
return LlavaOnevisionMultiModalProcessor


@pytest.mark.parametrize("model_id",
["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"])
@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488),
(488, 183), (198, 176), (176, 198)])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_prompt_replacements(
processor_for_llava_onevision,
model_id: str,
image_size: tuple[int, int],
num_imgs: int,
):
"""
Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement
properly.
"""
ctx = build_model_context(
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

# Build the image str / prompt based on the number of images we pass
prompt = "<image>" * num_imgs
mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs}

# The processor will throw an error if there is a mismatch
# in the prompt replacements
processor = processor_for_llava_onevision(ctx)
processed_inputs = processor.apply(prompt, mm_data, {})

image_placeholders = processed_inputs["mm_placeholders"]["image"]
assert len(image_placeholders) == num_imgs

first_placeholder = image_placeholders[0]

# NOTE: There is no BOS token here, so the first placeholder starts at offset 0
assert first_placeholder["offset"] == 0
assert first_placeholder["length"] == len(
processed_inputs["prompt_token_ids"]) // num_imgs
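The only behavioral difference from the LLaVA-NeXT test above is the missing leading BOS token, which shifts the expected offset from 1 to 0; a compact toy restatement of both assertions (the helper below is hypothetical, not part of the test suite):

def first_placeholder_length(prompt_token_ids: list[int], num_imgs: int,
                             bos_offset: int) -> int:
    """Placeholder tokens are split evenly across images after skipping any BOS."""
    return (len(prompt_token_ids) - bos_offset) // num_imgs

# LLaVA-NeXT: BOS at position 0, so offset 1; LLaVA-OneVision: no BOS, so offset 0.
assert first_placeholder_length([1] + [32000] * 10, num_imgs=2, bos_offset=1) == 5
assert first_placeholder_length([32000] * 10, num_imgs=2, bos_offset=0) == 5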
@@ -1,6 +1,4 @@
"""Tests for phi3v's multimodal preprocessing kwargs."""
from typing import Optional

import pytest
from transformers import AutoTokenizer

@@ -10,8 +8,6 @@
from .....conftest import _ImageAssets
from ....utils import build_model_context

models = ["microsoft/Phi-3.5-vision-instruct"]


# Wrap lazy imports to avoid initializing CUDA during test collection
@pytest.fixture()
@@ -20,40 +16,40 @@ def processor_for_phi3v():
return Phi3VMultiModalProcessor


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
# yapf: disable
@pytest.mark.parametrize(
"num_crops,expected_toks_per_img",
("mm_processor_kwargs", "expected_toks_per_img"),
[
(4, 757),
(16, 1921),
({"num_crops": 4}, 757),
({"num_crops": 16}, 1921),
# the default num_crops of phi-3.5-vision is 4
(None, 757),
({}, 757),
])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets,
model: str, num_crops: Optional[int],
expected_toks_per_img: int, num_imgs: int):
def test_processor_override(
processor_for_phi3v,
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,
num_imgs: int,
):
"""Ensure input_processor_for_phi3v handles num_crops properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
model_name=model_id,
tokenizer_name=model_id,
trust_remote_code=True,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

# Build the image str / prompt based on the number of images we pass
img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)])
prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n"
images = [image_assets[0].pil_image] * num_imgs

mm_data = {"image": images}
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs = {"num_crops": num_crops}
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}

processor = processor_for_phi3v(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
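For context, and purely as an assumption on my part rather than anything stated in this diff: the expected per-image token counts match the HF Phi-3-vision counting formula if the fixed test image is tiled into a 2x2 crop grid at num_crops=4 and a 3x4 grid at num_crops=16:

# Hedged sanity check; the formula and crop grids below are assumptions, not part of the diff.
def phi3v_image_tokens(h_crops: int, w_crops: int) -> int:
    # (crops + global image) * 144 tokens each, plus separator/newline tokens.
    return (h_crops * w_crops + 1) * 144 + (h_crops + 1) * 12 + 1

assert phi3v_image_tokens(2, 2) == 757     # num_crops=4 (and default) case
assert phi3v_image_tokens(3, 4) == 1921    # num_crops=16 case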
@@ -1,5 +1,3 @@
from typing import Any, Dict, Tuple

import pytest
from transformers import AutoTokenizer

@@ -8,56 +6,45 @@
from .....conftest import _ImageAssets
from ....utils import build_model_context

MODEL = "Qwen/Qwen2-VL-2B-Instruct"
MIN_PIXELS = "min_pixels"
MAX_PIXELS = "max_pixels"


# Fixtures lazy import to avoid initializing CUDA during test collection
# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple
# input mappers.
@pytest.fixture()
def processor_for_qwen2_vl():
from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor
return Qwen2VLMultiModalProcessor


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])
# yapf: disable
@pytest.mark.parametrize(
"mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [
("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [
({}, 1426, (5704, 1176)),
({
MIN_PIXELS: 64**2,
MAX_PIXELS: 512**2
}, 330, (1320, 1176)),
({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)),
])
@pytest.mark.parametrize("model", [MODEL])
# yapf: enable
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
processor_for_qwen2_vl,
image_assets: _ImageAssets,
model: str,
mm_processor_kwargs: Dict[str, Any],
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,
expected_pixels_shape: Tuple[int, int],
expected_pixels_shape: tuple[int, int],
num_imgs: int,
):
"""Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly."""
# Same as the previous test - don't initialize mm_processor_kwargs
# in this test and assume that the kwargs will be correctly expanded by
# the partial when calling the custom input processor.
ctx = build_model_context(
model_name=model,
tokenizer_name=model,
model_name=model_id,
tokenizer_name=model_id,
mm_processor_kwargs=None,
limit_mm_per_prompt={"image": num_imgs},
)
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
ctx = InputProcessingContext(ctx.model_config, tokenizer)

# Build the image str / prompt based on the number of images we pass
prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs
images = [image_assets[0].pil_image] * num_imgs

mm_data = {"image": images}
mm_data = {"image": [image_assets[0].pil_image] * num_imgs}

processor = processor_for_qwen2_vl(ctx)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
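As a cross-check (again an assumption about Qwen2-VL's default patch geometry, not something stated in the diff): each output image token merges a 2x2 block of 14x14 patches, and each flattened patch has 3 * 2 * 14 * 14 values, which is consistent with the parametrized expectations:

# Hedged sanity check; patch_size=14, temporal_patch_size=2 and the 2x2 spatial merge
# are assumed Qwen2-VL defaults, not values taken from this diff.
cases = [
    (1426, (5704, 1176)),   # default min/max pixels
    (330, (1320, 1176)),    # min_pixels=64**2, max_pixels=512**2
]
for expected_toks, (num_patches, patch_dim) in cases:
    assert expected_toks == num_patches // 4    # one token per 2x2 merged patch block
    assert patch_dim == 3 * 2 * 14 * 14         # channels * temporal patch * 14 * 14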
9 changes: 2 additions & 7 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -274,10 +274,8 @@
),
limit_mm_per_prompt={"image": 4},
)],
# Llava-next tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
),
"llava_one_vision": VLMTestInfo(
"llava_onevision": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS,
prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@@ -288,8 +286,6 @@
),
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output,
# Llava-one-vision tests fixed sizes & the default size factors
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
custom_test_opts=[CustomTestOptions(
inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs(
formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
@@ -306,7 +302,6 @@
max_model_len=4096,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output,
image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))],
),
"mantis": VLMTestInfo(
models=["TIGER-Lab/Mantis-8B-siglip-llama3"],
@@ -431,7 +426,7 @@
) for inp in custom_inputs.different_patch_input_cases_internvl()
],
),
"llava_one_vision-multiple-images": VLMTestInfo(
"llava_onevision-multiple-images": VLMTestInfo(
models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"],
test_type=VLMTestType.CUSTOM_INPUTS,
max_model_len=16384,