Skip to content

Commit

Permalink
cpu: x64: matmul: add f32:bf16 support on avx512_core and avx2
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 19, 2024
1 parent eaf336a commit 33a89d9
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 29 deletions.
6 changes: 3 additions & 3 deletions src/cpu/matmul/ref_matmul.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ struct ref_matmul_t : public primitive_t {
VDISPATCH_MATMUL(utils::one_of(dst_type, f32, bf16, f16, f8_e5m2,
f8_e4m3, f4_e2m1),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_MATMUL(
(src_type == wei_type
|| utils::one_of(wei_type, f16, u8, s8, u4, s4)),
VDISPATCH_MATMUL((src_type == wei_type
|| utils::one_of(wei_type, bf16, f16, u8,
s8, u4, s4)),
VERBOSE_UNSUPPORTED_DT);
/* int8 weights decompression support */
VDISPATCH_MATMUL(IMPLICATION(utils::one_of(wei_type, u8, s8),
Expand Down
9 changes: 6 additions & 3 deletions src/cpu/x64/matmul/brgemm_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= everyone_is(f16, src_dt, wei_dt) && one_of(dst_dt, f16, f32);
const bool is_f32_f16
= src_dt == f32 && wei_dt == f16 && one_of(dst_dt, f16, f32);
const bool is_f32_bf16
= src_dt == f32 && wei_dt == bf16 && one_of(dst_dt, bf16, f32);
const bool is_bf16_with_int_wei = src_dt == bf16
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, bf16, f32);
const bool is_f16_with_int_wei = src_dt == f16
Expand Down Expand Up @@ -121,7 +123,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
= [&]() -> bool { return attr()->zero_points_.common(); };
const bool problem_dt_correct
= one_of(true, is_int8, is_f8, is_bf16, is_f32, is_f16, is_f32_f16,
is_bf16_with_int_wei, is_f16_with_int_wei);
is_f32_bf16, is_bf16_with_int_wei, is_f16_with_int_wei);

auto src_d = memory_desc_wrapper(src_md_);
auto weights_d = memory_desc_wrapper(weights_md_);
Expand Down Expand Up @@ -161,7 +163,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

// f32:f16 configuration on AVX2 doesn't support tails with proper
// instruction sequence in copy routines. Anchor: F32_F16_AVX2_NO_TAIL.
VDISPATCH_MATMUL(IMPLICATION(is_f32_f16 && isa == avx2, bgmmc_.N % 8 == 0),
VDISPATCH_MATMUL(IMPLICATION((is_f32_f16 || is_f32_bf16) && isa == avx2,
bgmmc_.N % 8 == 0),
"unsupported configuration");

const float alpha = 1.0;
Expand All @@ -181,7 +184,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
const auto backup_isa = is_amx && bgmmc_.is_runtime_M && !is_s8s8
? (is_f16 || is_f32_f16 || is_f16_with_int_wei
? avx512_core_fp16
: (is_bf16 || is_bf16_with_int_wei
: (is_bf16 || is_f32_bf16 || is_bf16_with_int_wei
? avx512_core_bf16
: (is_int8 ? avx512_core_vnni
: avx512_core)))
Expand Down
28 changes: 26 additions & 2 deletions src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3390,6 +3390,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(

switch (dt_in_) {
case data_type::f32: uni_vmovups(vmm, op); break;
case data_type::bf16:
// Upconvert: load 16 bits and move them 16 bits left.
uni_vpmovzxwd(vmm, op);
uni_vpslld(vmm, vmm, 16);
break;
case data_type::f16:
if (is_superset(conf_->isa, avx512_core_fp16)) {
vcvtph2psx(vmm, op);
Expand Down Expand Up @@ -3602,6 +3607,11 @@ struct jit_brgemm_matmul_copy_b_transposed_t
, use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16)
&& conf_->orig_wei_dt == data_type::f16
&& conf_->wei_dt == data_type::f32)
// This variable is responsible for enabling the upconversion from bf16
// to f32, similarly to f16, mostly for proper tail handling.
, use_bf16_instructions_(is_subset(conf_->isa, avx512_core_bf16)
&& conf_->orig_wei_dt == data_type::bf16
&& conf_->wei_dt == data_type::f32)
, max_tmp_idx(16
- (avx512_core_dot_product_
? 8
Expand Down Expand Up @@ -3648,6 +3658,7 @@ struct jit_brgemm_matmul_copy_b_transposed_t
const bool req_apply_scales_;
const bool avx512_core_dot_product_;
const bool use_fp16_instructions_;
const bool use_bf16_instructions_;
const int max_tmp_idx;

const dim_t src_stride_, tr_src_stride_, scales_K_stride_, typesize_scale_;
Expand Down Expand Up @@ -3793,8 +3804,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
const int columns_tail, const bool use_int4_mask) {
assert(IMPLICATION(use_int4_mask, is_src_int4_));
if (columns_tail > 0) {
const int dt_step
= req_cvtps2xf16_ || use_fp16_instructions_ ? 1 : typesize_;
const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
|| use_bf16_instructions_
? 1
: typesize_;
const auto tail_mask = use_int4_mask
? size_t(((size_t)1 << div_up(dt_step * columns_tail, 2)) - 1)
: size_t(((size_t)1 << dt_step * columns_tail) - 1);
Expand Down Expand Up @@ -3998,6 +4011,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
} else {
vcvtph2ps(src_load, addr);
}
} else if (use_bf16_instructions_) {
// Upconvert: load 16 bits and move them 16 bits left.
uni_vpmovzxwd(src_load, addr);
uni_vpslld(src_load, src_load, 16);
} else {
vmovdqu8(src_load, addr);
}
Expand Down Expand Up @@ -4168,11 +4185,18 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
// For f32:f16 case need to convert raw bytes after `load_bytes`
// into f32 values.
vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx()));
} else if (use_bf16_instructions_) {
// Upconvert: move loaded 16 bits left.
uni_vpslld(vmm_src, vmm_src, 16);
}
} else {
if (use_fp16_instructions_) {
// For non-tailed case can use the convert instruction directly.
vcvtph2ps(vmm_src, ptr[reg_src + i * src_stride_]);
} else if (use_bf16_instructions_) {
// Upconvert: load 16 bits and move them 16 bits left.
uni_vpmovzxwd(vmm_src, ptr[reg_src + i * src_stride_]);
uni_vpslld(vmm_src, vmm_src, 16);
} else {
uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
}
Expand Down
49 changes: 32 additions & 17 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,11 @@ status_t check_isa_with_datatype(
&& IMPLICATION(bm_conf_utils.is_f32_f16(),
one_of(isa, avx512_core_fp16, avx2_vnni_2, avx512_core,
avx2))
// `avx512_core_amx` is not supported for plain upconversion since the
// HW supports native bf16 compute there.
&& IMPLICATION(bm_conf_utils.is_f32_bf16(),
one_of(isa, avx512_core_bf16, avx2_vnni_2, avx512_core,
avx2))
&& IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
is_superset(isa, avx512_core) || isa == avx2_vnni_2)
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(),
Expand All @@ -207,12 +212,13 @@ status_t check_isa_with_datatype(
}

status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const bool ok = one_of(true, bm_conf_utils.is_f32(),
bm_conf_utils.is_bf16(), bm_conf_utils.is_f16(),
bm_conf_utils.is_f32_f16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
const bool ok
= one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(),
bm_conf_utils.is_f16(), bm_conf_utils.is_f32_f16(),
bm_conf_utils.is_f32_bf16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.with_weights_decompression());
Expand Down Expand Up @@ -251,6 +257,10 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
// avx2 as there's no kernel for such combination.
, f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
// Keep this var separate from bf16_dt to avoid letting bf16:bf16 slip
// through on avx512_core and avx2, as there's no kernel for such a
// combination.
, f32_bf16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
, f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, A_any_layout(A_any_layout)
Expand Down Expand Up @@ -372,7 +382,7 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
const bool is_adbc_allowed
= (this->is_bf16() || this->is_f32() || this->is_bf32()
|| this->is_f16() || this->is_f32_f16()
|| this->is_bf16_with_int_wei()
|| this->is_f32_bf16() || this->is_bf16_with_int_wei()
|| this->is_f16_with_int_wei())
&& !xf16_avx2_vnni_2;
bgmmc.src_tag = is_adbc_allowed
Expand Down Expand Up @@ -475,7 +485,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
}

if (this->is_bf16() || this->is_bf16_with_int_wei()
|| ((this->is_f16() || this->is_f32_f16()
|| ((this->is_f16() || this->is_f32_f16() || this->is_f32_bf16()
|| this->is_f16_with_int_wei())
&& (is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2))))
Expand All @@ -488,7 +498,8 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
}
// Note: bf32 assumes f32 blocking
if (this->is_f32() || this->is_bf32() || this->is_f16()
|| this->is_f32_f16() || this->is_f16_with_int_wei())
|| this->is_f32_f16() || this->is_f32_bf16()
|| this->is_f16_with_int_wei())
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
Expand Down Expand Up @@ -742,7 +753,8 @@ void compute_blocking_heuristic_amx(const brgemm_matmul_conf_t &bgmmc,
= div_up(static_cast<int>(bgmmc.K), min_k_per_thread);
const bool is_amx_xf16 = bgmmc.is_amx
&& (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_bf32()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
|| bm_conf_utils.is_bf32()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei());
const bool is_amx_int8 = bgmmc.is_amx && bm_conf_utils.is_int8();
Expand Down Expand Up @@ -1297,6 +1309,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
bgmmc.is_f32_bf16 = bm_conf_utils.is_f32_bf16();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);

Expand All @@ -1314,15 +1327,17 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)) {
} else if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2)) {
// Note 1: Keep this branch separately from f16 one to have different
// ISA conditions (f16 includes f16:f32 and f16:f16 combinations).
// ISA conditions (f16 includes f16:f32 and f16:f16 combinations). The
// same applies to bf16 (which includes bf16:bf16).
// Note 2: If `use_buffer_b()` is false, let the kernel perform the
// conversion. Otherwise, make the copy_b routine handle the conversion
// and set kernel data types to f32.
// Note 3: Since `use_buffer_b()` depends on `bgmmc.wei_tag`, which is
// set later in the code due to its dependencies, the update of data
// types to f32 happens below in ANCHOR: `CONVERT_F32_F16_DATA_TYPES`.
// types to f32 happens below in ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`.
} else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) {
bgmmc.src_dt = f16;
bgmmc.wei_dt = f16;
Expand Down Expand Up @@ -1449,9 +1464,9 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
&& bgmmc.is_oscale_per_k && bgmmc.is_oscale_per_n
&& bgmmc.transposed_B;

if (bm_conf_utils.is_f32_f16() && is_superset(bgmmc.isa, avx2)
&& bm_conf_utils.use_buffer_b()) {
// ANCHOR: `CONVERT_F32_F16_DATA_TYPES`
if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) {
// ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
Expand Down Expand Up @@ -1660,7 +1675,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);

if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei()) {
// empirical observation for performance breakpoint between amx and vnni
Expand Down
5 changes: 4 additions & 1 deletion src/cpu/x64/matmul/brgemm_matmul_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ struct brgemm_matmul_conf_t {
bool is_bf16_with_int_wei = false;
bool is_f16_with_int_wei = false;
bool is_f32_f16 = false;
bool is_f32_bf16 = false;
bool is_int4_weights = false;
bool req_wei_vnni_downconvert = false;
bool is_runtime_M = false;
Expand Down Expand Up @@ -303,6 +304,8 @@ struct brgemm_matmul_conf_utils_t {

inline bool is_f32_f16() const { return f32_f16_dt; }

inline bool is_f32_bf16() const { return f32_bf16_dt; }

inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }

inline bool with_weights_decompression() const {
Expand Down Expand Up @@ -341,7 +344,7 @@ struct brgemm_matmul_conf_utils_t {

const bool f32_dt, bf16_dt, f16_dt, f8_dt, int8_dt, bf32_dt;
const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
f16_with_int_wei_dt;
f32_bf16_dt, f16_with_int_wei_dt;

const bool A_any_layout;
const bool B_any_layout;
Expand Down
6 changes: 3 additions & 3 deletions tests/benchdnn/inputs/matmul/test_matmul_bfloat16
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# bf16
--reset

--dt=bf16:bf16:f32,bf16
--dt=bf16:bf16:f32,bf16,f32:bf16:f32
--stag=ab,ba --wtag=ab,ba --dtag=ab
--runtime_dims_masks=0,2:1,1:0,3:1
--bia_dt=undef,f32 --bia_mask=2
Expand All @@ -28,13 +28,13 @@

# test any
--reset
--dt=bf16:bf16:f32,bf16
--dt=bf16:bf16:f32,bf16,f32:bf16:f32
--stag=ab,ba,any --wtag=ab,ba,any --dtag=ab,any
--batch=shapes_2d

# 3d
--reset
--dt=bf16:bf16:f32,bf16
--dt=bf16:bf16:f32,bf16,f32:bf16:f32
--stag=abc,acb --wtag=abc,acb --dtag=abc
--bia_dt=undef,f32 --bia_mask=4,6

Expand Down

0 comments on commit 33a89d9

Please sign in to comment.