From 5ebeba259e06d28b74ad4114074cc3ee1c53fa0a Mon Sep 17 00:00:00 2001
From: Nikolai Shchegolev
Date: Mon, 6 Feb 2023 11:46:09 +0400
Subject: [PATCH] [CPU] I64 initial support.

---
 include/oneapi/dnnl/dnnl.hpp                |   2 +
 include/oneapi/dnnl/dnnl_common_types.h     |   2 +
 src/common/c_types_map.hpp                  |   1 +
 src/common/dnnl_traits.hpp                  |   8 +
 src/common/memory_zero_pad.cpp              |   2 +
 src/common/type_helpers.hpp                 |   3 +-
 src/cpu/reorder/cpu_reorder.cpp             |   1 +
 src/cpu/reorder/cpu_reorder.hpp             |   1 +
 src/cpu/reorder/cpu_reorder_regular_s64.cpp |  61 +++++
 src/cpu/x64/jit_generator.hpp               | 258 +++++++++++++++++++-
 10 files changed, 334 insertions(+), 5 deletions(-)
 create mode 100644 src/cpu/reorder/cpu_reorder_regular_s64.cpp

diff --git a/include/oneapi/dnnl/dnnl.hpp b/include/oneapi/dnnl/dnnl.hpp
index 157265d0699..d927a15084a 100644
--- a/include/oneapi/dnnl/dnnl.hpp
+++ b/include/oneapi/dnnl/dnnl.hpp
@@ -864,6 +864,8 @@ struct memory : public handle<dnnl_memory_t> {
         f64 = dnnl_f64,
         /// 32-bit signed integer.
         s32 = dnnl_s32,
+        /// 64-bit signed integer.
+        s64 = dnnl_s64,
         /// 8-bit signed integer.
         s8 = dnnl_s8,
         /// 8-bit unsigned integer.
diff --git a/include/oneapi/dnnl/dnnl_common_types.h b/include/oneapi/dnnl/dnnl_common_types.h
index 8b80b7d0a60..8df682237ba 100644
--- a/include/oneapi/dnnl/dnnl_common_types.h
+++ b/include/oneapi/dnnl/dnnl_common_types.h
@@ -94,6 +94,8 @@ typedef enum {
     dnnl_boolean = 8,
     /// 1-bit integer.
     dnnl_bin = 9,
+    /// 64-bit signed integer.
+    dnnl_s64 = 10,
     /// Parameter to allow internal only data_types without undefined behavior.
     /// This parameter is chosen to be valid for so long as sizeof(int) >= 2.
     dnnl_data_type_max = 0x7fff,
diff --git a/src/common/c_types_map.hpp b/src/common/c_types_map.hpp
index 4e8651addb5..41808120fd9 100644
--- a/src/common/c_types_map.hpp
+++ b/src/common/c_types_map.hpp
@@ -160,6 +160,7 @@ const data_type_t bf16 = dnnl_bf16;
 const data_type_t f32 = dnnl_f32;
 const data_type_t f64 = dnnl_f64;
 const data_type_t s32 = dnnl_s32;
+const data_type_t s64 = dnnl_s64;
 const data_type_t s8 = dnnl_s8;
 const data_type_t u8 = dnnl_u8;
 const data_type_t boolean = dnnl_boolean;
diff --git a/src/common/dnnl_traits.hpp b/src/common/dnnl_traits.hpp
index 971542a7178..acbe200dcb9 100644
--- a/src/common/dnnl_traits.hpp
+++ b/src/common/dnnl_traits.hpp
@@ -63,6 +63,10 @@ struct prec_traits<data_type::s32> {
     typedef int32_t type;
 };
 template <>
+struct prec_traits<data_type::s64> {
+    typedef int64_t type;
+};
+template <>
 struct prec_traits<data_type::s8> {
     typedef int8_t type;
 };
@@ -96,6 +100,10 @@ struct data_traits<int32_t> {
     static constexpr data_type_t data_type = data_type::s32;
 };
 template <>
+struct data_traits<int64_t> {
+    static constexpr data_type_t data_type = data_type::s64;
+};
+template <>
 struct data_traits<int8_t> {
     static constexpr data_type_t data_type = data_type::s8;
 };
diff --git a/src/common/memory_zero_pad.cpp b/src/common/memory_zero_pad.cpp
index de3a84f9100..9c6351e5f7b 100644
--- a/src/common/memory_zero_pad.cpp
+++ b/src/common/memory_zero_pad.cpp
@@ -285,7 +285,9 @@ static status_t zero_pad(const memory_t *memory, const exec_ctx_t &ctx) {
         case f16: return typed_zero_pad<f16>(memory, ctx);
         case bf16: return typed_zero_pad<bf16>(memory, ctx);
         case f32: return typed_zero_pad<f32>(memory, ctx);
+        case f64: return typed_zero_pad<f64>(memory, ctx);
         case s32: return typed_zero_pad<s32>(memory, ctx);
+        case s64: return typed_zero_pad<s64>(memory, ctx);
         case s8: return typed_zero_pad<s8>(memory, ctx);
         case u8: return typed_zero_pad<u8>(memory, ctx);
         case bin: return typed_zero_pad<u8>(memory, ctx);
diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp
index 3e17667603e..8a8ea428609 100644
--- a/src/common/type_helpers.hpp
+++ b/src/common/type_helpers.hpp
@@ -89,6 +89,7 @@ inline size_t data_type_size(data_type_t data_type) {
         case f32: return sizeof(prec_traits<f32>::type);
         case f64: return sizeof(prec_traits<f64>::type);
         case s32: return sizeof(prec_traits<s32>::type);
+        case s64: return sizeof(prec_traits<s64>::type);
         case s8: return sizeof(prec_traits<s8>::type);
         case u8: return sizeof(prec_traits<u8>::type);
         case boolean: return sizeof(prec_traits<boolean>::type);
@@ -948,7 +949,7 @@ inline bool memory_desc_sanity_check(int ndims, const dims_t dims,
     if (ndims == 0) return true;
 
     bool ok = dims != nullptr && 0 < ndims && ndims <= DNNL_MAX_NDIMS
-            && utils::one_of(data_type, f16, bf16, f32, f64, s32, s8, u8, bin);
+            && utils::one_of(data_type, f16, bf16, f32, f64, s32, s64, s8, u8, bin);
     if (!ok) return false;
 
     bool has_runtime_dims = false;
diff --git a/src/cpu/reorder/cpu_reorder.cpp b/src/cpu/reorder/cpu_reorder.cpp
index 44007099608..13185c29dd6 100644
--- a/src/cpu/reorder/cpu_reorder.cpp
+++ b/src/cpu/reorder/cpu_reorder.cpp
@@ -34,6 +34,7 @@ regular_impl_list_map() {
         {{f32, bin, 0}, &regular_f32_bin_impl_list_map()},
         {{bf16, data_type::undef, 0}, &regular_bf16_impl_list_map()},
         {{f16, data_type::undef, 0}, &regular_f16_impl_list_map()},
+        {{s64, data_type::undef, 0}, &regular_s64_impl_list_map()},
         {{s32, data_type::undef, 0}, &regular_s32_impl_list_map()},
         {{s8, data_type::undef, 0}, &regular_s8_impl_list_map()},
         {{u8, data_type::undef, 0}, &regular_u8_impl_list_map()},
diff --git a/src/cpu/reorder/cpu_reorder.hpp b/src/cpu/reorder/cpu_reorder.hpp
index 7f78df807cf..0344afa5542 100644
--- a/src/cpu/reorder/cpu_reorder.hpp
+++ b/src/cpu/reorder/cpu_reorder.hpp
@@ -76,6 +76,7 @@ extern const impl_list_map_t &regular_f32_u8_impl_list_map();
 extern const impl_list_map_t &regular_f32_bin_impl_list_map();
 extern const impl_list_map_t &regular_bf16_impl_list_map();
 extern const impl_list_map_t &regular_f16_impl_list_map();
+extern const impl_list_map_t &regular_s64_impl_list_map();
 extern const impl_list_map_t &regular_s32_impl_list_map();
 extern const impl_list_map_t &regular_s8_impl_list_map();
 extern const impl_list_map_t &regular_u8_impl_list_map();
diff --git a/src/cpu/reorder/cpu_reorder_regular_s64.cpp b/src/cpu/reorder/cpu_reorder_regular_s64.cpp
new file mode 100644
index 00000000000..b9f52be09e7
--- /dev/null
+++ b/src/cpu/reorder/cpu_reorder_regular_s64.cpp
@@ -0,0 +1,61 @@
+/*******************************************************************************
+* Copyright 2020-2023 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu/reorder/cpu_reorder.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+
+// clang-format off
+
+const impl_list_map_t &regular_s64_impl_list_map() {
+    static const impl_list_map_t the_map = REG_REORDER_P({
+        // s64 ->
+        {{s64, data_type::undef, 0}, {
+            REG_FAST_DIRECT_COPY(s64, f32)
+            REG_FAST_DIRECT_COPY(s64, s64)
+            REG_FAST_DIRECT_COPY(s64, s32)
+            REG_FAST_DIRECT_COPY(s64, s8)
+            REG_FAST_DIRECT_COPY(s64, u8)
+
+            DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_blk_reorder_t))
+            DNNL_X64_ONLY(CPU_REORDER_INSTANCE(x64_jit_uni_reorder_t))
+
+            DNNL_AARCH64_ONLY(CPU_REORDER_INSTANCE(aarch64_jit_uni_reorder_t))
+
+            DNNL_NON_X64_ONLY(REG_SR_BIDIR(s64, any, f32, nChw16c))
+            DNNL_NON_X64_ONLY(REG_SR_BIDIR(s64, any, s32, nChw16c))
+            DNNL_NON_X64_ONLY(REG_SR_BIDIR(s64, any, s8, nChw16c))
+            DNNL_NON_X64_ONLY(REG_SR_BIDIR(s64, any, u8, nChw16c))
+
+            REG_SR(s64, any, f32, any, fmt_order_any, spec_reference)
+            REG_SR(s64, any, s64, any, fmt_order_any, spec_reference)
+            REG_SR(s64, any, s32, any, fmt_order_any, spec_reference)
+            REG_SR(s64, any, s8, any, fmt_order_any, spec_reference)
+            REG_SR(s64, any, u8, any, fmt_order_any, spec_reference)
+
+            nullptr,
+        }},
+    });
+    return the_map;
+}
+
+// clang-format on
+
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp
index 5baca44a153..936a190541a 100644
--- a/src/cpu/x64/jit_generator.hpp
+++ b/src/cpu/x64/jit_generator.hpp
@@ -584,6 +584,28 @@ class jit_generator : public Xbyak::MmapAllocator,
         }
     }
 
+    void uni_vbroadcastsd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            // vmovddup accepts xmm/m64; 3-operand vmovsd is register-only.
+            vmovddup(x, op);
+        } else {
+            // movddup: duplicate the low double of op into both lanes.
+            movddup(x, op);
+        }
+    }
+    void uni_vbroadcastsd(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
+        if (op.isMEM() || is_valid_isa(avx2)) {
+            vbroadcastsd(x, op);
+        } else { // register source without avx2: build the broadcast by hand
+            Xbyak::Xmm t(x.getIdx());
+            if (t.getIdx() != op.getIdx()) {
+                movsd(t, op);
+            }
+            vinsertf128(x, x, t, 1); // copy low 128 bits into the high half
+            vshufpd(x, x, x, 0); // replicate the low double of each half
+        }
+    }
+
     void uni_vpbroadcastb(const Xbyak::Ymm &x, const Xbyak::Reg8 &r) {
         if (is_valid_isa(avx512_core))
             vpbroadcastb(x, r); // broadcast reg32 directly
@@ -639,7 +661,9 @@ class jit_generator : public Xbyak::MmapAllocator,
         if (is_valid_isa(avx))
             vshufps(x1, x2, op, imm);
         else {
-            movups(x1, x2);
+            if (x1.getIdx() != x2.getIdx()) {
+                movups(x1, x2);
+            }
             shufps(x1, op, imm);
         }
     }
@@ -687,7 +711,9 @@ class jit_generator : public Xbyak::MmapAllocator,
         if (is_valid_isa(avx))
             vdivps(x, op1, op2);
         else {
-            assert(x.isEqualIfNotInherited(op1));
+            if (!x.isEqualIfNotInherited(op1)) { // also copies a memory op1
+                movups(x, op1);
+            }
             divps(x, op2);
         }
     }
@@ -722,6 +748,19 @@ class jit_generator : public Xbyak::MmapAllocator,
         vdivps(x, op1, op2);
     }
 
+    void uni_vdivpd(const Xbyak::Xmm& x,
+                    const Xbyak::Operand& op1,
+                    const Xbyak::Operand& op2) {
+        if (is_valid_isa(x64::avx)) {
+            vdivpd(x, op1, op2);
+        } else { // SSE is destructive: x = op1, then x /= op2
+            if (!x.isEqualIfNotInherited(op1)) { // also copies a memory op1
+                movupd(x, op1);
+            }
+            divpd(x, op2);
+        }
+    }
+
     void uni_vaddps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx))
@@ -735,6 +774,18 @@ class jit_generator : public Xbyak::MmapAllocator,
             const Xbyak::Operand &op2) {
         vaddps(x, op1, op2);
     }
+
+    void uni_vaddpd(const Xbyak::Xmm &x, const Xbyak::Operand &op1, const Xbyak::Operand &op2) {
+        if (is_valid_isa(avx)) {
+            vaddpd(x, op1, op2);
+        } else { // SSE is destructive: x = op1, then x += op2
+            if (!x.isEqualIfNotInherited(op1)) { // also copies a memory op1
+                movupd(x, op1);
+            }
+            addpd(x, op2);
+        }
+    }
+
     void uni_vaddss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx))
@@ -797,12 +848,30 @@ class jit_generator : public Xbyak::MmapAllocator,
         vpsignd(x1, x2, op);
     }
 
+    void uni_vpsubq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpsubq(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then x1 -= op
+            if (x1.getIdx() != x2.getIdx()) {
+                movups(x1, x2);
+            }
+            psubq(x1, op);
+        }
+    }
+    void uni_vpsubq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2,
+            const Xbyak::Operand &op) {
+        vpsubq(x1, x2, op);
+    }
+
     void uni_vpsubd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
         if (is_valid_isa(avx))
             vpsubd(x1, x2, op);
         else {
-            assert(x1.getIdx() == x2.getIdx());
+            if (x1.getIdx() != x2.getIdx()) {
+                movups(x1, x2);
+            }
             psubd(x1, op);
         }
     }
@@ -883,6 +952,19 @@ class jit_generator : public Xbyak::MmapAllocator,
         vsubps(x, op1, op2);
     }
 
+    void uni_vsubpd(const Xbyak::Xmm& x,
+                    const Xbyak::Operand& op1,
+                    const Xbyak::Operand& op2) {
+        if (is_valid_isa(x64::avx)) {
+            vsubpd(x, op1, op2);
+        } else { // SSE is destructive: x = op1, then x -= op2
+            if (!x.isEqualIfNotInherited(op1)) { // also copies a memory op1
+                movupd(x, op1);
+            }
+            subpd(x, op2);
+        }
+    }
+
     void uni_vaddsubps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx)) {
@@ -909,6 +991,17 @@ class jit_generator : public Xbyak::MmapAllocator,
         vpmulld(x1, x2, op);
     }
 
+    void uni_vpmuludq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpmuludq(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then multiply by op
+            if (x1.getIdx() != x2.getIdx()) {
+                movdqa(x1, x2);
+            }
+            pmuludq(x1, op);
+        }
+    }
+
     void uni_vmulps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx))
@@ -939,6 +1032,19 @@ class jit_generator : public Xbyak::MmapAllocator,
         vmulps(x, op1, op2);
     }
 
+    void uni_vmulpd(const Xbyak::Xmm& x,
+                    const Xbyak::Operand& op1,
+                    const Xbyak::Operand& op2) {
+        if (is_valid_isa(x64::avx)) {
+            vmulpd(x, op1, op2);
+        } else { // SSE is destructive: x = op1, then x *= op2
+            if (!x.isEqualIfNotInherited(op1)) { // also copies a memory op1
+                movupd(x, op1);
+            }
+            mulpd(x, op2);
+        }
+    }
+
     void uni_vmulss(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx))
@@ -1320,6 +1426,28 @@ class jit_generator : public Xbyak::MmapAllocator,
         vsqrtps(x, op);
     }
 
+    void uni_vsqrtpd(const Xbyak::Xmm &vmm_dst, const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vsqrtpd(vmm_dst, op);
+        } else {
+            sqrtpd(vmm_dst, op);
+        }
+    }
+
+    void uni_vpaddq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2, const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpaddq(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then x1 += op
+            if (x1.getIdx() != x2.getIdx()) {
+                movdqa(x1, x2);
+            }
+            paddq(x1, op);
+        }
+    }
+    void uni_vpaddq(const Xbyak::Ymm &x1, const Xbyak::Ymm &x2, const Xbyak::Operand &op) {
+        vpaddq(x1, x2, op);
+    }
+
     void uni_vpaddd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
         if (is_valid_isa(avx))
@@ -1455,6 +1583,17 @@ class jit_generator : public Xbyak::MmapAllocator,
         vpslld(x, op, imm);
     }
 
+    void uni_vpsllq(const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
+        if (is_valid_isa(avx))
+            vpsllq(x, op, imm);
+        else { // SSE shifts in place: x = op, then x <<= imm per qword
+            if (!x.isEqualIfNotInherited(op)) { // also copies a memory op
+                movups(x, op);
+            }
+            psllq(x, imm);
+        }
+    }
+
     void uni_vpsrld(
             const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
         if (is_valid_isa(avx))
@@ -1469,6 +1608,20 @@ class jit_generator : public Xbyak::MmapAllocator,
         vpsrld(x, op, imm);
     }
 
+    void uni_vpsrlq(const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
+        if (is_valid_isa(avx)) {
+            vpsrlq(x, op, imm);
+        } else { // SSE shifts in place: x = op, then x >>= imm per qword
+            if (!x.isEqualIfNotInherited(op)) { // also copies a memory op
+                uni_vmovups(x, op);
+            }
+            psrlq(x, imm);
+        }
+    }
+    void uni_vpsrlq(const Xbyak::Ymm &x, const Xbyak::Operand &op, const int imm) {
+        vpsrlq(x, op, imm);
+    }
+
     void uni_vmaxps(const Xbyak::Xmm &x, const Xbyak::Operand &op1,
             const Xbyak::Operand &op2) {
         if (is_valid_isa(avx))
@@ -1625,6 +1778,16 @@ class jit_generator : public Xbyak::MmapAllocator,
         vrndscaleps(x, op, imm & 0x3);
     }
 
+    void uni_vroundpd(const Xbyak::Xmm &x, const Xbyak::Operand &op, const int imm) {
+        if (is_valid_isa(avx512_core)) {
+            vrndscalepd(x, op, imm & 0x3); // only the two low imm bits are passed on
+        } else if (is_valid_isa(avx)) {
+            vroundpd(x, op, imm);
+        } else {
+            roundpd(x, op, imm);
+        }
+    }
+
     void uni_vcvtps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         if (is_valid_isa(avx))
             vcvtps2dq(x, op);
@@ -1641,11 +1804,18 @@ class jit_generator : public Xbyak::MmapAllocator,
         else
             cvtdq2ps(x, op);
     }
-
     void uni_vcvtdq2ps(const Xbyak::Ymm &x, const Xbyak::Operand &op) {
         vcvtdq2ps(x, op);
     }
 
+    void uni_vcvtdq2pd(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vcvtdq2pd(x, op);
+        } else {
+            cvtdq2pd(x, op);
+        }
+    }
+
     void uni_vcvtph2psx(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         assert(is_valid_isa(avx2));
         if (is_valid_isa(avx512_core_fp16))
@@ -1674,6 +1844,22 @@ class jit_generator : public Xbyak::MmapAllocator,
         vcvtps2ph(x1, x2, _op_mxcsr);
     }
 
+    void uni_vcvtpd2ps(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        if (is_valid_isa(x64::avx)) {
+            vcvtpd2ps(x, op);
+        } else {
+            cvtpd2ps(x, op);
+        }
+    }
+
+    void uni_vcvtpd2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
+        if (is_valid_isa(x64::avx)) {
+            vcvtpd2dq(x, op);
+        } else {
+            cvtpd2dq(x, op);
+        }
+    }
+
     void uni_vcvttps2dq(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         if (is_valid_isa(avx))
             vcvttps2dq(x, op);
@@ -2035,6 +2221,66 @@ class jit_generator : public Xbyak::MmapAllocator,
         vpcmpeqd(x1, x2, op);
     }
 
+    void uni_vpcmpeqq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpcmpeqq(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then compare with op
+            if (x1.getIdx() != x2.getIdx()) {
+                uni_vmovups(x1, x2);
+            }
+            pcmpeqq(x1, op);
+        }
+    }
+
+    void uni_vpcmpeqb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpcmpeqb(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then compare with op
+            if (x1.getIdx() != x2.getIdx()) {
+                uni_vmovups(x1, x2);
+            }
+            pcmpeqb(x1, op);
+        }
+    }
+
+    void uni_vpcmpgtd(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpcmpgtd(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then compare with op
+            if (x1.getIdx() != x2.getIdx()) {
+                uni_vmovups(x1, x2);
+            }
+            pcmpgtd(x1, op);
+        }
+    }
+
+    void uni_vpcmpgtq(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpcmpgtq(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then compare with op
+            if (x1.getIdx() != x2.getIdx()) {
+                uni_vmovups(x1, x2);
+            }
+            pcmpgtq(x1, op); // qword compare (was pcmpgtd); SSE4.2 instruction
+        }
+    }
+
+    void uni_vpcmpgtb(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
+            const Xbyak::Operand &op) {
+        if (is_valid_isa(avx)) {
+            vpcmpgtb(x1, x2, op);
+        } else { // SSE is destructive: x1 = x2, then compare with op
+            if (x1.getIdx() != x2.getIdx()) {
+                uni_vmovups(x1, x2);
+            }
+            pcmpgtb(x1, op);
+        }
+    }
+
     void uni_vpackusdw(const Xbyak::Xmm &x1, const Xbyak::Xmm &x2,
             const Xbyak::Operand &op) {
         if (is_valid_isa(avx))
@@ -2803,6 +3049,10 @@ class jit_generator : public Xbyak::MmapAllocator,
         return (jit_ker_) ? status::success : status::runtime_error;
     }
 
+    cpu_isa_t get_isa() const { // read-only accessor for max_cpu_isa_
+        return max_cpu_isa_;
+    }
+
 private:
     const cpu_isa_t max_cpu_isa_;
     const Xbyak::uint8 *getCode() {