Improve vlm support (add idefics3 support) #2437
Conversation
@@ -632,13 +630,12 @@ class FlashLlamaForCausalLM(torch.nn.Module):
    def __init__(self, prefix: str, config, weights):
        super().__init__()

        if config.model_type == "mllama_text_model":
            prefix = f"{prefix}.model"
No. The correct line was whatever line was before.
This class cannot know about the model_type (shouldn't).
Especially since you're removing everything a few lines below.
No shenanigans here.
100% agreed that the class should not know about the model type. I've updated the logic to handle this case by avoiding appending .model when the prefix ends in text_model:

base_model = "" if prefix.endswith("text_model") else ".model"

The reason for this complexity is the naming convention used by idefics3. The model has weights with names like model.text_model.embed_tokens.weight, while the current logic always expects models to contain model.embed_tokens or X.model.embed_tokens. The latest changes handle this case by conditionally appending ".model" before constructing the prefixes. Please let me know if there's a better way to handle this 🙏
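To make the behaviour concrete, here is a small illustrative sketch (the helper name resolve_embed_prefix and the "language_model" example prefix are hypothetical, not code from this PR) of how the conditional ".model" suffix resolves the three weight layouts mentioned above:

# Hypothetical helper, for illustration only: how the conditional ".model"
# suffix picks the right embedding prefix for each naming convention.
def resolve_embed_prefix(prefix: str) -> str:
    # idefics3 text weights look like "model.text_model.embed_tokens.weight",
    # so when the prefix already ends in "text_model" we must not add ".model".
    base_model = "" if prefix.endswith("text_model") else ".model"
    if not prefix:
        return "model.embed_tokens"
    return f"{prefix}{base_model}.embed_tokens"

assert resolve_embed_prefix("") == "model.embed_tokens"
assert resolve_embed_prefix("language_model") == "language_model.model.embed_tokens"
assert resolve_embed_prefix("model.text_model") == "model.text_model.embed_tokens"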
@@ -679,6 +679,215 @@ def forward(self, image_hidden_states, attention_mask):
        return image_hidden_states


class Idefics3Connector(nn.Module):
This belongs in the idefics3 file.
agreed, moved into a new file in latest commit
diff = mask_size - unrolled_image_size
if diff > 0:
    print(
        f"Mask size {mask_size} is greater than the number of images {unrolled_image_size}."
    )
Suggested change: delete these lines.
removed in latest changes, thanks!
if mask_size == unrolled_image_size:
    inputs_embeds = self._merge_input_ids_with_image_features(
        input_ids, inputs_embeds, image_hidden_states
    )
Suggested change (drop the guard and always merge):

inputs_embeds = self._merge_input_ids_with_image_features(
    input_ids, inputs_embeds, image_hidden_states
)
Let it crash if something is wrong here. We should NEVER do silent errors.
yes agreed! removed in the latest commits
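For reference, a minimal sketch of the "fail loudly" approach (not the code in this PR, just an illustration using the names from the snippet above) would be to assert on the sizes before merging:

# Sketch only: make a mask/image mismatch crash instead of being skipped silently.
assert mask_size == unrolled_image_size, (
    f"Expected {mask_size} image tokens in input_ids but got "
    f"{unrolled_image_size} image hidden states."
)
inputs_embeds = self._merge_input_ids_with_image_features(
    input_ids, inputs_embeds, image_hidden_states
)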
IDEFICS3_IMAGE_TOKEN = "<image>"
IDEFICS3_FAKE_IMAGE_TOKEN = "<fake_token_around_image>"
IDEFICS3_GLOBAL_IMG_TOKEN = "<global-img>"


def _prompt_split_image(
    image_seq_len,
    image_rows,
    image_cols,
    fake_token_around_image,
    image_token,
    global_img_token,
):
    """Prompt with expanded image tokens for when the image is split into patches."""
    text_split_images = ""
    for n_h in range(image_rows):
        for n_w in range(image_cols):
            text_split_images += (
                f"{fake_token_around_image}"
                + f"<row_{n_h + 1}_col_{n_w + 1}>"
                + f"{image_token}" * image_seq_len
            )
        text_split_images += "\n"

    text_split_images += (
        f"\n{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )
    return text_split_images


def _prompt_single_image(
    image_seq_len, fake_token_around_image, image_token, global_img_token
):
    """Prompt with expanded image tokens for a single image."""
    return (
        f"{fake_token_around_image}"
        + f"{global_img_token}"
        + f"{image_token}" * image_seq_len
        + f"{fake_token_around_image}"
    )


def get_image_prompt_string(
    image_rows,
    image_cols,
    image_seq_len,
    fake_token_around_image,
    image_token,
    global_img_token,
):
    if image_rows == 0 and image_cols == 0:
        return _prompt_single_image(
            image_seq_len,
            fake_token_around_image=fake_token_around_image,
            image_token=image_token,
            global_img_token=global_img_token,
        )
    return _prompt_split_image(
        image_seq_len,
        image_rows,
        image_cols,
        fake_token_around_image,
        image_token,
        global_img_token,
    )
Put everything in some idefics3 file.
Can't those 4 functions be trivially merged into one using joins instead of for loops and ifs?
Good point. I've moved the idefics3 code into a new file and reduced this logic to a much simpler function:
def get_image_prompt_string(
    rows=0,
    cols=0,
    seq_len=1,
    fake_token=IDEFICS3_FAKE_IMAGE_TOKEN,
    img_token=IDEFICS3_IMAGE_TOKEN,
    global_token=IDEFICS3_GLOBAL_IMG_TOKEN,
):
    tokens = img_token * seq_len
    end_token = f"{fake_token}{global_token}{tokens}{fake_token}"
    if rows == 0 or cols == 0:
        return end_token
    grid = "\n".join(
        "".join(f"{fake_token}<row_{i+1}_col_{j+1}>{tokens}" for j in range(cols))
        for i in range(rows)
    )
    return f"{grid}\n\n{end_token}"
This PR is a work in progress and adds support for Idefics3 in TGI. Opening for transparency and feedback.

This implementation uses the AutoProcessor/Idefics3Processor that will be added when this PR is merged: huggingface/transformers#32473

TODOs:
- processor_kwargs