
PEFT model doesn't update params when having changed LoRA config #2295

Open

d-kleine opened this issue Dec 23, 2024 · 2 comments

d-kleine (Contributor) commented Dec 23, 2024

System Info

I have noticed that after updating the target_modules setting in the LoRA config, the PEFT model's params remain unchanged. This might affect other PEFT settings too.

My assumption is that get_peft_model() does not re-instantiate or update the PEFT model once it has been initialized.

System: Windows 11
Python: 3.11
peft: 0.14.0

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

For reproduction in a Jupyter Notebook:

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch

label_list = ['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

model = AutoModelForTokenClassification.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    pad_token_id=tokenizer.eos_token_id,
    torch_dtype=torch.bfloat16,
    device_map="auto", 
    num_labels=len(label_list)
)

# Inspect module names to find candidate LoRA target modules
for name, module in model.named_modules():
    print(name)

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,             
    lora_alpha=32, 
    target_modules=["q_proj", "v_proj"],  
    lora_dropout=0.1
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

This outputs

trainable params: 1,722,377 || all params: 1,237,555,218 || trainable%: 0.1392
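
To double-check which modules actually received adapters, you can scan the wrapped model for LoRA parameters (a quick sanity check, assuming PEFT's usual lora_A/lora_B naming):

# Sanity check: list every module that received a LoRA adapter
for name, module in model.named_modules():
    if "lora_A" in name or "lora_B" in name:
        print(name)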

But when changing the above code without restarting the kernel to:

lora_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,             
    lora_alpha=32, 
    target_modules=["layers.0.self_attn.q_proj", "layers.0.self_attn.v_proj"], # changed to specific heads
    lora_dropout=0.1
)

and retrieving the trainable params again:

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

it outputs again

trainable params: 1,722,377 || all params: 1,237,555,218 || trainable%: 0.1392

but after the update it should be

trainable params: 124,937 || all params: 1,235,957,778 || trainable%: 0.0101
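
These numbers are consistent with the standard Llama-3.2-1B shapes (16 layers, hidden size 2048, v_proj output dim 512) and r=16: per layer, q_proj LoRA adds 16*2048 + 2048*16 = 65,536 params and v_proj adds 16*2048 + 512*16 = 40,960, i.e. 106,496 per layer. All 16 layers plus the 2048*9 + 9 = 18,441 params of the classification head give 16 * 106,496 + 18,441 = 1,722,377, while layer 0 alone gives 106,496 + 18,441 = 124,937.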

Expected behavior

Once lora_config has been updated, get_peft_model() should pick up the current config.

BenjaminBossan (Member) commented

Your observation is correct. Of course, re-defining a completely new lora_config cannot influence the model, as this just defines a new, unrelated variable that just happens to have the same name. Probably what you mean is that you would like to change the attribute on the existing lora_config:

lora_config = LoraConfig(..., target_modules=["foo"])
model = get_peft_model(base_model, lora_config)
lora_config.target_modules = ["bar"]  # <= you expect this to trigger re-initialization of peft model

Although it is technically possible to turn each parameter into a @property and define a setter that re-initializes the model each time the config is changed, I'd say this is not worth the effort. I also lean towards the current behavior being the more intuitive one, but that's hard to say.
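
For illustration, such a setter could look roughly like this (a hypothetical sketch, not actual PEFT code; the config class, the back-reference, and the re-init hook are all made-up names):

class AutoReinitLoraConfig:
    def __init__(self, target_modules, peft_model=None):
        self._target_modules = target_modules
        self._peft_model = peft_model  # hypothetical back-reference to the wrapped model

    @property
    def target_modules(self):
        return self._target_modules

    @target_modules.setter
    def target_modules(self, value):
        self._target_modules = value
        if self._peft_model is not None:
            self._peft_model.reinitialize(self)  # hypothetical re-init hook

Every config attribute would need this treatment, and the config would have to hold a reference to every model it was used with, which is part of why I don't think it's worth it.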

d-kleine (Contributor, Author) commented Dec 27, 2024

Your observation is correct. Of course, re-defining a completely new lora_config cannot influence the model, as this just defines a new, unrelated variable that just happens to have the same name.

Sorry, yeah, the model assignment was a bad example. I meant if you assign it like this:

model_peft = get_peft_model(model, lora_config)
model_peft.print_trainable_parameters()

In my opinion, if you change the lora_config, these changes should be picked up by get_peft_model(). Currently, there is not even a warning that the config has changed while the PEFT model doesn't reflect the change. I don't find this very intuitive.
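
As a workaround, the PEFT model can be rebuilt whenever the config changes; a minimal sketch, assuming PeftModel.unload() removes the LoRA layers and returns the base model as documented:

# Strip the old adapters, then wrap the base model with a fresh config
base_model = model_peft.unload()
new_config = LoraConfig(
    task_type=TaskType.TOKEN_CLS,
    r=16,
    lora_alpha=32,
    target_modules=["layers.0.self_attn.q_proj", "layers.0.self_attn.v_proj"],
    lora_dropout=0.1,
)
model_peft = get_peft_model(base_model, new_config)
model_peft.print_trainable_parameters()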
