Add ONNX export support for DinoV2, Hiera, Maskformer, PVT, SigLIP, SwinV2, VitMAE, and VitMSN models (#2001)

* Add support for siglip models

* cleanup

* remove submodule

* Add ONNX export for DinoV2 models

* Use height and width from preprocessor

* formatting

* Remove attention mask from model input

* Add ONNX export support for Hiera models

* Add ONNX export support for SwinV2

* Upgrade Siglip to opset=14

* Add VQA task

* Add ONNX export support for Maskformer

* Add ONNX export support for PVT

* Add ONNX export support for ViTMAE and ViTMSN

* Add siglip unit tests

* Add vit-mae unit tests

* Code formatting

* Add maskformer to list of supported models

* Formatting

* fix typo

* remove vit-mae masked-im task

* remove vit-msn masked-im task

* fix output names for maskformer export

---------

Co-authored-by: Ella Charlaix <[email protected]>
xenova and echarlaix authored Dec 19, 2024
1 parent 4daa408 commit 0c42291
Showing 9 changed files with 224 additions and 5 deletions.
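Before the per-file changes, here is a minimal usage sketch of what this commit enables. It is an illustration, not part of the diff: the checkpoint name, output directory, and task below are assumptions, and it presumes `optimum` is installed with the ONNX exporter extras.

```python
# Hedged sketch: exporting one of the newly supported architectures to ONNX.
# Assumes `pip install optimum[exporters]`; the names below are illustrative.
from optimum.exporters.onnx import main_export

main_export(
    "facebook/dinov2-base",     # any architecture added by this commit
    output="dinov2_onnx",       # directory that receives model.onnx + configs
    task="feature-extraction",  # one of the tasks registered in tasks.py below
)
```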
8 changes: 8 additions & 0 deletions docs/source/exporters/onnx/overview.mdx
@@ -39,6 +39,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Decision Transformer
- Deit
- Detr
- DINOv2
- DistilBert
- Donut-Swin
- Electra
@@ -53,6 +54,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- GPT-NeoX
- OPT
- GroupVit
- Hiera
- Hubert
- IBert
- LayoutLM
@@ -64,6 +66,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- M2-M100
- Marian
- MarkupLM
- MaskFormer
- MBart
- MGP-STR
- Mistral
@@ -84,6 +87,7 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- Phi3
- Pix2Struct
- PoolFormer
- PVT
- Qwen2(Qwen1.5)
- RegNet
- RemBERT
@@ -95,17 +99,21 @@ Supported architectures from [🤗 Transformers](https://huggingface.co/docs/tra
- SEW
- SEW-D
- Speech2Text
- SigLIP
- SpeechT5
- Splinter
- SqueezeBert
- Swin
- SwinV2
- T5
- Table Transformer
- TROCR
- UniSpeech
- UniSpeech SAT
- Vision Encoder Decoder
- Vit
- VitMAE
- VitMSN
- Wav2Vec2
- Wav2Vec2 Conformer
- WavLM
118 changes: 118 additions & 0 deletions optimum/exporters/onnx/model_configs.py
@@ -847,6 +847,65 @@ class ConvNextV2OnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class HieraOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class PvtOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class VitMAEOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14


class VitMSNOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14


class Dinov2DummyInputGenerator(DummyVisionInputGenerator):
    def __init__(
        self,
        task: str,
        normalized_config: NormalizedVisionConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        num_channels: int = DEFAULT_DUMMY_SHAPES["num_channels"],
        width: int = DEFAULT_DUMMY_SHAPES["width"],
        height: int = DEFAULT_DUMMY_SHAPES["height"],
        **kwargs,
    ):
        super().__init__(
            task=task,
            normalized_config=normalized_config,
            batch_size=batch_size,
            num_channels=num_channels,
            width=width,
            height=height,
            **kwargs,
        )

        from transformers.onnx.utils import get_preprocessor

        preprocessor = get_preprocessor(normalized_config._name_or_path)
        if preprocessor is not None and hasattr(preprocessor, "crop_size"):
            self.height = preprocessor.crop_size.get("height", self.height)
            self.width = preprocessor.crop_size.get("width", self.width)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        input_ = super().generate(
            input_name=input_name, framework=framework, int_dtype=int_dtype, float_dtype=float_dtype
        )
        return input_


class Dinov2OnnxConfig(ViTOnnxConfig):
    DUMMY_INPUT_GENERATOR_CLASSES = (Dinov2DummyInputGenerator,)


class MobileViTOnnxConfig(ViTOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-4
    DEFAULT_ONNX_OPSET = 11
@@ -888,6 +947,10 @@ class SwinOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11


class SwinV2OnnxConfig(SwinOnnxConfig):
    pass


class Swin2srOnnxConfig(SwinOnnxConfig):
    pass

@@ -923,6 +986,28 @@ class MobileNetV2OnnxConfig(MobileNetV1OnnxConfig):
    pass


class MaskFormerOnnxConfig(ViTOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::einsum' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 12, try exporting with this version.
    DEFAULT_ONNX_OPSET = 12

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "image-segmentation":
            return {
                "class_queries_logits": {0: "batch_size", 1: "num_queries"},
                "masks_queries_logits": {0: "batch_size", 1: "num_queries", 2: "height", 3: "width"},
            }
        else:
            return super().outputs

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "transformer_decoder_last_hidden_state": "last_hidden_state",
        }


class DonutSwinOnnxConfig(ViTOnnxConfig):
    DEFAULT_ONNX_OPSET = 11

@@ -1115,6 +1200,39 @@ def patch_model_for_export(
        return CLIPModelPatcher(self, model, model_kwargs=model_kwargs)


class SiglipNormalizedConfig(CLIPNormalizedConfig):
    pass


class SiglipOnnxConfig(CLIPOnnxConfig):
    NORMALIZED_CONFIG_CLASS = SiglipNormalizedConfig
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 13 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
            "pixel_values": {0: "image_batch_size", 1: "num_channels", 2: "height", 3: "width"},
            # NOTE: No attention_mask
        }


class SiglipTextWithProjectionOnnxConfig(CLIPTextWithProjectionOnnxConfig):
    pass


class SiglipTextOnnxConfig(CLIPTextOnnxConfig):
    pass


class SiglipVisionModelOnnxConfig(CLIPVisionModelOnnxConfig):
    # torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::scaled_dot_product_attention' to ONNX opset version 11 is not supported.
    # Support for this operator was added in version 14, try exporting with this version.
    DEFAULT_ONNX_OPSET = 14


class UNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-4
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component, needs the Trilu
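Worth noting from the hunk above: `SiglipOnnxConfig.inputs` deliberately omits `attention_mask` (matching the "Remove attention mask from model input" commit). A hedged sketch of what that implies when feeding the exported graph with ONNX Runtime; the model path and tensor shapes are illustrative:

```python
# Hedged sketch: the exported SigLIP graph accepts exactly the inputs declared
# in SiglipOnnxConfig.inputs above -- input_ids and pixel_values, no attention_mask.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("siglip_onnx/model.onnx")  # illustrative path
outputs = session.run(
    None,  # fetch every declared output
    {
        "input_ids": np.zeros((2, 16), dtype=np.int64),  # text_batch_size x sequence_length
        "pixel_values": np.zeros((1, 3, 224, 224), dtype=np.float32),  # image batch
    },
)
```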
70 changes: 68 additions & 2 deletions optimum/exporters/tasks.py
@@ -209,7 +209,12 @@ class TasksManager:
"feature-extraction": "AutoModel",
"fill-mask": "AutoModelForMaskedLM",
"image-classification": "AutoModelForImageClassification",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"image-segmentation": (
"AutoModelForImageSegmentation",
"AutoModelForSemanticSegmentation",
"AutoModelForInstanceSegmentation",
"AutoModelForUniversalSegmentation",
),
"image-to-image": "AutoModelForImageToImage",
"image-to-text": ("AutoModelForVision2Seq", "AutoModel"),
"mask-generation": "AutoModel",
@@ -224,6 +229,7 @@ class TasksManager:
"text2text-generation": "AutoModelForSeq2SeqLM",
"text-classification": "AutoModelForSequenceClassification",
"token-classification": "AutoModelForTokenClassification",
"visual-question-answering": "AutoModelForVisualQuestionAnswering",
"zero-shot-image-classification": "AutoModelForZeroShotImageClassification",
"zero-shot-object-detection": "AutoModelForZeroShotObjectDetection",
}
@@ -307,6 +313,7 @@ class TasksManager:
"vision2seq-lm": "image-to-text",
"zero-shot-classification": "text-classification",
"image-feature-extraction": "feature-extraction",
"pretraining": "feature-extraction",
# for backward compatibility and testing (where
# model task and model type are still the same)
"stable-diffusion": "text-to-image",
@@ -601,6 +608,11 @@ class TasksManager:
"image-segmentation",
onnx="DetrOnnxConfig",
),
"dinov2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="Dinov2OnnxConfig",
),
"distilbert": supported_tasks_mapping(
"feature-extraction",
"fill-mask",
@@ -732,6 +744,11 @@ class TasksManager:
"feature-extraction",
onnx="GroupViTOnnxConfig",
),
"hiera": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="HieraOnnxConfig",
),
"hubert": supported_tasks_mapping(
"feature-extraction",
"automatic-speech-recognition",
@@ -813,6 +830,11 @@ class TasksManager:
"question-answering",
onnx="MarkupLMOnnxConfig",
),
"maskformer": supported_tasks_mapping(
"feature-extraction",
"image-segmentation",
onnx="MaskFormerOnnxConfig",
),
"mbart": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
@@ -1011,6 +1033,11 @@ class TasksManager:
"image-classification",
onnx="PoolFormerOnnxConfig",
),
"pvt": supported_tasks_mapping(
"feature-extraction",
"image-classification",
onnx="PvtOnnxConfig",
),
"regnet": supported_tasks_mapping(
"feature-extraction",
"image-classification",
@@ -1070,6 +1097,23 @@ class TasksManager:
"audio-classification",
onnx="SEWDOnnxConfig",
),
"siglip": supported_tasks_mapping(
"feature-extraction",
"zero-shot-image-classification",
onnx="SiglipOnnxConfig",
),
"siglip-text-model": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipTextOnnxConfig",
),
"siglip-text-with-projection": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipTextWithProjectionOnnxConfig",
),
"siglip-vision-model": supported_tasks_mapping(
"feature-extraction",
onnx="SiglipVisionModelOnnxConfig",
),
"speech-to-text": supported_tasks_mapping(
"feature-extraction",
"feature-extraction-with-past",
@@ -1102,6 +1146,12 @@ class TasksManager:
"masked-im",
onnx="SwinOnnxConfig",
),
"swinv2": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"masked-im",
onnx="SwinV2OnnxConfig",
),
"swin2sr": supported_tasks_mapping(
"feature-extraction",
"image-to-image",
@@ -1148,7 +1198,19 @@ class TasksManager:
            onnx="VisionEncoderDecoderOnnxConfig",
        ),
        "vit": supported_tasks_mapping(
            "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
            "feature-extraction",
            "image-classification",
            "masked-im",
            onnx="ViTOnnxConfig",
        ),
        "vit-mae": supported_tasks_mapping(
            "feature-extraction",
            onnx="VitMAEOnnxConfig",
        ),
        "vit-msn": supported_tasks_mapping(
            "feature-extraction",
            "image-classification",
            onnx="VitMSNOnnxConfig",
        ),
        "vits": supported_tasks_mapping(
            "text-to-audio",
@@ -1232,6 +1294,10 @@ class TasksManager:
"unet-2d-condition",
"vae-encoder",
"vae-decoder",
"clip-text-model",
"clip-text-with-projection",
"siglip-text-model",
"siglip-text-with-projection",
# redundant model types
"trocr", # same as vision-encoder-decoder
}
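The registrations above can be checked from Python. A hedged sketch, assuming a current `optimum` install:

```python
# Hedged sketch: querying the task registry that this file edits.
from optimum.exporters.tasks import TasksManager

# After this commit, the result should include "feature-extraction"
# and "image-segmentation".
print(list(TasksManager.get_supported_tasks_for_model_type("maskformer", exporter="onnx")))
```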
4 changes: 2 additions & 2 deletions optimum/onnxruntime/modeling_ort.py
@@ -1696,7 +1696,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageClassification(ORTModel):
"""
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, vit.
ONNX Model for image-classification tasks. This class officially supports beit, convnext, convnextv2, data2vec_vision, deit, dinov2, levit, mobilenet_v1, mobilenet_v2, mobilevit, poolformer, resnet, segformer, swin, swinv2, vit.
"""

auto_model_class = AutoModelForImageClassification
@@ -1784,7 +1784,7 @@ def forward(
@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForSemanticSegmentation(ORTModel):
"""
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes. This class officially supports segformer.
ONNX Model for semantic-segmentation, with an all-MLP decode head on top e.g. for ADE20k, CityScapes. This class officially supports maskformer, segformer.
"""

auto_model_class = AutoModelForSemanticSegmentation
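A hedged usage sketch for the docstring updates above; the checkpoint name is an assumption (any DINOv2 checkpoint with a classification head would do), and `export=True` converts the PyTorch weights to ONNX on the fly:

```python
# Hedged sketch: loading a DINOv2 image classifier through ONNX Runtime.
from optimum.onnxruntime import ORTModelForImageClassification

model = ORTModelForImageClassification.from_pretrained(
    "facebook/dinov2-base-imagenet1k-1-layer",  # illustrative checkpoint
    export=True,  # export to ONNX on the fly
)
```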
1 change: 1 addition & 0 deletions optimum/onnxruntime/utils.py
@@ -178,6 +178,7 @@ def check_optimization_supported_model(cls, model_type: str, optimization_config
"clip",
"vit",
"swin",
"swinv2",
]
model_type = model_type.replace("_", "-")
if (model_type not in cls._conf) or (cls._conf[model_type] not in supported_model_types_for_optimization):
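Since `swinv2` is now on the optimization allowlist above, `ORTOptimizer` should accept exported SwinV2 models. A hedged sketch; the checkpoint and save directory are illustrative:

```python
# Hedged sketch: graph optimization for a SwinV2 model, newly allowed above.
from optimum.onnxruntime import ORTModelForImageClassification, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model = ORTModelForImageClassification.from_pretrained(
    "microsoft/swinv2-tiny-patch4-window8-256", export=True  # illustrative checkpoint
)
optimizer = ORTOptimizer.from_pretrained(model)
optimizer.optimize(
    save_dir="swinv2_optimized",  # illustrative output directory
    optimization_config=OptimizationConfig(optimization_level=2),
)
```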
4 changes: 4 additions & 0 deletions optimum/utils/normalized_config.py
@@ -204,8 +204,10 @@ class NormalizedConfigManager:
        'data2vec-text',
        'data2vec-vision',
        'detr',
        'dinov2',
        'flaubert',
        'groupvit',
        'hiera',
        'ibert',
        'layoutlm',
        'layoutlmv3',
@@ -216,6 +218,8 @@ class NormalizedConfigManager:
        'owlvit',
        'perceiver',
        'roformer',
        'segformer',
        'siglip',
        'squeezebert',
        'table-transformer',
    """