Shape of attention mask in distilbert example #2667

Open
fbilhaut opened this issue Dec 13, 2024 · 1 comment

fbilhaut commented Dec 13, 2024

Hi,

I'm trying to adapt the distilbert example to make it process multiple sequences at once (the provided example just processes one prompt).

But I'm having trouble providing the proper attention mask to the DistilBertModel::forward() method.

I noticed, when reading the documentation of the forward() method of the equivalent Python class, that this mask is expected to have the same shape as the input_ids parameter.

This seems sound, and it is also consistent with the BERT example in Candle, which does it that way when processing multiple sequences to compute similarities:

// Build one 1-D tensor of token ids per sequence...
let token_ids = tokens.iter().map(|tokens| {
    let tokens = tokens.get_ids().to_vec();
    Ok(Tensor::new(tokens.as_slice(), device)?)
}).collect::<Result<Vec<_>>>()?;

// ...and one 1-D attention mask per sequence (1 for real tokens, 0 for padding).
let attention_mask = tokens.iter().map(|tokens| {
    let tokens = tokens.get_attention_mask().to_vec();
    Ok(Tensor::new(tokens.as_slice(), device)?)
}).collect::<Result<Vec<_>>>()?;

// Stack both into [N, S] tensors (N sequences, each padded to length S).
let token_ids = Tensor::stack(&token_ids, 0)?;
let attention_mask = Tensor::stack(&attention_mask, 0)?;

// ...

model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?;
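
(For context, this stacking only works because the BERT example pads every sequence of the batch to the same length beforehand. A minimal sketch of how that padding can be enabled with the tokenizers crate, assuming a local tokenizer.json file and a `sentences` vector, and using BatchLongest padding with default parameters:)

use tokenizers::{PaddingParams, PaddingStrategy, Tokenizer};

// Pad every sequence in the batch to the length of the longest one, so that
// Tensor::stack produces a rectangular [N, S] tensor.
let mut tokenizer = Tokenizer::from_file("tokenizer.json").map_err(anyhow::Error::msg)?;
tokenizer.with_padding(Some(PaddingParams {
    strategy: PaddingStrategy::BatchLongest,
    ..Default::default()
}));
// `sentences` is assumed to be a Vec<String> with one entry per sequence.
let tokens = tokenizer.encode_batch(sentences, true).map_err(anyhow::Error::msg)?;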

BUT:

In the distilbert example the tokenizer doesn't add any padding, and there is a rather mysterious function that is supposed to compute the attention mask and returns a different (square) shape:

fn get_mask(size: usize, device: &Device) -> Tensor {
    let mask: Vec<_> = (0..size)
        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
        .collect();
    Tensor::from_slice(&mask, (size, size), device).unwrap()
}

For example, for a sequence of 3 tokens this generates the following NxN mask:

[[0, 1, 1],
 [0, 0, 1],
 [0, 0, 0]]

If I simply replace this with the result of the get_attention_mask() function as a 1xN tensor, it works for a single sequence.

But for several sequences, if I pad all of them to the same length S and stack the masks obtained for the N sequences into an NxS tensor (as the BERT example mentioned earlier does), I get an error like this:

cannot broadcast [2, 32] to [2, 12, 32, 32]
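
If the model broadcasts the mask against attention scores of shape [batch, heads, seq, seq] (which the error message above suggests), then I suppose the stacked NxS mask would need two extra singleton dimensions, something like the following (just a guess on my side, not a tested fix):

// Speculative sketch: reshape the stacked [N, S] padding mask to [N, 1, 1, S]
// so it can broadcast against scores of shape [N, num_heads, S, S].
// (Whether the convention is 1 = keep or 1 = masked would also need checking,
// since get_mask() above seems to use 1 for the positions to be masked.)
let (n, s) = attention_mask.dims2()?;
let attention_mask = attention_mask.reshape((n, 1, 1, s))?;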

I must admit that I don't understand what DistilBertModel::forward() expects for the provided mask. I also don't understand what the get_mask() function is supposed to do.

Maybe this is due to my lack of knowledge on the matter, but given the elements mentioned above (the Python equivalent and the similar Candle example with BERT), I'm wondering whether there isn't something wrong with the distilbert example and/or the model implementation?


fbilhaut commented Dec 13, 2024

@ToluClassics maybe? (it seems like you committed this code :-)
