diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/__init__.py b/tests/models/decoder_only/vision_language/processing/__init__.py similarity index 100% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/__init__.py rename to tests/models/decoder_only/vision_language/processing/__init__.py diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_idefics3.py b/tests/models/decoder_only/vision_language/processing/test_idefics3.py similarity index 100% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/test_idefics3.py rename to tests/models/decoder_only/vision_language/processing/test_idefics3.py diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py b/tests/models/decoder_only/vision_language/processing/test_internvl.py similarity index 100% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py rename to tests/models/decoder_only/vision_language/processing/test_internvl.py diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_next.py b/tests/models/decoder_only/vision_language/processing/test_llava_next.py new file mode 100644 index 0000000000000..6772130c9b884 --- /dev/null +++ b/tests/models/decoder_only/vision_language/processing/test_llava_next.py @@ -0,0 +1,58 @@ +import pytest +from PIL import Image +from transformers import AutoTokenizer + +from vllm.inputs import InputProcessingContext + +from ....utils import build_model_context + + +# Fixtures lazy import to avoid initializing CUDA during test collection +@pytest.fixture() +def processor_for_llava_next(): + from vllm.model_executor.models.llava_next import ( + LlavaNextMultiModalProcessor) + return LlavaNextMultiModalProcessor + + +# FIXME: image_size [(198, 176), (176, 198)] +@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488), + (488, 183)]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_prompt_replacements( + processor_for_llava_next, + model_id: str, + image_size: tuple[int, int], + num_imgs: int, +): + """ + Ensure LlavaNextMultiModalProcessor handles prompt replacement properly. 
+ """ + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) + + # Build the image str / prompt based on the number of images we pass + prompt = "" * num_imgs + mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs} + + # The processor will throw an error if there is a mismatch + # in the prompt replacements + processor = processor_for_llava_next(ctx) + processed_inputs = processor.apply(prompt, mm_data, {}) + + image_placeholders = processed_inputs["mm_placeholders"]["image"] + assert len(image_placeholders) == num_imgs + + first_placeholder = image_placeholders[0] + + # NOTE: There is a BOS token + assert first_placeholder["offset"] == 1 + assert first_placeholder["length"] == ( + len(processed_inputs["prompt_token_ids"]) - 1) // num_imgs diff --git a/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py new file mode 100644 index 0000000000000..71adde6568a17 --- /dev/null +++ b/tests/models/decoder_only/vision_language/processing/test_llava_onevision.py @@ -0,0 +1,59 @@ +import pytest +from PIL import Image +from transformers import AutoTokenizer + +from vllm.inputs import InputProcessingContext + +from ....utils import build_model_context + + +# Fixtures lazy import to avoid initializing CUDA during test collection +@pytest.fixture() +def processor_for_llava_onevision(): + from vllm.model_executor.models.llava_onevision import ( + LlavaOnevisionMultiModalProcessor) + return LlavaOnevisionMultiModalProcessor + + +@pytest.mark.parametrize("model_id", + ["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"]) +@pytest.mark.parametrize("image_size", [(1669, 2560), (2560, 1669), (183, 488), + (488, 183), (198, 176), (176, 198)]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_processor_prompt_replacements( + processor_for_llava_onevision, + model_id: str, + image_size: tuple[int, int], + num_imgs: int, +): + """ + Ensure LlavaOnevisionMultiModalProcessor handles prompt replacement + properly. 
+ """ + ctx = build_model_context( + model_name=model_id, + tokenizer_name=model_id, + mm_processor_kwargs=None, + limit_mm_per_prompt={"image": num_imgs}, + ) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + ctx = InputProcessingContext(ctx.model_config, tokenizer) + + # Build the image str / prompt based on the number of images we pass + prompt = "" * num_imgs + mm_data = {"image": [Image.new("RGB", size=image_size)] * num_imgs} + + # The processor will throw an error if there is a mismatch + # in the prompt replacements + processor = processor_for_llava_onevision(ctx) + processed_inputs = processor.apply(prompt, mm_data, {}) + + image_placeholders = processed_inputs["mm_placeholders"]["image"] + assert len(image_placeholders) == num_imgs + + first_placeholder = image_placeholders[0] + + # NOTE: There is a BOS token + assert first_placeholder["offset"] == 0 + assert first_placeholder["length"] == len( + processed_inputs["prompt_token_ids"]) // num_imgs diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py b/tests/models/decoder_only/vision_language/processing/test_phi3v.py similarity index 60% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py rename to tests/models/decoder_only/vision_language/processing/test_phi3v.py index 3edf96d11106d..249045b3c04ce 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_phi3v.py +++ b/tests/models/decoder_only/vision_language/processing/test_phi3v.py @@ -1,6 +1,4 @@ """Tests for phi3v's multimodal preprocessing kwargs.""" -from typing import Optional - import pytest from transformers import AutoTokenizer @@ -10,8 +8,6 @@ from .....conftest import _ImageAssets from ....utils import build_model_context -models = ["microsoft/Phi-3.5-vision-instruct"] - # Wrap lazy imports to avoid initializing CUDA during test collection @pytest.fixture() @@ -20,40 +16,40 @@ def processor_for_phi3v(): return Phi3VMultiModalProcessor -@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"]) +# yapf: disable @pytest.mark.parametrize( - "num_crops,expected_toks_per_img", + ("mm_processor_kwargs", "expected_toks_per_img"), [ - (4, 757), - (16, 1921), + ({"num_crops": 4}, 757), + ({"num_crops": 16}, 1921), # the default num_crops of phi-3.5-vision is 4 - (None, 757), + ({}, 757), ]) +# yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) -def test_processor_override(processor_for_phi3v, image_assets: _ImageAssets, - model: str, num_crops: Optional[int], - expected_toks_per_img: int, num_imgs: int): +def test_processor_override( + processor_for_phi3v, + image_assets: _ImageAssets, + model_id: str, + mm_processor_kwargs: dict[str, int], + expected_toks_per_img: int, + num_imgs: int, +): """Ensure input_processor_for_phi3v handles num_crops properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the custom input processor. 
ctx = build_model_context( - model_name=model, - tokenizer_name=model, + model_name=model_id, + tokenizer_name=model_id, trust_remote_code=True, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) + # Build the image str / prompt based on the number of images we pass img_str = "".join([f"<|image_{idx}|>\n" for idx in range(1, num_imgs + 1)]) prompt = f"<|user|>\n{img_str}<|end|>\n<|assistant|>\n" - images = [image_assets[0].pil_image] * num_imgs - - mm_data = {"image": images} - mm_processor_kwargs = {} - if num_crops is not None: - mm_processor_kwargs = {"num_crops": num_crops} + mm_data = {"image": [image_assets[0].pil_image] * num_imgs} processor = processor_for_phi3v(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py b/tests/models/decoder_only/vision_language/processing/test_qwen.py similarity index 100% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen.py rename to tests/models/decoder_only/vision_language/processing/test_qwen.py diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py similarity index 64% rename from tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py rename to tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py index 1f0b482666723..b9ac887edf90f 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/processing/test_qwen2_vl.py @@ -1,5 +1,3 @@ -from typing import Any, Dict, Tuple - import pytest from transformers import AutoTokenizer @@ -8,56 +6,45 @@ from .....conftest import _ImageAssets from ....utils import build_model_context -MODEL = "Qwen/Qwen2-VL-2B-Instruct" -MIN_PIXELS = "min_pixels" -MAX_PIXELS = "max_pixels" - # Fixtures lazy import to avoid initializing CUDA during test collection -# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple -# input mappers. 
@pytest.fixture() def processor_for_qwen2_vl(): from vllm.model_executor.models.qwen2_vl import Qwen2VLMultiModalProcessor return Qwen2VLMultiModalProcessor +@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) +# yapf: disable @pytest.mark.parametrize( - "mm_processor_kwargs, expected_toks_per_img, expected_pixels_shape", [ + ("mm_processor_kwargs", "expected_toks_per_img", "expected_pixels_shape"), [ ({}, 1426, (5704, 1176)), - ({ - MIN_PIXELS: 64**2, - MAX_PIXELS: 512**2 - }, 330, (1320, 1176)), + ({"min_pixels": 64**2, "max_pixels": 512**2}, 330, (1320, 1176)), ]) -@pytest.mark.parametrize("model", [MODEL]) +# yapf: enable @pytest.mark.parametrize("num_imgs", [1, 2]) def test_processor_override( processor_for_qwen2_vl, image_assets: _ImageAssets, - model: str, - mm_processor_kwargs: Dict[str, Any], + model_id: str, + mm_processor_kwargs: dict[str, object], expected_toks_per_img: int, - expected_pixels_shape: Tuple[int, int], + expected_pixels_shape: tuple[int, int], num_imgs: int, ): """Ensure Qwen2VLMultiModalProcessor handles min/max pixels properly.""" - # Same as the previous test - don't initialize mm_processor_kwargs - # in this test and assume that the kwargs will be correctly expanded by - # the partial when calling the custom input processor. ctx = build_model_context( - model_name=model, - tokenizer_name=model, + model_name=model_id, + tokenizer_name=model_id, mm_processor_kwargs=None, limit_mm_per_prompt={"image": num_imgs}, ) - tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) ctx = InputProcessingContext(ctx.model_config, tokenizer) + # Build the image str / prompt based on the number of images we pass prompt = "<|vision_start|><|image_pad|><|vision_end|>" * num_imgs - images = [image_assets[0].pil_image] * num_imgs - - mm_data = {"image": images} + mm_data = {"image": [image_assets[0].pil_image] * num_imgs} processor = processor_for_qwen2_vl(ctx) processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs) diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 7db08166826eb..dc0b683c1f1cb 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -274,10 +274,8 @@ ), limit_mm_per_prompt={"image": 4}, )], - # Llava-next tests fixed sizes & the default size factors - image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], ), - "llava_one_vision": VLMTestInfo( + "llava_onevision": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.CUSTOM_INPUTS, prompt_formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 @@ -288,8 +286,6 @@ ), auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, - # Llava-one-vision tests fixed sizes & the default size factors - image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], custom_test_opts=[CustomTestOptions( inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 @@ -306,7 +302,6 @@ max_model_len=4096, auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, - image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], ), "mantis": 
VLMTestInfo( models=["TIGER-Lab/Mantis-8B-siglip-llama3"], @@ -431,7 +426,7 @@ ) for inp in custom_inputs.different_patch_input_cases_internvl() ], ), - "llava_one_vision-multiple-images": VLMTestInfo( + "llava_onevision-multiple-images": VLMTestInfo( models=["llava-hf/llava-onevision-qwen2-0.5b-ov-hf"], test_type=VLMTestType.CUSTOM_INPUTS, max_model_len=16384, diff --git a/tests/models/decoder_only/vision_language/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/test_qwen2_vl.py index 51fe7d2ad32a8..16e256e040a74 100644 --- a/tests/models/decoder_only/vision_language/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/test_qwen2_vl.py @@ -427,130 +427,3 @@ def test_qwen2_vl_video_embeddings_input(vllm_runner, video_assets, model, mm_limit=1, tensor_parallel_size=1, ) - - -def run_chunked_prefill_test( - vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], PromptImageInput, PromptVideoInput]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - mm_limit: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Compare inference result between - chunked prefill disabled and chunked prefill enabled - """ - - # NOTE: - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - task="generate", - max_model_len=4000, - max_num_seqs=4, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend - ) as vllm_model: - - outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images or None, - videos=videos or None) - for prompts, images, videos in inputs - ] - - with vllm_runner( - model, - task="generate", - max_model_len=4000, - max_num_seqs=4, - dtype=dtype, - limit_mm_per_prompt={ - "image": mm_limit, - "video": mm_limit - }, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enable_chunked_prefill=True, - # should be small enough to ensure prefilling is chunked - max_num_batched_tokens=32, - mm_processor_kwargs={ - "max_pixels": 16 * 28 * 28, - }) as vllm_model_chunked: - outputs_per_case_chunked = [ - vllm_model_chunked.generate_greedy_logprobs( - prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images or None, - videos=videos or None) for prompts, images, videos in inputs - ] - - for outputs, \ - outputs_chunked \ - in zip(outputs_per_case, - outputs_per_case_chunked): - check_logprobs_close( - outputs_0_lst=outputs, - outputs_1_lst=outputs_chunked, - name_0="non_chunked", - name_1="chunked", - ) - - -@pytest.mark.core_model -@pytest.mark.parametrize("model", models) -@pytest.mark.parametrize("dtype", [target_dtype]) -@pytest.mark.parametrize("max_tokens", [1]) -@pytest.mark.parametrize("num_logprobs", [10]) -def test_qwen2_vl_mrope_chunked_prefill(vllm_runner, example_prompts, - model: str, dtype: str, - max_tokens: int, - num_logprobs: int) -> None: - """ - Test Qwen2-VL's chunked prefill with M-RoPE - """ - prompts = [ - qwen2_vl_chat_template(IMAGE_PLACEHOLDER, prompt) - for prompt in example_prompts[:1] - ] - - # 1. Qwen2-VL's M-RoPE works only when there are some multi-modal inputs, - # so an image is included in the inputs - # 2. 
however, Qwen2-VL currently won't work properly - # when chunked prefill is enabled and there are some multi-modal inputs, - # here use a hacky way: provide a **zero-length** image to make it happy - # - # and finally we achieved: - # (1) chunked_prefill enabled; (2) M-RoPE works; to continue our tests - zero_len_image = { - "image_embeds": torch.empty((0, MODEL_HIDDEN_SIZE)), - "image_grid_thw": torch.tensor([[0, 0, 0]]) - } - images = [zero_len_image] * len(prompts) - - inputs_per_case: List[Tuple[List[str], PromptImageInput, - PromptVideoInput]] = [ - (prompts, images, []), - ] - - run_chunked_prefill_test( - vllm_runner, - inputs_per_case, - model, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - mm_limit=1, - tensor_parallel_size=1, - ) diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py index f99d7556b27f9..b32faa699ebf2 100644 --- a/tests/multimodal/test_processing.py +++ b/tests/multimodal/test_processing.py @@ -11,8 +11,8 @@ from vllm.inputs import InputProcessingContext from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.processing import (ProcessingCache, PromptReplacement, - _PlaceholderInfo, find_text_matches, - find_token_matches, iter_placeholders, + _PlaceholderInfo, find_mm_placeholders, + find_text_matches, find_token_matches, iter_token_matches, replace_text_matches, replace_token_matches) @@ -314,21 +314,27 @@ def test_find_replace_text( # Should not be used since there is nothing to convert to text mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) + mm_prompt_repls = { + key: [ + PromptReplacement(key, target, + repl_by_key[key]).bind(mock_tokenizer) + ] for key, target in target_by_key.items() - ] - matches = find_text_matches(prompt, prompt_repls) + } + mm_matches = { + key: find_text_matches(prompt, prompt_repls) + for key, prompt_repls in mm_prompt_repls.items() + } result = replace_text_matches( prompt, - matches, + mm_matches, {key: mm_count for key in repl_by_key}, ) # Only displayed on error - print("matches:", matches) + print("mm_matches:", mm_matches) print("result:", result) # Manually constructed results @@ -380,21 +386,27 @@ def test_find_replace_tokens( # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer) + mm_prompt_repls = { + key: [ + PromptReplacement(key, target, + repl_by_key[key]).bind(mock_tokenizer) + ] for key, target in target_by_key.items() - ] - matches = find_token_matches(prompt, prompt_repls) + } + mm_matches = { + key: find_token_matches(prompt, prompt_repls) + for key, prompt_repls in mm_prompt_repls.items() + } result = replace_token_matches( prompt, - matches, + mm_matches, {key: mm_count for key in repl_by_key}, ) # Only displayed on error - print("matches:", matches) + print("mm_matches:", mm_matches) print("result:", result) # Manually constructed results @@ -417,58 +429,76 @@ def test_find_replace_tokens( [ ( [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918], - [ - _PlaceholderInfo( - modality="pattern_1", - start_idx=6, - replacement=[32000, 32000], - ), - ], + { + "pattern_1": [ + _PlaceholderInfo( + modality="pattern_1", + item_idx=0, + start_idx=6, + replacement=[32000, 32000], + ), + ], + } + ), ( [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550], - [ - _PlaceholderInfo( - modality="pattern_1", - 
start_idx=1, - replacement=[32000, 32000], - ), - _PlaceholderInfo( - modality="pattern_1", - start_idx=5, - replacement=[32000, 32000], - ), - _PlaceholderInfo( - modality="pattern_3", - start_idx=7, - replacement=[1550, 918, 1550], - ), - ], + { + "pattern_1": [ + _PlaceholderInfo( + modality="pattern_1", + item_idx=0, + start_idx=1, + replacement=[32000, 32000], + ), + _PlaceholderInfo( + modality="pattern_1", + item_idx=1, + start_idx=5, + replacement=[32000, 32000], + ), + ], + "pattern_3": [ + _PlaceholderInfo( + modality="pattern_3", + item_idx=0, + start_idx=7, + replacement=[1550, 918, 1550], + ), + ], + } ), ( [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550], - [ - _PlaceholderInfo( - modality="pattern_1", - start_idx=1, - replacement=[32000, 32000], - ), - _PlaceholderInfo( - modality="pattern_1", - start_idx=3, - replacement=[32000, 32000], - ), - _PlaceholderInfo( - modality="pattern_3", - start_idx=6, - replacement=[1550, 918, 1550], - ), - ], + { + "pattern_1": [ + _PlaceholderInfo( + modality="pattern_1", + item_idx=0, + start_idx=1, + replacement=[32000, 32000], + ), + _PlaceholderInfo( + modality="pattern_1", + item_idx=1, + start_idx=3, + replacement=[32000, 32000], + ), + ], + "pattern_3": [ + _PlaceholderInfo( + modality="pattern_3", + item_idx=0, + start_idx=6, + replacement=[1550, 918, 1550], + ), + ], + } ), ] ) # yapf: enable -def test_iter_placeholders( +def test_find_mm_placeholders( repl_by_key, prompt, expected, @@ -476,19 +506,18 @@ def test_iter_placeholders( # Should not be used since there is nothing to convert to tokens mock_tokenizer = cast(AnyTokenizer, object()) - prompt_repls = [ - PromptReplacement(key, [], repl).bind(mock_tokenizer) + mm_prompt_repls = { + key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)] for key, repl in repl_by_key.items() - ] + } - result = list( - iter_placeholders( - prompt_repls, - prompt, - # Effectively match all occurrences in the prompt - {key: 3 - for key in repl_by_key}, - )) + result = find_mm_placeholders( + mm_prompt_repls, + prompt, + # Effectively match all occurrences in the prompt + {key: 3 + for key in repl_by_key}, + ) # Only displayed on error print("result:", result) @@ -694,7 +723,10 @@ def _test_processing_cache_correctness( } mm_counts = {k: len(vs) for k, vs in mm_data.items()} - prompt = baseline_processor._get_dummy_mm_inputs(mm_counts).prompt_text + prompt = baseline_processor._get_dummy_processor_inputs( + model_config.max_model_len, + mm_counts, + ).prompt_text # Drop unnecessary keys and test single -> multi conversion if rng.rand() < simplify_rate: @@ -728,6 +760,8 @@ def _test_processing_cache_correctness( ("adept/fuyu-8b", {"image": False}), ("llava-hf/llava-1.5-7b-hf", {"image": True}), ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}), + ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}), + ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}), # noqa: E501 ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}), ("mistral-community/pixtral-12b", {"image": True}), ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}), diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index 4f0d679bd6c28..2fd4262a9d3b9 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -456,7 +456,7 @@ def _get_num_image_tokens(self) -> int: hf_config = self.ctx.get_hf_config() return max(hf_config.projector_patch_to_query_dict.values()) - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + 
def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} def _get_mm_fields_config( @@ -488,8 +488,9 @@ def _get_prompt_replacements( ) ] - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: hf_config = self.ctx.get_hf_config() diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 0fe10d8585215..b3ecb2f22dc19 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -405,7 +405,7 @@ def _get_num_image_tokens(self) -> int: hf_config = self.ctx.get_hf_config(Blip2Config) return hf_config.num_query_tokens - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} def _get_hf_processor(self) -> Blip2Processor: @@ -457,8 +457,9 @@ def apply( return result - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: hf_config = self.ctx.get_hf_config(Blip2Config) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 0bd0194243ceb..1ad44678a591d 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -57,7 +57,7 @@ def _get_num_image_tokens(self) -> int: processor = self._get_hf_processor() return processor.image_seq_length - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: return {"image": self._get_num_image_tokens()} def _get_hf_processor(self) -> ChameleonProcessor: @@ -90,8 +90,9 @@ def _get_prompt_replacements( ) ] - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: config = self.ctx.get_hf_config(ChameleonConfig) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 0188452054b8c..1bde45cb140cb 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -164,15 +164,18 @@ def get_num_image_tokens( def get_max_image_tokens(self) -> int: return get_max_clip_image_tokens(self.vision_config) - def get_num_patches(self) -> int: + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: return get_clip_patch_grid_length( image_size=self.vision_config.image_size, patch_size=self.vision_config.patch_size, ) - def get_image_size(self) -> int: - return self.vision_config.image_size - # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa class CLIPVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 3680d01725238..7cd58fbc7cf21 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -96,7 +96,7 @@ def _get_image_feature_grid_size( nrows = math.ceil(image_height / 30) return ncols, nrows - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: target_width, target_height = self._get_image_target_size() max_ncols, max_nrows = self._get_image_feature_grid_size( @@ -208,8 +208,9 
@@ def apply( return result - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: target_width, target_height = self._get_image_target_size() diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 78de27cd821c6..d522378e0bebb 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -25,11 +25,9 @@ NestedTensors) from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, ImageSize) -from vllm.multimodal.processing import (BaseMultiModalProcessor, - InputProcessingContext, +from vllm.multimodal.processing import (InputProcessingContext, MultiModalDataItems, ProcessingCache, - ProcessorInputs, PromptReplacement, - full_groupby_modality) + ProcessorInputs, PromptReplacement) from vllm.sequence import IntermediateTensors from .clip import CLIPVisionModel @@ -39,7 +37,7 @@ from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -from .vision import vision_encoder_info +from .vision import BaseVisionLanguageMultiModalProcessor class LlavaImagePixelInputs(TypedDict): @@ -100,19 +98,7 @@ class LlavaLikeConfig(Protocol): vision_feature_layer: Final[Union[int, List[int]]] -class BaseLlavaMultiModalProcessor(BaseMultiModalProcessor): - - def __init__(self, - ctx: InputProcessingContext, - *, - cache: Optional[ProcessingCache] = None, - enable_sanity_checks: bool = True) -> None: - super().__init__(ctx, - cache=cache, - enable_sanity_checks=enable_sanity_checks) - - vision_config = self._get_hf_config().vision_config - self._vision_encoder_info = vision_encoder_info(vision_config) +class BaseLlavaMultiModalProcessor(BaseVisionLanguageMultiModalProcessor): @abstractmethod def _get_hf_config(self) -> LlavaLikeConfig: @@ -121,6 +107,19 @@ def _get_hf_config(self) -> LlavaLikeConfig: def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None} + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + return {"image": self._get_max_image_tokens()} + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + ) + def _apply_feature_select_strategy( self, strategy: str, @@ -142,19 +141,6 @@ def _get_max_image_tokens(self) -> int: self._vision_encoder_info.get_max_image_tokens(), ) - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: - return {"image": self._get_max_image_tokens()} - - def _get_mm_fields_config( - self, - hf_inputs: BatchFeature, - hf_processor_mm_kwargs: Mapping[str, object], - ) -> Mapping[str, MultiModalFieldConfig]: - return dict( - pixel_values=MultiModalFieldConfig.batched("image"), - image_embeds=MultiModalFieldConfig.batched("image"), - ) - def _get_dummy_image_size(self) -> ImageSize: image_size = self._vision_encoder_info.get_image_size() return ImageSize(image_size, image_size) @@ -163,8 +149,9 @@ def _get_dummy_image_size(self) -> ImageSize: def _get_image_token(self) -> str: raise NotImplementedError - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) @@ -709,7 +696,7 @@ def get_replacement_mantis(item_idx: 
int): ")", # 3 tokens ]) - mantis_repls = self._bind_prompt_replacements([ + mantis_mm_repls = self._bind_and_group_repls([ PromptReplacement( modality="image", target=[image_token_id] * num_image_tokens, @@ -719,7 +706,7 @@ def get_replacement_mantis(item_idx: int): prompt_ids, prompt_text, _ = self._apply_prompt_replacements( result["prompt_token_ids"], - mantis_repls, + mantis_mm_repls, mm_item_counts, ) @@ -728,15 +715,19 @@ def get_replacement_mantis(item_idx: int): hf_processor_mm_kwargs, mm_kwargs, ) - orig_repls = self._bind_prompt_replacements(unbound_orig_repls) + orig_repls = self._bind_and_group_repls(unbound_orig_repls) + + mm_placeholders = self._find_mm_placeholders( + orig_repls, + prompt_ids, + mm_item_counts, + ) - all_placeholders = self._find_placeholders(orig_repls, prompt_ids, - mm_item_counts) - assert len(all_placeholders) == mm_item_counts.get("image", 0) + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) - mm_placeholders = { - modality: [item.to_range() for item in items] - for modality, items in full_groupby_modality(all_placeholders) + mm_placeholder_ranges = { + modality: [item.to_range() for item in placeholders] + for modality, placeholders in mm_placeholders.items() } return MultiModalInputsV2( @@ -744,7 +735,7 @@ def get_replacement_mantis(item_idx: int): prompt=prompt_text, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, - mm_placeholders=mm_placeholders, + mm_placeholders=mm_placeholder_ranges, ) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 24debd1cbf3fe..3769f04f94a92 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -67,9 +67,6 @@ def _get_hf_config(self) -> LlavaNextConfig: def _get_hf_processor(self) -> LlavaNextProcessor: return self.ctx.get_hf_processor(LlavaNextProcessor) - def _get_image_token(self) -> str: - return self._get_hf_processor().image_token - def _get_mm_fields_config( self, hf_inputs: BatchFeature, @@ -81,6 +78,9 @@ def _get_mm_fields_config( image_embeds=MultiModalFieldConfig.batched("image"), ) + def _get_image_token(self) -> str: + return self._get_hf_processor().image_token + def _get_max_image_tokens(self) -> int: largest_feature_size, _ = self._get_pinpoint_with_most_features() return largest_feature_size @@ -97,20 +97,20 @@ def _get_num_image_tokens( image_height: int, ) -> int: hf_config = self._get_hf_config() + vision_encoder_info = self._vision_encoder_info base_feature_size = self._apply_feature_select_strategy( hf_config.vision_feature_select_strategy, - self._vision_encoder_info.get_num_image_tokens( + vision_encoder_info.get_num_image_tokens( image_width=image_width, image_height=image_height, ), ) - num_patches = self._vision_encoder_info.get_num_patches() num_patch_height, num_patch_width = get_anyres_image_grid_shape( image_size=(image_height, image_width), grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=self._vision_encoder_info.get_image_size(), + patch_size=vision_encoder_info.get_image_size(), ) ( @@ -119,7 +119,7 @@ def _get_num_image_tokens( ) = self._get_num_unpadded_features( original_height=image_height, original_width=image_width, - npatches=num_patches, + npatches=vision_encoder_info.get_patch_grid_length(), num_patch_height=num_patch_height, num_patch_width=num_patch_width, ) @@ -155,6 +155,7 @@ def _get_num_unpadded_features( unpadded_features = current_height * current_width newline_features = current_height + return (unpadded_features, newline_features) def 
_get_pinpoint_with_most_features(self) -> tuple[int, ImageSize]: diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 0de9d8c5ea572..ee6b89f0d4498 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -3,38 +3,32 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.nn as nn -from transformers import (CLIPVisionConfig, LlavaNextVideoConfig, - SiglipVisionConfig) +from transformers import (BatchFeature, LlavaNextVideoConfig, + LlavaNextVideoProcessor) from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.parse import (ImageSize, MultiModalDataItems, + VideoEmbeddingItems, VideoProcessorItems) +from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava -from .siglip import (SiglipVisionModel, dummy_image_for_siglip, - dummy_seq_data_for_siglip) +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) - -# For profile run -_MAX_FRAMES_PER_VIDEO = 32 -_MAX_NUM_VIDEOS = 1 +from .vision import BaseVisionLanguageMultiModalProcessor class LlavaNextVideoPixelInputs(TypedDict): @@ -50,143 +44,148 @@ class LlavaNextVideoPixelInputs(TypedDict): """ -def get_llava_next_video_frame_feature_size( - hf_config: LlavaNextVideoConfig) -> int: - # Support both CLIPVisionConfig and SiglipVisionConfig - image_size = hf_config.vision_config.image_size - patch_size = hf_config.vision_config.patch_size - spatial_pool_stride = hf_config.spatial_pool_stride +class LlavaNextVideoMultiModalProcessor(BaseVisionLanguageMultiModalProcessor): - return int((image_size / patch_size / spatial_pool_stride)**2) + def _get_hf_config(self) -> LlavaNextVideoConfig: + return self.ctx.get_hf_config(LlavaNextVideoConfig) + def _get_hf_processor(self) -> LlavaNextVideoProcessor: + return self.ctx.get_hf_processor(LlavaNextVideoProcessor) -def _get_max_llm_tokens(ctx: InputContext) -> int: - """ - Calculated from the maximum video frames under the context length - constraints of the language model. 
- """ - hf_text_config = ctx.model_config.hf_text_config - model_config = ctx.model_config - max_tokens = model_config.max_model_len - rope_scaling = model_config.rope_scaling - - if rope_scaling: - rope_scaling_factor = hf_text_config.rope_scaling["factor"] - else: - rope_scaling_factor = 1 - - max_tokens *= rope_scaling_factor - - return max_tokens - - -def get_max_llava_next_video_tokens(ctx: InputContext) -> int: - # Currently set to 32 frames - # TODO: max_tokens = _get_max_llm_tokens(ctx) - hf_config = ctx.get_hf_config(LlavaNextVideoConfig) - tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config) - return _MAX_FRAMES_PER_VIDEO * tokens_per_frame - - -def dummy_data_for_llava_next_video(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaNextVideoConfig) - vision_config = hf_config.vision_config - - # TODO: support multiple videos - num_videos = mm_counts["video"] - if num_videos != _MAX_NUM_VIDEOS: - raise NotImplementedError( - f"Only {_MAX_NUM_VIDEOS} videos are supported") - - # TODO: support configuring the number of frames - frames_per_video = _MAX_FRAMES_PER_VIDEO - # num_images = num_videos * frames_per_video - - # fills the sequence with as longer video data as possible - tokens_per_frame = get_llava_next_video_frame_feature_size(hf_config) - video_feature_size = frames_per_video * tokens_per_frame - - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_videos, - image_token_id=hf_config.video_token_index, - image_feature_size_override=video_feature_size, - mm_key="video", - ) + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + return {"video": 1} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + num_frames = self._get_dummy_num_frames(seq_len) + max_video_tokens = self._get_max_video_tokens(num_frames) + + return {"video": max_video_tokens} + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict(pixel_values_videos=MultiModalFieldConfig.batched("video")) + + def _get_num_frame_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + spatial_pool_stride = hf_config.spatial_pool_stride - pil_frame = dummy_image_for_clip(vision_config, num_images=1) - np_frame = np.array(pil_frame["image"]) - mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) - mm_data = {"video": mm_data_per_video} - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_videos, - image_token_id=hf_config.video_token_index, - image_feature_size_override=video_feature_size, - mm_key="video", + patch_grid_length = self._vision_encoder_info.get_patch_grid_length() + pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) + + return pooled_grid_length * pooled_grid_length + + def _get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + num_frame_tokens = self._get_num_frame_tokens( + image_width=image_width, + image_height=image_height, ) - pil_frame = dummy_image_for_siglip(vision_config, num_images=1) - np_frame = np.array(pil_frame["image"]) - mm_data_per_video = np.repeat([np_frame], frames_per_video, axis=0) - mm_data = {"video": mm_data_per_video} 
- return DummyData(seq_data, mm_data, ranges) + return num_frame_tokens * num_frames - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def _get_max_video_tokens(self, num_frames: int) -> int: + return self._get_num_video_tokens(image_width=999999, + image_height=999999, + num_frames=num_frames) + def _get_max_video_frames(self, max_tokens: int) -> int: + num_frames = 0 -def input_processor_for_llava_next_video(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "video" not in multi_modal_data: - return inputs + while True: + next_num_frames = num_frames + 1 - if "multi_modal_placeholders" in inputs and "video" in inputs[ - "multi_modal_placeholders"]: - # The inputs already have placeholders. - return inputs + if self._get_max_video_tokens(next_num_frames) > max_tokens: + break - video_data = multi_modal_data["video"] + num_frames = next_num_frames - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaNextVideoConfig) - vision_config = hf_config.vision_config + return num_frames - if isinstance(video_data, np.ndarray): - # Supports both CLIP and Siglip - num_frames = video_data.shape[0] - frame_feature_size = \ - get_llava_next_video_frame_feature_size(hf_config) - video_feature_size = num_frames * frame_feature_size + def _get_dummy_num_frames(self, seq_len: int) -> int: + mm_config = self.ctx.get_mm_config() + max_videos = mm_config.limit_per_prompt.get("video", 1) - tokenizer = cached_get_tokenizer(model_config.tokenizer) + max_total_frames = self._get_max_video_frames(seq_len) - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=hf_config.video_token_index, - repeat_count=video_feature_size, - ) + return max(max_total_frames // max(max_videos, 1), 1) - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"video": ranges}) + def _get_dummy_image_size(self) -> ImageSize: + image_size = self._vision_encoder_info.get_image_size() + return ImageSize(image_size, image_size) - elif is_list_of(video_data, np.ndarray): - raise NotImplementedError( - "Processing multiple videos is not supported") + def _get_video_token(self) -> str: + return self._get_hf_processor().video_token - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + hf_config = self._get_hf_config() + video_token_id = hf_config.video_token_index + + def get_replacement(item_idx: int): + videos = mm_items.get_items( + "video", (VideoEmbeddingItems, VideoProcessorItems)) + + if isinstance(videos, VideoEmbeddingItems): + num_video_tokens = videos.get_feature_size(item_idx) + else: + image_size = videos.get_frame_size(item_idx) + num_video_tokens = self._get_num_video_tokens( + image_width=image_size.width, + image_height=image_size.height, + num_frames=videos.get_num_frames(item_idx), + ) + + return [video_token_id] * num_video_tokens + + return [ + PromptReplacement( + modality="video", + target=[video_token_id], + replacement=get_replacement, + ), + ] + + def _get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + 
num_videos = mm_counts.get("video", 0) + + video_token = self._get_video_token() + target_width, target_height = self._get_dummy_image_size() + + mm_data = { + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=video_token * num_videos, + mm_data=mm_data, + ) # adopted from transformers modeling_llava_next_video.py @@ -246,11 +245,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_input_mapper("video") -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "video", get_max_llava_next_video_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) +@MULTIMODAL_REGISTRY.register_processor(LlavaNextVideoMultiModalProcessor) class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 0bebc1c745e2b..1e51e09a24c18 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -3,47 +3,36 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.nn as nn -from PIL import Image -from transformers import (CLIPVisionConfig, LlavaOnevisionConfig, - SiglipVisionConfig) +from transformers import (BatchFeature, LlavaOnevisionConfig, + LlavaOnevisionProcessor) from transformers.models.llava_onevision.modeling_llava_onevision import ( get_anyres_image_grid_shape, unpad_image) from typing_extensions import NotRequired from vllm.attention import AttentionMetadata from vllm.config import VllmConfig -from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, - InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.inputs import NestedTensors -from vllm.multimodal.utils import (cached_get_tokenizer, - repeat_and_pad_placeholder_tokens) +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors +from vllm.multimodal.parse import (MultiModalDataItems, VideoEmbeddingItems, + VideoProcessorItems) +from vllm.multimodal.processing import (MultiModalFieldConfig, ProcessorInputs, + PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of -from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, - dummy_video_for_clip, get_clip_image_feature_size, - get_clip_patch_grid_length, input_processor_for_clip) +from .clip import CLIPVisionModel from .interfaces import SupportsMultiModal, SupportsPP from .llava import init_vision_tower_for_llava -from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, - dummy_video_for_siglip, get_siglip_image_feature_size, - get_siglip_patch_grid_length, input_processor_for_siglip) +from .llava_next import LlavaNextMultiModalProcessor +from .siglip import SiglipVisionModel from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model, maybe_prefix, merge_multimodal_embeddings) -# Result in the max possible feature size (2x2 grid of 336x336px tiles) 
-MAX_IMAGE_FEATURE_SIZE_HEIGHT = MAX_IMAGE_FEATURE_SIZE_WIDTH = 448 - -# For profile run -_MAX_FRAMES_PER_VIDEO = 16 - class LlavaOnevisionVideoPixelInputs(TypedDict): type: Literal["pixel_values_videos"] @@ -92,286 +81,251 @@ class LlavaOnevisionImageEmbeddingInputs(TypedDict): LlavaOnevisionVideoPixelInputs] -def _get_llava_onevision_image_unppaded_feature_size(height, width, patches, - scale_height, - scale_width): - current_height = patches * scale_height - current_width = patches * scale_width - - original_aspect_ratio = width / height - current_aspect_ratio = current_width / current_height - if original_aspect_ratio > current_aspect_ratio: - new_height = int(height * (current_width / width)) - padding = (current_height - new_height) // 2 - current_height -= padding * 2 - else: - new_width = int(width * (current_height / height)) - padding = (current_width - new_width) // 2 - current_width -= padding * 2 - - unpadded_features = current_height * current_width - newline_features = current_height - - ratio = math.sqrt(current_height * current_width / (9 * patches**2)) - if ratio > 1.1: - unpadded_features = int(current_height // ratio) * int( - current_width // ratio) - newline_features = int(current_height // ratio) - - return (unpadded_features, newline_features) - - -def get_llava_onevision_image_feature_size( - hf_config: LlavaOnevisionConfig, - *, - input_height: int, - input_width: int, -) -> int: - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - num_patches = get_clip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_clip_image_feature_size(vision_config) - elif isinstance(vision_config, SiglipVisionConfig): - num_patches = get_siglip_patch_grid_length( - image_size=vision_config.image_size, - patch_size=vision_config.patch_size, - ) - base_feature_size = get_siglip_image_feature_size(vision_config) - else: - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - strategy = hf_config.vision_feature_select_strategy - if strategy == "default": - base_feature_size -= 1 - elif strategy == "full": - pass - else: - raise ValueError(f"Unexpected select feature strategy: {strategy}") +class LlavaOnevisionMultiModalProcessor(LlavaNextMultiModalProcessor): - num_patch_height, num_patch_width = get_anyres_image_grid_shape( - image_size=(input_height, input_width), - grid_pinpoints=hf_config.image_grid_pinpoints, - patch_size=vision_config.image_size, - ) + def _get_hf_config(self) -> LlavaOnevisionConfig: + return self.ctx.get_hf_config(LlavaOnevisionConfig) - ( - unpadded_feature_size, - newline_feature_size, - ) = _get_llava_onevision_image_unppaded_feature_size( - input_height, input_width, num_patches, num_patch_height, - num_patch_width) - - return unpadded_feature_size + newline_feature_size + base_feature_size - - -def get_max_llava_onevision_image_tokens(ctx: InputContext): - return get_llava_onevision_image_feature_size( - ctx.get_hf_config(LlavaOnevisionConfig), - input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, - input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, - ) - - -def get_llava_onevision_video_frame_feature_size( - hf_config: LlavaOnevisionConfig) -> int: - # Support both CLIPVisionConfig and SiglipVisionConfig - image_size = hf_config.vision_config.image_size - patch_size = hf_config.vision_config.patch_size - spatial_pool_stride = hf_config.spatial_pool_stride if hasattr( - hf_config, "spatial_pool_stride") else 2 - - height 
= width = image_size // patch_size - return math.ceil(height / spatial_pool_stride) * math.ceil( - width / spatial_pool_stride) - - -def get_llava_onevision_video_tokens(ctx: InputContext, - num_frames: int) -> int: - hf_config = ctx.get_hf_config(LlavaOnevisionConfig) - - # TODO: support configuring (not supported by HF right now) - num_token_image_newline = 1 - tokens_per_frame = get_llava_onevision_video_frame_feature_size(hf_config) - video_feature_size = num_frames * tokens_per_frame + num_token_image_newline - - return video_feature_size - - -def get_max_llava_onevision_video_tokens(ctx: InputContext) -> int: - return get_llava_onevision_video_tokens(ctx, _MAX_FRAMES_PER_VIDEO) - - -def dummy_data_for_llava_onevision(ctx: InputContext, seq_len: int, - mm_counts: Mapping[str, int]): - hf_config = ctx.get_hf_config(LlavaOnevisionConfig) - vision_config = hf_config.vision_config - - num_videos = mm_counts["video"] - - # TODO: support configuring the number of frames - num_frames = _MAX_FRAMES_PER_VIDEO - video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) - - if isinstance(vision_config, CLIPVisionConfig): - seq_data, ranges = dummy_seq_data_for_clip( - vision_config, - seq_len, - num_videos, - image_token_id=hf_config.video_token_index, - image_feature_size_override=video_feature_size, - mm_key="video") - - mm_data = dummy_video_for_clip(vision_config, - num_frames=num_frames, - num_videos=num_videos) - return DummyData(seq_data, mm_data, ranges) - elif isinstance(vision_config, SiglipVisionConfig): - seq_data, ranges = dummy_seq_data_for_siglip( - vision_config, - seq_len, - num_videos, - image_token_id=hf_config.video_token_index, - image_feature_size_override=video_feature_size, - mm_key="video") - - mm_data = dummy_video_for_siglip(vision_config, - num_frames=num_frames, - num_videos=num_videos) - return DummyData(seq_data, mm_data, ranges) - - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) - - -def input_processor_when_multimodal_input_image(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - return inputs - - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaOnevisionConfig) - vision_config = hf_config.vision_config - - image_data = multi_modal_data["image"] - if isinstance(image_data, Image.Image): - width, height = image_data.size - - image_feature_size = get_llava_onevision_image_feature_size( - hf_config, - input_height=height, - input_width=width, - ) - elif is_list_of(image_data, Image.Image): - image_feature_size = [ - get_llava_onevision_image_feature_size(hf_config, - input_height=img.height, - input_width=img.width) - for img in image_data - ] - elif isinstance(image_data, torch.Tensor): - num_images, image_feature_size, hidden_size = image_data.shape - elif is_list_of(image_data, torch.Tensor): - image_feature_size = [item.shape[1] for item in image_data] - else: - raise TypeError(f"Invalid image type: {type(image_data)}") - - vision_config = hf_config.vision_config - - if isinstance(vision_config, CLIPVisionConfig): - return input_processor_for_clip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + def _get_hf_processor(self) -> LlavaOnevisionProcessor: + return self.ctx.get_hf_processor(LlavaOnevisionProcessor) + + def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: + 
return {"image": None, "video": None} + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + max_image_tokens = self._get_max_image_tokens() + + num_frames = self._get_dummy_num_frames(seq_len) + max_video_tokens = self._get_max_video_tokens(num_frames) + + return { + "image": max_image_tokens, + "video": max_video_tokens, + } + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + pixel_values=MultiModalFieldConfig.batched("image"), + image_sizes=MultiModalFieldConfig.batched("image"), + image_embeds=MultiModalFieldConfig.batched("image"), + pixel_values_videos=MultiModalFieldConfig.batched("video"), ) - elif isinstance(vision_config, SiglipVisionConfig): - return input_processor_for_siglip( - model_config, - vision_config, - inputs, - image_token_id=hf_config.image_token_index, - image_feature_size_override=image_feature_size, + + def _get_num_unpadded_features( + self, + *, + original_height: int, + original_width: int, + npatches: int, + num_patch_height: int, + num_patch_width: int, + ) -> tuple[int, int]: + current_height = npatches * num_patch_height + current_width = npatches * num_patch_width + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + if original_aspect_ratio > current_aspect_ratio: + new_height = int(original_height * + (current_width / original_width)) + padding = (current_height - new_height) // 2 + current_height -= padding * 2 + else: + new_width = int(original_width * + (current_height / original_height)) + padding = (current_width - new_width) // 2 + current_width -= padding * 2 + + unpadded_features = current_height * current_width + newline_features = current_height + + ratio = math.sqrt(current_height * current_width / (9 * npatches**2)) + if ratio > 1.1: + unpadded_features = int(current_height // ratio) * int( + current_width // ratio) + newline_features = int(current_height // ratio) + + return (unpadded_features, newline_features) + + def _get_num_frame_tokens( + self, + *, + image_width: int, + image_height: int, + ) -> int: + hf_config = self._get_hf_config() + spatial_pool_stride = getattr(hf_config, "spatial_pool_stride", 2) + + patch_grid_length = self._vision_encoder_info.get_patch_grid_length() + pooled_grid_length = math.ceil(patch_grid_length / spatial_pool_stride) + + return pooled_grid_length * pooled_grid_length + + def _get_num_video_tokens( + self, + *, + image_width: int, + image_height: int, + num_frames: int, + ) -> int: + num_frame_tokens = self._get_num_frame_tokens( + image_width=image_width, + image_height=image_height, ) - msg = f"Unsupported vision config: {type(vision_config)}" - raise NotImplementedError(msg) + return num_frame_tokens * num_frames + 1 # Newline token + + def _get_max_video_tokens(self, num_frames: int) -> int: + return self._get_num_video_tokens(image_width=999999, + image_height=999999, + num_frames=num_frames) + def _get_max_video_frames(self, max_tokens: int) -> int: + num_frames = 0 -def input_processor_when_multimodal_input_video(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or "video" not in multi_modal_data: - return inputs - video_data = multi_modal_data["video"] + while True: + next_num_frames = num_frames + 1 - model_config = ctx.model_config - hf_config = ctx.get_hf_config(LlavaOnevisionConfig) + if 
self._get_max_video_tokens(next_num_frames) > max_tokens: + break - if isinstance(video_data, np.ndarray): - # Supports both CLIP and Siglip - num_frames = video_data.shape[0] - video_feature_size = get_llava_onevision_video_tokens(ctx, num_frames) - tokenizer = cached_get_tokenizer(model_config.tokenizer) + num_frames = next_num_frames - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=hf_config.video_token_index, - repeat_count=video_feature_size, + return num_frames + + def _get_dummy_num_frames(self, seq_len: int) -> int: + mm_config = self.ctx.get_mm_config() + max_images = mm_config.limit_per_prompt.get("image", 1) + max_videos = mm_config.limit_per_prompt.get("video", 1) + + max_image_tokens = self._get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + + return max(max_total_frames // max(max_videos, 1), 1) + + def _get_video_token(self) -> str: + return self._get_hf_processor().video_token + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + ) -> BatchFeature: + mm_data = dict(mm_data) + videos = mm_data.pop("videos", []) + assert isinstance(videos, list) + + if not videos: + return super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + video_token = self._get_video_token() + + # LLaVA-OneVision processor doesn't support multiple videos + # with different sizes when converting back to tensors + text_image_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + ) + + pixel_values_videos = [] + for video in videos: + item_processor_data = dict(prompt=video_token, videos=video) + + item_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=item_processor_data, + mm_kwargs=mm_kwargs, + ) + + pixel_values_videos.append( + item_outputs.pop("pixel_values_videos")[0]) + + combined_outputs = dict( + **text_image_outputs, + pixel_values_videos=pixel_values_videos, ) + return BatchFeature(combined_outputs) - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"video": ranges}) - - elif is_list_of(video_data, np.ndarray): - video_feature_size = [] - for video in video_data: - num_frames = video.shape[0] - video_feature_size.append( - get_llava_onevision_video_tokens(ctx, num_frames)) - - tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids, ranges = repeat_and_pad_placeholder_tokens( - tokenizer, - inputs.get("prompt"), - inputs["prompt_token_ids"], - placeholder_token_id=hf_config.video_token_index, - repeat_count=video_feature_size, + def _get_prompt_replacements( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargs, + ) -> list[PromptReplacement]: + image_repls = super()._get_prompt_replacements( + mm_items=mm_items, + hf_processor_mm_kwargs=hf_processor_mm_kwargs, + out_mm_kwargs=out_mm_kwargs, ) - return token_inputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data, - multi_modal_placeholders={"video": ranges}) - else: - raise TypeError(f"Invalid video type: {type(video_data)}") - msg = f"Unsupported video type: {type(video_data)}" - raise NotImplementedError(msg) + hf_config = self._get_hf_config() + video_token_id = 
hf_config.video_token_index + def get_video_replacement(item_idx: int): + videos = mm_items.get_items( + "video", (VideoEmbeddingItems, VideoProcessorItems)) -def input_processor_for_llava_onevision(ctx: InputContext, - inputs: DecoderOnlyInputs): - multi_modal_data = inputs.get("multi_modal_data") - if multi_modal_data is None or ("video" not in multi_modal_data - and "image" not in multi_modal_data): - return inputs - if "image" in multi_modal_data: - return input_processor_when_multimodal_input_image(ctx, inputs) - if "video" in multi_modal_data: - return input_processor_when_multimodal_input_video(ctx, inputs) + if isinstance(videos, VideoEmbeddingItems): + num_video_tokens = videos.get_feature_size(item_idx) + else: + image_size = videos.get_frame_size(item_idx) + num_video_tokens = self._get_num_video_tokens( + image_width=image_size.width, + image_height=image_size.height, + num_frames=videos.get_num_frames(item_idx), + ) + + return [video_token_id] * num_video_tokens - msg = "Unsupported multi data type" - raise NotImplementedError(msg) + return image_repls + [ + PromptReplacement( + modality="video", + target=[video_token_id], + replacement=get_video_replacement, + ), + ] + + def _get_dummy_processor_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> ProcessorInputs: + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + + image_token = self._get_image_token() + video_token = self._get_video_token() + target_width, target_height = self._get_dummy_image_size() + + mm_data = { + "image": + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + num_videos=num_videos, + ) + } + + return ProcessorInputs( + prompt_text=image_token * num_images + video_token * num_videos, + mm_data=mm_data, + ) class LlavaOnevisionMultiModalProjector(nn.Module): @@ -394,14 +348,7 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: return hidden_states -@MULTIMODAL_REGISTRY.register_image_input_mapper() -@MULTIMODAL_REGISTRY.register_input_mapper("video") -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "image", get_max_llava_onevision_image_tokens) -@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "video", get_max_llava_onevision_video_tokens) -@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) -@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) +@MULTIMODAL_REGISTRY.register_processor(LlavaOnevisionMultiModalProcessor) class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index f2e49d8e4848d..7aa9d58d1d348 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -323,7 +323,7 @@ def _get_num_image_tokens( height=image_height, ) - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: max_image_tokens = self._get_num_image_tokens( image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH, image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT, @@ -415,12 +415,12 @@ def get_replacement_phi3v(item_idx: int): def _apply_prompt_replacements( self, token_ids: list[int], - prompt_repls: Sequence[_BoundPromptReplacement], + mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_item_counts: Mapping[str, 
int], - ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: token_ids, text, placeholders = super()._apply_prompt_replacements( token_ids=token_ids, - prompt_repls=prompt_repls, + mm_prompt_repls=mm_prompt_repls, mm_item_counts=mm_item_counts, ) @@ -428,15 +428,23 @@ def _apply_prompt_replacements( if text.startswith(" <|image|>"): text = text.replace(" <|image|>", "<|image|>", 1) token_ids = [token_ids[0], *token_ids[2:]] - placeholders = [ - _PlaceholderInfo(p.modality, p.start_idx - 1, p.replacement) - for p in placeholders - ] + placeholders = { + modality: [ + _PlaceholderInfo( + modality=p.modality, + item_idx=p.item_idx, + start_idx=p.start_idx - 1, + replacement=p.replacement, + ) for p in ps + ] + for modality, ps in placeholders.items() + } return token_ids, text, placeholders - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: num_images = mm_counts.get("image", 0) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d7233bd6028ed..9e1d38512c0b4 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -780,15 +780,18 @@ def get_num_image_tokens( def get_max_image_tokens(self) -> int: return get_max_pixtral_hf_image_tokens(self.vision_config) - def get_num_patches(self) -> int: + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: return get_pixtral_hf_patch_grid_length( image_size=self.vision_config.image_size, patch_size=self.vision_config.patch_size, ) - def get_image_size(self) -> int: - return self.vision_config.image_size - class PixtralHFMLP(nn.Module): diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index d050fd060353a..bc3bb1f79b407 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -84,7 +84,7 @@ class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: hf_config = self.ctx.get_hf_config(Qwen2AudioConfig) max_source_positions = hf_config.audio_config.max_source_positions max_output_lengths = (max_source_positions - 2) // 2 + 1 @@ -184,15 +184,16 @@ def get_replacement_qwen2_audio(item_idx: int): ] def _always_apply_prompt_replacements(self) -> bool: - # HF never applies prompt replacements, so we have to do it ourselves - # _find_placeholders may incorrectly think that HF has already performed - # processing for multi-audio input when the input audios are short - # (the corresponding placeholders may take up fewer tokens than - # the number of audio items) + # HF never applies prompt replacements, so we have to do it ourselves. 
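As an aside on the comment above, here is a minimal self-contained sketch of the control flow this opt-out affects; the function and argument names below are invented for illustration and are not vLLM APIs. The processor normally skips its own replacement pass when a placeholder was detected for every input item, and returning True from _always_apply_prompt_replacements forces the pass to run regardless.

# Hedged sketch only; names are made up for this example.
def should_apply_replacements(
    found_placeholders: dict[str, int],  # placeholder runs detected per modality
    item_counts: dict[str, int],         # multi-modal items passed per modality
    always_apply: bool,                  # _always_apply_prompt_replacements()
) -> bool:
    if always_apply:
        return True
    # Otherwise, only re-apply replacements for modalities where the HF
    # processor did not already insert one placeholder per item.
    return any(
        found_placeholders.get(modality, 0) != count
        for modality, count in item_counts.items()
    )

# Short audio clips can make detection undercount (e.g. one run for two
# items), so Qwen2-Audio simply opts out of the auto-detection:
assert should_apply_replacements({"audio": 1}, {"audio": 2}, always_apply=False)
assert should_apply_replacements({"audio": 2}, {"audio": 2}, always_apply=True)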
+ # NOTE: `_find_placeholders_by_modality` may incorrectly think that HF + # has already performed processing for multi-audio input when the input + # audios are short (the corresponding placeholders may take up fewer + # tokens than the number of audio items) return True - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: feature_extractor = self._get_feature_extractor() diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 5a8c6e4deb7ac..abca85e0e2024 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -56,7 +56,8 @@ from vllm.multimodal.inputs import (ImageItem, ModalityData, MultiModalFieldConfig, MultiModalKwargs, NestedTensors, VideoItem) -from vllm.multimodal.parse import ModalityDataItems, MultiModalDataParser +from vllm.multimodal.parse import (ImageSize, ModalityDataItems, + MultiModalDataParser) from vllm.multimodal.processing import (BaseMultiModalProcessor, MultiModalDataItems, ProcessorInputs, PromptReplacement) @@ -641,58 +642,6 @@ def load_weights(self, weights: Iterable[Tuple[str, return loaded_params -# === Vision input helpers === # - - -def _get_vision_info( - vision_config: Qwen2VLVisionConfig, - height: int, - width: int, - min_pixels: int, - max_pixels: int, - *, - do_resize: bool = True, - modality: str = "image", - mm_count: int = 1, -): - """Get information (resized height / width and number of vision tokens) - of input image / video frame.""" - patch_size = vision_config.patch_size - merge_size = vision_config.spatial_merge_size - temporal_patch_size = vision_config.temporal_patch_size - - if do_resize: - resized_height, resized_width = smart_resize( - height=height, - width=width, - factor=patch_size * merge_size, - min_pixels=min_pixels, - max_pixels=max_pixels, - ) - else: - resized_height, resized_width = height, width - - if modality == "image": - grid_t = mm_count - elif modality == "video": - grid_t = max(mm_count // temporal_patch_size, 1) - else: - raise ValueError(f"Modality {modality} is not supported") - - grid_h = resized_height // patch_size - grid_w = resized_width // patch_size - vision_tokens = grid_t * grid_h * grid_w - llm_num_vision_tokens = vision_tokens // (merge_size**2) - - return resized_height, resized_width, llm_num_vision_tokens - - -def _get_image_processor(hf_processor: Qwen2VLProcessor): - image_processor = hf_processor.image_processor # type: ignore - assert isinstance(image_processor, Qwen2VLImageProcessor) - return image_processor - - class Qwen2EmbeddingItems(ModalityDataItems[dict[str, torch.Tensor], dict[str, torch.Tensor]]): @@ -764,32 +713,111 @@ class Qwen2VLMultiModalProcessor(BaseMultiModalProcessor): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"image": None, "video": None} - def _get_max_mm_tokens(self, modality: str) -> int: + def _get_vision_info( + self, + *, + image_width: int, + image_height: int, + num_frames: int = 1, + do_resize: bool = True, + ) -> tuple[ImageSize, int]: hf_config = self.ctx.get_hf_config(Qwen2VLConfig) vision_config = hf_config.vision_config + patch_size = vision_config.patch_size + merge_size = vision_config.spatial_merge_size + temporal_patch_size = vision_config.temporal_patch_size hf_processor = self._get_hf_processor() - image_processor = _get_image_processor(hf_processor) - - _, _, max_llm_image_tokens = _get_vision_info( - vision_config, - height=9999999, - width=9999999, - 
min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - modality=modality, + image_processor = self._get_image_processor(hf_processor) + + if do_resize: + resized_height, resized_width = smart_resize( + height=image_height, + width=image_width, + factor=patch_size * merge_size, + min_pixels=image_processor.min_pixels, + max_pixels=image_processor.max_pixels, + ) + preprocessed_size = ImageSize(width=resized_width, + height=resized_height) + else: + preprocessed_size = ImageSize(width=image_width, + height=image_height) + + grid_t = max(num_frames // temporal_patch_size, 1) + grid_h = preprocessed_size.height // patch_size + grid_w = preprocessed_size.width // patch_size + + num_patches = grid_t * grid_h * grid_w + num_vision_tokens = num_patches // (merge_size**2) + + return preprocessed_size, num_vision_tokens + + def _get_dummy_image_size(self) -> ImageSize: + max_image_size, _ = self._get_vision_info( + image_width=9999999, + image_height=9999999, + ) + return max_image_size + + def _get_max_image_tokens(self) -> int: + _, max_image_tokens = self._get_vision_info( + image_width=9999999, + image_height=9999999, + ) + return max_image_tokens + + def _get_max_video_tokens(self, num_frames: int) -> int: + _, max_video_tokens = self._get_vision_info( + image_width=9999999, + image_height=9999999, + num_frames=num_frames, ) - return max_llm_image_tokens + return max_video_tokens + + def _get_max_video_frames(self, max_tokens: int) -> int: + num_frames = 0 + + while True: + next_num_frames = num_frames + 1 + + if self._get_max_video_tokens(next_num_frames) > max_tokens: + break + + num_frames = next_num_frames + + return num_frames + + def _get_dummy_num_frames(self, seq_len: int) -> int: + mm_config = self.ctx.get_mm_config() + max_images = mm_config.limit_per_prompt.get("image", 1) + max_videos = mm_config.limit_per_prompt.get("video", 1) + + max_image_tokens = self._get_max_image_tokens() * max_images + max_total_frames = self._get_max_video_frames(seq_len - + max_image_tokens) + + return max(max_total_frames // max(max_videos, 1), 1) + + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: + max_image_tokens = self._get_max_image_tokens() + + num_frames = self._get_dummy_num_frames(seq_len) + max_video_tokens = self._get_max_video_tokens(num_frames) - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: return { - "image": self._get_max_mm_tokens("image"), - "video": self._get_max_mm_tokens("video"), + "image": max_image_tokens, + "video": max_video_tokens, } def _get_data_parser(self) -> MultiModalDataParser: return Qwen2MultiModalDataParser() + def _get_image_processor(self, hf_processor: Qwen2VLProcessor): + image_processor = hf_processor.image_processor # type: ignore + assert isinstance(image_processor, Qwen2VLImageProcessor) + return image_processor + def _get_hf_processor( self, *, @@ -797,7 +825,7 @@ def _get_hf_processor( max_pixels: Optional[int] = None, ) -> Qwen2VLProcessor: hf_processor = self.ctx.get_hf_processor(Qwen2VLProcessor) - image_processor = _get_image_processor(hf_processor) + image_processor = self._get_image_processor(hf_processor) if min_pixels: image_processor.min_pixels = min_pixels @@ -818,7 +846,7 @@ def _get_prompt_replacements( out_mm_kwargs: MultiModalKwargs, ) -> list[PromptReplacement]: hf_processor = self._get_hf_processor() - image_processor = _get_image_processor(hf_processor) + image_processor = self._get_image_processor(hf_processor) # NOTE: Only Qwen2VLProcessor in transformers 4.47.0 has # 
image_token and video_token registered @@ -873,32 +901,35 @@ def _get_mm_fields_config( video_grid_thw=MultiModalFieldConfig.batched("video"), ) - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: - hf_processor = self._get_hf_processor() - image_processor = _get_image_processor(hf_processor) + num_images = mm_counts.get("image", 0) + num_videos = mm_counts.get("video", 0) + hf_processor = self._get_hf_processor() image_token: str = hf_processor.image_token - resized_height, resized_width = smart_resize( - height=9999999, - width=9999999, - factor=image_processor.patch_size * image_processor.merge_size, - min_pixels=image_processor.min_pixels, - max_pixels=image_processor.max_pixels, - ) - num_images = mm_counts.get("image", 0) + video_token: str = hf_processor.video_token + target_width, target_height = self._get_dummy_image_size() mm_data = { "image": - self._get_dummy_images(width=resized_width, - height=resized_height, - num_images=num_images) + self._get_dummy_images(width=target_width, + height=target_height, + num_images=num_images), + "video": + self._get_dummy_videos( + width=target_width, + height=target_height, + num_frames=self._get_dummy_num_frames(seq_len), + num_videos=num_videos, + ) } return ProcessorInputs( - prompt_text=image_token * num_images, + prompt_text=image_token * num_images + video_token * num_videos, mm_data=mm_data, ) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 115eaaac900e0..7ea177e94afc0 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -171,15 +171,18 @@ def get_num_image_tokens( def get_max_image_tokens(self) -> int: return get_max_siglip_image_tokens(self.vision_config) - def get_num_patches(self) -> int: + def get_image_size(self) -> int: + return self.vision_config.image_size + + def get_patch_size(self) -> int: + return self.vision_config.patch_size + + def get_patch_grid_length(self) -> int: return get_siglip_patch_grid_length( image_size=self.vision_config.image_size, patch_size=self.vision_config.patch_size, ) - def get_image_size(self) -> int: - return self.vision_config.image_size - # Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa class SiglipVisionEmbeddings(nn.Module): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 0b83684c9bac5..6ad4661e3bb8d 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -6,7 +6,6 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple, TypedDict, Union) -import numpy as np import torch import torch.utils.checkpoint from torch import nn @@ -31,7 +30,6 @@ PromptReplacement) from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.ultravox import UltravoxConfig -from vllm.utils import is_list_of from .interfaces import SupportsMultiModal, SupportsPP from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn, @@ -62,7 +60,7 @@ class UltravoxMultiModalProcessor(BaseMultiModalProcessor): def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: return {"audio": None} - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: feature_extractor = self._get_feature_extractor() max_audio_tokens = 
math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) @@ -103,6 +101,7 @@ def _call_hf_processor( mm_data = dict(mm_data) audios = mm_data.pop("audios", []) + assert isinstance(audios, list) if not audios: return super()._call_hf_processor( @@ -117,9 +116,6 @@ def _call_hf_processor( sampling_rate=feature_extractor.sampling_rate, ) - # Already resampled by _get_hf_mm_data - assert is_list_of(audios, np.ndarray) - # Ultravox processor doesn't support multiple inputs, # therefore we need to input text and audio one by one audio_features, audio_token_len = [], [] @@ -177,8 +173,9 @@ def get_replacement_ultravox(item_idx: int): ) ] - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: feature_extractor = self._get_feature_extractor() diff --git a/vllm/model_executor/models/vision.py b/vllm/model_executor/models/vision.py index 65a773480d2a1..014f02ee10a1b 100644 --- a/vllm/model_executor/models/vision.py +++ b/vllm/model_executor/models/vision.py @@ -1,8 +1,12 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar +from typing import Final, Generic, Optional, Protocol, TypeVar from transformers import PretrainedConfig +from vllm.multimodal.processing import (BaseMultiModalProcessor, + InputProcessingContext, + ProcessingCache) + _C = TypeVar("_C", bound=PretrainedConfig) @@ -27,11 +31,15 @@ def get_max_image_tokens(self) -> int: raise NotImplementedError @abstractmethod - def get_num_patches(self) -> int: + def get_image_size(self) -> int: raise NotImplementedError @abstractmethod - def get_image_size(self) -> int: + def get_patch_size(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_patch_grid_length(self) -> int: raise NotImplementedError @@ -50,3 +58,26 @@ def vision_encoder_info(vision_config: PretrainedConfig) -> VisionEncoderInfo: msg = f"Unsupported vision config: {type(vision_config)}" raise NotImplementedError(msg) + + +class VisionLanguageConfig(Protocol): + vision_config: Final[PretrainedConfig] + + +class BaseVisionLanguageMultiModalProcessor(BaseMultiModalProcessor): + + def __init__(self, + ctx: InputProcessingContext, + *, + cache: Optional[ProcessingCache] = None, + enable_sanity_checks: bool = True) -> None: + super().__init__(ctx, + cache=cache, + enable_sanity_checks=enable_sanity_checks) + + vision_config = self._get_hf_config().vision_config + self._vision_encoder_info = vision_encoder_info(vision_config) + + @abstractmethod + def _get_hf_config(self) -> VisionLanguageConfig: + raise NotImplementedError diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py index 00acb77435163..6be046ba77ca7 100644 --- a/vllm/multimodal/parse.py +++ b/vllm/multimodal/parse.py @@ -146,6 +146,20 @@ class VideoProcessorItems(ProcessorBatchItems[HfVideoItem]): def __init__(self, data: Sequence[HfVideoItem]) -> None: super().__init__(data, "video") + def get_num_frames(self, item_idx: int) -> int: + return len(self.get(item_idx)) + + def get_frame_size(self, item_idx: int) -> ImageSize: + image = self.get(item_idx)[0] # Assume that the video isn't empty + + if isinstance(image, Image): + return ImageSize(*image.size) + if isinstance(image, (np.ndarray, torch.Tensor)): + _, h, w = image.shape + return ImageSize(w, h) + + assert_never(image) + class VideoEmbeddingItems(EmbeddingItems): diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index eb7552176e974..ebc16b817684a 100644 --- a/vllm/multimodal/processing.py +++ 
b/vllm/multimodal/processing.py @@ -16,7 +16,8 @@ from vllm.inputs import DummyData, InputProcessingContext from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer import AnyTokenizer, encode_tokens +from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, + encode_tokens) from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from .inputs import (MultiModalDataDict, MultiModalFieldConfig, @@ -69,19 +70,6 @@ def _cached_encode( add_special_tokens=add_special_tokens) -def _decode( - tokenizer: AnyTokenizer, - token_ids: list[int], - *, - skip_special_tokens: bool = False, -) -> str: - """ - Backend-agnostic equivalent of HF's - :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. - """ - return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) - - @lru_cache(maxsize=2048) def _cached_decode( tokenizer: AnyTokenizer, @@ -89,9 +77,9 @@ def _cached_decode( *, skip_special_tokens: bool = False, ) -> str: - return _decode(tokenizer, - list(token_ids), - skip_special_tokens=skip_special_tokens) + return decode_tokens(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) class _HasModalityAttr(Protocol): @@ -269,8 +257,10 @@ def end_idx(self) -> int: return self.match.end() -class _PlaceholderInfo(NamedTuple): +@dataclass +class _PlaceholderInfo: modality: str + item_idx: int start_idx: int replacement: list[int] @@ -311,12 +301,14 @@ def find_text_matches( def _resolve_matches( prompt: _PromptSeq, - matches: Sequence[_PromptReplacementMatch], + mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], ) -> list[_PromptReplacementMatch]: """ - Resolve :code:`matches` to ensure that there are no overlapping matches, + Resolve :code:`mm_matches` to ensure that there are no overlapping matches, and sort them such that earlier matches take priority over later ones. 
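To make the contract of this docstring concrete, here is a small standalone sketch, not the actual implementation: matches are modelled as plain (start_idx, end_idx, modality) tuples instead of _PromptReplacementMatch objects, overlaps are rejected, and the result is ordered so that earlier (leftmost) matches take priority.

# Hedged sketch; data types are simplified for illustration.
def resolve_matches(prompt_len: int,
                    matches: list[tuple[int, int, str]]
                    ) -> list[tuple[int, int, str]]:
    claimed: list[bool] = [False] * prompt_len

    for start, end, modality in matches:
        if any(claimed[start:end]):
            raise ValueError(f"Overlapping match at [{start}, {end}) "
                             f"for {modality=}")
        for idx in range(start, end):
            claimed[idx] = True

    # Earlier (leftmost) matches are applied first.
    return sorted(matches, key=lambda m: m[0])

print(resolve_matches(10, [(6, 8, "video"), (1, 4, "image")]))
# -> [(1, 4, 'image'), (6, 8, 'video')]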
""" + matches = [m for matches in mm_matches.values() for m in matches] + seen_matches: list[Optional[_PromptReplacementMatch]] = [None ] * len(prompt) @@ -334,14 +326,15 @@ def _resolve_matches( def _replace_matches( prompt: _S, - matches: Sequence[_PromptReplacementMatch], + mm_matches: Mapping[str, Sequence[_PromptReplacementMatch]], mm_item_counts: Mapping[str, int], ) -> list[_S]: + """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" out_seqs = list[_S]() prev_end_idx = 0 next_idx_by_modality = defaultdict[str, int](lambda: 0) - for match in _resolve_matches(prompt, matches): + for match in _resolve_matches(prompt, mm_matches): modality = match.modality item_idx = next_idx_by_modality[modality] @@ -371,28 +364,28 @@ def _replace_matches( def replace_token_matches( prompt: list[int], - matches: Sequence[_PromptReplacementTokenMatch], + mm_matches: Mapping[str, Sequence[_PromptReplacementTokenMatch]], mm_item_counts: Mapping[str, int], ) -> list[int]: - """Apply :code:`prompt_repls` to :code:`prompt`.""" - if not matches: + """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + if not mm_matches: return prompt - token_id_seqs = _replace_matches(prompt, matches, mm_item_counts) + token_id_seqs = _replace_matches(prompt, mm_matches, mm_item_counts) return flatten_2d_lists(token_id_seqs) def replace_text_matches( prompt: str, - matches: Sequence[_PromptReplacementTextMatch], + mm_matches: Mapping[str, Sequence[_PromptReplacementTextMatch]], mm_item_counts: Mapping[str, int], ) -> str: - """Apply :code:`prompt_repls` to :code:`prompt`.""" - if not matches: + """Apply the replacements in :code:`mm_matches` to :code:`prompt`.""" + if not mm_matches: return prompt - texts = _replace_matches(prompt, matches, mm_item_counts) + texts = _replace_matches(prompt, mm_matches, mm_item_counts) return "".join(texts) @@ -407,14 +400,14 @@ def _iter_modality_placeholders( return prompt_len = len(prompt) - item_index = 0 + item_idx = 0 start_idx = 0 while start_idx < prompt_len: found = False for repl_info in modality_repls: - replacement = repl_info.get_replacement(item_index) + replacement = repl_info.get_replacement(item_idx) repl_tokens = replacement.token_ids repl_len = len(repl_tokens) end_idx = start_idx + repl_len @@ -425,12 +418,13 @@ def _iter_modality_placeholders( if prompt[start_idx:end_idx] == repl_tokens: yield _PlaceholderInfo( modality=modality, + item_idx=item_idx, start_idx=start_idx, replacement=repl_tokens, ) - item_index += 1 - if item_index >= modal_item_count: + item_idx += 1 + if item_idx >= modal_item_count: return # Exclude overlapping matches @@ -442,28 +436,36 @@ def _iter_modality_placeholders( start_idx += 1 -def iter_placeholders( - prompt_repls: Sequence[_BoundPromptReplacement], +def _iter_placeholders( + mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], prompt: list[int], mm_item_counts: Mapping[str, int], ) -> Iterable[_PlaceholderInfo]: """ - Yield each set of placeholder tokens found in :code:`prompt`. + For each modality, yield each set of placeholder tokens found in + :code:`prompt`. Note that empty matches are ignored. 
""" - repls_by_modality = dict(full_groupby_modality(prompt_repls)) - for modality, modal_item_count in mm_item_counts.items(): - if modality in repls_by_modality: + if modality in mm_prompt_repls: yield from _iter_modality_placeholders( prompt, modality, - repls_by_modality[modality], + mm_prompt_repls[modality], modal_item_count, ) +def find_mm_placeholders( + mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], + prompt: list[int], + mm_item_counts: Mapping[str, int], +) -> Mapping[str, list[_PlaceholderInfo]]: + it = _iter_placeholders(mm_prompt_repls, prompt, mm_item_counts) + return dict(full_groupby_modality(it)) + + @dataclass class ProcessorInputs: """Keyword arguments to :meth:`BaseMultiModalProcessor`.""" @@ -620,7 +622,7 @@ def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]: raise NotImplementedError @abstractmethod - def get_mm_max_tokens_per_item(self) -> Mapping[str, int]: + def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]: """ Get the maximum possible number of tokens per data item for each modality. @@ -703,14 +705,14 @@ def _get_prompt_replacements( """ raise NotImplementedError - def _find_placeholders( + def _find_mm_placeholders( self, - all_prompt_repls: Sequence[_BoundPromptReplacement], + mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], new_token_ids: list[int], mm_item_counts: Mapping[str, int], - ) -> list[_PlaceholderInfo]: - return list( - iter_placeholders(all_prompt_repls, new_token_ids, mm_item_counts)) + ) -> Mapping[str, list[_PlaceholderInfo]]: + return find_mm_placeholders(mm_prompt_repls, new_token_ids, + mm_item_counts) def _get_hf_mm_data( self, @@ -797,7 +799,10 @@ def _apply_hf_processor_missing( # Some HF processors (e.g. Qwen2-VL) expect corresponding # multi-modal tokens to be in the prompt text - dummy_inputs = self._get_dummy_mm_inputs(mm_missing_counts) + dummy_inputs = self._get_dummy_processor_inputs( + self.ctx.model_config.max_model_len, + mm_missing_counts, + ) _, mm_missing_kwargs = self._apply_hf_processor( prompt_text=dummy_inputs.prompt_text, @@ -889,50 +894,44 @@ def _cached_apply_hf_processor( mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) - if self.enable_sanity_checks: - mm_item_counts = mm_data_items.get_all_counts() - - for modality, item_count in mm_item_counts.items(): - for item_idx in range(item_count): - try: - mm_kwargs.get_item(modality, item_idx) - except Exception as e: - # Make it easy to set a breakpoint in the debugger - raise e - return prompt_ids, mm_kwargs - def _bind_prompt_replacements( + def _bind_and_group_repls( self, prompt_repls: list[PromptReplacement], - ) -> list[_BoundPromptReplacement]: + ) -> dict[str, list[_BoundPromptReplacement]]: tokenizer = self._get_tokenizer() - return [prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls] + it = (prompt_repl.bind(tokenizer) for prompt_repl in prompt_repls) + return dict(full_groupby_modality(it)) def _always_apply_prompt_replacements(self) -> bool: """ A flag which can be overridden so that :meth:`_apply_prompt_replacements` is always called even if we - detect that HF has performed processing via :meth:`_find_placeholders`. + detect that HF has performed processing via + :meth:`_find_placeholders_by_modality`. - This is useful in cases where :meth:`_find_placeholders` cannot be - reliably used to detect whether HF has performed processing or not. 
+ This is useful in cases where :meth:`_find_placeholders_by_modality` + cannot be reliably used to detect whether HF has performed processing. """ return False def _apply_prompt_replacements( self, token_ids: list[int], - prompt_repls: Sequence[_BoundPromptReplacement], + mm_prompt_repls: Mapping[str, Sequence[_BoundPromptReplacement]], mm_item_counts: Mapping[str, int], - ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + ) -> tuple[list[int], str, Mapping[str, list[_PlaceholderInfo]]]: tokenizer = self._get_tokenizer() - token_matches = find_token_matches(token_ids, prompt_repls) + mm_token_matches = { + modality: find_token_matches(token_ids, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } mm_match_counts = { modality: len(matches) - for modality, matches in full_groupby_modality(token_matches) + for modality, matches in mm_token_matches.items() } # If the search text does not represent a special token, @@ -951,32 +950,92 @@ def _apply_prompt_replacements( ): # yapf: disable token_ids = replace_token_matches( token_ids, - token_matches, + mm_token_matches, mm_item_counts, ) - text = _decode(tokenizer, token_ids) - matched_repls = [match.prompt_repl for match in token_matches] + text = decode_tokens(tokenizer, token_ids) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_token_matches.items() + } else: - text = _decode(tokenizer, token_ids) + text = decode_tokens(tokenizer, token_ids) - text_matches = find_text_matches(text, prompt_repls) + mm_text_matches = { + modality: find_text_matches(text, prompt_repls) + for modality, prompt_repls in mm_prompt_repls.items() + } text = replace_text_matches( text, - text_matches, + mm_text_matches, mm_item_counts, ) token_ids = encode_tokens(tokenizer, text, add_special_tokens=False) - matched_repls = [match.prompt_repl for match in text_matches] - - placeholders = self._find_placeholders(matched_repls, token_ids, - mm_item_counts) + matched_repls = { + modality: [match.prompt_repl for match in token_matches] + for modality, token_matches in mm_text_matches.items() + } + + placeholders = self._find_mm_placeholders( + matched_repls, + token_ids, + mm_item_counts, + ) return token_ids, text, placeholders + def _validate_mm_kwargs( + self, + mm_kwargs: MultiModalKwargs, + mm_item_counts: Mapping[str, int], + ) -> None: + for modality, item_count in mm_item_counts.items(): + if modality in mm_kwargs.modalities: + items = mm_kwargs.get_items(modality) + else: + items = [] + + if len(items) != item_count: + raise RuntimeError( + f"Expected there to be {item_count} {modality} items in " + f"keyword arguments corresponding to {item_count} " + f"{modality} data items, but only found {len(items)}! 
" + "There is likely a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_mm_fields_config`).") + + def _validate_mm_placeholders( + self, + mm_placeholders: Mapping[str, list[_PlaceholderInfo]], + mm_item_counts: Mapping[str, int], + *, + allow_missing: bool = False, + ) -> Mapping[str, int]: + missing_repl_counts = dict[str, int]() + + for modality, item_count in mm_item_counts.items(): + placeholders = mm_placeholders.get(modality, []) + + if len(placeholders) != item_count and not allow_missing: + raise RuntimeError( + f"Expected there to be {item_count} prompt replacements " + f"corresponding to {item_count} {modality} items, but only " + f"found {len(placeholders)} prompt replacements! Either " + "the prompt text has missing/incorrect tokens for " + "multi-modal inputs, or there is a problem with your " + "implementation of merged multi-modal processor for this " + "model (usually arising from an inconsistency between " + "`_call_hf_processor` and `_get_prompt_replacements`).") + + missing_repl_counts[modality] = item_count - len(placeholders) + + return missing_repl_counts + def apply( self, prompt_text: str, @@ -1009,56 +1068,69 @@ def apply( hf_processor_mm_kwargs, mm_kwargs, ) - prompt_repls = self._bind_prompt_replacements(unbound_prompt_repls) + mm_prompt_repls = self._bind_and_group_repls(unbound_prompt_repls) - # If HF processor already inserts placeholder tokens, - # there is no need for us to insert them mm_item_counts = mm_items.get_all_counts() - all_placeholders = self._find_placeholders(prompt_repls, prompt_ids, - mm_item_counts) + self._validate_mm_kwargs(mm_kwargs, mm_item_counts) + + hf_mm_placeholders = self._find_mm_placeholders( + mm_prompt_repls, + prompt_ids, + mm_item_counts, + ) + + if self._always_apply_prompt_replacements(): + mm_missing_repl_counts = mm_item_counts + mm_missing_repls = dict(mm_prompt_repls) + else: + mm_missing_repl_counts = self._validate_mm_placeholders( + hf_mm_placeholders, + mm_item_counts, + allow_missing=True, + ) + + mm_missing_repls = dict[str, list[_BoundPromptReplacement]]() + for modality, missing_repl_count in mm_missing_repl_counts.items(): + if missing_repl_count == 0: + mm_missing_repls[modality] = [] + elif missing_repl_count == mm_item_counts.get(modality, 0): + mm_missing_repls[modality] = mm_prompt_repls[modality] + else: + raise ValueError("Partial prompt replacement within " + f"{modality=} is not supported") - if all_placeholders and not self._always_apply_prompt_replacements(): + # If HF processor already inserts placeholder tokens, + # there is no need for us to insert them + if all(len(repls) == 0 for repls in mm_missing_repls.items()): tokenizer = self._get_tokenizer() - prompt_text = _decode(tokenizer, prompt_ids) + prompt_text = decode_tokens(tokenizer, prompt_ids) + mm_placeholders = hf_mm_placeholders else: ( prompt_ids, prompt_text, - all_placeholders, + missing_mm_placeholders, ) = self._apply_prompt_replacements( prompt_ids, - prompt_repls, - mm_item_counts, + mm_missing_repls, + mm_missing_repl_counts, ) - mm_placeholders = dict[str, list[PlaceholderRange]]() - err_suffix = ("This suggests a problem with your implementation of " - "the merged multi-modal processor for this model, " - "particularly in the `_get_prompt_replacements` method.") - - for modality, placeholders in full_groupby_modality(all_placeholders): - if modality not in mm_items: - raise AssertionError( - f"Expected no 
placeholders for {modality=}, " - f"but found {placeholders=}. Input items: {mm_items}" - f"\n{err_suffix}") - - if len(placeholders) != len(mm_items[modality]): - raise AssertionError( - f"Expected length of {placeholders=} for {modality=} " - f"to equal that of input items: {mm_items[modality]}" - f"\n{err_suffix}") - - mm_placeholders[modality] = [ - item.to_range() for item in placeholders - ] + mm_placeholders = {**hf_mm_placeholders, **missing_mm_placeholders} + + self._validate_mm_placeholders(mm_placeholders, mm_item_counts) + + mm_placeholder_ranges = { + modality: [item.to_range() for item in placeholders] + for modality, placeholders in mm_placeholders.items() + } return MultiModalInputsV2( type="multimodal", prompt=prompt_text, prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, - mm_placeholders=mm_placeholders, + mm_placeholders=mm_placeholder_ranges, ) def _get_dummy_audios( @@ -1092,8 +1164,9 @@ def _get_dummy_videos( return [video] * num_videos @abstractmethod - def _get_dummy_mm_inputs( + def _get_dummy_processor_inputs( self, + seq_len: int, mm_counts: Mapping[str, int], ) -> ProcessorInputs: """ @@ -1121,12 +1194,25 @@ def _get_and_validate_dummy_mm_counts(self) -> Mapping[str, int]: return mm_limits + def _get_dummy_mm_inputs( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> MultiModalInputsV2: + processor_inputs = self._get_dummy_processor_inputs(seq_len, mm_counts) + + return self.apply( + prompt_text=processor_inputs.prompt_text, + mm_data=processor_inputs.mm_data, + hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, + ) + def get_dummy_data(self, seq_len: int) -> DummyData: # Avoid circular import from vllm.sequence import SequenceData mm_counts = self._get_and_validate_dummy_mm_counts() - mm_max_tokens_per_item = self.get_mm_max_tokens_per_item() + mm_max_tokens_per_item = self.get_mm_max_tokens_per_item(seq_len) if mm_counts.keys() != mm_max_tokens_per_item.keys(): raise AssertionError( "The keys returned by `get_supported_mm_limits`" @@ -1134,13 +1220,7 @@ def get_dummy_data(self, seq_len: int) -> DummyData: "returned by `get_mm_max_tokens_per_item` " f"({set(mm_max_tokens_per_item.keys())})") - processor_inputs = self._get_dummy_mm_inputs(mm_counts) - mm_inputs = self.apply( - prompt_text=processor_inputs.prompt_text, - mm_data=processor_inputs.mm_data, - hf_processor_mm_kwargs=processor_inputs.hf_processor_mm_kwargs, - ) - + mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts) prompt_token_ids = mm_inputs["prompt_token_ids"] placeholders_by_modality = mm_inputs["mm_placeholders"] @@ -1171,6 +1251,12 @@ def get_dummy_data(self, seq_len: int) -> DummyData: "reduce `max_num_seqs`, and/or reduce `mm_counts`.", seq_len, total_len, total_placeholders_by_modality) + return DummyData( + seq_data=SequenceData.from_prompt_token_counts((0, seq_len)), + multi_modal_data=None, + multi_modal_placeholders=None, + ) + prompt_token_ids.extend([0] * (seq_len - len(prompt_token_ids))) return DummyData( diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py index 073d49d7d2009..fb4389dc4df42 100644 --- a/vllm/multimodal/registry.py +++ b/vllm/multimodal/registry.py @@ -223,7 +223,8 @@ def get_max_tokens_per_item_by_modality( if self.has_processor(model_config): tokenizer = cached_get_tokenizer(model_config.tokenizer) processor = self.create_processor(model_config, tokenizer) - return processor.get_mm_max_tokens_per_item() + seq_len = model_config.max_model_len + return processor.get_mm_max_tokens_per_item(seq_len) return { key: 
plugin.get_max_multimodal_tokens(model_config) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 42b2f095bc543..97920f42ec52f 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -21,6 +21,19 @@ MistralTokenizer] +def decode_tokens( + tokenizer: AnyTokenizer, + token_ids: list[int], + *, + skip_special_tokens: bool = False, +) -> str: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. + """ + return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + + def encode_tokens( tokenizer: AnyTokenizer, text: str,