[BUG] Attention.head_to_batch_dim has a tensor permutation bug #10303

Open
Dawn-LX opened this issue Dec 19, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@Dawn-LX

Dawn-LX commented Dec 19, 2024

Describe the bug

def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:

When out_dim == 4, the output shape does not match the shape given in the function's docstring, `[batch_size, seq_len, heads, dim // heads]`.

Here is the original function:

def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is
        the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)  # line 633: this permute runs unconditionally

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

        return tensor

At line 633 (`tensor = tensor.permute(0, 2, 1, 3)`), the tensor is permuted unconditionally, so when out_dim == 4 the returned shape is `[batch_size, heads, seq_len, dim // heads]` rather than the documented `[batch_size, seq_len, heads, dim // heads]`.
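
For illustration, here is a minimal shape trace of the current ordering in plain PyTorch (the sizes batch_size=1, seq_len=16, heads=8, dim=96 are just example values, not taken from the library):

import torch

batch_size, seq_len, heads, dim = 1, 16, 8, 96
tensor = torch.randn(batch_size, seq_len, dim)

# current ordering: reshape, then permute unconditionally
tensor = tensor.reshape(batch_size, seq_len, heads, dim // heads)  # [1, 16, 8, 12]
tensor = tensor.permute(0, 2, 1, 3)                                # [1, 8, 16, 12]

# with out_dim == 4 the function returns here, so the result is
# [batch_size, heads, seq_len, dim // heads] instead of the documented
# [batch_size, seq_len, heads, dim // heads]
print(tensor.shape)  # torch.Size([1, 8, 16, 12])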

The correction is to move line 633 into the out_dim == 3 branch (i.e., between lines 635 and 636):

        ...
        tensor = tensor.reshape(batch_size, seq_len * extra_dim, head_size, dim // head_size)    

        if out_dim == 3:
            tensor = tensor.permute(0, 2, 1, 3)
            tensor = tensor.reshape(batch_size * head_size, seq_len * extra_dim, dim // head_size)

        return tensor
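
As a quick sanity check of the proposed ordering, here is the same toy-size example in plain PyTorch (a sketch of the reshape/permute semantics, not the actual diffusers code path):

import torch

batch_size, seq_len, heads, dim = 1, 16, 8, 96
tensor = torch.randn(batch_size, seq_len, dim)

# reshape only; the permute now happens inside the out_dim == 3 branch
out4 = tensor.reshape(batch_size, seq_len, heads, dim // heads)
print(out4.shape)  # torch.Size([1, 16, 8, 12]) == [batch_size, seq_len, heads, dim // heads]

out3 = out4.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads)
print(out3.shape)  # torch.Size([8, 16, 12]) == [batch_size * heads, seq_len, dim // heads]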

Reproduction

Inside Attention, run:

self.head_to_batch_dim(query, out_dim=4)

Logs

No response

System Info

Not applicable; the bug is independent of system and environment info.

Who can help?

No response

@Dawn-LX Dawn-LX added the bug Something isn't working label Dec 19, 2024
@hlky
Collaborator

hlky commented Dec 19, 2024

Thanks for finding this @Dawn-LX. Would you like to create a PR with the fix?

Here's a self-contained reproduction.

from diffusers.models.attention_processor import Attention
import torch

attn = Attention(96, heads=8)

tensor = torch.randn([1, 16, 96])

out = attn.head_to_batch_dim(tensor, out_dim=3)
out.shape
# torch.Size([8, 16, 12])
# [batch_size * heads, seq_len, dim // heads]

out = attn.head_to_batch_dim(tensor, out_dim=4)
out.shape
# torch.Size([1, 8, 16, 12]) !=
# [batch_size, seq_len, heads, dim // heads]

# after fix
# torch.Size([1, 16, 8, 12])

cc @sayakpaul @yiyixuxu

@Dawn-LX
Author

Dawn-LX commented Dec 20, 2024

> Would you like to create a PR with the fix?

OK, I have created the PR with the fix: #10312
