Fix position encodings for Pixtral #2678

Merged: 5 commits, Dec 23, 2024
candle-transformers/src/models/pixtral/vision_model.rs (68 changes: 55 additions & 13 deletions)
@@ -1,8 +1,8 @@
-use candle::{DType, Module, Result, Tensor, D};
+use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
 
 fn default_act() -> candle_nn::Activation {
-    candle_nn::Activation::Gelu
+    candle_nn::Activation::Silu
 }
 
 fn default_hidden_size() -> usize {
@@ -58,7 +58,7 @@ impl Config {
             num_attention_heads: 16,
             head_dim: None,
             // Default
-            hidden_act: candle_nn::Activation::Gelu,
+            hidden_act: candle_nn::Activation::Silu,
         }
     }
 
@@ -104,6 +104,7 @@ impl Attention {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let (b, patches, _) = xs.dims3()?;
@@ -116,7 +117,8 @@
         let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
         let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
 
-        let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+        let (query_states, key_states) =
+            emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
         let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;
 
         let attn_weights = match attention_mask {
@@ -189,12 +191,16 @@ impl AttentionLayer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let residual = xs;
-        let xs = self
-            .attention
-            .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+        let xs = self.attention.forward(
+            &xs.apply(&self.attention_norm)?,
+            emb,
+            subsampled_positions,
+            attention_mask,
+        )?;
         let xs = (residual + xs)?;
         let residual = &xs;
         let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@@ -222,11 +228,12 @@ impl Transformer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let mut xs = xs.clone();
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, emb, attention_mask)?
+            xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
         }
         Ok(xs)
     }
@@ -270,10 +277,20 @@ impl RotaryEmbedding {
         Ok(Self { cos, sin })
     }
 
-    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+    fn apply_rotary_emb_qkv(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        subsampled_positions: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
         let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
-        let cos = &self.cos;
-        let sin = &self.sin;
+        let (cos, sin) = match subsampled_positions {
+            None => (&self.cos, &self.sin),
+            Some(pos) => (
+                &self.cos.index_select(pos, 0)?,
+                &self.sin.index_select(pos, 0)?,
+            ),
+        };
         let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
         let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
         Ok((q_embed, k_embed))
@@ -286,6 +303,7 @@ pub struct Model {
     ln_pre: RmsNorm,
     transformer: Transformer,
     patch_positional_embedding: RotaryEmbedding,
+    max_image_width: u32,
 }
 
 impl Model {
@@ -305,20 +323,44 @@ impl Model {
         let transformer = Transformer::new(cfg, vb.pp("transformer"))?;
         let patch_positional_embedding =
             RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?;
+        let max_image_width = (cfg.image_size / cfg.patch_size) as u32;
         Ok(Self {
             patch_conv,
             ln_pre,
             transformer,
             patch_positional_embedding,
+            max_image_width,
         })
     }
+
+    pub fn position_ids_in_meshgrid(
+        &self,
+        num_patches_h: usize,
+        num_patches_w: usize,
+        device: &Device,
+    ) -> Result<Tensor> {
+        let idx = Tensor::arange(0, num_patches_h as u32, device)?;
+        let idy = Tensor::arange(0, num_patches_w as u32, device)?;
+        let mesh = Tensor::meshgrid(&[idx, idy], false)?;
+        let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?;
+        Ok(ids)
+    }
 }
 
 impl Module for Model {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let patch_embeds = xs.apply(&self.patch_conv)?;
+        let subsampled_positions = Some(self.position_ids_in_meshgrid(
+            patch_embeds.dim(2)?,
+            patch_embeds.dim(3)?,
+            patch_embeds.device(),
+        )?);
         let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
-        self.transformer
-            .forward(&patch_embeds, &self.patch_positional_embedding, None)
+        self.transformer.forward(
+            &patch_embeds,
+            &self.patch_positional_embedding,
+            subsampled_positions.as_ref(),
+            None,
+        )
     }
 }
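
Note (not part of the diff): the change replaces implicit 0..n rotary positions with explicit 2D grid positions. `position_ids_in_meshgrid` flattens the (num_patches_h, num_patches_w) patch grid row by row, striding rows by `max_image_width` (the maximum number of patches per side), and `apply_rotary_emb_qkv` then uses `index_select` to pick only those rows from the precomputed cos/sin tables. Below is a minimal plain-Rust sketch of the same index arithmetic, without candle tensors; the helper name and the example sizes are illustrative only and do not appear in the PR.

    // Row-major flattening: patch (row, col) gets id row * max_image_width + col,
    // mirroring `&mesh[0] * max_image_width + &mesh[1]` followed by flatten_all.
    fn grid_position_ids(num_patches_h: usize, num_patches_w: usize, max_image_width: u32) -> Vec<u32> {
        let mut ids = Vec::with_capacity(num_patches_h * num_patches_w);
        for row in 0..num_patches_h as u32 {
            for col in 0..num_patches_w as u32 {
                ids.push(row * max_image_width + col);
            }
        }
        ids
    }

    fn main() {
        // For a 2 x 3 patch grid with max_image_width = 4 this prints [0, 1, 2, 4, 5, 6]:
        // the jump from 2 to 4 encodes the row break, which a contiguous 0..n
        // numbering cannot express for images narrower than the maximum width.
        println!("{:?}", grid_position_ids(2, 3, 4));
    }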