
ERROR - converting 'scaled_dot_product_attention' op (located at: 'text_encoder/text_model/encoder/0/self_attn'): #345

Open
khanhvp2k opened this issue Aug 3, 2024 · 2 comments

khanhvp2k commented Aug 3, 2024

I am trying to convert my local model to Core ML and hit an error while converting the text encoder:

ERROR - converting 'scaled_dot_product_attention' op (located at: 'text_encoder/text_model/encoder/0/self_attn'):
NotImplementedError: scaled_dot_product_attention op: scale parameter is not handled.

Below is the full log content:

ERROR - converting 'scaled_dot_product_attention' op (located at: 'text_encoder/text_model/encoder/0/self_attn'):

Converting PyTorch Frontend ==> MIL Ops:  15%|▏| 69/449 [00:00<00:00, 5385.12 o
Traceback (most recent call last):
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/Users/vietnamdtssoftware/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 1729, in <module>
    main(args)
  File "/Users/vietnamdtssoftware/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 1518, in main
    convert_text_encoder(pipe.text_encoder, pipe.tokenizer, "text_encoder", args)
  File "/Users/vietnamdtssoftware/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 403, in convert_text_encoder
    coreml_text_encoder, out_path = convert_to_coreml(
  File "/Users/vietnamdtssoftware/ml-stable-diffusion/python_coreml_stable_diffusion/torch2coreml.py", line 129, in convert_to_coreml
    coreml_model = ct.convert(
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/_converters_entry.py", line 635, in convert
    mlmodel = mil_convert(
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 188, in mil_convert
    return _mil_convert(model, convert_from, convert_to, ConverterRegistry, MLModel, compute_units, **kwargs)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 212, in _mil_convert
    proto, mil_program = mil_convert_to_proto(
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 288, in mil_convert_to_proto
    prog = frontend_converter(model, **kwargs)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/converter.py", line 108, in __call
    return load(args, *kwargs)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 84, in load
    return _perform_torch_convert(converter, debug)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 126, in _perform_torch_convert
    raise e
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/load.py", line 118, in _perform_torch_convert
    prog = converter.convert()
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/converter.py", line 1184, in convert
    convert_nodes(self.context, self.graph, early_exit=not has_states)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 93, in convert_nodes
    raise e     # re-raise exception
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 88, in convert_nodes
    convert_single_node(context, node)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 144, in convert_single_node
    add_op(context, node)
  File "/opt/miniconda3/envs/coreml_stable_diffusion/lib/python3.8/site-packages/coremltools/converters/mil/frontend/torch/ops.py", line 6895, in scaled_dot_product_attention
    raise NotImplementedError(
NotImplementedError: scaled_dot_product_attention op: scale parameter is not handled.

I ran the following command:

python -m python_coreml_stable_diffusion.torch2coreml --convert-vae-decoder --convert-vae-encoder --convert-unet --unet-support-controlnet --convert-text-encoder --model-version v1-5-pruned-emaonly_mergemyrobo_home_ratio_1_diffusers --bundle-resources-for-swift-cli --attention-implementation SPLIT_EINSUM -o v1-5-pruned-emaonly_mergemyrobo_home_ratio_1_split-einsum

Running on a Mac mini, macOS Sonoma 14.0.

Can anyone help?

atiorh (Collaborator) commented Aug 4, 2024

huggingface/transformers#31940 introduced the use of torch.nn.functional.scaled_dot_product_attention in the transformers implementation of CLIP. A short-term workaround is to pip install transformers==4.42.4, the most recent release that doesn't include this change. Another workaround is to modify the transformers code to apply the scale to q before calling the sdpa function, as sketched below.
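
For the second option, here is a minimal sketch of what that edit might look like inside the CLIP attention forward pass (variable names like query_states are assumptions and will vary by transformers version). The idea is to fold self.scale into the query and drop the scale kwarg, compensating for sdpa's implicit 1/sqrt(head_dim) factor:

# Instead of:
#   attn_output = torch.nn.functional.scaled_dot_product_attention(
#       query_states, key_states, value_states, attn_mask=attn_mask, scale=self.scale)
# pre-scale q and let sdpa apply its default 1/sqrt(head_dim) scaling:
head_dim = query_states.shape[-1]
query_states = query_states * (self.scale * head_dim**0.5)
attn_output = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=attn_mask
)

Since CLIP sets self.scale = head_dim**-0.5, the multiplier above is 1.0, so simply deleting the scale kwarg is numerically equivalent.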

qubvel commented Aug 29, 2024

Hi! You can also load the model with eager attention instead of sdpa (scaled_dot_product_attention):

from transformers import CLIPModel

model = CLIPModel.from_pretrained(..., attn_implementation="eager")
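
For this repo's conversion flow, the same idea applied to just the text encoder might look like the following sketch (model_path and the subfolder name are assumptions about a diffusers-format checkpoint):

from transformers import CLIPTextModel

# Hypothetical: load only the pipeline's text encoder with eager attention,
# so the conversion never traces a scaled_dot_product_attention call with a scale argument.
text_encoder = CLIPTextModel.from_pretrained(
    model_path, subfolder="text_encoder", attn_implementation="eager"
)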
