
Different Results When Predicting with Multiple LoRA Adapters in a Loop VS. Using only One LoRA #2270

beyondguo opened this issue Dec 10, 2024 · 11 comments


@beyondguo

beyondguo commented Dec 10, 2024

System Info

Linux, Python 3.8

A two-H100 node.

Name: transformers
Version: 4.34.1

Name: peft
Version: 0.11.1

Who can help?

@BenjaminBossan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Description

I encountered a strange issue while using PEFT LoRA adapters with the Hugging Face Trainer. When I predict with several different LoRA adapters in a loop, the predictions for a given adapter (e.g., m4) differ from the predictions I get when loading and using that adapter on its own.

Steps to Reproduce

  1. I have a dictionary lora_map that maps LoRA adapter names to their respective paths.
  2. The code below iterates over lora_map and predicts using each LoRA adapter:
import pandas as pd
import torch
from peft import PeftModel
from transformers import Trainer

dfs = []
for lora_name in lora_map:
    pred_df = test_df[useful_columns].copy()
    # model.set_adapter(lora_name)
    model = PeftModel.from_pretrained(base_model, lora_map[lora_name], adapter_name=lora_name)
    print("predicting with lora", lora_name)
    trainer = Trainer(model=model, args=args, data_collator=data_collator)
    preds = trainer.predict(token_test_dataset).predictions  # logits
    pred_df[['neu', 'pos', 'neg']] = torch.softmax(torch.tensor(preds), dim=-1).numpy()
    pred_df['lora'] = lora_name

    dfs.append(pred_df)

final_pred_df = pd.concat(dfs)

The lora_map looks like lora_map = {'m1': xxx, 'm2': xxx, ...}, mapping adapter names to adapter paths.

I found that the results in final_pred_df[final_pred_df.lora == 'm4'] are different from those obtained when loading m4 only. But the results for m1 are the same, probably because it's the first entry in lora_map.

What could be the problem? What happens when I load the second adapter using PeftModel.from_pretrained?


I'm sorry I can't share my LoRA weights (they were trained with PiSSA) since they belong to a private model.

Expected behavior

Same results.

@beyondguo
Author

Update: I identified the problem.

After the first call, model = PeftModel.from_pretrained(base_model, lora_map['m1'], adapter_name='m1'), printing the model gives:

...
            (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (m1): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (m1): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (m1): Linear(in_features=8, out_features=1024, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): ParameterDict()
                  )
                  (key): Linear(in_features=1024, out_features=1024, bias=True)
                  (value): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (m1): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (m1): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (m1): Linear(in_features=8, out_features=1024, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): ParameterDict()
                  )
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=1024, out_features=1024, bias=True)
                  (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=1024, out_features=4096, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=4096, out_features=1024, bias=True)
                (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (classifier): ModulesToSaveWrapper(
        (original_module): Linear(in_features=1024, out_features=3, bias=True)
        (modules_to_save): ModuleDict(
          (m1): Linear(in_features=1024, out_features=3, bias=True)
...

Then, if I load another LoRA with model = PeftModel.from_pretrained(base_model, lora_map['m2'], adapter_name='m2'), the model becomes:

...
(attention): BertAttention(
                (self): BertSelfAttention(
                  (query): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (m1): Dropout(p=0.1, inplace=False)
                      (m2): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (m1): Linear(in_features=1024, out_features=8, bias=False)
                      (m2): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (m1): Linear(in_features=8, out_features=1024, bias=False)
                      (m2): Linear(in_features=8, out_features=1024, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): ParameterDict()
                  )
                  (key): Linear(in_features=1024, out_features=1024, bias=True)
                  (value): lora.Linear(
                    (base_layer): Linear(in_features=1024, out_features=1024, bias=True)
                    (lora_dropout): ModuleDict(
                      (m1): Dropout(p=0.1, inplace=False)
                      (m2): Dropout(p=0.1, inplace=False)
                    )
                    (lora_A): ModuleDict(
                      (m1): Linear(in_features=1024, out_features=8, bias=False)
                      (m2): Linear(in_features=1024, out_features=8, bias=False)
                    )
                    (lora_B): ModuleDict(
                      (m1): Linear(in_features=8, out_features=1024, bias=False)
                      (m2): Linear(in_features=8, out_features=1024, bias=False)
                    )
                    (lora_embedding_A): ParameterDict()
                    (lora_embedding_B): ParameterDict()
                  )
                  (dropout): Dropout(p=0.1, inplace=False)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_features=1024, out_features=1024, bias=True)
                  (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
                  (dropout): Dropout(p=0.1, inplace=False)
                )
              )
              (intermediate): BertIntermediate(
                (dense): Linear(in_features=1024, out_features=4096, bias=True)
                (intermediate_act_fn): GELUActivation()
              )
              (output): BertOutput(
                (dense): Linear(in_features=4096, out_features=1024, bias=True)
                (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
          )
        )
        (pooler): BertPooler(
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (activation): Tanh()
        )
      )
      (dropout): Dropout(p=0.1, inplace=False)
      (classifier): ModulesToSaveWrapper(
        (original_module): Linear(in_features=1024, out_features=3, bias=True)
        (modules_to_save): ModuleDict(
          (m1): Linear(in_features=1024, out_features=3, bias=True)
          (m2): Linear(in_features=1024, out_features=3, bias=True)
...

There are now two LoRA adapters in the model, which is beyond my expectation: I thought PeftModel.from_pretrained would initialize a new model instead of building on the previous one. However, PeftModel.from_pretrained(base_model, ...) modifies base_model in place, resulting in this issue.

@beyondguo
Author

It seems related to issue #2184. However, I didn't see any warnings when I tried to load multiple PiSSA LoRA adapters.

@BenjaminBossan
Member

There are now two LoRA adapters in the model, which is beyond my expectation: I thought PeftModel.from_pretrained would initialize a new model instead of building on the previous one. However, PeftModel.from_pretrained(base_model, ...) modifies base_model in place, resulting in this issue.

Note that after loading the first adapter with PeftModel.from_pretrained, all further adapters should be loaded by calling model.load_adapter.
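
For regular (non-PiSSA) LoRA adapters, a minimal sketch of that pattern, reusing the names from your snippet (base_model, lora_map, args, data_collator, token_test_dataset), could look like this:

from peft import PeftModel
from transformers import Trainer

model = None
for lora_name, lora_path in lora_map.items():
    if model is None:
        # First adapter: wrap the base model once.
        model = PeftModel.from_pretrained(base_model, lora_path, adapter_name=lora_name)
    else:
        # Further adapters: add them to the existing PeftModel.
        model.load_adapter(lora_path, adapter_name=lora_name)
    # All adapters live in the same model, so activate only the one to use for this run.
    model.set_adapter(lora_name)
    trainer = Trainer(model=model, args=args, data_collator=data_collator)
    preds = trainer.predict(token_test_dataset).predictions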

It seems related to issue #2184. However, I didn't see any warnings when I tried to load multiple PiSSA LoRA adapters.

Okay, so you're using PiSSA. Please follow the advice in the linked issue to convert the adapters to normal LoRA adapters; then they should not interfere with one another. Regarding the missing warning, it was added in a later PEFT version; if you upgrade PEFT, you should see it.

@beyondguo
Author

beyondguo commented Dec 10, 2024

@BenjaminBossan Thanks for your reply, happy to meet you in the timeline :)
I'm converting the PiSSA adapter to LoRA using the following code:

from peft import PeftModel

base_model = MyModelForClassification.from_pretrained(base_model_path, num_labels=3, ignore_mismatched_sizes=True)

lora_name = 'm1'
# D maps adapter names to the output paths for the converted adapters;
# lora_map (defined earlier) maps them to the trained PiSSA checkpoints.
D = {
    "m1": "trained_model/pissa_to_lora/TP16to19_seed2",
    "m2": "trained_model/pissa_to_lora/TP16to19_seed3",
    "m3": "trained_model/pissa_to_lora/TP20to23_seed2",
    "m4": "trained_model/pissa_to_lora/TP20to23_seed3",
}
print('loading pissa...')
model = PeftModel.from_pretrained(base_model, lora_map[lora_name])
output_dir = D[lora_name]
print('saving pissa...')
model.save_pretrained(output_dir)
model.save_pretrained(output_dir, convert_pissa_to_lora="pissa_init")
print('Saved!')

and got this error:

loading model...
Some weights of MyModelForClassification were not initialized from the model checkpoint at ../../llm_hub/Erlangshen-Roberta-330M-Sentiment and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([2, 1024]) in the checkpoint and torch.Size([3, 1024]) in the model instantiated
- classifier.bias: found shape torch.Size([2]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
loading pissa...
saving pissa...
/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/utils/save_and_load.py:195: UserWarning: Could not find a config file in ../../llm_hub/Erlangshen-Roberta-330M-Sentiment - will assume that the vocabulary was not modified.
  warnings.warn(
Traceback (most recent call last):
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/config.py", line 197, in _get_peft_type
    config_file = hf_hub_download(
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 110, in _inner_fn
    validate_repo_id(arg_value)
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 164, in validate_repo_id
    raise HFValidationError(
huggingface_hub.utils._validators.HFValidationError: Repo id must use alphanumeric chars or '-', '_', '.', '--' and '..' are forbidden, '-' and '.' cannot start or end the name, max length is 96: ''.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "pissa_to_lora.py", line 62, in <module>
    model.save_pretrained(output_dir, convert_pissa_to_lora="pissa_init")
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/peft_model.py", line 283, in save_pretrained
    output_state_dict = save_pissa_as_lora(
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/peft_model.py", line 232, in save_pissa_as_lora
    self.load_adapter(
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/peft_model.py", line 970, in load_adapter
    PeftConfig._get_peft_type(
  File "/home/guoby/app/Anaconda3-2021.05/envs/news/lib/python3.8/site-packages/peft/config.py", line 203, in _get_peft_type
    raise ValueError(f"Can't find '{CONFIG_NAME}' at '{model_id}'")
ValueError: Can't find 'adapter_config.json' at ''

The contents of the output_dir are now:
(screenshot of the output directory contents)

@BenjaminBossan
Member

model.save_pretrained(output_dir, convert_pissa_to_lora="pissa_init")

Note that, according to the docs, the convert_pissa_to_lora argument should point to the

path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA and before performing any training.

(Also note that the convert_pissa_to_lora argument is deprecated and has been renamed to path_initial_model_for_weight_conversion.)
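
In other words, something along these lines (the init path below is just a placeholder):

# Sketch: the conversion argument points at the adapter saved right after
# PiSSA initialization and before any training (placeholder path).
model.save_pretrained(output_dir, path_initial_model_for_weight_conversion="path/to/pissa_init")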

@beyondguo
Author

beyondguo commented Dec 10, 2024

I don't quite understand what "path to the initialized adapter, which is obtained after initializing the model with PiSSA or OLoRA and before performing any training" means.

Right now I have a trained PiSSA adapter saved at path_1, and I want to convert the PiSSA adapter to LoRA and save it to path_2. Which path should I pass as path_initial_model_for_weight_conversion? Or neither of them?

I tried both, but found that if I load from the saved path, the model is still doing SVD.


I checked issue #1929, which seems to say that the path should be created before training. So how can I convert an already-trained PiSSA model?

@BenjaminBossan
Member

No, it's neither of those. What you need to do is to save the adapter after it has been initialized but before training of the model has started, as e.g. shown here:

lora_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    init_lora_weights=script_args.init_lora_weights,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
# Save PiSSA modules:
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained(os.path.join(script_args.output_dir, "pissa_init"))

I know that you have already trained your model, so perhaps it's too late. But if you still have your training script and if you have fixed all random seeds, you should be able to restore that state exactly as it was before you did your first training run. So try to restore this state and save the model somewhere. Then use that model path as the argument for path_initial_model_for_weight_conversion.

@beyondguo
Author

Thanks a lot for your help and the detailed information! I will try to train my models again and see if that solves all the problems. I will report back tomorrow and close the issue then.

@BenjaminBossan
Member

Good luck. As mentioned, if you can precisely restore the model state to before training started, there should be no need for retraining, otherwise it is unfortunately required.

@beyondguo
Author

Hi @BenjaminBossan, I'm back again. I fixed all random seeds during the previous training, so I tried to restore the model state without retraining. Here's how I save the initial PiSSA weights:

from transformers import set_seed
from transformers import BertForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

set_seed(42)  # <---------- same as in training

base_model_path = '../../llm_hub/Erlangshen-Roberta-330M-Sentiment'

class MyModelForClassification(BertForSequenceClassification):
    ...
base_model = MyModelForClassification.from_pretrained(base_model_path, num_labels=3,ignore_mismatched_sizes=True)

peft_config = LoraConfig(init_lora_weights="pissa", task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1) # <---------- same as in training
init_pissa_model = get_peft_model(base_model, peft_config)

init_pissa_model.peft_config["default"].init_lora_weights = True 
init_pissa_model.save_pretrained('trained_model/pissa_init_seed42')

Then I try to convert the PiSSA adapter to LoRA with:

trained_pissa_path = '...'  # -----> where I saved the model during training (the original PiSSA weights)

class MyModelForClassification(BertForSequenceClassification):
    ...
base_model = MyModelForClassification.from_pretrained(base_model_path, num_labels=3, ignore_mismatched_sizes=True)

converted_pissa_path = '...'  # -----> where I want to save the converted version
trained_pissa_model = PeftModel.from_pretrained(base_model, trained_pissa_path)
trained_pissa_model.save_pretrained(converted_pissa_path, path_initial_model_for_weight_conversion='trained_model/pissa_init_seed42')

Finally I load the model with:

model = PeftModel.from_pretrained(base_model, converted_pissa_path)

and found it's still doing SVD.

Note that the above three scripts are separate Python programs.

Checking #1929 (comment) again, I guess this is due to my PEFT version (0.11.1). Maybe I have to upgrade and try again. I'm not sure whether a newer version will work, since my trained model was produced with the old version.

What confuses me during these trials is that everything runs fine without throwing errors or warnings, so I can't tell whether a correctly converted PiSSA adapter was saved until I finally load it. I guess more checks could be done internally in the save_pretrained method to tell users whether the weights were correctly converted.
Another thing I learned is that PiSSA is maybe not a good choice for online production, at least right now.

Anyway, thank you @BenjaminBossan for your help these days! Best regards!

@BenjaminBossan
Member

Thanks for trying out the suggestion. I'm not quite sure if your second code block is correct. Where does the actual training happen? The intent is that you start training at the end of the first code block, after the init_pissa_model.save_pretrained('trained_model/pissa_init_seed42') call. After you have finished training, the conversion is performed when you call .save_pretrained(converted_pissa_path, path_initial_model_for_weight_conversion='trained_model/pissa_init_seed42'). There shouldn't be any need for an additional trained_pissa_model = PeftModel.from_pretrained(base_model, trained_pissa_path) call.
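
Put together, the intended ordering is roughly the following (a sketch that reuses your config and paths; the training step is only indicated, not shown):

from peft import LoraConfig, get_peft_model, TaskType

peft_config = LoraConfig(init_lora_weights="pissa", task_type=TaskType.SEQ_CLS, r=8, lora_alpha=32, lora_dropout=0.1)
peft_model = get_peft_model(base_model, peft_config)

# 1. Save the freshly initialized PiSSA adapter before any training.
peft_model.peft_config["default"].init_lora_weights = True
peft_model.save_pretrained('trained_model/pissa_init_seed42')
peft_model.peft_config["default"].init_lora_weights = "pissa"  # restore the flag; the already initialized weights are unaffected

# 2. ... train peft_model with your Trainer setup ...

# 3. In the same run, save the trained adapter and convert it to a plain LoRA adapter.
peft_model.save_pretrained(
    converted_pissa_path,
    path_initial_model_for_weight_conversion='trained_model/pissa_init_seed42',
)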

it's still doing SVD.

Could you paste the adapter_config.json of that adapter?
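
For reference, a quick way to inspect the relevant field (placeholder path below); if I'm not mistaken, the converted adapter should have init_lora_weights set to true there, whereas "pissa" would mean SVD is re-run at load time:

import json

# Placeholder path to the converted adapter's config.
with open('path/to/converted_adapter/adapter_config.json') as f:
    print(json.load(f).get('init_lora_weights'))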

I guess this is due to my PEFT version (0.11.1). Maybe I have to upgrade and try again. I'm not sure whether a newer version will work, since my trained model was produced with the old version.

Please check whether upgrading resolves the issue for you. In general, it is safe to upgrade PEFT, as we avoid making changes that invalidate older checkpoints. There are rare occasions where upgrading PEFT can lead to different results, but we document those in our release notes.

I guess more checks could be done internally in the save_pretrained method to tell users whether the weights were correctly converted.

We don't know yet whether there is any error in the conversion here; I suspect it's either that the necessary steps have not been performed or that upgrading PEFT is indeed required.
