Skip to content

Commit

Permalink
Fix ONNX checkpoint loading (#2544)
Browse files Browse the repository at this point in the history
* Revert "Disable ONNX tests (#2509)"

This reverts commit a0549fe.

* add external weights

* + pb

* style
  • Loading branch information
anton-l authored Mar 3, 2023
1 parent 1021929 commit 4f0141a
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 3 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: docker-cpu
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ jobs:
runner: docker-tpu
image: diffusers/diffusers-flax-tpu
report: flax_tpu
- name: Slow ONNXRuntime CUDA tests on Ubuntu
framework: onnxruntime
runner: docker-gpu
image: diffusers/diffusers-onnxruntime-cuda
report: onnx_cuda

name: ${{ matrix.config.name }}

Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/push_tests_fast.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ jobs:
runner: docker-cpu
image: diffusers/diffusers-flax-cpu
report: flax_cpu
- name: Fast ONNXRuntime CPU tests on Ubuntu
framework: onnxruntime
runner: docker-cpu
image: diffusers/diffusers-onnxruntime-cpu
report: onnx_cpu
- name: PyTorch Example CPU tests on Ubuntu
framework: pytorch_examples
runner: docker-cpu
Expand Down
12 changes: 9 additions & 3 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME

from ..utils import FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME
from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME


INDEX_FILE = "diffusion_pytorch_model.bin"
Expand Down Expand Up @@ -176,7 +176,13 @@ def is_safetensors_compatible(filenames, variant=None) -> bool:

def variant_compatible_siblings(info, variant=None) -> Union[List[os.PathLike], str]:
filenames = set(sibling.rfilename for sibling in info.siblings)
weight_names = [WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, FLAX_WEIGHTS_NAME, ONNX_WEIGHTS_NAME]
weight_names = [
WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
FLAX_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
ONNX_EXTERNAL_WEIGHTS_NAME,
]

if is_transformers_available():
weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]
Expand Down Expand Up @@ -604,7 +610,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
]

if from_flax:
ignore_patterns = ["*.bin", "*.safetensors", ".onnx"]
ignore_patterns = ["*.bin", "*.safetensors", "*.onnx", "*.pb"]
elif is_safetensors_available() and is_safetensors_compatible(model_filenames, variant=variant):
ignore_patterns = ["*.bin", "*.msgpack"]

Expand Down

0 comments on commit 4f0141a

Please sign in to comment.