
Request for adding a LoRA implementation for torch.nn.Conv1d rather than transformers.utils.Conv1D #2241

HelloWorldLTY opened this issue Nov 28, 2024 · 16 comments

@HelloWorldLTY

Feature request

Hi, I found that LoRA does not support models that use torch.nn.Conv1d as their convolution layers, which limits its use for models pre-trained with this class (for example, Enformer). I wonder if it is possible to add an implementation based on this class.

Motivation

To fine-tune Enformer.

Your contribution

If needed, I can open a PR.

@BenjaminBossan
Member

Thanks for opening this feature request. We cannot drop support for transformers Conv1D, as it is required for certain models like gpt2. However, we can consider adding support for torch Conv1d on top. If that layer can re-use the same LoRA Linear implementation as transformers Conv1D does, it should be fairly easy. If you have some code to enable this, feel free to open a (draft) PR.
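
For context, the reason transformers Conv1D can reuse the Linear LoRA code path is that, despite its name, it is just a linear projection with a transposed weight, whereas torch.nn.Conv1d slides a kernel over the sequence dimension. A minimal sketch to illustrate the difference (illustrative only, not PEFT code):

import torch
from torch import nn
from transformers.pytorch_utils import Conv1D

# transformers Conv1D(nf, nx) stores its weight as (nx, nf) and computes x @ weight + bias,
# i.e. it behaves like a Linear layer with a transposed ("fan_in_fan_out") weight.
conv1d_hf = Conv1D(nf=64, nx=32)
linear = nn.Linear(32, 64)
linear.weight.data = conv1d_hf.weight.data.T
linear.bias.data = conv1d_hf.bias.data

x = torch.randn(2, 10, 32)
assert torch.allclose(conv1d_hf(x), linear(x), atol=1e-6)

# torch.nn.Conv1d, in contrast, convolves over the length dimension with a kernel,
# so it cannot be expressed as a single Linear layer once kernel_size > 1.
conv1d_torch = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=5, padding=2)
y = conv1d_torch(torch.randn(2, 32, 10))  # expects (batch, channels, length)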

@HelloWorldLTY
Author

Thanks, could you please provide some hints on how to implement it? For example, do you think it is OK if I directly replace every transformers Conv1D with torch nn.Conv1d? Thanks a lot.

@BenjaminBossan
Member

First of all, I'm not sure whether we can simply use torch Conv1d here or whether it won't work; that would need to be tested. Second, here is the crucial part of the code:

elif isinstance(target_base_layer, Conv1D):
    if not kwargs["fan_in_fan_out"]:
        warnings.warn(
            "fan_in_fan_out is set to False but the target module is `Conv1D`. "
            "Setting fan_in_fan_out to True."
        )
        kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
    kwargs.update(lora_config.loftq_config)
    new_module = Linear(target, adapter_name, is_target_conv_1d_layer=True, **kwargs)

This is the logic where we check the base layer type and decide that we want to apply a LoraLayer (in this case the Linear LoRA layer).

Theoretically, we can just replace the isinstance check by:

elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)):

to match torch Conv1d. Maybe you can give this a try and check if it works for your use case before opening a PR.

However, we should not replace transformers Conv1D because this is needed for some models to work.

@HelloWorldLTY
Author

Hi, I tried your recommended change, but I received a new error:

File /home/tl688/.conda/envs/evo/lib/python3.11/site-packages/torch/nn/modules/linear.py:98, in Linear.__init__(self, in_features, out_features, bias, device, dtype)
     96 self.in_features = in_features
     97 self.out_features = out_features
---> 98 self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
     99 if bias:
    100     self.bias = Parameter(torch.empty(out_features, **factory_kwargs))

TypeError: empty(): argument 'size' failed to unpack the object at pos 2 with error "type must be tuple of ints,but got NoneType"

Since my Conv1d has a kernel size larger than 1, the transformation is not trivial. I will try other packages to see if any of them work.
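
I suspect the error comes from nn.Conv1d exposing in_channels/out_channels rather than in_features/out_features, so the Linear LoRA path ends up building its low-rank nn.Linear layers from None (just my guess; the quick check below only mimics what I assume happens internally):

import torch.nn as nn

conv = nn.Conv1d(in_channels=4, out_channels=768, kernel_size=5, padding=2)

# nn.Conv1d has in_channels/out_channels, not in_features/out_features,
# so code written for Linear/Conv1D would find None here:
print(getattr(conv, "in_features", None), getattr(conv, "out_features", None))  # -> None None

# Building a Linear layer from those attributes then fails inside torch.empty(),
# which matches the traceback above:
# nn.Linear(None, 8)  # TypeError: empty(): argument 'size' ... but got NoneType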

@BenjaminBossan
Member

If you provide the code to reproduce the error, I can take a look.

@HelloWorldLTY
Author

Hi, thanks a lot. I am trying to build a LoRA version of Enformer:

https://github.com/lucidrains/enformer-pytorch

Here is my code to create the LoRA model:

def get_lora(model, lora_config = None, train = False):
    """
    Applies Low-Rank Adaptation (LoRA) to the model.
    This function integrates LoRA modules into specified layers of the model, enabling parameter-efficient
    fine-tuning. If `train` is True, it sets the LoRA parameters and specific layers in the base model
    to be trainable. Otherwise, it freezes all parameters.
    Args:
        model (nn.Module): The base model to wrap with LoRA.
        lora_config (LoraConfig, optional): Configuration for LoRA. If None, uses a default configuration.
        train (bool): Whether the model is being prepared for training.
    """
    if lora_config is None:
#         lora_config = LoraConfig(
#             target_modules=r"(?!separable\d+).*ConvBlock|.*to_q|.*to_v|EnformerTransformerBlock\.\d+\.1\.fn\.1|EnformerTransformerBlock\.\d+\.1\.fn\.4",
#         )

        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["linear", "to_q", "to_k", "to_v", "conv"],
            lora_dropout=0.01,
        )
    model = get_peft_model(model, lora_config) # get LoRA model
    print(model)
    if train:
        for params in model.base_model.model.model.embedding.conv_tower.parameters():
            params.requires_grad = True
        if model.base_model.model.model.embedding.transformer_tower:
            for params in model.base_model.model.model.embedding.transformer_tower.parameters():
                params.requires_grad = True
        model.print_trainable_parameters()

    else:
        for params in model.parameters():
            params.requires_grad = False
    return model

LoRA works well for linear, to_q, to_k, and to_v, but conv corresponds to an nn.Conv1d module, and that is where I hit this error. The conv layer has a kernel size of 5.

@BenjaminBossan
Member

I was able to make a bit more progress:

import torch
from peft import LoraConfig, get_peft_model
from enformer_pytorch import Enformer

model = Enformer.from_pretrained("EleutherAI/enformer-official-rough", device_map=0)
model = get_peft_model(model, LoraConfig(target_modules=["linear", "to_q", "to_k", "to_v", "conv"]))
seq = torch.randint(0, 5, (1, 196_608)).to(0) # for ACGTN, in that order (-1 for padding)
output = model(seq)

The only changes I had to make were to this line:

-    elif isinstance(target_base_layer, Conv1D):
+    elif isinstance(target_base_layer, (Conv1D, nn.Conv1d)):

and this line:

-        elif isinstance(base_layer, nn.Conv2d):
+        elif isinstance(base_layer, (nn.Conv2d, nn.Conv1d)):

However, the forward pass will fail because of mismatched shapes. I think the nn.Conv1d module cannot simply be treated as a Linear layer, unlike transformers Conv1D. It probably needs its own LoRA layer type.
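
To sketch what such a layer could look like, here is a rough, untested draft that follows the same recipe as the existing Conv2d LoRA layer (a kernel-sized lora_A followed by a pointwise lora_B); the class name and exact signature are made up for illustration:

import math
import torch.nn as nn

class LoraConv1d(nn.Module):
    """Rough sketch of a LoRA adapter around nn.Conv1d, mirroring the Conv2d recipe:
    lora_A maps in_channels -> r with the base layer's kernel/stride/padding/dilation,
    lora_B maps r -> out_channels with a pointwise (kernel_size=1) convolution."""

    def __init__(self, base_layer: nn.Conv1d, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base_layer = base_layer
        self.lora_A = nn.Conv1d(
            base_layer.in_channels,
            r,
            kernel_size=base_layer.kernel_size,
            stride=base_layer.stride,
            padding=base_layer.padding,
            dilation=base_layer.dilation,
            bias=False,
        )
        self.lora_B = nn.Conv1d(r, base_layer.out_channels, kernel_size=1, bias=False)
        self.scaling = lora_alpha / r
        nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B.weight)  # adapter starts as a no-op
        self.base_layer.requires_grad_(False)  # freeze the pretrained conv

    def forward(self, x):
        # No weight merging: the low-rank update is added to the frozen base output.
        return self.base_layer(x) + self.lora_B(self.lora_A(x)) * self.scaling

Because lora_B starts at zero, the wrapped module initially behaves exactly like the frozen base convolution and only the low-rank branch receives gradients; merging/unmerging support would still need to be added for a proper PEFT layer.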

@HelloWorldLTY
Author

Ok, thanks a lot. Do you think it is possible to include other released LoRA modules directly (e.g., from LoRA torch)? Thanks.

@BenjaminBossan
Member

I'm not sure exactly what you mean. Are you asking whether other LoRA implementations exist that already support nn.Conv1d? From my experience, most other LoRA libraries support fewer layer types than we do in PEFT.

e.g., from LoRA torch

Not sure which package exactly you mean. If you mean torchtune, it looks like they only support linear layers.

@HelloWorldLTY
Author

How about this one: https://github.com/Baijiong-Lin/LoRA-Torch

It claims that LoRA-Torch supports nn.Conv1d.

@BenjaminBossan
Member

Yeah, it looks like this package has a LoRA implementation for nn.Conv1d. If you give it a try and it works well, feel free to report back. We can check if their implementation can be ported to PEFT.

@HelloWorldLTY
Author

Thanks, sure, I'm happy to give it a try.

@HelloWorldLTY
Author

I have tried it, but it does not work due to a shape-mismatch error. Let me wait for a response from the authors and then decide on the next step. Thanks.

@HelloWorldLTY
Author

Hi, I have exciting updates! I helped the authors resolve the previous bugs, and the LoRA model now runs inference correctly with the nn.Conv1d layers. Do you think it would be promising to incorporate that implementation into PEFT? Thanks.

https://github.com/Baijiong-Lin/LoRA-Torch

The results look good in Enformer:

[screenshot: Enformer prediction results]

@BenjaminBossan
Member

Thanks for your feedback. If possible, could you please provide the full code? Also, did you test this on your actual use case?

I did a quick check of their implementation and I see no reason why nn.Conv1d could not be supported in PEFT. We would implement this a bit differently (IIUC, their code relies on merging the weights for the forward call, which we don't want to do).
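
To make that difference concrete, here is a small comparison of the two forward styles, shown with nn.Linear for simplicity (the same idea carries over to conv layers, where computing the merged delta weight is more involved):

import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(16, 32)
lora_A = nn.Linear(16, 4, bias=False)
lora_B = nn.Linear(4, 32, bias=False)
nn.init.normal_(lora_B.weight)  # non-zero so the comparison is meaningful
scaling = 2.0
x = torch.randn(8, 16)

# Additive (unmerged) forward, which PEFT uses during training:
y_unmerged = base(x) + lora_B(lora_A(x)) * scaling

# Merge-based forward: fold the low-rank update into the base weight first.
merged = nn.Linear(16, 32)
merged.weight.data = base.weight.data + (lora_B.weight @ lora_A.weight) * scaling
merged.bias.data = base.bias.data
y_merged = merged(x)

assert torch.allclose(y_unmerged, y_merged, atol=1e-5)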

@HelloWorldLTY
Author

Hi, sure. I made a simple test by replacing the nn.Conv1d layer in Enformer's stem with lora.Conv1d. That is:

class Enformer_lora(PreTrainedModel):
    config_class = EnformerConfig
    base_model_prefix = "enformer"

    @staticmethod
    def from_hparams(**kwargs):
        return Enformer(EnformerConfig(**kwargs))

    def __init__(self, config):
        super().__init__(config)
        self.dim = config.dim
        half_dim = config.dim // 2
        twice_dim = config.dim * 2

        # create stem

        self.stem = nn.Sequential(
            lora.Conv1d(in_channels=4, out_channels=half_dim, kernel_size=15, padding = 7, r=16, lora_alpha=32),
            Residual(ConvBlock(half_dim)),
            AttentionPool(half_dim, pool_size = 2)
        )

        # create conv tower

        filter_list = exponential_linspace_int(half_dim, config.dim, num = (config.num_downsamples - 1), divisible_by = config.dim_divisible_by)
        filter_list = [half_dim, *filter_list]

        conv_layers = []
        for dim_in, dim_out in zip(filter_list[:-1], filter_list[1:]):
            conv_layers.append(nn.Sequential(
                ConvBlock(dim_in, dim_out, kernel_size = 5),
                Residual(ConvBlock(dim_out, dim_out, 1)),
                AttentionPool(dim_out, pool_size = 2)
            ))

        self.conv_tower = nn.Sequential(*conv_layers)

        # whether to use tensorflow gamma positions

        use_tf_gamma = config.use_tf_gamma
        self.use_tf_gamma = use_tf_gamma

        # transformer

        transformer = []
        for _ in range(config.depth):
            transformer.append(nn.Sequential(
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    Attention(
                        config.dim,
                        heads = config.heads,
                        dim_key = config.attn_dim_key,
                        dim_value = config.dim // config.heads,
                        dropout = config.attn_dropout,
                        pos_dropout = config.pos_dropout,
                        num_rel_pos_features = config.dim // config.heads,
                        use_tf_gamma = use_tf_gamma
                    ),
                    nn.Dropout(config.dropout_rate)
                )),
                Residual(nn.Sequential(
                    nn.LayerNorm(config.dim),
                    nn.Linear(config.dim, config.dim * 2),
                    nn.Dropout(config.dropout_rate),
                    nn.ReLU(),
                    nn.Linear(config.dim * 2, config.dim),
                    nn.Dropout(config.dropout_rate)
                ))
            ))

        self.transformer = nn.Sequential(*transformer)

        # target cropping

        self.target_length = config.target_length
        self.crop_final = TargetLengthCrop(config.target_length)

        # final pointwise

        self.final_pointwise = nn.Sequential(
            Rearrange('b n d -> b d n'),
            ConvBlock(filter_list[-1], twice_dim, 1),
            Rearrange('b d n -> b n d'),
            nn.Dropout(config.dropout_rate / 8),
            GELU()
        )

        # create trunk sequential module

        self._trunk = nn.Sequential(
            Rearrange('b n d -> b d n'),
            self.stem,
            self.conv_tower,
            Rearrange('b d n -> b n d'),
            self.transformer,
            self.crop_final,
            self.final_pointwise
        )

        # create final heads for human and mouse

        self.add_heads(**config.output_heads)

        # use checkpointing on transformer trunk

        self.use_checkpointing = config.use_checkpointing

    def add_heads(self, **kwargs):
        self.output_heads = kwargs

        self._heads = nn.ModuleDict(map_values(lambda features: nn.Sequential(
            nn.Linear(self.dim * 2, features),
            nn.Softplus()
        ), kwargs))

    def set_target_length(self, target_length):
        crop_module = self._trunk[-2]
        crop_module.target_length = target_length

    @property
    def trunk(self):
        return self._trunk

    @property
    def heads(self):
        return self._heads

    def trunk_checkpointed(self, x):
        x = rearrange(x, 'b n d -> b d n')
        x = self.stem(x)
        x = self.conv_tower(x)
        x = rearrange(x, 'b d n -> b n d')
        x = checkpoint_sequential(self.transformer, len(self.transformer), x)
        x = self.crop_final(x)
        x = self.final_pointwise(x)
        return x

    def forward(
        self,
        x,
        target = None,
        return_corr_coef = False,
        return_embeddings = False,
        return_only_embeddings = False,
        head = None,
        target_length = None
    ):
        if isinstance(x, list):
            x = str_to_one_hot(x)

        elif type(x) == torch.Tensor and x.dtype == torch.long:
            x = seq_indices_to_one_hot(x)
        x = x.to(self.device)

        no_batch = x.ndim == 2

        if no_batch:
            x = rearrange(x, '... -> () ...')

        if exists(target_length):
            self.set_target_length(target_length)

        trunk_fn = self.trunk_checkpointed if self.use_checkpointing else self._trunk
        x = trunk_fn(x)

        if no_batch:
            x = rearrange(x, '() ... -> ...')

        if return_only_embeddings:
            return x

        out = map_values(lambda fn: fn(x), self._heads)

        if exists(head):
            assert head in self._heads, f'head {head} not found'
            out = out[head]

        if exists(target):
            assert exists(head), 'head must be passed in if one were to calculate loss directly with targets'

            if return_corr_coef:
                return pearson_corr_coef(out, target)

            return poisson_loss(out, target)

        if return_embeddings:
            return out, x

        return out

You can see that there is a new module in the stem: lora.Conv1d(in_channels=4, out_channels=half_dim, kernel_size=15, padding=7, r=16, lora_alpha=32).

My actual use case requires fine-tuning, so I also include the result after fine-tuning for one step here, which looks good to me:

seq = torch.randint(0, 5, (1, 196_608))
one_hot = seq_indices_to_one_hot(seq)
model = Enformer_lora.from_pretrained('EleutherAI/enformer-official-rough').cuda()
# (!!!) This sets requires_grad to False for all parameters without the string "lora_" in their names
lora.mark_only_lora_as_trainable(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Training loop

model.train()
# forward process
out = model(one_hot.cuda())['human']
# backward process

loss = torch.nn.functional.mse_loss(out, predict)  # `predict` is the target tensor (defined elsewhere)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# (!!!) re-register model params to ensure they are in model.state_dict() and model.parameters()
# (!!!) Without this line, performance is not affected, but you will find that some weights are missing from model.state_dict() and model.parameters()
lora.register_model_param_after_backward(model)

It works well, but the LoRA registration step is too noisy; I would prefer a PEFT implementation.
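
For reference, this is roughly what I would hope the workflow looks like once PEFT supports nn.Conv1d (a hypothetical sketch built on the existing PEFT API; targeting "conv" this way does not work yet, which is exactly what this issue asks for):

import torch
from enformer_pytorch import Enformer
from peft import LoraConfig, get_peft_model

# Hypothetical: assumes a future PEFT version that can wrap nn.Conv1d targets ("conv" below).
model = Enformer.from_pretrained("EleutherAI/enformer-official-rough").cuda()
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.01,
    target_modules=["linear", "to_q", "to_k", "to_v", "conv"],
)
model = get_peft_model(model, lora_config)  # freezes the base model and injects LoRA parameters
model.print_trainable_parameters()

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model.train()

seq = torch.randint(0, 5, (1, 196_608)).cuda()  # ACGTN indices, as in the examples above
out = model(seq)["human"]
target = torch.randn_like(out)  # placeholder target; use real tracks in practice
loss = torch.nn.functional.mse_loss(out, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # no extra re-registration step needed, unlike with LoRA-Torch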
