
Adapter name conflict with tuner prefix leads to unclear warning during model loading #2252

Open · pzdkn opened this issue Dec 3, 2024 · 3 comments

pzdkn (Contributor) commented Dec 3, 2024

System Info

  • peft=0.13.3.dev0
  • transformers=4.46.3
  • Python=3.12.6

Who can help?

@BenjaminBossan
@stevhliu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

import peft 
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType
from tempfile import TemporaryDirectory


def get_adapter_state_dict(model, adapter_name):
    adapter_state_dict = model.state_dict()
    adapter_weights = {key: value for key, value in adapter_state_dict.items() if adapter_name in key}
    return adapter_weights

def test_adapter_name():
    MODEL_ID = "openai-community/gpt2"
    # "lora" collides with the LoRA tuner prefix ("lora_") used in adapter parameter names
    ADAPTER_NAME = "lora"
    
    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
    
    # Define LoraConfig without base_model_name_or_path
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, 
        inference_mode=False, 
        r=8
    )
    
    # Get the PEFT model with the adapter name
    peft_model = peft.get_peft_model(base_model, peft_config, adapter_name=ADAPTER_NAME)
    
    # Saving and loading the model
    with TemporaryDirectory() as temp_dir:
        temp_dir = Path(temp_dir)

        adapter_weights = get_adapter_state_dict(peft_model, ADAPTER_NAME)
        peft_model.save_pretrained(temp_dir)
        
        loaded_peft_model = peft.PeftModel.from_pretrained(
            model=base_model,
            model_id=temp_dir / ADAPTER_NAME,
            adapter_name=ADAPTER_NAME,
        )
        loaded_adapter_weights = get_adapter_state_dict(loaded_peft_model, ADAPTER_NAME)
       
        # Assertion fails because the adapter weights were newly initialized on load
        for key in adapter_weights:
            assert torch.isclose(adapter_weights[key], loaded_adapter_weights[key]).all()

Expected behavior

A clear error or warning message indicating that the adapter_name and tuner_prefix should not be the same.

pzdkn (Contributor, Author) commented Dec 3, 2024

This is admittedly an edge case, but not unthinkable (it happened to me). I suggest raising an error when the adapter name and the tuner prefix are the same, already when the PEFT model is loaded in the first place.
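
A minimal sketch of the kind of check I have in mind; check_adapter_name is a hypothetical helper, not part of the PEFT API, and the prefix would come from the tuner class (e.g. LoraModel.prefix, which is "lora_"):

# Hypothetical helper (not actual PEFT code): reject adapter names that
# collide with the tuner prefix used in parameter keys.
def check_adapter_name(adapter_name: str, tuner_prefix: str) -> None:
    # LoRA keys look like "...lora_A.<adapter_name>.weight", so an adapter
    # name such as "lora" that overlaps with the prefix "lora_" breaks the
    # key matching when the checkpoint is loaded.
    if adapter_name in tuner_prefix:
        raise ValueError(
            f"Adapter name {adapter_name!r} conflicts with the tuner prefix "
            f"{tuner_prefix!r}; please choose a different adapter name."
        )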

BenjaminBossan (Member) commented
Ah yes, this is a bit of an unfortunate situation. You should have gotten a warning:

UserWarning: Found missing adapter keys while loading the checkpoint: [...]

But I agree that we could give another warning when we detect that the adapter name is a substring of the PEFT method prefix.

Btw. you may want to use peft.get_peft_model_state_dict in situations like this.
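
For reference, a minimal usage sketch of that helper, reusing the names from the reproduction above (the adapter_name keyword argument is the one I believe get_peft_model_state_dict accepts):

from peft import get_peft_model_state_dict

# Collect only the adapter weights of the given adapter, with properly
# scoped keys, instead of filtering model.state_dict() by substring.
adapter_weights = get_peft_model_state_dict(peft_model, adapter_name=ADAPTER_NAME)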

pzdkn (Contributor, Author) commented Dec 3, 2024

@BenjaminBossan I see. I have opened a PR for this, but maybe the problem is so minor that it doesn't matter.
