Skip to content

Commit

Permalink
xe: sdpa: Fix mask loads for unaligned memory
Browse files Browse the repository at this point in the history
  • Loading branch information
umar456 committed Dec 19, 2024
1 parent 1d9b22a commit a42ea4d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/gpu/intel/ocl/micro_sdpa.cl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,9 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
A += DST_OFF(b1, b0, 0, 0, 0);
#if WITH_ATTN_MASK
msk += MSK_OFF(b1 % MSK_D0, b0 % MSK_D1, 0, 0);
#ifndef BLOCK_MSK
int mask_aligned = (((size_t)msk) % 4) == 0;
#endif
#endif

#if KEY_SCALES
Expand Down Expand Up @@ -320,9 +323,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
/* Load mask. No remainder handling needed assuming k block size is a power of 2. */
mask_tile_type mask_tile;
#if BROADCAST_MASK_Q
#if BLOCK_MSK
tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
#else
tile_load_t(&mask_tile, msk, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
if (mask_aligned) {
tile_load_block(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
} else {
tile_load_full(&mask_tile, msk, 0, k0 + sg_i0_kq, 0);
}
#endif
#else
tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
#endif
#endif

Expand Down
2 changes: 2 additions & 0 deletions src/gpu/intel/ocl/micro_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
auto ldk = gemm_desc_t::get_ld(*pd()->key_md()) * key_mdw.data_type_size();
auto ldv = gemm_desc_t::get_ld(*pd()->val_md()) * val_mdw.data_type_size();
auto lda = gemm_desc_t::get_ld(*pd()->dst_md()) * dst_mdw.data_type_size();
auto ldmsk = pd()->attn_mask_md()->dims[3] * msk_mdw.data_type_size();
kernel_ctx.define_int("Q_ALIGN", jit::alignmentForLD(int(ldq)));
kernel_ctx.define_int("K_ALIGN", jit::alignmentForLD(int(ldk)));
kernel_ctx.define_int("V_ALIGN", jit::alignmentForLD(int(ldv)));
Expand Down Expand Up @@ -483,6 +484,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
if (d_full) {
if (ldq % 4 == 0) kernel_ctx.define_int("BLOCK_Q", 1);
if (lda % 4 == 0 && v_full) kernel_ctx.define_int("BLOCK_A", 1);
if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
kernel_ctx.define_int("REMAINDER_Q", (d->queries() % tile_q) != 0);
} else if (pd()->arch() >= compute::gpu_arch_t::xe_hpc) {
auto vbytes = d->values() * val_mdw.data_type_size();
Expand Down

0 comments on commit a42ea4d

Please sign in to comment.