Inconsistent Parameter Mismatches After Merging PEFT and Base Models #2289

enhulu-ms opened this issue Dec 19, 2024 · 10 comments

enhulu-ms commented Dec 19, 2024

System Info

peft 0.14.0, transformers 4.45.2, accelerate 1.0.1, Python 3.11.9, windows

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

from transformers_custom.modeling import CustomConfig
from transformers_custom.tokenization import CustomTokenizer
from transformers_custom.multitask_model import CustomForSequenceClassificationMultitask
from peft import PeftModel
import torch

def compare_model_params(model1, model2):
    # Extract state dictionaries
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()
    
    # First, check if they have the same keys
    keys1 = set(sd1.keys())
    keys2 = set(sd2.keys())
    
    # Find parameters that are not present in both
    missing_in_model2 = keys1 - keys2
    missing_in_model1 = keys2 - keys1
    
    if missing_in_model2:
        print("Parameters missing in model2:", missing_in_model2)
    if missing_in_model1:
        print("Parameters missing in model1:", missing_in_model1)
        
    # Now compare parameters that exist in both
    mismatch_names = []
    for key in sorted(keys1.intersection(keys2)):
        param1 = sd1[key]
        param2 = sd2[key]
        
        # Check for shape mismatch
        if param1.shape != param2.shape:
            mismatch_names.append(key)
            continue
        
        # Check for value mismatch
        if not torch.allclose(param1, param2):
            print("Mismatched values for parameter:", key, f"model1: {param1}", f"model2: {param2}")
            mismatch_names.append(key)
    
    # Print out results
    if mismatch_names:
        print("Mismatched parameters:", mismatch_names)
    else:
        print("All parameters match perfectly.")

base_model_path = r"C:\models\tms\download\base2"
peft_path = r"C:\models\tms\download\adapter2"
merged_model_path = r"C:\models\tms\download\adapter2_merged\peft_merged"

config = CustomConfig.from_pretrained(
    base_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

peft_model = PeftModel.from_pretrained(base_model, peft_path)

peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

merged_config = CustomConfig.from_pretrained(
    merged_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

merged_model = CustomForSequenceClassificationMultitask.from_pretrained(
    merged_model_path,
    config=merged_config,
    cache_dir=None,
    revision="main",
)
merged_model.eval()

compare_model_params(peft_model_merged, merged_model)

Expected behavior

I saved the base model and the merged model (using save_pretrained) after training and calling merge_and_unload(). I also saved the PEFT model (via trainer.save_model). After loading the PEFT parameters on top of the base model and calling merge_and_unload(), I compared the newly merged model with the previously saved merged model. Some parameters do not match, and the specific mismatches change with each comparison run. For example, sometimes the mismatched parameters are ['classifier2.class_dense.bias', 'classifier2.class_dense.weight', ...] and other times ['custom.encoder.layer.19.attention.self.query.weight'].

How can I resolve this issue? Ideally, there should be no mismatches, or at least the mismatches should be consistent across runs.
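
For reference, a minimal sketch of the save-side workflow described above. The trainer object is an assumption (a Hugging Face Trainer wrapping the PEFT model), not the author's exact training script; the paths are the ones from the reproduction script.

# Sketch of the save-side workflow. `trainer` is assumed to wrap the PEFT model;
# peft_path / merged_model_path as in the reproduction script above.
trainer.save_model(peft_path)                  # saves the PEFT adapter checkpoint

merged = trainer.model.merge_and_unload()      # fold the LoRA weights into the base model
merged.save_pretrained(merged_model_path)      # merged checkpoint later loaded as merged_model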

@enhulu-ms enhulu-ms changed the title PEFT loaded parameters are randomized after merge_and_unload Inconsistent Parameter Mismatches After Merging PEFT and Base Models Dec 19, 2024

enhulu-ms commented Dec 19, 2024

Although each run produces a different set of mismatched components, I also noticed that across multiple runs the same component shows the same mismatched values. E.g. Mismatched values for parameter: classifier4.out_proj.bias model1: tensor([ 0.0013, -0.0003, -0.0003, -0.0004, -0.0004, -0.0005, -0.0007, -0.0010]) model2: tensor([ 2.3966e-01, 6.4979e-03, 8.9810e-04, 1.0589e-04, -2.9830e-03, -5.1880e-03, -1.1035e-02, -2.7188e-02]), where model1 is the newly merged model and model2 is the merged model saved after training.

@enhulu-ms
Author

I also noticed that just adding an irrelevant second model load, without actually using it, makes all parameters mismatch... I guess it might be an incorrect memory reference somewhere in the PEFT implementation?

peft_model1 = PeftModel.from_pretrained(base_model, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()


enhulu-ms commented Dec 19, 2024

I noticed a really weird behavior while debugging. It seems that both PeftModel.from_pretrained(base_model, peft_path) and peft_model.merge_and_unload() change the parameters of base_model? Below is the code to reproduce it:

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

base_model1 = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

print("Comparing base_model and base_model1 before loading peft model")
compare_model_params(base_model, base_model1)

peft_model = PeftModel.from_pretrained(base_model, peft_path)
peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

peft_model1 = PeftModel.from_pretrained(base_model, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()

print("Comparing base_model and base_model1")
compare_model_params(base_model, base_model1)

It results in mismatches in all components between base_model and base_model1... Any explanation?


enhulu-ms commented Dec 19, 2024

For the PEFT configuration, I am using modules_to_save = ["classifier","classifier2","classifier3","classifier4"], and the random mismatches seem to happen mostly in "classifier2", "classifier3", and "classifier4". For example: Mismatched parameters: ['classifier2.class_dense.bias', 'classifier2.class_dense.weight', 'classifier2.out_proj.bias', 'classifier2.out_proj.weight', 'classifier3.class_dense.bias', 'classifier3.class_dense.weight', 'classifier3.out_proj.bias', 'classifier3.out_proj.weight', 'classifier4.class_dense.bias', 'classifier4.class_dense.weight', 'classifier4.out_proj.bias', 'classifier4.out_proj.weight']
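
As a sanity check, the saved adapter file can be inspected to confirm that the modules_to_save heads are actually part of the checkpoint. This is a sketch that assumes the default adapter_model.safetensors filename; peft_path is the adapter directory used above.

# Sketch: list the tensors saved with the adapter to confirm the modules_to_save heads
# are included (assumes the default adapter_model.safetensors filename).
import os
from safetensors import safe_open

adapter_file = os.path.join(peft_path, "adapter_model.safetensors")
with safe_open(adapter_file, framework="pt") as f:
    for key in f.keys():
        if "classifier" in key:
            print(key)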

@githubnemo
Collaborator

Hey :) Thanks for raising an issue.

I noticed a really weird behavior while debugging. It seems that both PeftModel.from_pretrained(base_model, peft_path) and peft_model.merge_and_unload() will change the parameters in base_model??? below is the code to reproduce:

I think this behavior is expected and documented. This is done to save memory on large models. See from_pretrained and merge_and_unload.

Could this already explain the discrepancies you're seeing? It is not really possible for me to reproduce your setup exactly, since I don't know your exact LoRA config or how your model behaves.
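
If an unmodified copy of the base model is needed for such comparisons, one option (a sketch, not part of the original discussion) is to deep-copy it before handing it to PEFT, since both loading and merging operate on the module object that is passed in:

# Sketch: keep a pristine copy of the base model for comparison, since
# PeftModel.from_pretrained wraps (and merge_and_unload later merges into)
# the module object it is given. Reuses base_model, peft_path and
# compare_model_params from the reproduction script above.
import copy

from peft import PeftModel

base_model_copy = copy.deepcopy(base_model)           # untouched reference weights

peft_model = PeftModel.from_pretrained(base_model, peft_path)
peft_model_merged = peft_model.merge_and_unload()     # modifies base_model's weights in place

compare_model_params(peft_model_merged, base_model_copy)  # differences reflect only the merge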


enhulu-ms commented Dec 19, 2024

@githubnemo, thanks for the explanation. Unfortunately, the in-place modification of the base model does not explain the mismatches in this case. The issue is that I get different model parameters in each session in which I load the PEFT model (load base -> apply LoRA -> merge_and_unload). I tried loading two models in the same session and they are identical, so the discrepancy occurs between sessions, not between instances within the same session. I also tried setting the random seed, and the problem persists. I suspect it is related to the modules_to_save feature. Anyway, the LoRA configuration I am using is below:

  • lora_rank: 128, lora_alpha: 256, lora_dropout: 0.1
  • lora target_modules: ['query', 'value', 'key', 'dense'], saved_modules: ['classifier', 'classifier2', 'classifier3', 'classifier4', 'gate_ur_linear']
  • The base model architecture is below
CustomForSequenceClassificationMultitask(
  (Custom): CustomModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(500002, 1024, padding_idx=1)
      (position_embeddings): Embedding(514, 1024, padding_idx=1)
      (token_type_embeddings): Embedding(1, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): CustomEncoder(
      (layer): ModuleList(
        (0-23): 24 x CustomLayer(
          (attention): CustomAttention(
            (self): CustomSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (gate_ur_linear): Linear(in_features=64, out_features=8, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (intermediate): BertIntermediate(
            (dense): Linear(in_features=1024, out_features=4096, bias=True)
            (intermediate_act_fn): GELUActivation()
          )
          (output): BertOutput(
            (dense): Linear(in_features=4096, out_features=1024, bias=True)
            (LayerNorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
    )
    (pooler): BertPooler(
      (dense): Linear(in_features=1024, out_features=1024, bias=True)
      (activation): Tanh()
    )
    (rel_pos_bias): Linear(in_features=32, out_features=16, bias=False)
  )
  (classifier): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier2): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier3): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
  (classifier4): CustomClassificationHead(
    (class_dense): Linear(in_features=1024, out_features=1024, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (out_proj): Linear(in_features=1024, out_features=8, bias=True)
  )
)


githubnemo commented Dec 19, 2024

I think this is expected. If CustomForSequenceClassificationMultitask works similarly to AutoModelForSequenceClassification, then a model, say gpt2, will receive a new (untrained, freshly initialized) classification head. That would be the classifier* heads in your case. If you use the task type SEQ_CLS in your LoraConfig, the classification head(s) are automatically added to LoraConfig.modules_to_save; thus they are saved, but they are not adapters and therefore are not merged onto the base model.

So the mystery resolves as follows: you are comparing base models that have the same merged adapter weights but differently initialized classification heads, which, of course, differ.

You should not see a difference when comparing the PeftModel.from_pretrained() instances.

peft_model1 = PeftModel.from_pretrained(base_model1, peft_path)
peft_model1_merged = peft_model1.merge_and_unload()
peft_model1_merged.eval()

peft_model2 = PeftModel.from_pretrained(base_model2, peft_path)
peft_model2_merged = peft_model2.merge_and_unload()
peft_model2_merged.eval()

print("Comparing base_model1 and base_model2")
compare_model_params(base_model1, base_model2) # expecting a difference

print("Comparing peft_model1 and peft_model2")
compare_model_params(peft_model1, peft_model2) # no difference
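
To see which modules PEFT treats as modules_to_save (copied and saved, rather than adapted and merged), the wrapped modules can be listed. This is a sketch that relies on the ModulesToSaveWrapper class PEFT uses internally; the import path may vary between versions.

# Sketch: list modules wrapped because of modules_to_save. These carry a copy of the
# original module and are saved with the adapter, not merged like LoRA layers.
from peft.utils.other import ModulesToSaveWrapper

for name, module in peft_model1.named_modules():
    if isinstance(module, ModulesToSaveWrapper):
        print("modules_to_save wrapper:", name)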


enhulu-ms commented Dec 19, 2024

@githubnemo, the CustomForSequenceClassificationMultitask checkpoint does include the parameters for the classification heads, so there should not be any random initialization of the classification heads. What I am comparing is the previously trained-and-merged model against the freshly loaded-and-merged model. The trained-and-merged model is saved as CustomForSequenceClassificationMultitask, so no initialization happens for it either. As I already mentioned, within the same session, if I load two models (load base -> apply LoRA -> merge_and_unload), there is no difference between them. However, in each session I compare the loaded model with the checkpoint of the trained-and-merged one, which should not change across sessions since I just load it from the checkpoint. Below is the same code attached in the description of the issue. In each session, the mismatched components are different; if it were an initialization issue, the mismatched components should stay the same. Do you have any explanation for that?

base_model_path = r"C:\models\tms\download\base2"
peft_path = r"C:\models\tms\download\adapter2"
merged_model_path = r"C:\models\tms\download\adapter2_merged\peft_merged"

config = CustomConfig.from_pretrained(
    base_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

base_model = CustomForSequenceClassificationMultitask.from_pretrained(
    base_model_path,
    config=config,
    cache_dir=None,
    revision="main",
) 

peft_model = PeftModel.from_pretrained(base_model, peft_path)
peft_model_merged = peft_model.merge_and_unload()
peft_model_merged.eval()

merged_config = CustomConfig.from_pretrained(
    merged_model_path,
    num_labels=8,
    finetuning_task=None,
    cache_dir=None,
    revision="main",
)

merged_model = CustomForSequenceClassificationMultitask.from_pretrained(
    merged_model_path,
    config=merged_config,
    cache_dir=None,
    revision="main",
)
merged_model.eval()

print("Comparing base_model and base_model1")
compare_model_params(base_model, base_model1)

@githubnemo
Collaborator

If I understand you correctly, you are wondering why the classification heads are replaced even though you are passing pretrained classifiers. That is understandably surprising when you simply want to fine-tune the in-between layers rather than the classification heads. PEFT assumes that if your task type is classification (LoraConfig.task_type), the classifier* or score* layers should be trained as well.

What you are seeing happens because you probably set the task type to SEQ_CLS, and once you use get_peft_model the classification heads are retrained as well.

Try setting LoraConfig.task_type to None or using the PeftModel class directly when training your adapters. You should get your expected behavior then.
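
A minimal sketch of that suggestion, reusing the hyperparameters and target modules mentioned earlier in the thread (illustrative only, not the author's actual training script):

# Sketch: leave task_type unset (None) so PEFT does not auto-register classifier/score
# heads; list modules_to_save explicitly only if the pretrained heads should be
# retrained and saved with the adapter.
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.1,
    target_modules=["query", "value", "key", "dense"],
    # task_type deliberately left as None
)
model = get_peft_model(base_model, peft_config)
model.print_trainable_parameters()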


enhulu-ms commented Dec 20, 2024

@githubnemo, when I was training the model, I did not pass task_type, so the default should be None? The pretrained model I load includes the classification head parameters as well, and I put the classification heads into modules_to_save. Do you mean that the default task_type is not None in this case? I checked the saved adapter_config.json, and task_type seems to be null.

target_modules = ["query", "value", "key", "dense", "gate_ur_linear"]
saved_modules = ["classifier", "classifier2", "classifier3", "classifier4"]
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    target_modules=target_modules,
    modules_to_save=saved_modules,
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Below is the saved adapter_config.json:

{
  "alpha_pattern": {},
  "auto_mapping": {
    "base_model_class": "CustomForSequenceClassificationMultitask",
    "parent_library": "transformers_custom.multitask_model"
  },
  "base_model_name_or_path": "<local>",
  "bias": "none",
  "eva_config": null,
  "exclude_modules": null,
  "fan_in_fan_out": false,
  "inference_mode": true,
  "init_lora_weights": true,
  "layer_replication": null,
  "layers_pattern": null,
  "layers_to_transform": null,
  "loftq_config": {},
  "lora_alpha": 256,
  "lora_bias": false,
  "lora_dropout": 0.1,
  "megatron_config": null,
  "megatron_core": "megatron.core",
  "modules_to_save": [
    "classifier",
    "classifier2",
    "classifier3",
    "classifier4"
  ],
  "peft_type": "LORA",
  "r": 128,
  "rank_pattern": {},
  "revision": null,
  "target_modules": [
    "key",
    "query",
    "value",
    "dense"
  ],
  "task_type": null,
  "use_dora": false,
  "use_rslora": false
}
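
For completeness, the recorded task type and modules_to_save can also be read back programmatically. A sketch using the public PeftConfig loader and the peft_path defined earlier:

# Sketch: read back what PEFT recorded for the saved adapter. For a LoRA adapter
# PeftConfig.from_pretrained returns a LoraConfig instance.
from peft import PeftConfig

cfg = PeftConfig.from_pretrained(peft_path)
print(cfg.task_type)         # expected: None, matching "task_type": null above
print(cfg.modules_to_save)   # the classifier heads listed under modules_to_save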
