diff --git a/src/transformers/models/llava/convert_llava_weights_to_hf.py b/src/transformers/models/llava/convert_llava_weights_to_hf.py index b8d936e8cc4473..3582b9772c9c44 100644 --- a/src/transformers/models/llava/convert_llava_weights_to_hf.py +++ b/src/transformers/models/llava/convert_llava_weights_to_hf.py @@ -15,7 +15,7 @@ import glob import torch -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import file_exists, hf_hub_download, snapshot_download from safetensors import safe_open from transformers import ( @@ -140,11 +140,12 @@ def convert_llava_llama_to_hf(text_model_id, vision_model_id, output_hub_path, o with torch.device("meta"): model = LlavaForConditionalGeneration(config) - if "Qwen" in text_model_id: - state_dict = load_original_state_dict(old_state_dict_id) - else: + # Some llava variants like microsoft/llava-med-v1.5-mistral-7b use safetensors to store weights + if file_exists(old_state_dict_id, "model_state_dict.bin"): state_dict_path = hf_hub_download(old_state_dict_id, "model_state_dict.bin") - state_dict = torch.load(state_dict_path, map_location="cpu") + state_dict = torch.load(state_dict_path, map_location="cpu", weights_only=True) + else: + state_dict = load_original_state_dict(old_state_dict_id) state_dict = convert_state_dict_to_hf(state_dict) model.load_state_dict(state_dict, strict=True, assign=True)