
Commit 365801f

[VLM] Add max-count checking in data parser for single image models (#11661)

Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
DarkLight1337 and ywang96 authored Jan 1, 2025
1 parent 4db72e5 commit 365801f
Showing 6 changed files with 48 additions and 11 deletions.

docs/source/models/supported_models.md (2 changes: 1 addition & 1 deletion)

@@ -566,7 +566,7 @@ See [this page](#generative-models) for more information on how to use generativ
   - [V1](gh-issue:8779)
 * - `AriaForConditionalGeneration`
   - Aria
-  - T + I
+  - T + I<sup>+</sup>
   - `rhymes-ai/Aria`
   -
   - ✅︎

tests/multimodal/test_processing.py (3 changes: 2 additions & 1 deletion)

@@ -622,10 +622,11 @@ def _test_processing_cache_correctness(


 # yapf: disable
+# True if the model supports multiple data items of the modality per request
 @pytest.mark.parametrize(("model_id", "modalities"), [
     ("rhymes-ai/Aria", {"image": True}),
     ("Salesforce/blip2-opt-2.7b", {"image": False}),
-    ("facebook/chameleon-7b", {"image": True}),
+    ("facebook/chameleon-7b", {"image": False}),
     ("adept/fuyu-8b", {"image": False}),
     ("llava-hf/llava-1.5-7b-hf", {"image": True}),
     ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
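
For models whose flag is False here, the cap is now enforced at the parser level rather than left unchecked. As a hedged illustration (not part of this commit's test suite), the new behavior could be asserted directly against the parser; the test name and image size below are made up.

import pytest
from PIL import Image

from vllm.multimodal.parse import MultiModalDataParser


def test_single_image_cap_rejects_second_image():
    parser = MultiModalDataParser(max_mm_counts={"image": 1})
    image = Image.new("RGB", (32, 32))

    # One image per prompt is fine for a single-image model...
    parser.parse_mm_data({"image": [image]})

    # ...but a second image trips the new max-count check.
    with pytest.raises(ValueError, match="at most 1 image"):
        parser.parse_mm_data({"image": [image, image]})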

vllm/model_executor/models/blip2.py (4 changes: 4 additions & 0 deletions)

@@ -18,6 +18,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)

@@ -404,6 +405,9 @@ def get_max_blip2_image_tokens(ctx: InputContext):

 class Blip2MultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> Blip2Processor:
         return self.ctx.get_hf_processor(Blip2Processor)
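
Chameleon and Fuyu receive the same three-line override below. As a sketch of the pattern (the model name is hypothetical, not code from this commit): a processor for a single-image model opts into the check simply by building its data parser with a per-modality cap, while the enforcement itself lives centrally in the parser.

from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import BaseMultiModalProcessor


class MyModelMultiModalProcessor(BaseMultiModalProcessor):
    """Hypothetical processor for a model that accepts at most one image."""

    def _get_data_parser(self) -> MultiModalDataParser:
        # Cap images at 1; prompts carrying more images fail at parse time.
        return MultiModalDataParser(max_mm_counts={"image": 1})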

vllm/model_executor/models/chameleon.py (4 changes: 4 additions & 0 deletions)

@@ -31,6 +31,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)

@@ -60,6 +61,9 @@ def get_max_chameleon_image_tokens(ctx: InputContext):

 class ChameleonMultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> ChameleonProcessor:
         return self.ctx.get_hf_processor(ChameleonProcessor)

vllm/model_executor/models/fuyu.py (18 changes: 11 additions & 7 deletions)

@@ -34,7 +34,7 @@
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalInputsV2, MultiModalKwargs,
                                     NestedTensors, PlaceholderRange)
-from vllm.multimodal.parse import ImageProcessorItems
+from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataParser
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         MultiModalDataItems, ProcessorInputs,
                                         PromptReplacement)

@@ -54,7 +54,7 @@

 class FuyuImagePatchInputs(TypedDict):
     type: Literal["image_patches"]
-    data: torch.Tensor
+    flat_data: torch.Tensor
     """
     Shape:
     `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`

@@ -63,7 +63,7 @@ class FuyuImagePatchInputs(TypedDict):
     patches_per_image: List[int]
     """
     List of number of total patches for each image in the batch.
-    This is used to restore the first two dimensions of `data`.
+    This is used to restore the first two dimensions of `flat_data`.
     """


@@ -102,6 +102,9 @@ def get_max_fuyu_image_tokens(ctx: InputContext):

 class FuyuMultiModalProcessor(BaseMultiModalProcessor):

+    def _get_data_parser(self) -> MultiModalDataParser:
+        return MultiModalDataParser(max_mm_counts={"image": 1})
+
     def _get_hf_processor(self) -> FuyuProcessor:
         return self.ctx.get_hf_processor(FuyuProcessor)

@@ -304,7 +307,7 @@ def _parse_and_validate_image_input(

         return FuyuImagePatchInputs(
             type="image_patches",
-            data=self._validate_pixel_values(
+            flat_data=self._validate_pixel_values(
                 flatten_bn(image_patches_flat, concat=True)),
             patches_per_image=[x.size(0) for x in image_patches_flat],
         )

@@ -313,12 +316,13 @@ def _parse_and_validate_image_input(

     def _process_image_input(
             self, image_input: FuyuImagePatchInputs) -> NestedTensors:
-        image_patches = image_input["data"]
+        image_patches_flat = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"]

         assert self.vision_embed_tokens is not None
-        vision_embeddings, _ = self.vision_embed_tokens(image_patches)
-        return vision_embeddings.split(patches_per_image, dim=0)
+        vision_embeddings_flat, _ = self.vision_embed_tokens(
+            image_patches_flat)
+        return vision_embeddings_flat.split(patches_per_image, dim=0)

     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
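
The rename from `data` to `flat_data` makes the flattened layout explicit: patches from all images in the batch are stacked along one dimension, and `patches_per_image` is what lets the embeddings be split back per image. A small self-contained torch sketch of that split (the shapes here are arbitrary, for illustration only):

import torch

patches_per_image = [3, 5]  # e.g. two images contributing 3 and 5 patches
hidden_size = 16
vision_embeddings_flat = torch.randn(sum(patches_per_image), hidden_size)

# Recover one embedding tensor per image from the flattened batch.
per_image = vision_embeddings_flat.split(patches_per_image, dim=0)
assert [t.shape[0] for t in per_image] == patches_per_image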

vllm/multimodal/parse.py (28 changes: 26 additions & 2 deletions)

@@ -220,11 +220,24 @@ def get_items(
 class MultiModalDataParser:
     """
     Parses :class:`MultiModalDataDict` into :class:`MultiModalDataItems`.
+
+    Args:
+        max_mm_counts (Mapping[str, int]): The maximum allowed number of items
+            belonging to each modality. This effectively sets a hard limit over
+            `--limit-mm-per-prompt`.
+        target_sr (float, optional): Enables automatic resampling of audio
+            items to the model's expected sampling rate.
     """

-    def __init__(self, *, target_sr: Optional[float] = None) -> None:
+    def __init__(
+        self,
+        *,
+        max_mm_counts: Mapping[str, int] = {},
+        target_sr: Optional[float] = None,
+    ) -> None:
         super().__init__()

+        self.max_mm_counts = max_mm_counts
         self.target_sr = target_sr

     def _is_embeddings(self, data: object) -> TypeGuard[NestedTensors]:

@@ -332,13 +345,24 @@ def _get_subparsers(self) -> Mapping[str, ModalityDataParser]:

     def parse_mm_data(self,
                       mm_data: MultiModalDataDict) -> MultiModalDataItems:
+        max_mm_counts = self.max_mm_counts
         subparsers = self._get_subparsers()

         mm_items = MultiModalDataItems()
         for k, v in mm_data.items():
             if k not in subparsers:
                 raise ValueError(f"Unsupported modality: {k}")

-            mm_items[k] = subparsers[k](v)
+            modality_items = subparsers[k](v)
+
+            if k in max_mm_counts:
+                max_count = max_mm_counts[k]
+                if len(modality_items) > max_count:
+                    raise ValueError(
+                        f"This model supports at most {max_count} {k} items "
+                        f"per prompt, but {len(modality_items)} {k} items "
+                        "were given or set as its limit_mm_per_prompt.")
+
+            mm_items[k] = modality_items

         return mm_items
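
A self-contained sketch of the control flow added above, in plain Python with no vLLM imports (parse_prompt_data and the toy subparsers are illustrative stand-ins, not vLLM APIs): items are parsed per modality first, then their count is checked against an optional hard cap.

from typing import Callable, Mapping, Sequence


def parse_prompt_data(
    mm_data: Mapping[str, Sequence[object]],
    subparsers: Mapping[str, Callable[[Sequence[object]], list]],
    max_mm_counts: Mapping[str, int],
) -> dict:
    mm_items: dict = {}
    for modality, values in mm_data.items():
        if modality not in subparsers:
            raise ValueError(f"Unsupported modality: {modality}")

        items = subparsers[modality](values)

        # Mirror of the new check: enforce the model's hard per-modality cap.
        max_count = max_mm_counts.get(modality)
        if max_count is not None and len(items) > max_count:
            raise ValueError(
                f"This model supports at most {max_count} {modality} items "
                f"per prompt, but {len(items)} were given.")

        mm_items[modality] = items

    return mm_items


# A cap of one image lets a single image through and rejects a second one.
toy_subparsers = {"image": list}
parse_prompt_data({"image": ["img0"]}, toy_subparsers, {"image": 1})  # OK
try:
    parse_prompt_data({"image": ["img0", "img1"]}, toy_subparsers, {"image": 1})
except ValueError as exc:
    print(exc)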
