Request for adding a LoRA implementation for torch.nn.Conv1d rather than transformers' Conv1D #2241
Thanks for opening this feature request. We cannot drop support for transformers' Conv1D, though.
Thanks. Could you please provide some hints for implementing it? E.g., do you think it is OK if I directly replace every transformers Conv1D with torch.nn.Conv1d? Thanks a lot.
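For context on that question: transformers' Conv1D is, despite its name, a plain linear layer with a transposed weight (the GPT-2 convention), not a convolution. A minimal sketch of that equivalence (an illustration added here, not code from the thread):

```python
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

# Conv1D(nf, nx) maps nx inputs to nf outputs via x @ weight + bias,
# with weight stored as (nx, nf) -- the transpose of nn.Linear's layout.
c = Conv1D(nf=16, nx=8)
lin = nn.Linear(8, 16)
lin.weight.data = c.weight.data.T  # nn.Linear stores weight as (out, in)
lin.bias.data = c.bias.data

x = torch.randn(2, 10, 8)
assert torch.allclose(c(x), lin(x), atol=1e-5)

# torch.nn.Conv1d, by contrast, convolves over a length dimension and is
# only equivalent to a Linear in the special case kernel_size == 1.
```

So blindly swapping one class for the other changes semantics unless the kernel size is 1 and the weight layout is accounted for.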
First of all, let me preface by saying that I'm not sure if we can just use torch's nn.Conv1d here. See peft/src/peft/tuners/lora/layer.py, lines 1262 to 1269 (at commit 3f9ce55): this is the logic where we check the base layer type and decide which LoRA layer to apply. Theoretically, we could just change the check to elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)): so that it also matches torch's nn.Conv1d. However, we should not replace transformers' Conv1D support.
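To make that dispatch concrete, here is a toy re-creation of the type-based branching the linked lines perform (an illustrative sketch, not PEFT's actual code):

```python
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

def dispatch(base_layer: nn.Module) -> str:
    # Map each supported base layer type to the LoRA layer that would wrap it.
    if isinstance(base_layer, nn.Linear):
        return "lora.Linear"
    elif isinstance(base_layer, nn.Conv2d):
        return "lora.Conv2d"
    elif isinstance(base_layer, Conv1D):
        # transformers' Conv1D is a matmul, so it is wrapped as a Linear
        # (with fan_in_fan_out to account for the transposed weight)
        return "lora.Linear"
    raise ValueError(f"unsupported layer type: {type(base_layer)}")

print(dispatch(nn.Linear(8, 16)))  # lora.Linear
print(dispatch(Conv1D(16, 8)))     # lora.Linear
try:
    dispatch(nn.Conv1d(8, 16, kernel_size=5))
except ValueError as err:
    print(err)  # no branch for nn.Conv1d -- the gap this issue is about
```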
Hi, I tried your recommended method, but I received a new error:
Since my conv1d has a kernel size larger than 1, it is not trivial to make the transformation. I will try other software to see if it works.
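To spell out why the kernel size matters, here is a small self-contained check (a sketch with made-up channel sizes, not from the thread). A kernel_size=1 nn.Conv1d is just a per-position linear map, so it could be rewritten as nn.Linear, which PEFT already supports; any larger kernel mixes neighboring positions and cannot:

```python
import torch
import torch.nn as nn

# kernel_size=1: each output position depends only on the same input
# position, so the conv is equivalent to a Linear applied position-wise.
conv_k1 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=1)
lin = nn.Linear(8, 16)
lin.weight.data = conv_k1.weight.data.squeeze(-1)  # (16, 8, 1) -> (16, 8)
lin.bias.data = conv_k1.bias.data

x = torch.randn(2, 8, 100)  # (batch, channels, length)
out_conv = conv_k1(x)
out_lin = lin(x.transpose(1, 2)).transpose(1, 2)
assert torch.allclose(out_conv, out_lin, atol=1e-5)

# kernel_size=5: each output position now mixes 5 neighboring input
# positions, so no single position-wise Linear can reproduce it.
conv_k5 = nn.Conv1d(8, 16, kernel_size=5, padding=2)
```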
If you provide the code to reproduce the error, I can take a look.
Hi, thanks a lot. I am trying to create a LoRA version of Enformer: https://github.com/lucidrains/enformer-pytorch. Here is my code to set up the LoRA model:
LoRA works well for linear, to_q, to_k, and to_v, but conv is an nn.Conv1d module, and I faced this error. The conv layer has a kernel size of 5.
I could make a bit more progress:

```python
import torch
from peft import LoraConfig, get_peft_model
from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-official-rough", device_map=0)
model = get_peft_model(model, LoraConfig(target_modules=["linear", "to_q", "to_k", "to_v", "conv"]))

seq = torch.randint(0, 5, (1, 196_608)).to(0)  # for ACGTN, in that order (-1 for padding)
output = model(seq)
```

The only changes I had to make were to this line:

```diff
- elif isinstance(target_base_layer, Conv1D):
+ elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)):
```

and this line:

```diff
- elif isinstance(base_layer, nn.Conv2d):
+ elif isinstance(base_layer, (nn.Conv2d, nn.Conv1d)):
```

However, the forward pass will fail because of mismatched shapes. I think the lora.Conv2d code would need to be extended to also accommodate 1d convolutions.
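For readers following along, here is a minimal sketch of what such a LoRA wrapper for nn.Conv1d could look like, following the usual LoRA-for-convolutions recipe (lora_A carries the full kernel at rank r; lora_B is a kernel-size-1 conv projecting back to out_channels). The class name and defaults are hypothetical, and this is an illustration, not PEFT's implementation:

```python
import torch
import torch.nn as nn

class LoraConv1d(nn.Module):
    """Hypothetical LoRA wrapper for nn.Conv1d (illustrative only)."""

    def __init__(self, base_layer: nn.Conv1d, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base_layer = base_layer
        # lora_A carries the kernel at rank r; lora_B projects back up.
        # (stride/padding copied; dilation and groups omitted for brevity)
        self.lora_A = nn.Conv1d(
            base_layer.in_channels, r,
            kernel_size=base_layer.kernel_size,
            stride=base_layer.stride,
            padding=base_layer.padding,
            bias=False,
        )
        self.lora_B = nn.Conv1d(r, base_layer.out_channels, kernel_size=1, bias=False)
        nn.init.zeros_(self.lora_B.weight)  # zero init: the LoRA delta starts as a no-op
        self.scaling = lora_alpha / r
        for p in self.base_layer.parameters():  # freeze the pretrained weights
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.lora_B(self.lora_A(x)) * self.scaling

# Shapes mirror the Enformer stem conv discussed in this thread.
conv = nn.Conv1d(4, 32, kernel_size=15, padding=7)
wrapped = LoraConv1d(conv, r=16, lora_alpha=32)
x = torch.randn(1, 4, 128)
assert torch.allclose(conv(x), wrapped(x))  # identical at init thanks to zero-init B
```

This mirrors the two-conv composition that loralib-style implementations use for Conv2d, transposed down to one spatial dimension.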
OK, thanks a lot. Do you think it is possible to include other released LoRA modules directly (e.g., from LoRA-Torch)? Thanks.
Not sure exactly what you mean. Are you asking if other LoRA implementations exist that already support nn.Conv1d?
Not sure which package exactly you mean. If you mean torchtune, it looks like it only supports linear layers.
How about this one: https://github.com/Baijiong-Lin/LoRA-Torch? It claims that LoRA-Torch supports nn.Conv1d.
Yeah, it looks like this package has a LoRA implementation for nn.Conv1d, so it could be worth a try.
Thanks, sure, I am happy to give it a try.
I have tried it, but it does not work due to a shape-mismatch error. Let me wait for responses from the authors and then determine the next step. Thanks.
Hi, I have exciting updates! I have helped the authors resolve the previous bugs, and now the LoRA model can run inference correctly with the nn.Conv1d layer. Do you think it would be promising to incorporate that implementation into PEFT? Thanks. https://github.com/Baijiong-Lin/LoRA-Torch The results look good in Enformer:
Thanks for your feedback. If possible, could you please provide the full code? Also, did you test this on your actual use case? I did a quick check of their implementation and I see no reason why it shouldn't work.
Hi, sure. I made a simple test by replacing the nn.Conv1d layer in Enformer with lora.Conv1d. That is:
You can see that there is a new module:

```python
lora.Conv1d(in_channels=4, out_channels=half_dim, kernel_size=15, padding=7, r=16, lora_alpha=32),
```

My actual use case needs fine-tuning, so I also include the results after fine-tuning for one step here, which look good to me:
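A sketch of what that swap might look like in context, assuming LoRA-Torch mirrors loralib's interface (the import name and the mark_only_lora_as_trainable helper are assumptions; only the lora.Conv1d call itself comes from this thread):

```python
import torch.nn as nn
import loratorch as lora  # assumed import name for https://github.com/Baijiong-Lin/LoRA-Torch

half_dim = 768  # hypothetical value; the real width depends on the Enformer config

# Drop-in replacement for the stem's nn.Conv1d(4, half_dim, 15, padding=7),
# now carrying rank-16 LoRA factors.
stem = nn.Sequential(
    lora.Conv1d(in_channels=4, out_channels=half_dim, kernel_size=15,
                padding=7, r=16, lora_alpha=32),
)

# Freeze everything except the injected LoRA parameters before fine-tuning
# (helper name assumed from loralib's API).
lora.mark_only_lora_as_trainable(stem)
```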
It works well, but registering the LoRA layers this way is too noisy; I prefer the implementation in PEFT.
Feature request
Hi, I found that LoRA does not support models with torch.nn.Conv1d convolution layers, which limits the use case for models pre-trained with this class (for example, Enformer). I wonder if it is possible to add an implementation based on this class.
Motivation
To fine-tune Enformer.
Your contribution
If needed, I can open a PR.