Improve vlm support (add idefics3 support) #2437

Open · drbh wants to merge 12 commits into main from improve-vlm-support

Conversation

drbh (Collaborator) commented on Aug 20, 2024:

This PR is a work in progress and adds support for Idefics3 in TGI. Opening it now for transparency and feedback.

This implementation uses the AutoProcessor/Idefics3Processor that will be added once huggingface/transformers#32473 is merged (a minimal usage sketch follows the todo list below).

Todos

  • add more comprehensive tests
  • ensure the Rust image token logic is correct
  • ensure the correct config is loaded (related to processor_kwargs)
  • refactor, clean up typos, etc.
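
For reference, a minimal sketch of how the processor is expected to be used once the transformers PR lands (checkpoint name and message format are illustrative, not taken from this PR):

from transformers import AutoProcessor

# Assumes the Idefics3Processor from huggingface/transformers#32473 is available.
processor = AutoProcessor.from_pretrained("HuggingFaceM4/Idefics3-8B-Llama3")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
# apply_chat_template inserts the <image> placeholder that is later expanded
# into the full row/col image-token grid (see get_image_prompt_string below).
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)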

ErikKaum mentioned this pull request on Sep 9, 2024
drbh force-pushed the improve-vlm-support branch from c93fd85 to 35c64b2 on October 3, 2024
drbh force-pushed the improve-vlm-support branch from 35c64b2 to ebef284 on December 17, 2024
HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

drbh marked this pull request as ready for review on December 19, 2024
@@ -632,13 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        if config.model_type == "mllama_text_model":
            prefix = f"{prefix}.model"
Collaborator commented:

No. The correct line is whatever was there before.

This class cannot know about the model_type (shouldn't).
Especially since you're removing everything a few lines below.

No shenanigans here.

drbh (Collaborator, Author) replied:

100% agreed that the class should not know about the model type. I've updated the logic to handle this case by avoiding appending ".model" when the prefix ends in "text_model":

base_model = "" if prefix.endswith("text_model") else ".model"

The reason for this complexity is the naming convention used by Idefics3: the model has weights with names like model.text_model.embed_tokens.weight, while the current logic always expects models to contain model.embed_tokens or X.model.embed_tokens.

The latest changes handle this case by conditionally appending ".model" before constructing the prefixes. Please let me know if there's a better way to handle this 🙏
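
A rough illustration of how that conditional suffix covers both weight layouts (the helper name and surrounding code are assumed for illustration, not the actual TGI code):

def embed_tokens_prefix(prefix: str) -> str:
    # Skip the ".model" segment for the Idefics3 text backbone, whose weights
    # are named like "model.text_model.embed_tokens.weight".
    base_model = "" if prefix.endswith("text_model") else ".model"
    return f"{prefix}{base_model}.embed_tokens"

# Idefics3-style layout
assert embed_tokens_prefix("model.text_model") == "model.text_model.embed_tokens"
# Llama-style layout ("X.model.embed_tokens")
assert embed_tokens_prefix("language_model") == "language_model.model.embed_tokens"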

@@ -679,6 +679,215 @@ def forward(self, image_hidden_states, attention_mask):
        return image_hidden_states


class Idefics3Connector(nn.Module):
Collaborator commented:

This belongs in the idefics3 file.

drbh (Collaborator, Author) replied:

Agreed, moved into a new file in the latest commit.

Comment on lines 861 to 865
diff = mask_size - unrolled_image_size
if diff > 0:
    print(
        f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
    )
Collaborator commented:

Suggested change (remove these lines):

diff = mask_size - unrolled_image_size
if diff > 0:
    print(
        f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
    )

drbh (Collaborator, Author) replied:

removed in latest changes, thanks!

Comment on lines 867 to 870
if mask_size == unrolled_image_size:
    inputs_embeds = self._merge_input_ids_with_image_features(
        input_ids, inputs_embeds, image_hidden_states
    )
Collaborator commented:

Suggested change

-if mask_size == unrolled_image_size:
-    inputs_embeds = self._merge_input_ids_with_image_features(
-        input_ids, inputs_embeds, image_hidden_states
-    )
+inputs_embeds = self._merge_input_ids_with_image_features(
+    input_ids, inputs_embeds, image_hidden_states
+)

Let it crash if something is wrong here. We should NEVER fail silently.

drbh (Collaborator, Author) replied:

yes agreed! removed in the latest commits
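
A minimal sketch of the fail-loudly approach (the helper below is illustrative, not code from the PR; the variable names mirror the diff above):

def check_image_token_mask(mask_size: int, unrolled_image_size: int) -> None:
    # Raise instead of printing or silently skipping the merge when the
    # image-token mask and the unrolled image features disagree in size.
    if mask_size != unrolled_image_size:
        raise ValueError(
            f"Image token mask size ({mask_size}) does not match the number "
            f"of image features ({unrolled_image_size})."
        )

check_image_token_mask(577, 577)   # passes
# check_image_token_mask(577, 576) # raises ValueError instead of failing silently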

Comment on lines 26 to 93
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"


def _prompt_split_image(
    image_seq_len,
    image_rows,
    image_cols,
    fake_token_around_image,
    image_token,
    global_img_token,
):
    """Prompt with expanded image tokens for when the image is split into patches."""
    text_split_images = ""
    for n_h in range(image_rows):
        for n_w in range(image_cols):
            text_split_images += (
                f"{fake_token_around_image}"
                + f"<row_{n_h + 1}_col_{n_w + 1}>"
                + f"{image_token}" * image_seq_len
            )
        text_split_images += "\n"

    text_split_images += (
        f"\n{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )
    return text_split_images


def _prompt_single_image(
    image_seq_len, fake_token_around_image, image_token, global_img_token
):
    """Prompt with expanded image tokens for a single image."""
    return (
        f"{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )


def get_image_prompt_string(
    image_rows,
    image_cols,
    image_seq_len,
    fake_token_around_image,
    image_token,
    global_img_token,
):
    if image_rows == 0 and image_cols == 0:
        return _prompt_single_image(
            image_seq_len,
            fake_token_around_image=fake_token_around_image,
            image_token=image_token,
            global_img_token=global_img_token,
        )
    return _prompt_split_image(
        image_seq_len,
        image_rows,
        image_cols,
        fake_token_around_image,
        image_token,
        global_img_token,
    )
Collaborator commented:

Put everything in some idefics3 file.

Can't those 4 functions be trivially merged into one using joins instead of for loops and ifs?

drbh (Collaborator, Author) replied:

Good point. I've moved the idefics code into a new file and reduced this logic to a much simpler function:

def get_image_prompt_string(
    rows=0,
    cols=0,
    seq_len=1,
    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
    img_token=IDEFICS3_IMAGE_TOKEN,
    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
):
    tokens = img_token * seq_len
    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"

    if rows == 0 or cols == 0:
        return end_token

    grid = "\n".join(
        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
        for i in range(rows)
    )

    return f"{grid}\n\n{end_token}"

drbh requested a review from Narsil on December 23, 2024