diff --git a/candle-transformers/src/models/pixtral/vision_model.rs b/candle-transformers/src/models/pixtral/vision_model.rs index 20d8f08231..3f884aaf89 100644 --- a/candle-transformers/src/models/pixtral/vision_model.rs +++ b/candle-transformers/src/models/pixtral/vision_model.rs @@ -1,8 +1,8 @@ -use candle::{DType, Module, Result, Tensor, D}; +use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder}; fn default_act() -> candle_nn::Activation { - candle_nn::Activation::Gelu + candle_nn::Activation::Silu } fn default_hidden_size() -> usize { @@ -58,7 +58,7 @@ impl Config { num_attention_heads: 16, head_dim: None, // Default - hidden_act: candle_nn::Activation::Gelu, + hidden_act: candle_nn::Activation::Silu, } } @@ -104,6 +104,7 @@ impl Attention { &self, xs: &Tensor, emb: &RotaryEmbedding, + subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result { let (b, patches, _) = xs.dims3()?; @@ -116,7 +117,8 @@ impl Attention { let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?; - let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?; + let (query_states, key_states) = + emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?; let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?; let attn_weights = match attention_mask { @@ -189,12 +191,16 @@ impl AttentionLayer { &self, xs: &Tensor, emb: &RotaryEmbedding, + subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result { let residual = xs; - let xs = self - .attention - .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?; + let xs = self.attention.forward( + &xs.apply(&self.attention_norm)?, + emb, + subsampled_positions, + attention_mask, + )?; let xs = (residual + xs)?; let residual = &xs; let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?; @@ -222,11 +228,12 @@ impl Transformer { &self, xs: &Tensor, emb: &RotaryEmbedding, + subsampled_positions: Option<&Tensor>, attention_mask: Option<&Tensor>, ) -> Result { let mut xs = xs.clone(); for layer in self.layers.iter() { - xs = layer.forward(&xs, emb, attention_mask)? + xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)? } Ok(xs) } @@ -270,10 +277,20 @@ impl RotaryEmbedding { Ok(Self { cos, sin }) } - fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { + fn apply_rotary_emb_qkv( + &self, + q: &Tensor, + k: &Tensor, + subsampled_positions: Option<&Tensor>, + ) -> Result<(Tensor, Tensor)> { let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?; - let cos = &self.cos; - let sin = &self.sin; + let (cos, sin) = match subsampled_positions { + None => (&self.cos, &self.sin), + Some(pos) => ( + &self.cos.index_select(pos, 0)?, + &self.sin.index_select(pos, 0)?, + ), + }; let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?; let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?; Ok((q_embed, k_embed)) @@ -286,6 +303,7 @@ pub struct Model { ln_pre: RmsNorm, transformer: Transformer, patch_positional_embedding: RotaryEmbedding, + max_image_width: u32, } impl Model { @@ -305,20 +323,44 @@ impl Model { let transformer = Transformer::new(cfg, vb.pp("transformer"))?; let patch_positional_embedding = RotaryEmbedding::new(cfg, vb.pp("patch_positional_embedding"))?; + let max_image_width = (cfg.image_size / cfg.patch_size) as u32; Ok(Self { patch_conv, ln_pre, transformer, patch_positional_embedding, + max_image_width, }) } + + pub fn position_ids_in_meshgrid( + &self, + num_patches_h: usize, + num_patches_w: usize, + device: &Device, + ) -> Result { + let idx = Tensor::arange(0, num_patches_h as u32, device)?; + let idy = Tensor::arange(0, num_patches_w as u32, device)?; + let mesh = Tensor::meshgrid(&[idx, idy], false)?; + let ids = (&mesh[0] * (self.max_image_width as f64) + &mesh[1])?.flatten_all()?; + Ok(ids) + } } impl Module for Model { fn forward(&self, xs: &Tensor) -> Result { let patch_embeds = xs.apply(&self.patch_conv)?; + let subsampled_positions = Some(self.position_ids_in_meshgrid( + patch_embeds.dim(2)?, + patch_embeds.dim(3)?, + patch_embeds.device(), + )?); let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?; - self.transformer - .forward(&patch_embeds, &self.patch_positional_embedding, None) + self.transformer.forward( + &patch_embeds, + &self.patch_positional_embedding, + subsampled_positions.as_ref(), + None, + ) } }