Add support to export ColPali Model to ONNX #2074
base: main
Conversation
@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun Are you open to merging this?
Apologies for the delay @akshayballal95. Could you add a test with a tiny random model like https://huggingface.co/hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration? It can be added here: https://github.com/huggingface/optimum/blob/main/tests/exporters/exporters_utils.py#L37
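For context, such a test entry is usually just a mapping from a model type to a tiny random checkpoint. The sketch below shows the general shape, assuming the mapping name `PYTORCH_EXPORT_MODELS_TINY` and the key `"colpali"` (both are illustrative; check the actual names in `exporters_utils.py`):

```python
# Hypothetical sketch of a tiny-model entry for the export tests.
# The dict name and key are assumptions; the real file defines many such pairs.
PYTORCH_EXPORT_MODELS_TINY = {
    "colpali": "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration",
}
```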
```python
class ColPaliModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
            outputs = self.orig_forward(
                input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **kwargs
            )
            return outputs

        self.patched_forward = patched_forward
```
Why is it needed?
It's needed because the original ColPali model's forward only takes **kwargs and no named arguments, which caused an error during export. This patch fixes that error.
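To illustrate the point with a toy example (not the real ColPali class): exporters that inspect a model's forward signature to build and name dummy inputs see no parameters on a `**kwargs`-only forward, whereas the patched signature exposes the expected input names.

```python
import inspect

# Toy stand-in for a forward that only accepts **kwargs:
# its signature exposes no named parameters to match inputs against.
def kwargs_only_forward(**kwargs):
    return kwargs

# Patched version with explicit named arguments, as in the PR.
def patched_forward(input_ids=None, pixel_values=None, attention_mask=None, **kwargs):
    return {"input_ids": input_ids, "pixel_values": pixel_values,
            "attention_mask": attention_mask, **kwargs}

print(list(inspect.signature(kwargs_only_forward).parameters))  # ['kwargs']
print(list(inspect.signature(patched_forward).parameters)[:3])  # ['input_ids', 'pixel_values', 'attention_mask']
```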
I have added the conversion test. It works fine locally.
What does this PR do?
This PR adds support for exporting the merged ColPali model to ONNX format. The model is based on the PaliGemma model type, so I have added it under the "feature-extraction" task. Please suggest if there is a better way to integrate this. If this looks fine with a few modifications, I can add support for the PaliGemma text-generation task as well.
Before submitting
Who can review?
@fxmarty, @echarlaix, @JingyaHuang, @michaelbenayoun