diff --git a/src/common/snippets/include/snippets/op/reshape.hpp b/src/common/snippets/include/snippets/op/reshape.hpp
index b4e0c9233c73f0..d80a02ebc33c9a 100644
--- a/src/common/snippets/include/snippets/op/reshape.hpp
+++ b/src/common/snippets/include/snippets/op/reshape.hpp
@@ -32,6 +32,27 @@ class Reshape : public ov::op::Op {
     ov::PartialShape m_target_shape = {};
 };
 
+/**
+ * @interface ReshapeWithOrder
+ * @brief ReshapeWithOrder reshapes the input tensor shape according to the required target order.
+ *        The tensor data is not updated.
+ *        Note: the order is stored in the input PortDescriptor
+ * @ingroup snippets
+ */
+class ReshapeWithOrder : public ov::op::Op {
+public:
+    OPENVINO_OP("ReshapeWithOrder", "SnippetsOpset");
+    ReshapeWithOrder() = default;
+    ReshapeWithOrder(const Output<Node>& x, std::vector<size_t> order);
+
+    bool visit_attributes(AttributeVisitor& visitor) override;
+    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
+    void validate_and_infer_types() override;
+
+private:
+    void custom_constructor_validate_and_infer_types(std::vector<size_t> order);
+};
+
 }  // namespace op
 }  // namespace snippets
 }  // namespace ov
diff --git a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
index 1b91ea573ab1c4..c062fed338638d 100644
--- a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
+++ b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
@@ -82,5 +82,13 @@ class ReshapeShapeInfer : public IShapeInferSnippets {
     explicit ReshapeShapeInfer(const std::shared_ptr& n);
     Result infer(const std::vector& input_shapes) override;
 };
+
+class ReshapeWithOrderShapeInfer : public IShapeInferSnippets {
+    std::vector<size_t> m_target_order {};
+public:
+    explicit ReshapeWithOrderShapeInfer(const std::shared_ptr<Node>& n);
+    Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
+};
+
 }  // namespace snippets
 }  // namespace ov
diff --git a/src/common/snippets/include/snippets/snippets_isa_tbl.hpp b/src/common/snippets/include/snippets/snippets_isa_tbl.hpp
index 9b207b09fe411f..5c5e0f3701ad42 100644
--- a/src/common/snippets/include/snippets/snippets_isa_tbl.hpp
+++ b/src/common/snippets/include/snippets/snippets_isa_tbl.hpp
@@ -17,6 +17,7 @@ OV_OP(LoopEnd, ov::snippets::op)
 OV_OP(Brgemm, ov::snippets::op)
 OV_OP(BroadcastLoad, ov::snippets::op)
 OV_OP(Reshape, ov::snippets::op)
+OV_OP(ReshapeWithOrder, ov::snippets::op)
 
 OV_OP(Store, ov::snippets::op)
diff --git a/src/common/snippets/include/snippets/utils/utils.hpp b/src/common/snippets/include/snippets/utils/utils.hpp
index ff4646f24d03b7..0569a230e91f32 100644
--- a/src/common/snippets/include/snippets/utils/utils.hpp
+++ b/src/common/snippets/include/snippets/utils/utils.hpp
@@ -290,13 +290,26 @@ std::shared_ptr get_leaf_node_of_first_child_shape_infer_seq(const std
 std::shared_ptr get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr& start_node);
 
 /**
- *
  * @param Get stride of input/output dimension
  * @param expr_port target port that contains shape and layout info
 * @param idx index of the target dimension starting from the shape's end (default = 1)
 */
 int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx = 1);
 
+/**
+ * @brief Get stride of input dimension
+ * @param shape target shape
+ * @param layout target layout
+ * @param idx index of the target dimension starting from the shape's end (default = 1)
+ */
+int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);
+/**
+ * @brief Get stride of output dimension
+ * @param shape target shape
+ * @param layout target layout
+ * @param idx index of the target dimension starting from the shape's end (default = 1)
+ */
+int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);
 
 /**
  * @brief Traverses path starting from "expr", and calls "func" for each expression.
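Note: the following standalone sketch is not part of the patch. It illustrates the stride semantics behind the two new helpers, assuming dense shapes; `dim_stride` is a hypothetical stand-in for the internal `get_stride`, which both helpers apply after resolving the dimension index through the layout (via `get_input_dim_idx`/`get_output_dim_idx` in the utils.cpp hunk further below).

```cpp
// Hypothetical sketch: the stride of a dimension in a dense shape is the
// product of all dimensions to its right.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using VectorDims = std::vector<size_t>;

int64_t dim_stride(size_t dim_idx, const VectorDims& shape) {
    int64_t stride = 1;
    for (size_t i = dim_idx + 1; i < shape.size(); ++i)
        stride *= static_cast<int64_t>(shape[i]);
    return stride;
}

int main() {
    const VectorDims shape{2, 3, 16, 64};
    assert(dim_stride(3, shape) == 1);        // innermost dimension is contiguous
    assert(dim_stride(2, shape) == 64);       // one row of 64 elements
    assert(dim_stride(1, shape) == 16 * 64);  // one full 16x64 plane
    return 0;
}
```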
diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp
index d059ddd94d5724..7869b4427d579d 100644
--- a/src/common/snippets/src/generator.cpp
+++ b/src/common/snippets/src/generator.cpp
@@ -77,6 +77,7 @@ RegType Generator::get_op_out_reg_type(const ov::Output& out) const {
            std::dynamic_pointer_cast(op) ||
            std::dynamic_pointer_cast(op) ||
            std::dynamic_pointer_cast(op) ||
+           std::dynamic_pointer_cast<op::ReshapeWithOrder>(op) ||
            std::dynamic_pointer_cast(op)
 #ifdef SNIPPETS_DEBUG_CAPS
            || std::dynamic_pointer_cast(op)
diff --git a/src/common/snippets/src/op/reshape.cpp b/src/common/snippets/src/op/reshape.cpp
index 72823d2815cdbf..ae7887e558b5f2 100644
--- a/src/common/snippets/src/op/reshape.cpp
+++ b/src/common/snippets/src/op/reshape.cpp
@@ -11,6 +11,7 @@ namespace ov {
 namespace snippets {
 namespace op {
+
 Reshape::Reshape(const Output& arg, ov::PartialShape target_shape) : Op({arg}), m_target_shape(std::move(target_shape)) {
     constructor_validate_and_infer_types();
 }
@@ -38,6 +39,46 @@ const ov::PartialShape& Reshape::get_target_shape() const {
 void Reshape::set_target_shape(ov::PartialShape shape) {
     m_target_shape = std::move(shape);
 }
+
+ReshapeWithOrder::ReshapeWithOrder(const Output<Node>& arg, std::vector<size_t> order)
+    : Op({arg}) {
+    custom_constructor_validate_and_infer_types(std::move(order));
+}
+
+void ReshapeWithOrder::custom_constructor_validate_and_infer_types(std::vector<size_t> order) {
+    INTERNAL_OP_SCOPE(ReshapeWithOrder_constructor_validate_and_infer_types);
+
+    const auto& input_pshape = get_input_partial_shape(0);
+    OPENVINO_ASSERT(input_pshape.rank().is_static() && input_pshape.size() == order.size(),
+                    "Incompatible shape and order sizes");
+
+    // During the ctor call, ReshapeWithOrder doesn't know its port descriptors yet,
+    // so we use the explicit order from the parameters
+    set_output_type(0, get_input_element_type(0), ov::snippets::utils::get_planar_pshape(input_pshape, order));
+}
+
+void ReshapeWithOrder::validate_and_infer_types() {
+    const auto& input_pshape = get_input_partial_shape(0);
+    const auto order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
+    OPENVINO_ASSERT(input_pshape.rank().is_static() && input_pshape.size() == order.size(),
+                    "Incompatible shape and order sizes");
+    const auto output_pshape = utils::get_planar_pshape(get_input_partial_shape(0), order);
+    set_output_type(0, get_input_element_type(0), output_pshape);
+}
+
+std::shared_ptr<Node> ReshapeWithOrder::clone_with_new_inputs(const OutputVector& new_args) const {
+    INTERNAL_OP_SCOPE(ReshapeWithOrder);
+    check_new_args_count(this, new_args);
+    const auto& order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
+    return std::make_shared<ReshapeWithOrder>(new_args.at(0), order);
+}
+
+bool ReshapeWithOrder::visit_attributes(AttributeVisitor& visitor) {
+    auto order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout();
+    visitor.on_attribute("target_order", order);
+    return true;
+}
+
 }// namespace op
 }// namespace snippets
 }// namespace ov
\ No newline at end of file
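Note: a minimal sketch of the shape reordering `ReshapeWithOrder` performs, assuming `utils::get_planar_pshape`/`get_planar_vdims` act as a gather over the input shape (planar[i] = shape[order[i]]); `planar_shape` below is a hypothetical stand-in, and the tensor data itself is never moved.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for utils::get_planar_pshape / get_planar_vdims:
// gather the input dimensions according to the target order.
std::vector<size_t> planar_shape(const std::vector<size_t>& shape,
                                 const std::vector<size_t>& order) {
    assert(shape.size() == order.size() && "Incompatible shape and order sizes");
    std::vector<size_t> result(shape.size());
    for (size_t i = 0; i < order.size(); ++i)
        result[i] = shape[order[i]];
    return result;
}

int main() {
    // A transpose-like target order {0, 2, 3, 1}:
    const std::vector<size_t> in{2, 3, 16, 64};
    assert((planar_shape(in, {0, 2, 3, 1}) == std::vector<size_t>{2, 16, 64, 3}));
    return 0;
}
```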
diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp
index 98e3392a65e1e2..25934829b80e00 100644
--- a/src/common/snippets/src/op/subgraph.cpp
+++ b/src/common/snippets/src/op/subgraph.cpp
@@ -96,6 +96,7 @@ auto Subgraph::is_domain_sensitive_op(const std::shared_ptr& op) -> bo
 auto Subgraph::is_shape_infer_op(const std::shared_ptr& op) -> bool {
     return ov::is_type(op) ||
+           ov::is_type<op::ReshapeWithOrder>(op) ||
            ov::is_type(op);
 }
 
diff --git a/src/common/snippets/src/runtime_configurator.cpp b/src/common/snippets/src/runtime_configurator.cpp
index 06beb8db94ae3d..4ddb4c19ea5a32 100644
--- a/src/common/snippets/src/runtime_configurator.cpp
+++ b/src/common/snippets/src/runtime_configurator.cpp
@@ -118,7 +118,23 @@ void RuntimeConfigurator::init_data_info(const lowered::LinearIRCPtr& linear_ir)
         // input->shape changing ops->load
         PortDescriptorPtr desc = nullptr;
         const auto& shape_infer_seq = utils::get_first_child_shape_infer_expr_seq(param);
-        const auto& mem_desc_expr = shape_infer_seq.empty() ? param : shape_infer_seq.back();
+        ExpressionPtr mem_desc_expr = param;
+        if (!shape_infer_seq.empty()) {
+            // If there is a ReshapeWithOrder, we should take its descriptor because it affects the shape via the target order
+            const auto& reordered_reshape_it = std::find_if(shape_infer_seq.cbegin(), shape_infer_seq.cend(),
+                                                            [](const ExpressionPtr& expr) {
+                                                                return ov::is_type<op::ReshapeWithOrder>(expr->get_node());
+                                                            });
+            if (reordered_reshape_it != shape_infer_seq.cend()) {
+                const auto& reshape = *reordered_reshape_it;
+                const auto& etype = reshape->get_node()->get_output_element_type(0);
+                update_io_parameters(reshape->get_input_port_descriptor(0), etype);
+                continue;
+            }
+
+            mem_desc_expr = shape_infer_seq.back();
+        }
+
         auto consumer_inputs = mem_desc_expr->get_output_port_connector(0)->get_consumers();
         for (const auto& child_input : consumer_inputs) {
             const auto ma = std::dynamic_pointer_cast(child_input.get_expr()->get_node());
@@ -127,6 +143,7 @@ void RuntimeConfigurator::init_data_info(const lowered::LinearIRCPtr& linear_ir)
                 break;
             }
         }
+        OPENVINO_ASSERT(desc, "Descriptor is missing!");
         const auto& etype = mem_desc_expr->get_node()->get_output_element_type(0);
         update_io_parameters(desc, etype);
     }
diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
index a3e3d9652c0ac8..417996ae2a5f31 100644
--- a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
+++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
@@ -245,5 +245,16 @@ Result ReshapeShapeInfer::infer(const std::vector& input_shapes)
     return {{target_shape}, ShapeInferStatus::success};
 }
 
+ReshapeWithOrderShapeInfer::ReshapeWithOrderShapeInfer(const std::shared_ptr<Node>& n) {
+    const auto& reshape = as_type_ptr<op::ReshapeWithOrder>(n);
+    OPENVINO_ASSERT(reshape, "Invalid node passed to ReshapeWithOrderShapeInfer.");
+    m_target_order = lowered::PortDescriptorUtils::get_port_descriptor_ptr(reshape->input(0))->get_layout();
+}
+
+Result ReshapeWithOrderShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
+    OPENVINO_ASSERT(input_shapes.size() == 1, "Invalid number of shapes passed to ReshapeWithOrderShapeInfer");
+    return {{ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_target_order)}, ShapeInferStatus::success};
+}
+
 }  // namespace snippets
 }  // namespace ov
diff --git a/src/common/snippets/src/shape_inference/shape_inference.cpp b/src/common/snippets/src/shape_inference/shape_inference.cpp
index 76a4c491c66983..017567ea86bd55 100644
--- a/src/common/snippets/src/shape_inference/shape_inference.cpp
+++ b/src/common/snippets/src/shape_inference/shape_inference.cpp
@@ -58,6 +58,7 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
     SHAPE_INFER_PREDEFINED(op::KernelDynamic, EmptyShapeInfer),
     SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Reshape, ReshapeShapeInfer),
+    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReshapeWithOrder, ReshapeWithOrderShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::ReduceMax, ReduceShapeInfer),
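Note: a hypothetical mini shape-infer object (simplified types, not the real `IShapeInferSnippets` interface) showing the contract of `ReshapeWithOrderShapeInfer`: the target order is captured once from the port descriptor at construction and re-applied to every incoming shape.

```cpp
#include <cassert>
#include <stdexcept>
#include <utility>
#include <vector>

using Dims = std::vector<size_t>;

class OrderedReshapeInfer {
    Dims m_target_order;  // snapshotted once, like ReshapeWithOrderShapeInfer's m_target_order
public:
    explicit OrderedReshapeInfer(Dims order) : m_target_order(std::move(order)) {}

    Dims infer(const std::vector<Dims>& input_shapes) const {
        if (input_shapes.size() != 1)
            throw std::runtime_error("Invalid number of shapes passed to OrderedReshapeInfer");
        const Dims& in = input_shapes.front();
        if (in.size() != m_target_order.size())
            throw std::runtime_error("Incompatible shape and order sizes");
        Dims out(in.size());
        for (size_t i = 0; i < m_target_order.size(); ++i)
            out[i] = in[m_target_order[i]];  // same gather as the planar-dims helper
        return out;
    }
};

int main() {
    const OrderedReshapeInfer infer({0, 2, 3, 1});
    assert((infer.infer({{2, 3, 16, 64}}) == Dims{2, 16, 64, 3}));
    assert((infer.infer({{1, 8, 32, 32}}) == Dims{1, 32, 32, 8}));
    return 0;
}
```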
diff --git a/src/common/snippets/src/utils/utils.cpp b/src/common/snippets/src/utils/utils.cpp
index e7381fe6754758..249970b65baa5d 100644
--- a/src/common/snippets/src/utils/utils.cpp
+++ b/src/common/snippets/src/utils/utils.cpp
@@ -317,14 +317,21 @@ std::shared_ptr get_leaf_node_of_first_parent_shape_infer_seq(const st
 }
 
 int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx) {
-    size_t dim_idx = 0;
+    const auto& shape = expr_port.get_descriptor_ptr()->get_shape();
     const auto& layout = expr_port.get_descriptor_ptr()->get_layout();
     switch (expr_port.get_type()) {
-        case lowered::ExpressionPort::Input: dim_idx = utils::get_input_dim_idx(layout, idx); break;
-        case lowered::ExpressionPort::Output: dim_idx = utils::get_output_dim_idx(layout, idx); break;
-        default: OPENVINO_THROW("Unsupported expression port type!");
+        case lowered::ExpressionPort::Input: return get_dim_in_stride(shape, layout, idx);
+        case lowered::ExpressionPort::Output: return get_dim_out_stride(shape, layout, idx);
     }
-    return get_stride(dim_idx, expr_port.get_descriptor_ptr()->get_shape());
+    OPENVINO_THROW("Unsupported expression port type!");
+}
+
+int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
+    return get_stride(utils::get_input_dim_idx(layout, idx), shape);
+}
+
+int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
+    return get_stride(utils::get_output_dim_idx(layout, idx), shape);
 }
 
 void visit_path(const lowered::ExpressionPtr& expr,
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
index 65741d7031d289..0971e9e69a661f 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp
@@ -39,12 +39,45 @@ std::string CPURuntimeConfig::to_string() const {
 }
 #endif
 
-CPURuntimeConfigurator::CPURuntimeConfigurator()
-    : ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()) {}
+#ifndef OPENVINO_ARCH_ARM64
+
+CPURuntimeConfig::RepackedInput::RepackedInput(std::shared_ptr kernel,
+                                               CpuBlockedMemoryDescPtr desc,
+                                               VectorDims in_offsets,
+                                               VectorDims out_offsets)
+    : m_kernel(std::move(kernel)),
+      m_desc(std::move(desc)),
+      m_in_offsets(std::move(in_offsets)),
+      m_out_offsets(std::move(out_offsets)) {
+    OPENVINO_ASSERT(m_in_offsets.size() == m_out_offsets.size(), "Incorrect size of offsets");
+    OPENVINO_ASSERT(m_desc, "Descriptor is empty");
+}
+
+const CpuBlockedMemoryDescPtr& CPURuntimeConfig::RepackedInput::desc() const {
+    return m_desc;
+}
+
+const std::shared_ptr& CPURuntimeConfig::RepackedInput::kernel() const {
+    return m_kernel;
+}
+
+const VectorDims& CPURuntimeConfig::RepackedInput::in_offsets() const {
+    return m_in_offsets;
+}
+
+const VectorDims& CPURuntimeConfig::RepackedInput::out_offsets() const {
+    return m_out_offsets;
+}
+
+#endif  // OPENVINO_ARCH_ARM64
+
+CPURuntimeConfigurator::CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache)
+    : ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()),
+      compiled_kernel_cache(std::move(cache)) {}
 
 void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) {
     RuntimeConfigurator::initialization(linear_ir);
-#ifndef OPENVINO_ARCH_ARM64
+#ifdef OPENVINO_ARCH_X86_64
     RuntimeOptimizer::register_if_applicable(m_intermediate_optimizers, linear_ir, this);
     RuntimeOptimizer::register_if_applicable(m_final_optimizers, linear_ir, this);
 #endif
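Note: a small sketch (hypothetical helper, not part of the patch) of how the `in_offsets`/`out_offsets` stored in `RepackedInput` are meant to be consumed: as per-dimension strides combined with loop indexes, which is what the x64 repacking loops later in this patch do.

```cpp
#include <cassert>
#include <vector>

using VectorDims = std::vector<size_t>;

// Same invariant as the RepackedInput constructor above: offset ranks match,
// and each offset acts as a per-dimension stride in address computation,
// e.g. addr = base + sum(idx_d * offsets[d]).
size_t linear_offset(const VectorDims& offsets, const VectorDims& indexes) {
    assert(offsets.size() >= indexes.size() && "Incorrect size of offsets");
    size_t off = 0;
    for (size_t d = 0; d < indexes.size(); ++d)
        off += indexes[d] * offsets[d];
    return off;
}

int main() {
    const VectorDims in_offsets{4096, 1024, 64, 1};
    assert(linear_offset(in_offsets, {1, 2, 0, 5}) == 4096 + 2 * 1024 + 5);
    return 0;
}
```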
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
index 1706670ce870d1..abec42bbbe0abb 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp
@@ -5,6 +5,12 @@
 #pragma once
 
 #include "emitters/snippets/jit_snippets_call_args.hpp"
+
+#ifndef OPENVINO_ARCH_ARM64
+#    include "emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp"
+#endif
+
+#include "cache/multi_cache.h"
 #include "memory_desc/cpu_blocked_memory_desc.h"
 #include "snippets/lowered/port_descriptor.hpp"
 #include "snippets/runtime_configurator.hpp"
@@ -21,13 +27,41 @@ class CPURuntimeConfig : public ov::snippets::RuntimeConfig {
     std::string to_string() const override;
 #endif
 
+#ifndef OPENVINO_ARCH_ARM64
+    struct RepackedInput {
+        RepackedInput() = default;
+        RepackedInput(std::shared_ptr kernel,
+                      CpuBlockedMemoryDescPtr desc,
+                      VectorDims in_offsets,
+                      VectorDims out_offsets);
+
+        const std::shared_ptr& kernel() const;
+        const CpuBlockedMemoryDescPtr& desc() const;
+        const VectorDims& in_offsets() const;
+        const VectorDims& out_offsets() const;
+
+    private:
+        std::shared_ptr m_kernel{nullptr};
+        CpuBlockedMemoryDescPtr m_desc{nullptr};
+        VectorDims m_in_offsets{};
+        VectorDims m_out_offsets{};
+    };
+    std::unordered_map repacked_inputs = {};
+
+    enum class RepackingImplType {
+        NONE,         // no kernel-outside repacking
+        IN_PARALLEL,  // repacking should be executed in parallel_nt by each thread
+        SEPARATE,     // repacking should be executed separately from the kernel
+    };
+    RepackingImplType repacking_impl_type = RepackingImplType::NONE;
+#endif  // OPENVINO_ARCH_ARM64
+
     std::vector loop_args = {};
-    std::unordered_map m_in_requested_descs = {};
 };
 
 class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
 public:
-    CPURuntimeConfigurator();
+    CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache = {});
 
     /**
      * @brief Calculate Loop parameters of Loop emitters and update these values in CPURuntimeConfig
@@ -35,6 +69,10 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
      */
     void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const;
 
+    const ov::intel_cpu::MultiCacheWeakPtr& get_cache() const {
+        return compiled_kernel_cache;
+    }
+
 protected:
     void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
     void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override;
@@ -42,6 +80,8 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
     void initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
 
     static const size_t rank6D;
+
+    ov::intel_cpu::MultiCacheWeakPtr compiled_kernel_cache;
 };
 
 }  // namespace intel_cpu
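Note: an illustrative sketch of the dispatch implied by `RepackingImplType`. The names here are hypothetical, but the branching mirrors the `should_repacking_be_separately`/`should_repacking_be_in_parallel` checks in the x64 executor later in this patch.

```cpp
#include <iostream>

enum class RepackingImplType {
    NONE,         // no kernel-outside repacking
    IN_PARALLEL,  // each thread repacks its own slice inside the parallel section
    SEPARATE,     // one standalone repacking pass runs before the kernel
};

void run_subgraph(RepackingImplType type) {
    switch (type) {
    case RepackingImplType::SEPARATE:
        std::cout << "repack all inputs once, then execute on the repacked buffers\n";
        break;
    case RepackingImplType::IN_PARALLEL:
        std::cout << "repack lazily per thread right before each kernel call\n";
        break;
    case RepackingImplType::NONE:
        std::cout << "execute directly on the original inputs\n";
        break;
    }
}

int main() {
    run_subgraph(RepackingImplType::SEPARATE);
    return 0;
}
```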
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp
index 39e384837856a1..7835f17adb97be 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp
@@ -165,7 +165,7 @@ class jit_snippet : public dnnl::impl::cpu::x64::jit_generator {
 
 intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                                               ov::intel_cpu::MultiCacheWeakPtr cache)
-    : TargetMachine(std::make_shared<CPURuntimeConfigurator>()),
+    : TargetMachine(std::make_shared<CPURuntimeConfigurator>(cache)),
       h(new jit_snippet()),
       isa(host_isa),
       compiled_kernel_cache(std::move(cache)) {
@@ -177,6 +177,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
     jitters[snippets::op::RankNormalization::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);
     jitters[snippets::op::Reshape::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);
+    jitters[snippets::op::ReshapeWithOrder::get_type_info_static()] =
+        CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter);
     jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter);
     jitters[snippets::op::LoadReshape::get_type_info_static()] =
diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
index 6df658d8d72d0c..861b9779c25533 100644
--- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
+++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp
@@ -21,15 +21,6 @@ using namespace ov::snippets::utils;
 namespace ov {
 namespace intel_cpu {
 
-namespace {
-bool get_is_transposed(const ov::snippets::lowered::ExpressionPtr& expr) {
-    const auto& layout = expr->get_input_port_descriptor(0)->get_layout();
-    const auto is_transposed = !layout.empty() && layout.back() != layout.size() - 1;
-    OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(is_transposed, (layout[layout.size() - 2] == layout.size() - 1)),
-                              "supports only N dim placed as last or pre last dimension");
-    return is_transposed;
-}
-}  // namespace
 
 jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h,
                                                      cpu_isa_t isa,
@@ -50,7 +41,7 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h,
     const auto& src_prc = brgemm_repack->get_src_element_type();
     const auto& wei_prc = brgemm_repack->get_input_element_type(0);
     const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
-    const auto is_transposed = get_is_transposed(expr);
+    const auto is_transposed = BrgemmCopyB::is_transposed(expr->get_input_port_descriptor(0)->get_layout());
     const auto brgemm_type = get_brgemm_type(src_prc, is_transposed);
     const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
     m_with_comp = with_compensations(brgemm_type);
diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.cpp
new file mode 100644
index 00000000000000..08c2a977f33c43
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.cpp
@@ -0,0 +1,82 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "nodes/executors/aarch64/subgraph.hpp"
+
+#include "snippets/op/subgraph.hpp"
+
+
+namespace ov {
+namespace intel_cpu {
+
+SubgraphExecutor::SubgraphExecutor(const std::shared_ptr& snippet_config,
+                                   const std::shared_ptr& snippet_attrs,
+                                   const std::shared_ptr& snippet,
+                                   const std::vector& start_offset_in,
+                                   const std::vector& start_offset_out,
+                                   const BufferScratchpadAllocator& allocator,
+                                   const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache)
+    : SubgraphBaseExecutor(snippet_config,
+                           snippet_attrs,
+                           snippet,
+                           start_offset_in,
+                           start_offset_out,
+                           allocator,
+                           kernel_cache) {
+    m_buffer_scratchpad = allocator(m_internal_buffer_size);
+}
+
+void SubgraphStaticExecutor::exec_impl(const std::vector& inMemPtrs,
+                                       const std::vector& outMemPtrs) {
+    const auto& callable = m_schedule->get_callable();
+
+    auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) {
+        init_call_args(call_args, inMemPtrs, outMemPtrs, m_start_offset_in, m_start_offset_out, ithr);
+        update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr);
+    };
+    auto caller = [&](jit_snippets_call_args& call_args, const 
std::vector& indexes, size_t ithr) { + callable(&call_args, indexes.data()); + }; + + if (m_parallel_exec_domain.size() == rank6D) { + parallel_for6d(initializer, caller); + } else { + parallel_forNd(initializer, caller); + } +} + +void SubgraphDynamicSpecializedExecutor::exec_impl(const std::vector& inMemPtrs, + const std::vector& outMemPtrs) { + const auto& callable = m_schedule->get_callable(); + + OPENVINO_ASSERT(m_data_offsets.size() == inMemPtrs.size() + outMemPtrs.size(), "Incorrect data offset count!"); + OPENVINO_ASSERT(m_data_offsets.front().size() == m_parallel_exec_domain.size(), + "Data offsets with invalid ranks detected"); + + // Note: we need to reset KernelExecutorTable to the state that was recorded in the + // SubgraphDynamicSpecializedExecutor constructor because the table might've been used for other shapes + m_reset_exec_table_state(); + + std::vector src_ptrs; + std::vector dst_ptrs; + init_original_ptrs(inMemPtrs, outMemPtrs, src_ptrs, dst_ptrs, m_start_offset_in, m_start_offset_out); + + auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, ithr); + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); + }; + auto caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes, size_t ithr) { + update_ptrs(call_args, src_ptrs, dst_ptrs, indexes); + callable(&call_args); + }; + + if (m_parallel_exec_domain.size() == rank6D) { + parallel_for6d(initializer, caller); + } else { + parallel_forNd(initializer, caller); + } +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.hpp b/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.hpp new file mode 100644 index 00000000000000..47b06e8c69a99b --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/aarch64/subgraph.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "nodes/executors/subgraph.hpp" + +namespace ov { +namespace intel_cpu { + +class SubgraphExecutor : public SubgraphBaseExecutor { +public: + SubgraphExecutor(const std::shared_ptr& snippet_config, + const std::shared_ptr& snippet_attrs, + const std::shared_ptr& snippet, + const std::vector& start_offset_in, + const std::vector& start_offset_out, + const BufferScratchpadAllocator& allocator, + const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache); +}; + +class SubgraphStaticExecutor : public SubgraphExecutor, public SubgraphStaticBaseExecutor { +public: + template + SubgraphStaticExecutor(const std::shared_ptr& snippet_config, Args... args) + : SubgraphExecutor(snippet_config, args...), + SubgraphStaticBaseExecutor() {} + + void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override; +}; + +class SubgraphDynamicSpecializedExecutor : public SubgraphExecutor, public SubgraphDynamicSpecializedBaseExecutor { +public: + template + SubgraphDynamicSpecializedExecutor(const std::shared_ptr& snippet_config, Args... 
args) + : SubgraphExecutor(snippet_config, args...), + SubgraphDynamicSpecializedBaseExecutor(snippet_config) {} + + void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/executors/subgraph.cpp new file mode 100644 index 00000000000000..34ae1449b56567 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/executors/subgraph.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "nodes/executors/subgraph.hpp" +#if defined(OPENVINO_ARCH_ARM64) +# include "emitters/snippets/aarch64/cpu_generator.hpp" +#else +# include "emitters/snippets/x64/cpu_generator.hpp" +#endif +#include "openvino/core/parallel.hpp" + +namespace ov { +namespace intel_cpu { + +SubgraphCodeGenerator::SubgraphCodeGenerator(const std::shared_ptr& snippet_attrs, + const std::shared_ptr& config) { + OPENVINO_ASSERT(snippet_attrs, "Subgraph attributes are empty!"); + OPENVINO_ASSERT(config, "Runtime Config is empty!"); + + jit_snippets_compile_args jcp; + jcp.data_offsets = config->io_data_offsets; + SubgraphBaseExecutor::init_parallel_domain(config, jcp.exec_domain); + schedule = + std::make_shared(snippet_attrs->snippet->generate(reinterpret_cast(&jcp))); +} + +SubgraphBaseExecutor::SubgraphBaseExecutor(const std::shared_ptr& snippet_config, + const std::shared_ptr& snippet_attrs, + const std::shared_ptr& snippet, + const std::vector& start_offset_in, + const std::vector& start_offset_out, + const BufferScratchpadAllocator& allocator, + const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache) + : m_schedule(snippet->get()), + m_start_offset_in(start_offset_in), + m_start_offset_out(start_offset_out) { + OPENVINO_ASSERT(m_schedule, "Schedule is empty!"); + OPENVINO_ASSERT(snippet_config, "Runtime Config is empty!"); + init_parallel_domain(snippet_config, m_parallel_exec_domain); + + m_tensor_rank = snippet_config->tensor_rank; + m_harness_work_amount = std::accumulate(m_parallel_exec_domain.cbegin(), + m_parallel_exec_domain.cend(), + size_t(1), + std::multiplies()); + m_nthreads = std::min(parallel_get_max_threads(), static_cast(m_harness_work_amount)); + + m_buffer_scratchpad_size = snippet_config->buffer_scratchpad_size; + OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(m_buffer_scratchpad_size), + "Undefined buffer scratchpad size!"); + m_internal_buffer_size = static_cast(m_nthreads) * m_buffer_scratchpad_size; +} + +void SubgraphBaseExecutor::init_parallel_domain(const std::vector& master_shape, + size_t tensor_rank, + size_t tile_rank, + std::vector& domain) { + domain.resize(tensor_rank, 1); + std::fill(domain.begin(), domain.end(), 1); + std::copy(master_shape.cbegin(), + master_shape.cbegin() + (master_shape.size() - tile_rank), + domain.begin() + (tensor_rank - master_shape.size())); +} + +void SubgraphBaseExecutor::init_parallel_domain(const std::shared_ptr& snippet_config, + std::vector& domain) { + init_parallel_domain(snippet_config->master_shape, snippet_config->tensor_rank, snippet_config->tile_rank, domain); +} +void SubgraphBaseExecutor::parallel_for6d( + const std::function& initializer, + const std::function&, size_t)>& caller) { + const auto& dom = m_parallel_exec_domain; + + parallel_nt_static(m_nthreads, [&](const int ithr, const int nthr) { + jit_snippets_call_args call_args; + initializer(call_args, ithr); + + size_t start = 0, end = 0; + 
splitter(m_harness_work_amount, nthr, ithr, start, end);
+
+        std::vector indexes{0, 0, 0, 0, 0};
+        parallel_it_init(start, indexes[0], dom[0], indexes[1], dom[1], indexes[2], dom[2],
+                         indexes[3], dom[3], indexes[4], dom[4]);
+        for (size_t iwork = start; iwork < end; ++iwork) {
+            caller(call_args, indexes, ithr);
+            parallel_it_step(indexes[0], dom[0], indexes[1], dom[1], indexes[2], dom[2],
+                             indexes[3], dom[3], indexes[4], dom[4]);
+        }
+    });
+}
+
+void SubgraphBaseExecutor::parallel_forNd(
+    const std::function& initializer,
+    const std::function&, size_t)>& caller) {
+    const auto& dom = m_parallel_exec_domain;
+
+    parallel_nt_static(m_nthreads, [&](const int ithr, const int nthr) {
+        jit_snippets_call_args call_args;
+        initializer(call_args, ithr);
+
+        size_t start = 0, end = 0;
+        splitter(m_harness_work_amount, nthr, ithr, start, end);
+
+        std::vector indexes(dom.size() - 1, 0);
+        for (size_t iwork = start; iwork < end; ++iwork) {
+            size_t tmp = iwork;
+            for (ptrdiff_t j = static_cast(dom.size()) - 2; j >= 0; j--) {
+                indexes[j] = tmp % dom[j];
+                tmp /= dom[j];
+            }
+
+            caller(call_args, indexes, ithr);
+        }
+    });
+}
+
+void SubgraphBaseExecutor::execute(const dnnl::stream& strm,
+                                   const std::vector& inMemPtrs,
+                                   const std::vector& outMemPtrs) {
+    exec_impl(inMemPtrs, outMemPtrs);
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
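Note: a standalone rendition of `init_parallel_domain` above with a worked example. The innermost `tile_rank` dimensions are executed inside the kernel, so they stay at 1 in the parallel domain, and the remaining dimensions are right-aligned into a `tensor_rank`-sized vector.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

std::vector<size_t> make_parallel_domain(const std::vector<size_t>& master_shape,
                                         size_t tensor_rank,
                                         size_t tile_rank) {
    std::vector<size_t> domain(tensor_rank, 1);
    std::copy(master_shape.cbegin(),
              master_shape.cbegin() + (master_shape.size() - tile_rank),
              domain.begin() + (tensor_rank - master_shape.size()));
    return domain;
}

int main() {
    // 4D master shape, rank-6 domain, kernel tiles over the 2 innermost dims:
    const auto domain = make_parallel_domain({2, 3, 64, 128}, 6, 2);
    assert((domain == std::vector<size_t>{1, 1, 2, 3, 1, 1}));
    // Harness work amount = 2 * 3 = 6 outer iterations split across threads.
    return 0;
}
```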
diff --git a/src/plugins/intel_cpu/src/nodes/executors/subgraph.hpp b/src/plugins/intel_cpu/src/nodes/executors/subgraph.hpp
new file mode 100644
index 00000000000000..f2dd48ab6788a6
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/subgraph.hpp
@@ -0,0 +1,190 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "cpu_memory.h"
+#include "emitters/snippets/cpu_runtime_configurator.hpp"
+#include "emitters/snippets/jit_snippets_call_args.hpp"
+#include "snippets/generator.hpp"
+#include "snippets/op/subgraph.hpp"
+
+namespace ov {
+namespace intel_cpu {
+
+struct SubgraphAttrs {
+    // Local copy of the subgraph node for canonicalization & code generation
+    std::shared_ptr snippet;
+    uint64_t bodyHash;
+    std::vector inMemOrders;
+    std::vector outMemOrders;
+    std::vector inMemPrecs;
+    std::vector outMemPrecs;
+};
+
+class SubgraphCodeGenerator {
+public:
+    SubgraphCodeGenerator(const std::shared_ptr& snippet_attrs,
+                          const std::shared_ptr& config);
+
+    const std::shared_ptr& get() const {
+        return schedule;
+    }
+
+private:
+    std::shared_ptr schedule;
+};
+
+class SubgraphBaseExecutor {
+public:
+    using BufferScratchpadAllocator = std::function;
+
+    SubgraphBaseExecutor() = default;
+    SubgraphBaseExecutor(const std::shared_ptr& snippet_config,
+                         const std::shared_ptr& snippet_attrs,
+                         const std::shared_ptr& snippet,
+                         const std::vector& start_offset_in,
+                         const std::vector& start_offset_out,
+                         const BufferScratchpadAllocator& allocator,
+                         const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache);
+    virtual ~SubgraphBaseExecutor() = default;
+
+    virtual void execute(const dnnl::stream& strm,
+                         const std::vector& inMemPtrs,
+                         const std::vector& outMemPtrs);
+
+    static void init_parallel_domain(const std::vector& master_shape,
+                                     size_t tensor_rank,
+                                     size_t tile_rank,
+                                     std::vector& domain);
+    static void init_parallel_domain(const std::shared_ptr& snippet_config,
+                                     std::vector& domain);
+
+protected:
+    virtual void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) = 0;
+
+    virtual void parallel_for6d(
+        const std::function& initializer,
+        const std::function&, size_t)>& caller);
+    virtual void parallel_forNd(
+        const std::function& initializer,
+        const std::function&, size_t)>& caller);
+
+    inline void update_scratchpad_ptr(void*& scratchpad_ptr, size_t ithr) const {
+        if (m_buffer_scratchpad_size > 0)
+            scratchpad_ptr = m_buffer_scratchpad->getDataAs() + ithr * m_buffer_scratchpad_size;
+    }
+
+    using initializer_functor = std::function;
+    using call_functor = std::function&, size_t)>;
+
+    std::shared_ptr m_schedule;
+    // Parallel execution domain; it should be compatible with the schedule's work size
+    std::vector m_parallel_exec_domain = {};
+    size_t m_harness_work_amount = 0;
+
+    // Buffer scratchpad
+    MemoryPtr m_buffer_scratchpad = nullptr;
+    size_t m_buffer_scratchpad_size = 0;
+    size_t m_internal_buffer_size = 0;
+    size_t m_tensor_rank = 0;
+
+    const size_t rank6D = 6;
+
+    // Count of threads for parallel_nt
+    int m_nthreads = 0;
+
+    std::vector m_start_offset_in = {};
+    std::vector m_start_offset_out = {};
+};
+
+// Class for Subgraphs with static shapes
+class SubgraphStaticBaseExecutor {
+public:
+    SubgraphStaticBaseExecutor() = default;
+    virtual ~SubgraphStaticBaseExecutor() = default;
+
+protected:
+    typedef void (*kernel)(const void*, const void*);
+
+    inline void init_call_args(jit_snippets_call_args& call_args,
+                               const std::vector& srcMemPtrs,
+                               const std::vector& dstMemPtrs,
+                               const std::vector& start_offset_in,
+                               const std::vector& start_offset_out,
+                               size_t ithr) {
+        for (size_t i = 0; i < srcMemPtrs.size(); i++)
+            call_args.src_ptrs[i] = srcMemPtrs[i]->getDataAs() + start_offset_in[i];
+
+        for (size_t i = 0; i < dstMemPtrs.size(); i++)
+            call_args.dst_ptrs[i] = dstMemPtrs[i]->getDataAs() + start_offset_out[i];
+    }
+};
+
+// Specialized dynamic executor based on shape agnostic kernel for the specific input shapes
+class SubgraphDynamicSpecializedBaseExecutor {
+public:
+    SubgraphDynamicSpecializedBaseExecutor(const std::shared_ptr& snippet_config) {
+        m_buffer_offsets = snippet_config->buffer_cluster_offsets;
+        m_data_offsets = snippet_config->io_data_offsets;
+        m_loop_args = snippet_config->loop_args;
+        m_reset_exec_table_state = snippet_config->kernel_executor_table->get_state_reset();
+    }
+    virtual ~SubgraphDynamicSpecializedBaseExecutor() = default;
+
+protected:
+    typedef void (*dynamic_kernel)(const void*);
+
+    inline void init_call_args(jit_snippets_call_args& call_args, size_t ithr) {
+        call_args.register_loops(m_loop_args);
+        std::copy(m_buffer_offsets.cbegin(), m_buffer_offsets.cend(), call_args.buffer_offsets);
+    }
+
+    inline void init_original_ptrs(const std::vector& srcMemPtrs,
+                                   const std::vector& dstMemPtrs,
+                                   std::vector& src_ptrs,
+                                   std::vector& dst_ptrs,
+                                   const std::vector& start_offset_in,
+                                   const std::vector& start_offset_out) {
+        const auto in_num = srcMemPtrs.size();
+        const auto out_num = dstMemPtrs.size();
+
+        src_ptrs.resize(in_num, nullptr);
+        dst_ptrs.resize(out_num, nullptr);
+
+        for (size_t i = 0; i < in_num; i++)
+            src_ptrs[i] = srcMemPtrs[i]->getDataAs() + start_offset_in[i];
+        for (size_t i = 0; i < out_num; i++)
+            dst_ptrs[i] = dstMemPtrs[i]->getDataAs() + start_offset_out[i];
+    }
+
+    inline void update_ptrs(jit_snippets_call_args& call_args,
+                            const std::vector& src_ptrs,
+                            const std::vector& dst_ptrs,
+                            const std::vector& indexes) const {
+        for (size_t i = 0; i < src_ptrs.size(); i++) {
+            auto i_ptr = src_ptrs[i];
+            for (size_t j = 0; j < indexes.size(); j++) {
+                i_ptr += m_data_offsets[i][j] * indexes[j];
+            }
+            call_args.src_ptrs[i] = i_ptr;
+        }
+        for (size_t i = 0; i < dst_ptrs.size(); i++) {
+            auto i_ptr = dst_ptrs[i];
+            for (size_t j = 0; j < indexes.size(); j++) {
+                i_ptr += m_data_offsets[i + src_ptrs.size()][j] * indexes[j];
+            }
+            call_args.dst_ptrs[i] = i_ptr;
+        }
+    }
+
+    std::vector m_buffer_offsets = {};
+    std::vector<std::vector<size_t>> m_data_offsets = {};
+    std::vector m_loop_args = {};
+    std::function m_reset_exec_table_state;
+};
+
+}  // namespace intel_cpu
+}  // namespace ov
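Note: a simplified sketch (hypothetical types, not the plugin API) of the scratchpad layout the x64 executor below relies on for IN_PARALLEL repacking: the per-thread internal buffers come first, followed by one `nthreads`-sized region per repacked input. It mirrors the pointer walk in `get_external_scratchpad_ptr` later in this patch.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct Region {
    size_t input_idx;  // which subgraph input this repacked region belongs to
    size_t mem_size;   // bytes required for one thread's repacked copy
};

uint8_t* external_scratch_ptr(uint8_t* base, size_t internal_size, size_t nthreads,
                              const std::vector<Region>& regions, size_t ithr, size_t idx) {
    uint8_t* p = base + internal_size;  // skip the per-thread internal buffers
    for (const auto& r : regions) {
        if (r.input_idx == idx)
            return p + ithr * r.mem_size;  // this thread's private slice
        p += nthreads * r.mem_size;        // skip another input's whole region
    }
    return nullptr;  // index not registered for repacking
}

int main() {
    std::vector<uint8_t> buf(4096);
    const std::vector<Region> regions{{0, 64}, {2, 32}};
    // Thread 1's slice for input 2: after the internal area (128 bytes)
    // and input 0's full region (4 threads x 64 bytes).
    assert(external_scratch_ptr(buf.data(), 128, 4, regions, 1, 2) ==
           buf.data() + 128 + 4 * 64 + 1 * 32);
    return 0;
}
```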
diff --git a/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.cpp
new file mode 100644
index 00000000000000..f0024707164c8d
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.cpp
@@ -0,0 +1,309 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "nodes/executors/x64/subgraph.hpp"
+
+#include "emitters/snippets/x64/cpu_generator.hpp"
+#include "openvino/core/parallel.hpp"
+#include "snippets/op/subgraph.hpp"
+
+#if defined(__linux__) && defined(SNIPPETS_DEBUG_CAPS)
+#    include <csignal>
+
+#    include "emitters/snippets/x64/jit_segfault_detector_emitter.hpp"
+std::mutex err_print_lock;
+#endif
+
+namespace ov {
+namespace intel_cpu {
+
+namespace {
+inline void parallel4d_repacking(const BrgemmCopyBKernel* ker,
+                                 const VectorDims& dom,
+                                 const VectorDims& in_str,
+                                 const VectorDims& out_str,
+                                 const uint8_t* src,
+                                 uint8_t* dst) {
+    parallel_for4d(dom[0], dom[1], dom[2], dom[3], [&](size_t d0, size_t d1, size_t d2, size_t d3) {
+        BrgemmCopyBKernel::call_args args;
+        args.src = src + d0 * in_str[0] + d1 * in_str[1] + d2 * in_str[2] + d3 * in_str[3];
+        args.tr_src = dst + d0 * out_str[0] + d1 * out_str[1] + d2 * out_str[2] + d3 * out_str[3];
+        (*ker)(&args);
+    });
+};
+inline void parallelNd_repacking(const BrgemmCopyBKernel* ker,
+                                 const VectorDims& dom,
+                                 const VectorDims& in_str,
+                                 const VectorDims& out_str,
+                                 const uint8_t* src,
+                                 uint8_t* dst) {
+    const size_t batch = std::accumulate(dom.rbegin() + 2, dom.rend(), 1lu, std::multiplies());
+    parallel_nt_static(0, [&](const int ithr, const int nthr) {
+        BrgemmCopyBKernel::call_args args;
+        size_t start = 0, end = 0;
+        splitter(batch, nthr, ithr, start, end);
+        for (size_t iwork = start; iwork < end; ++iwork) {
+            const uint8_t* src_u8 = src;
+            uint8_t* dst_u8 = dst;
+            size_t tmp = iwork;
+            for (ptrdiff_t j = static_cast(dom.size()) - 3; j >= 0; j--) {
+                auto idx = tmp % dom[j];
+                tmp /= dom[j];
+
+                src_u8 += idx * in_str[j];
+                dst_u8 += idx * out_str[j];
+            }
+            args.src = src_u8;
+            args.tr_src = dst_u8;
+            (*ker)(&args);
+        }
+    });
+};
+}  // namespace
+
+SubgraphExecutor::SubgraphExecutor(const std::shared_ptr& snippet_config,
+                                   const std::shared_ptr& snippet_attrs,
+                                   const std::shared_ptr& snippet,
+                                   const std::vector& start_offset_in,
+                                   const std::vector& start_offset_out,
+                                   const BufferScratchpadAllocator& allocator,
+                                   const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache)
+    : SubgraphBaseExecutor(snippet_config,
+                           snippet_attrs,
+                           snippet,
+                           start_offset_in,
+                           start_offset_out,
+                           allocator,
+                           kernel_cache) {
+    m_repacking_impl_type = snippet_config->repacking_impl_type;
+    m_repacked_inputs = snippet_config->repacked_inputs;
+
+    auto external_buffer_size =
+        std::accumulate(m_repacked_inputs.begin(),
+                        m_repacked_inputs.end(),
+                        size_t(0),
+                        [](size_t sum, const std::pair& p) {
+                            return sum + p.second.desc()->getCurrentMemSize();
+                        });
+
+    if (should_repacking_be_in_parallel()) {
+        // When external repacking is applied in the parallel section,
+        // each thread should have its own buffer to store repacked data
+        external_buffer_size *= m_nthreads;
+
+        // To avoid extra runtime overhead on vector creation,
+        // we initialize `repacked_offsets_by_threads` by default here
+        m_repacked_offsets_by_threads.resize(m_nthreads);
+        for (size_t i = 0; i < m_repacked_offsets_by_threads.size(); ++i)
+            clean_repacked_offsets(i);
+
+        if (m_tensor_rank == rank6D) {
+            init_offset = [](const std::vector& offsets, const std::vector& indexes, size_t& offset) {
+                offset += offsets[0] * indexes[0] + offsets[1] * indexes[1] + offsets[2] * indexes[2] +
+                          offsets[3] * indexes[3];
+            };
+        } else {
+            init_offset = [](const std::vector& offsets, const std::vector& indexes, size_t& offset) {
+                for (size_t j = 0; j < indexes.size(); j++)
+                    offset += offsets[j] * indexes[j];
+            };
+        }
+    }
+
+    m_buffer_scratchpad = allocator(m_internal_buffer_size + external_buffer_size);
+
+#if defined(__linux__) && defined(SNIPPETS_DEBUG_CAPS)
+    const auto target = std::dynamic_pointer_cast(
+        snippet_attrs->snippet->get_generator()->get_target_machine());
+    enabled_segfault_detector = target && target->debug_config.enable_segfault_detector;
+#endif
+}
+
+#if defined(__linux__) && defined(SNIPPETS_DEBUG_CAPS)
+void SubgraphExecutor::segfault_detector() {
+    if (enabled_segfault_detector) {
+        __sighandler_t signal_handler = [](int signal) {
+            std::lock_guard guard(err_print_lock);
+            if (auto segfault_detector_emitter = ov::intel_cpu::g_custom_segfault_handler->local())
+                std::cout << segfault_detector_emitter->info() << std::endl;
+            auto tid = parallel_get_thread_num();
+            OPENVINO_THROW("Segfault was caught by the signal handler in subgraph node execution on thread " +
+                           std::to_string(tid));
+        };
+        struct sigaction new_handler {};
+        new_handler.sa_handler = signal_handler;
+        sigaction(SIGSEGV, &new_handler, nullptr);
+    }
+}
+#endif
+
+std::vector SubgraphExecutor::separately_repack_inputs(const dnnl::stream& strm,
+                                                       const std::vector& srcMemPtrs) {
+    auto reordered_in_ptrs = srcMemPtrs;
+    size_t offset = m_internal_buffer_size;
+    for (const auto& p : m_repacked_inputs) {
+        const auto in_idx = p.first;
+        const auto& repacked_input = p.second;
+        const auto& desc = repacked_input.desc();
+        const void* data_ptr = m_buffer_scratchpad->getDataAs() + offset;
+
+        OPENVINO_ASSERT(in_idx < srcMemPtrs.size(), "Incorrect index of input repacked mem ptr");
+        const auto& src_mem = srcMemPtrs[in_idx];
+        const auto& dst_mem = std::make_shared<Memory>(strm.get_engine(), desc, data_ptr, false);
+
+        const auto* src = src_mem->getDataAs() + m_start_offset_in[in_idx];
+        auto* dst = dst_mem->getDataAs();
+
+        VectorDims dom;
+        const auto& shape = dst_mem->getShape().getDims();
+        OPENVINO_ASSERT(shape.size() <= m_tensor_rank, "Unsupported shape rank of repacking data");
+        init_parallel_domain(shape, m_tensor_rank, 2lu, dom);
+
+        const auto& in_strides = repacked_input.in_offsets();
+        const auto& out_strides = repacked_input.out_offsets();
+        OPENVINO_ASSERT(everyone_is(m_tensor_rank, in_strides.size(), out_strides.size(), dom.size()),
+                        "Unsupported shape rank of repacking data");
+
+        const auto& kernel = repacked_input.kernel();
+        if (m_tensor_rank == rank6D)
+            parallel4d_repacking(kernel.get(), dom, in_strides, out_strides, src, dst);
+        else
+            parallelNd_repacking(kernel.get(), dom, in_strides, out_strides, src, dst);
+
+        reordered_in_ptrs[in_idx] = dst_mem;
+        offset += desc->getCurrentMemSize();
+    }
+    return reordered_in_ptrs;
+}
+
+void SubgraphExecutor::in_parallel_repack_inputs(const std::vector& inMemPtrs,
+                                                 const std::vector& indexes,
+                                                 int ithr,
jit_snippets_call_args& call_args) { + size_t repacked_offset_idx = 0; + for (const auto& p : m_repacked_inputs) { + const auto& in_idx = p.first; + const auto& repacked_in = p.second; + + size_t src_offset = m_start_offset_in[in_idx]; + init_offset(repacked_in.in_offsets(), indexes, src_offset); + + auto* repacked_ptr = get_external_scratchpad_ptr(ithr, in_idx); + + auto& last_processed_src_offset = m_repacked_offsets_by_threads[ithr][repacked_offset_idx]; + if (src_offset != last_processed_src_offset) { + BrgemmCopyBKernel::call_args args; + args.src = inMemPtrs[in_idx]->getDataAs() + src_offset; + args.tr_src = repacked_ptr; + (*repacked_in.kernel())(&args); + + last_processed_src_offset = src_offset; + } + + call_args.src_ptrs[in_idx] = repacked_ptr; + ++repacked_offset_idx; + } +} + +void SubgraphExecutor::execute(const dnnl::stream& strm, + const std::vector& inMemPtrs, + const std::vector& outMemPtrs) { + if (should_repacking_be_separately()) { + exec_impl(separately_repack_inputs(strm, inMemPtrs), outMemPtrs); + return; + } + + exec_impl(inMemPtrs, outMemPtrs); +} + +void SubgraphStaticExecutor::exec_impl(const std::vector& inMemPtrs, + const std::vector& outMemPtrs) { + const auto& callable = m_schedule->get_callable(); + + initializer_functor initializer; + call_functor caller; + if (should_repacking_be_in_parallel()) { + initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, inMemPtrs, outMemPtrs, m_start_offset_in, m_start_offset_out, ithr); + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); + clean_repacked_offsets(ithr); + }; + caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes, size_t ithr) { + in_parallel_repack_inputs(inMemPtrs, indexes, ithr, call_args); + callable(&call_args, indexes.data()); + }; + } else { + initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, inMemPtrs, outMemPtrs, m_start_offset_in, m_start_offset_out, ithr); + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); + }; + caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes, size_t ithr) { + callable(&call_args, indexes.data()); + }; + } + +#if defined(__linux__) && defined(SNIPPETS_DEBUG_CAPS) + segfault_detector(); +#endif + + if (m_parallel_exec_domain.size() == rank6D) { + parallel_for6d(initializer, caller); + } else { + parallel_forNd(initializer, caller); + } +} + +void SubgraphDynamicSpecializedExecutor::exec_impl(const std::vector& inMemPtrs, + const std::vector& outMemPtrs) { + const auto& callable = m_schedule->get_callable(); + + OPENVINO_ASSERT(m_data_offsets.size() == inMemPtrs.size() + outMemPtrs.size(), "Incorrect data offset count!"); + OPENVINO_ASSERT(m_data_offsets.front().size() == m_parallel_exec_domain.size(), + "Data offsets with invalid ranks detected"); + + // Note: we need to reset KernelExecutorTable to the state that was recorded in the + // SubgraphDynamicSpecializedExecutor constructor because the table might've been used for other shapes + m_reset_exec_table_state(); + + std::vector src_ptrs; + std::vector dst_ptrs; + init_original_ptrs(inMemPtrs, outMemPtrs, src_ptrs, dst_ptrs, m_start_offset_in, m_start_offset_out); + + initializer_functor initializer; + call_functor caller; + if (should_repacking_be_in_parallel()) { + initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { + init_call_args(call_args, ithr); + update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); + clean_repacked_offsets(ithr); + }; + 
+        caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes, size_t ithr) {
+            update_ptrs(call_args, src_ptrs, dst_ptrs, indexes);
+            in_parallel_repack_inputs(inMemPtrs, indexes, ithr, call_args);
+            callable(&call_args);
+        };
+    } else {
+        initializer = [&](jit_snippets_call_args& call_args, size_t ithr) {
+            init_call_args(call_args, ithr);
+            update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr);
+        };
+        caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes, size_t ithr) {
+            update_ptrs(call_args, src_ptrs, dst_ptrs, indexes);
+            callable(&call_args);
+        };
+    }
+
+#if defined(__linux__) && defined(SNIPPETS_DEBUG_CAPS)
+    segfault_detector();
+#endif
+
+    if (m_parallel_exec_domain.size() == rank6D) {
+        parallel_for6d(initializer, caller);
+    } else {
+        parallel_forNd(initializer, caller);
+    }
+}
+
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.hpp b/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.hpp
new file mode 100644
index 00000000000000..fa0eb5f1583d2d
--- /dev/null
+++ b/src/plugins/intel_cpu/src/nodes/executors/x64/subgraph.hpp
@@ -0,0 +1,95 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "nodes/executors/subgraph.hpp"
+
+namespace ov {
+namespace intel_cpu {
+
+class SubgraphExecutor : public SubgraphBaseExecutor {
+public:
+    SubgraphExecutor(const std::shared_ptr& snippet_config,
+                     const std::shared_ptr& snippet_attrs,
+                     const std::shared_ptr& snippet,
+                     const std::vector& start_offset_in,
+                     const std::vector& start_offset_out,
+                     const BufferScratchpadAllocator& allocator,
+                     const ov::intel_cpu::MultiCacheWeakPtr& kernel_cache);
+
+    void execute(const dnnl::stream& strm,
+                 const std::vector& inMemPtrs,
+                 const std::vector& outMemPtrs) override;
+
+protected:
+    std::vector separately_repack_inputs(const dnnl::stream& strm, const std::vector& srcMemPtrs);
+    void in_parallel_repack_inputs(const std::vector& inMemPtrs,
+                                   const std::vector& indexes,
+                                   int ithr,
+                                   jit_snippets_call_args& call_args);
+
+    inline void* get_external_scratchpad_ptr(size_t ithr, size_t idx) const {
+        if (m_repacked_inputs.empty())
+            return nullptr;
+
+        uint8_t* data_ptr = m_buffer_scratchpad->getDataAs() + m_internal_buffer_size;
+        for (const auto& p : m_repacked_inputs) {
+            const auto& desc = p.second.desc();
+            const auto size = desc->getCurrentMemSize();
+            if (p.first == idx) {
+                return data_ptr + ithr * size;
+            }
+            data_ptr += m_nthreads * size;
+        }
+        OPENVINO_THROW("External buffer pointer has not been found");
+    }
+
+    // [ Thread Index -> Index of input with repacking data -> last repacked src_offset ]
+    std::vector<std::vector<size_t>> m_repacked_offsets_by_threads = {};
+    std::unordered_map m_repacked_inputs = {};
+
+    std::function<void(const std::vector<size_t>&, const std::vector<size_t>&, size_t&)> init_offset = {};
+
+    inline bool should_repacking_be_separately() const {
+        return m_repacking_impl_type == CPURuntimeConfig::RepackingImplType::SEPARATE;
+    }
+    inline bool should_repacking_be_in_parallel() const {
+        return m_repacking_impl_type == CPURuntimeConfig::RepackingImplType::IN_PARALLEL;
+    }
+    inline void clean_repacked_offsets(size_t ithr) {
+        m_repacked_offsets_by_threads[ithr].assign(m_repacked_inputs.size(), std::numeric_limits::max());
+    }
+
+#ifdef SNIPPETS_DEBUG_CAPS
+    bool enabled_segfault_detector = false;
+    inline void segfault_detector();
+#endif
+
+private:
+    CPURuntimeConfig::RepackingImplType m_repacking_impl_type = CPURuntimeConfig::RepackingImplType::NONE;
+}; + +class SubgraphStaticExecutor : public SubgraphExecutor, public SubgraphStaticBaseExecutor { +public: + template + SubgraphStaticExecutor(const std::shared_ptr& snippet_config, Args... args) + : SubgraphExecutor(snippet_config, args...), + SubgraphStaticBaseExecutor() {} + + void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override; +}; + +class SubgraphDynamicSpecializedExecutor : public SubgraphExecutor, public SubgraphDynamicSpecializedBaseExecutor { +public: + template + SubgraphDynamicSpecializedExecutor(const std::shared_ptr& snippet_config, Args... args) + : SubgraphExecutor(snippet_config, args...), + SubgraphDynamicSpecializedBaseExecutor(snippet_config) {} + + void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 2b0c7b55fb043d..5bd83e6f6c82a0 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -7,7 +7,6 @@ #include "dnnl_extension_utils.h" #include "onednn/dnnl.h" #include "openvino/core/parallel.hpp" -#include "openvino/core/rt_info.hpp" #include "shape_inference/custom/subgraph.hpp" #include "snippets/lowered/pass/init_loops.hpp" #include "snippets/lowered/pass/insert_buffers.hpp" @@ -27,9 +26,11 @@ #if defined(OPENVINO_ARCH_ARM64) # include "emitters/snippets/aarch64/cpu_generator.hpp" +# include "executors/aarch64/subgraph.hpp" # include "transformations/snippets/aarch64/shape_inference.hpp" #else # include "emitters/snippets/x64/cpu_generator.hpp" +# include "executors/x64/subgraph.hpp" # include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp" # include "transformations/snippets/x64/pass/eliminate_brgemm_copy_b.hpp" # include "transformations/snippets/x64/pass/enforce_precision.hpp" @@ -48,13 +49,6 @@ #include "utils/cpu_utils.hpp" #include "utils/ngraph_utils.hpp" -#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS) -# include - -# include "emitters/snippets/x64/jit_segfault_detector_emitter.hpp" -std::mutex err_print_lock; -#endif - #ifdef SNIPPETS_LIBXSMM_TPP # include "snippets/lowered/pass/optimize_domain.hpp" # include "transformations/tpp/x64/pass/brgemm_to_brgemm_tpp.hpp" @@ -70,152 +64,9 @@ namespace intel_cpu { namespace node { namespace { -// Class for Subgraphs with static shapes -class SubgraphStaticExecutor : public Subgraph::SubgraphExecutor { -public: - SubgraphStaticExecutor(const std::shared_ptr& snippet_attrs, - const std::shared_ptr& snippet, - const std::vector& start_offset_in, - const std::vector& start_offset_out, - const std::shared_ptr& snippet_config, - const BufferScratchpadAllocator& allocator) - : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out, snippet_config, allocator) {} - - void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override { - const auto& callable = m_schedule->get_callable(); - - auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { - init_call_args(call_args, inMemPtrs, outMemPtrs, ithr); - }; - auto caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes) { - callable(&call_args, indexes.data()); - }; - - if (m_parallel_exec_domain.size() == rank6D) { - parallel_for6d(initializer, caller); - } else { - parallel_forNd(initializer, caller); - } - } - -protected: - typedef void (*kernel)(const void*, const 
void*); - - inline void init_call_args(jit_snippets_call_args& call_args, - const std::vector& srcMemPtrs, - const std::vector& dstMemPtrs, - size_t ithr) { - for (size_t i = 0; i < srcMemPtrs.size(); i++) - call_args.src_ptrs[i] = srcMemPtrs[i]->getDataAs() + m_start_offset_in[i]; - - for (size_t i = 0; i < dstMemPtrs.size(); i++) - call_args.dst_ptrs[i] = dstMemPtrs[i]->getDataAs() + m_start_offset_out[i]; - - update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); - } -}; - -// Specialized dynamic executor based on shape agnostic kernel for the specific input shapes -class SubgraphDynamicSpecializedExecutor : public Subgraph::SubgraphExecutor { -public: - SubgraphDynamicSpecializedExecutor(const std::shared_ptr& snippet_attrs, - const std::shared_ptr& snippet, - const std::vector& start_offset_in, - const std::vector& start_offset_out, - const std::shared_ptr& snippet_config, - const BufferScratchpadAllocator& allocator) - : SubgraphExecutor(snippet_attrs, snippet, start_offset_in, start_offset_out, snippet_config, allocator) { - buffer_offsets = snippet_config->buffer_cluster_offsets; - data_offsets = snippet_config->io_data_offsets; - loop_args = snippet_config->loop_args; - reset_exec_table_state = snippet_config->kernel_executor_table->get_state_reset(); - } - - void exec_impl(const std::vector& inMemPtrs, const std::vector& outMemPtrs) override { - const auto& callable = m_schedule->get_callable(); - - OPENVINO_ASSERT(data_offsets.size() == inMemPtrs.size() + outMemPtrs.size(), "Incorrect data offset count!"); - OPENVINO_ASSERT(data_offsets.front().size() == m_parallel_exec_domain.size(), - "Data offsets with invalid ranks detected"); - - // Note: we need to reset KernelExecutorTable to the state that was recorded in the - // SubgraphDynamicSpecializedExecutor constructor because the table might've been used for other shapes - reset_exec_table_state(); - - std::vector src_ptrs; - std::vector dst_ptrs; - init_original_ptrs(inMemPtrs, outMemPtrs, src_ptrs, dst_ptrs); - - auto initializer = [&](jit_snippets_call_args& call_args, size_t ithr) { - init_call_args(call_args, ithr); - }; - auto caller = [&](jit_snippets_call_args& call_args, const std::vector& indexes) { - update_ptrs(call_args, src_ptrs, dst_ptrs, indexes); - callable(&call_args); - }; - - if (m_parallel_exec_domain.size() == rank6D) { - parallel_for6d(initializer, caller); - } else { - parallel_forNd(initializer, caller); - } - } - -protected: - typedef void (*dynamic_kernel)(const void*); - - inline void init_call_args(jit_snippets_call_args& call_args, size_t ithr) { - call_args.register_loops(loop_args); - std::copy(buffer_offsets.cbegin(), buffer_offsets.cend(), call_args.buffer_offsets); - - update_scratchpad_ptr(call_args.buffer_scratchpad_ptr, ithr); - } - - inline void init_original_ptrs(const std::vector& srcMemPtrs, - const std::vector& dstMemPtrs, - std::vector& src_ptrs, - std::vector& dst_ptrs) { - const auto in_num = srcMemPtrs.size(); - const auto out_num = dstMemPtrs.size(); - - src_ptrs.resize(in_num, nullptr); - dst_ptrs.resize(out_num, nullptr); - - for (size_t i = 0; i < in_num; i++) - src_ptrs[i] = srcMemPtrs[i]->getDataAs() + m_start_offset_in[i]; - for (size_t i = 0; i < out_num; i++) - dst_ptrs[i] = dstMemPtrs[i]->getDataAs() + m_start_offset_out[i]; - } - - inline void update_ptrs(jit_snippets_call_args& call_args, - const std::vector& src_ptrs, - const std::vector& dst_ptrs, - const std::vector& indexes) const { - for (size_t i = 0; i < src_ptrs.size(); i++) { - auto i_ptr = src_ptrs[i]; - 
 struct SubgraphKey {
     SubgraphKey() = default;
-    SubgraphKey(const std::shared_ptr<Subgraph::SubgraphAttrs>& attrs_, const std::vector<VectorDims>& in_shapes_)
+    SubgraphKey(const std::shared_ptr<SubgraphAttrs>& attrs_, const std::vector<VectorDims>& in_shapes_)
         : attrs(attrs_),
           in_shapes(in_shapes_) {}
 
     virtual ~SubgraphKey() = default;
 
@@ -223,19 +74,19 @@ struct SubgraphKey {
     size_t hash() const;
     bool operator==(const SubgraphKey& rhs) const;
 
-    std::shared_ptr<Subgraph::SubgraphAttrs> attrs = nullptr;
+    std::shared_ptr<SubgraphAttrs> attrs = nullptr;
     std::vector<VectorDims> in_shapes = {};
 };
 
 struct SubgraphCodeGeneratorKey {
-    SubgraphCodeGeneratorKey(const std::shared_ptr<Subgraph::SubgraphAttrs>& attrs_, uint8_t mask_)
+    SubgraphCodeGeneratorKey(const std::shared_ptr<SubgraphAttrs>& attrs_, uint8_t mask_)
         : attrs(attrs_),
           broadcasting_mask(mask_) {}
 
     size_t hash() const;
     bool operator==(const SubgraphCodeGeneratorKey& rhs) const;
 
-    std::shared_ptr<Subgraph::SubgraphAttrs> attrs = nullptr;
+    std::shared_ptr<SubgraphAttrs> attrs = nullptr;
     uint8_t broadcasting_mask = 0;
 };
 
@@ -251,7 +102,7 @@ struct SubgraphShapeInferResultKey {
     uint64_t body_hash = 0;
 };
 
-size_t get_attr_hash(size_t seed, const std::shared_ptr<Subgraph::SubgraphAttrs>& attrs) {
+size_t get_attr_hash(size_t seed, const std::shared_ptr<SubgraphAttrs>& attrs) {
     using namespace dnnl::impl;
     using namespace dnnl::impl::primitive_hashing;
 
@@ -301,7 +152,7 @@ size_t SubgraphShapeInferResultKey::hash() const {
     return seed;
 }
 
-bool operator==(const Subgraph::SubgraphAttrs& lhs, const Subgraph::SubgraphAttrs& rhs) {
+bool operator==(const SubgraphAttrs& lhs, const SubgraphAttrs& rhs) {
     if (&lhs == &rhs)
         return true;
     if (lhs.bodyHash != rhs.bodyHash)
@@ -796,10 +647,10 @@ void Subgraph::optimizeIR() {
 
 void Subgraph::prepareParams() {
     const auto& cache = context->getParamsCache();
 
-    auto builder = [this, &cache](const SubgraphKey& key) -> std::shared_ptr<SubgraphExecutor> {
+    auto builder = [this, &cache](const SubgraphKey& key) -> std::shared_ptr<SubgraphBaseExecutor> {
         const auto& snippet = subgraph_attrs->snippet;
 
-        SubgraphExecutor::BufferScratchpadAllocator allocator = [this](size_t size) {
+        SubgraphBaseExecutor::BufferScratchpadAllocator allocator = [this](size_t size) {
             return getScratchPadMem(std::make_shared<CpuBlockedMemoryDesc>(ov::element::u8, intel_cpu::Shape{size}));
         };
 
@@ -822,12 +673,13 @@ void Subgraph::prepareParams() {
                     code_gen->get()->lowering_result.kernel_executor_table);
             }
             const auto& snippet_config = ov::as_type_ptr<CPURuntimeConfig>(snippet->update_runtime_config());
-            return std::make_shared<SubgraphDynamicSpecializedExecutor>(key.attrs,
+            return std::make_shared<SubgraphDynamicSpecializedExecutor>(snippet_config,
+                                                                        key.attrs,
                                                                         code_gen,
                                                                         start_offset_in,
                                                                         start_offset_out,
-                                                                        snippet_config,
-                                                                        allocator);
+                                                                        allocator,
+                                                                        cache);
         } else {
             // Static case:
             // 1. Update runtime config to get static scheduling data (io data offsets, parallel domain) which will be
@@ -840,12 +692,13 @@ void Subgraph::prepareParams() {
                 [&snippet_config](const SubgraphCodeGeneratorKey& key) -> std::shared_ptr<SubgraphCodeGenerator> {
                     return std::make_shared<SubgraphCodeGenerator>(key.attrs, snippet_config);
                 });
-            return std::make_shared<SubgraphStaticExecutor>(key.attrs,
+            return std::make_shared<SubgraphStaticExecutor>(snippet_config,
+                                                            key.attrs,
                                                             code_gen_result.first,
                                                             start_offset_in,
                                                             start_offset_out,
-                                                            snippet_config,
-                                                            allocator);
+                                                            allocator,
+                                                            cache);
         }
     };
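The cache lookups above only stay correct if `hash()` and `operator==` of these keys cover the same fields: the subgraph body hash plus the input shapes. A hedged sketch of the seed-combining idiom the keys rely on (a generic stand-in in the spirit of dnnl's `primitive_hashing`, not its exact code):

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Boost-style seed combining.
    static size_t hash_combine(size_t seed, size_t value) {
        return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    // A key hashes its body hash and every input dim; operator== must compare
    // exactly the same fields, or the executor cache would alias entries.
    static size_t hash_key(size_t body_hash, const std::vector<std::vector<size_t>>& in_shapes) {
        size_t seed = body_hash;
        for (const auto& shape : in_shapes)
            for (const auto dim : shape)
                seed = hash_combine(seed, std::hash<size_t>()(dim));
        return seed;
    }

    int main() {
        return hash_key(0x42, {{1, 128, 64}}) == hash_key(0x42, {{1, 128, 64}}) ? 0 : 1;
    }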
@@ -905,191 +758,6 @@ void Subgraph::executeDynamicImpl(dnnl::stream strm) {
     execute(strm);
 }
 
-namespace {
-inline void init_parallel_domain(const std::shared_ptr<CPURuntimeConfig>& snippet_config, std::vector<size_t>& domain) {
-    const auto& master_shape = snippet_config->master_shape;
-    const auto& tensor_rank = snippet_config->tensor_rank;
-    const auto& tile_rank = snippet_config->tile_rank;
-    domain.resize(tensor_rank, 1);
-
-    std::fill(domain.begin(), domain.end(), 1);
-    std::copy(master_shape.cbegin(),
-              master_shape.cbegin() + (master_shape.size() - tile_rank),
-              domain.begin() + (tensor_rank - master_shape.size()));
-}
-}  // namespace
-
-Subgraph::SubgraphCodeGenerator::SubgraphCodeGenerator(const std::shared_ptr<Subgraph::SubgraphAttrs>& snippet_attrs,
-                                                       const std::shared_ptr<CPURuntimeConfig>& config) {
-    OPENVINO_ASSERT(snippet_attrs, "Subgraph attributes are empty!");
-    OPENVINO_ASSERT(config, "Runtime Config is empty!");
-
-    jit_snippets_compile_args jcp;
-    jcp.data_offsets = config->io_data_offsets;
-    init_parallel_domain(config, jcp.exec_domain);
-    schedule =
-        std::make_shared<snippets::Schedule>(snippet_attrs->snippet->generate(reinterpret_cast<const void*>(&jcp)));
-}
-
-Subgraph::SubgraphExecutor::SubgraphExecutor(const std::shared_ptr<Subgraph::SubgraphAttrs>& snippet_attrs,
-                                             const std::shared_ptr<SubgraphCodeGenerator>& snippet,
-                                             const std::vector<ptrdiff_t>& start_offset_in,
-                                             const std::vector<ptrdiff_t>& start_offset_out,
-                                             const std::shared_ptr<CPURuntimeConfig>& snippet_config,
-                                             const BufferScratchpadAllocator& allocator)
-    : m_schedule(snippet->get()),
-      m_start_offset_in(start_offset_in),
-      m_start_offset_out(start_offset_out) {
-    OPENVINO_ASSERT(m_schedule, "Schedule is empty!");
-    OPENVINO_ASSERT(snippet_config, "Runtime Config is empty!");
-    init_parallel_domain(snippet_config, m_parallel_exec_domain);
-
-    m_harness_work_amount = std::accumulate(m_parallel_exec_domain.cbegin(),
-                                            m_parallel_exec_domain.cend(),
-                                            size_t(1),
-                                            std::multiplies<size_t>());
-    m_nthreads = std::min(parallel_get_max_threads(), static_cast<int>(m_harness_work_amount));
-
-    m_buffer_scratchpad_size = snippet_config->buffer_scratchpad_size;
-    OPENVINO_ASSERT(!ov::snippets::utils::is_dynamic_value(m_buffer_scratchpad_size),
-                    "Undefined buffer scratchpad size!");
-    m_internal_buffer_size = static_cast<size_t>(m_nthreads) * m_buffer_scratchpad_size;
-    m_in_requested_descs = snippet_config->m_in_requested_descs;
-    const auto external_repacking_buffer_size =
-        std::accumulate(m_in_requested_descs.begin(),
-                        m_in_requested_descs.end(),
-                        size_t(0),
-                        [](size_t sum, const std::pair<size_t, CpuBlockedMemoryDescPtr>& requested_desc_elem) {
-                            return sum + requested_desc_elem.second->getCurrentMemSize();
-                        });
-    m_buffer_scratchpad = allocator(m_internal_buffer_size + external_repacking_buffer_size);
-
-#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS)
-    const auto target = std::dynamic_pointer_cast<const CPUTargetMachine>(
-        snippet_attrs->snippet->get_generator()->get_target_machine());
-    enabled_segfault_detector = target && target->debug_config.enable_segfault_detector;
-#endif
-}
-
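For reference, the `init_parallel_domain` removed above derives the host-side parallel domain: the master shape minus its innermost `tile_rank` dimensions (which the generated kernel iterates itself), right-aligned into a `tensor_rank`-sized vector padded with ones. A standalone mirror of that logic under the same assumptions:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Right-align the outer master-shape dims into a fixed-rank domain of 1s;
    // the innermost `tile_rank` dims are executed inside the generated kernel.
    static std::vector<size_t> make_parallel_domain(const std::vector<size_t>& master_shape,
                                                    size_t tensor_rank,
                                                    size_t tile_rank) {
        std::vector<size_t> domain(tensor_rank, 1);
        std::copy(master_shape.cbegin(),
                  master_shape.cbegin() + (master_shape.size() - tile_rank),
                  domain.begin() + (tensor_rank - master_shape.size()));
        return domain;
    }

    int main() {
        const auto d = make_parallel_domain({8, 16, 384, 64}, 6, 2);
        // d == {1, 1, 8, 16, 1, 1}: only the outer dims are parallelized.
        return (d[2] == 8 && d[3] == 16) ? 0 : 1;
    }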
-#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS)
-void Subgraph::SubgraphExecutor::segfault_detector() {
-    if (enabled_segfault_detector) {
-        __sighandler_t signal_handler = [](int signal) {
-            std::lock_guard<std::mutex> guard(err_print_lock);
-            if (auto segfault_detector_emitter = ov::intel_cpu::g_custom_segfault_handler->local())
-                std::cout << segfault_detector_emitter->info() << std::endl;
-            auto tid = parallel_get_thread_num();
-            OPENVINO_THROW("Segfault was caught by the signal handler in subgraph node execution on thread " +
-                           std::to_string(tid));
-        };
-        struct sigaction new_handler {};
-        new_handler.sa_handler = signal_handler;
-        sigaction(SIGSEGV, &new_handler, nullptr);
-    }
-}
-#endif
-
-void Subgraph::SubgraphExecutor::parallel_for6d(
-    const std::function<void(jit_snippets_call_args&, size_t)>& initializer,
-    const std::function<void(jit_snippets_call_args&, const std::vector<size_t>&)>& caller) {
-    const auto& dom = m_parallel_exec_domain;
-
-#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS)
-    segfault_detector();
-#endif
-
-    parallel_nt_static(m_nthreads, [&](const int ithr, const int nthr) {
-        jit_snippets_call_args call_args;
-        initializer(call_args, ithr);
-
-        size_t start = 0, end = 0;
-        splitter(m_harness_work_amount, nthr, ithr, start, end);
-
-        std::vector<size_t> indexes{0, 0, 0, 0, 0};
-        parallel_it_init(start,
-                         indexes[0],
-                         dom[0],
-                         indexes[1],
-                         dom[1],
-                         indexes[2],
-                         dom[2],
-                         indexes[3],
-                         dom[3],
-                         indexes[4],
-                         dom[4]);
-        for (size_t iwork = start; iwork < end; ++iwork) {
-            caller(call_args, indexes);
-            parallel_it_step(indexes[0],
-                             dom[0],
-                             indexes[1],
-                             dom[1],
-                             indexes[2],
-                             dom[2],
-                             indexes[3],
-                             dom[3],
-                             indexes[4],
-                             dom[4]);
-        }
-    });
-}
-
-void Subgraph::SubgraphExecutor::parallel_forNd(
-    const std::function<void(jit_snippets_call_args&, size_t)>& initializer,
-    const std::function<void(jit_snippets_call_args&, const std::vector<size_t>&)>& caller) {
-    const auto& dom = m_parallel_exec_domain;
-
-#if defined(__linux__) && defined(OPENVINO_ARCH_X86_64) && defined(SNIPPETS_DEBUG_CAPS)
-    segfault_detector();
-#endif
-
-    parallel_nt_static(m_nthreads, [&](const int ithr, const int nthr) {
-        jit_snippets_call_args call_args;
-        initializer(call_args, ithr);
-
-        size_t start = 0, end = 0;
-        splitter(m_harness_work_amount, nthr, ithr, start, end);
-
-        std::vector<size_t> indexes(dom.size() - 1, 0);
-        for (size_t iwork = start; iwork < end; ++iwork) {
-            size_t tmp = iwork;
-            for (ptrdiff_t j = static_cast<ptrdiff_t>(dom.size()) - 2; j >= 0; j--) {
-                indexes[j] = tmp % dom[j];
-                tmp /= dom[j];
-            }
-
-            caller(call_args, indexes);
-        }
-    });
-}
-
-void Subgraph::SubgraphExecutor::execute(const dnnl::stream& strm,
-                                         const std::vector<MemoryPtr>& inMemPtrs,
-                                         const std::vector<MemoryPtr>& outMemPtrs) {
-    if (!m_in_requested_descs.empty()) {
-        auto reorderedInMemPtrs = reorder_inputs(strm, inMemPtrs);
-        exec_impl(reorderedInMemPtrs, outMemPtrs);
-    } else {
-        exec_impl(inMemPtrs, outMemPtrs);
-    }
-}
-
-std::vector<MemoryPtr> Subgraph::SubgraphExecutor::reorder_inputs(const dnnl::stream& strm,
-                                                                  const std::vector<MemoryPtr>& inMemPtrs) {
-    auto reordered_in_ptrs = inMemPtrs;
-    size_t offset = m_internal_buffer_size;
-    for (const auto& requested_descs_elem : m_in_requested_descs) {
-        const auto in_idx = requested_descs_elem.first;
-        const auto& requested_desc = requested_descs_elem.second;
-
-        const void* data_ptr = m_buffer_scratchpad->getDataAs<uint8_t>() + offset;
-        const auto scratch_mem = std::make_shared<Memory>(strm.get_engine(), requested_desc, data_ptr, false);
-        scratch_mem->load(*reordered_in_ptrs[in_idx]);
-        reordered_in_ptrs[in_idx] = scratch_mem;
-        offset += requested_desc->getCurrentMemSize();
-    }
-    return reordered_in_ptrs;
-}
-
 }  // namespace node
 }  // namespace intel_cpu
 }  // namespace ov
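Before moving on to the header changes: the removed `parallel_forNd` decodes each flat work item into N-d coordinates by mixed-radix division over all but the innermost domain dimension. A self-contained mirror of that decoding:

    #include <cstddef>
    #include <vector>

    // Decode a flat work index into coordinates over dom[0..n-2]
    // (innermost-fastest); the last domain dim runs inside the kernel.
    static std::vector<size_t> decode_index(size_t iwork, const std::vector<size_t>& dom) {
        std::vector<size_t> indexes(dom.size() - 1, 0);
        size_t tmp = iwork;
        for (std::ptrdiff_t j = static_cast<std::ptrdiff_t>(dom.size()) - 2; j >= 0; j--) {
            indexes[j] = tmp % dom[j];
            tmp /= dom[j];
        }
        return indexes;
    }

    int main() {
        const auto idx = decode_index(11, {2, 3, 4, 64});
        // 11 = ((0 * 3) + 2) * 4 + 3 -> {0, 2, 3}
        return (idx[0] == 0 && idx[1] == 2 && idx[2] == 3) ? 0 : 1;
    }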
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.h b/src/plugins/intel_cpu/src/nodes/subgraph.h
index aac0fa1ea2f535..ea7d51650e5cad 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.h
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.h
@@ -4,10 +4,8 @@
 
 #pragma once
 
-#include "emitters/snippets/cpu_runtime_configurator.hpp"
-#include "emitters/snippets/jit_snippets_call_args.hpp"
+#include "executors/subgraph.hpp"
 #include "node.h"
-#include "snippets/op/subgraph.hpp"
 
 #if defined(OPENVINO_ARCH_ARM64)
 #    include "cpu/aarch64/cpu_isa_traits.hpp"
@@ -15,8 +13,6 @@
 #    include "cpu/x64/cpu_isa_traits.hpp"
 #endif
 
-#include <array>
-
 namespace ov {
 namespace intel_cpu {
 namespace node {
@@ -41,21 +37,6 @@ class Subgraph : public Node {
     void execute(dnnl::stream strm) override;
     void executeDynamicImpl(dnnl::stream strm) override;
 
-    struct SubgraphAttrs {
-        // Local copy of subgraph node for canonization & code generation
-        std::shared_ptr<snippets::op::Subgraph> snippet;
-        uint64_t bodyHash;
-        std::vector<std::vector<size_t>> inMemOrders;
-        std::vector<std::vector<size_t>> outMemOrders;
-        std::vector<ov::element::Type> inMemPrecs;
-        std::vector<ov::element::Type> outMemPrecs;
-    };
-
-    // Class for snippet compilation
-    class SubgraphCodeGenerator;
-    // Base class for executors
-    class SubgraphExecutor;
-
 protected:
     IShapeInfer::Result shapeInfer() const override;
 
@@ -103,79 +84,7 @@ class Subgraph : public Node {
     // Input shapes that are used in PrepareParams and ShapeInfer to avoid frequent memory allocation
     mutable std::vector<VectorDims> in_shapes;
 
-    std::shared_ptr<SubgraphExecutor> execPtr = nullptr;
-};
-
-class Subgraph::SubgraphCodeGenerator {
-public:
-    SubgraphCodeGenerator(const std::shared_ptr<Subgraph::SubgraphAttrs>& snippet_attrs,
-                          const std::shared_ptr<CPURuntimeConfig>& config);
-
-    const std::shared_ptr<snippets::Schedule>& get() const {
-        return schedule;
-    }
-
-private:
-    std::shared_ptr<snippets::Schedule> schedule;
-};
-
-class Subgraph::SubgraphExecutor {
-public:
-    using BufferScratchpadAllocator = std::function<MemoryPtr(size_t)>;
-
-    SubgraphExecutor(const std::shared_ptr<Subgraph::SubgraphAttrs>& snippet_attrs,
-                     const std::shared_ptr<SubgraphCodeGenerator>& snippet,
-                     const std::vector<ptrdiff_t>& start_offset_in,
-                     const std::vector<ptrdiff_t>& start_offset_out,
-                     const std::shared_ptr<CPURuntimeConfig>& snippet_config,
-                     const BufferScratchpadAllocator& allocator);
-    virtual ~SubgraphExecutor() = default;
-
-    void execute(const dnnl::stream& strm,
-                 const std::vector<MemoryPtr>& inMemPtrs,
-                 const std::vector<MemoryPtr>& outMemPtrs);
-
-protected:
-    virtual void exec_impl(const std::vector<MemoryPtr>& inMemPtrs, const std::vector<MemoryPtr>& outMemPtrs) = 0;
-
-    void parallel_for6d(const std::function<void(jit_snippets_call_args&, size_t)>& initializer,
-                        const std::function<void(jit_snippets_call_args&, const std::vector<size_t>&)>& caller);
-    void parallel_forNd(const std::function<void(jit_snippets_call_args&, size_t)>& initializer,
-                        const std::function<void(jit_snippets_call_args&, const std::vector<size_t>&)>& caller);
-
-    inline void update_scratchpad_ptr(void*& scratchpad_ptr, size_t ithr) const {
-        if (m_buffer_scratchpad_size > 0)
-            scratchpad_ptr = m_buffer_scratchpad->getDataAs<uint8_t>() + ithr * m_buffer_scratchpad_size;
-    }
-
-    std::shared_ptr<snippets::Schedule> m_schedule;
-    // Holds index of output used as in execution domain
-    // it should be compatible with a schedule's work size
-    std::vector<size_t> m_parallel_exec_domain = {};
-    size_t m_harness_work_amount = 0;
-
-    // Buffer scratchpad
-    MemoryPtr m_buffer_scratchpad = nullptr;
-    size_t m_buffer_scratchpad_size = 0;
-    size_t m_internal_buffer_size = 0;
-
-    const size_t rank6D = 6;
-
-    // Count of threads for parallel_nt
-    int m_nthreads = 0;
-
-    std::vector<ptrdiff_t> m_start_offset_in = {};
-    std::vector<ptrdiff_t> m_start_offset_out = {};
-
-#ifdef SNIPPETS_DEBUG_CAPS
-    bool enabled_segfault_detector = false;
-    inline void segfault_detector();
-#endif
-
-private:
-    std::vector<MemoryPtr> reorder_inputs(const dnnl::stream& strm, const std::vector<MemoryPtr>& inMemPtrs);
-
-    std::unordered_map<size_t, CpuBlockedMemoryDescPtr> m_in_requested_descs = {};
+    std::shared_ptr<SubgraphBaseExecutor> execPtr = nullptr;
 };
 
 }  // namespace node
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
index 7e52905145869f..ce57cd1529b893 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
@@ -114,6 +114,13 @@ size_t BrgemmCopyB::get_offset_compensations() const {
     return get_output_offset(1);
 }
 
+bool BrgemmCopyB::is_transposed(const std::vector<size_t>& layout) {
+    const auto is_transposed = !layout.empty() && layout.back() != layout.size() - 1;
+    OPENVINO_ASSERT(IMPLICATION(is_transposed, (layout[layout.size() - 2] == layout.size() - 1)),
+                    "BrgemmCopyB supports only layouts with N placed as the last or second-to-last dimension");
+    return is_transposed;
+}
+
 BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
     const auto& brg_copyb = ov::as_type_ptr<BrgemmCopyB>(n);
     OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer");
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
index 54e2c39fcf1c06..b4e7b030fc605b 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
@@ -72,6 +72,8 @@ class BrgemmCopyB : public snippets::modifier::MemoryAccess, public ov::op::Op {
         Result infer(const std::vector<snippets::VectorDimsRef>& input_shapes) override;
     };
 
+    static bool is_transposed(const std::vector<size_t>& layout);
+
 private:
     void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
     void validate_element_type(const ov::element::Type& element_type);
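To make the new `BrgemmCopyB::is_transposed` concrete: the layout is treated as a permutation, and B counts as transposed when the innermost position no longer holds the innermost original dimension. A small sketch with example layouts:

    #include <cstddef>
    #include <vector>

    // `layout` is a permutation: layout[i] is the original dim stored at position i.
    // B is transposed when the last position no longer holds the innermost dim.
    static bool is_transposed_layout(const std::vector<std::size_t>& layout) {
        return !layout.empty() && layout.back() != layout.size() - 1;
    }

    int main() {
        const bool planar = is_transposed_layout({0, 1, 2});      // false: N innermost
        const bool transposed = is_transposed_layout({0, 2, 1});  // true: K and N swapped
        return (!planar && transposed) ? 0 : 1;
    }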
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
index 939ae93ad92b18..f17d052e7ffe43 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/eliminate_brgemm_copy_b.cpp
@@ -10,6 +10,7 @@
 #include "openvino/pass/pattern/op/wrap_type.hpp"
 #include "snippets/itt.hpp"
 #include "snippets/op/rank_normalization.hpp"
+#include "snippets/op/reshape.hpp"
 #include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
 
 namespace ov {
@@ -30,12 +31,28 @@ pass::EliminateBrgemmCopyB::EliminateBrgemmCopyB() {
         const auto& in_desc = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(copy_b_node->input(0));
         const auto& layout = in_desc->get_layout();
 
-        // TODO:
-        // 1. Ticket 157340: support external repacking for copyB with compensations
-        // 2. Ticket 157339: support external repacking for non-planar layout
-        if (!ov::snippets::utils::is_planar_layout(layout) ||
-            brgemm_utils::with_compensations(copy_b_node->get_type()) || transformation_callback(copy_b_node))
+
+        auto is_supported_layout = [](const std::vector<size_t>& layout) {
+            return layout.empty() || (layout.size() - 1 == layout.back());
+        };
+
+        // TODO [157340]: support external repacking for copyB with compensations
+        if (!is_supported_layout(layout) || brgemm_utils::with_compensations(copy_b_node->get_type()) ||
+            transformation_callback(copy_b_node))
             return false;
+
+        // If the layout is non-empty and non-planar, insert a ReshapeWithOrder to keep shape inference consistent
+        if (!ov::snippets::utils::is_planar_layout(layout)) {
+            const auto& subtensor = in_desc->get_subtensor();
+            const auto& reshape =
+                std::make_shared<ov::snippets::op::ReshapeWithOrder>(copy_b_node->input_value(0), layout);
+            ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(reshape->input(0), subtensor, layout);
+            ov::snippets::lowered::PortDescriptorUtils::set_port_descriptor(reshape->output(0), subtensor);
+            ov::replace_node_update_name(copy_b_node, reshape);
+            return true;
+        }
+
+        // If the layout is planar (or empty), we can just remove BrgemmCopyB from the subgraph
         return ov::replace_output_update_name(copy_b_out, copy_b_node->input_value(0));
     };
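The `ReshapeWithOrder` inserted above keeps shape inference consistent for non-planar inputs: downstream shapes are, conceptually, the input dims permuted by the stored order. A sketch of that permutation under one common convention (the authoritative semantics live in the snippets port descriptors, not this snippet):

    #include <cstddef>
    #include <vector>

    // Position i of the result takes input dim order[i].
    static std::vector<std::size_t> apply_order(const std::vector<std::size_t>& shape,
                                                const std::vector<std::size_t>& order) {
        std::vector<std::size_t> result(order.size());
        for (std::size_t i = 0; i < order.size(); ++i)
            result[i] = shape[order[i]];
        return result;
    }

    int main() {
        const auto out = apply_order({2, 64, 384}, {0, 2, 1});
        // out == {2, 384, 64}: the planar K x N view the downstream Brgemm expects.
        return (out[1] == 384 && out[2] == 64) ? 0 : 1;
    }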
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
index 1cb8263d189d18..16df97bb209ed9 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/adjust_brgemm_copy_b_loop_ports.cpp
@@ -70,8 +70,11 @@ bool pass::AdjustBrgemmCopyBLoopPorts::run(const snippets::lowered::LinearIR& li
     auto get_repacking_loop_idces = [](const snippets::lowered::ExpressionPtr& brgemm_expr) {
         // Repacking may be extracted outside the snippets kernel. In this case, brgemm parent expression is a
         // parameter.
-        if (is_type<ov::op::v0::Parameter>(
-                brgemm_expr->get_input_port_connector(1)->get_source().get_expr()->get_node()))
+        const auto& brgemm_in1 = brgemm_expr->get_input_port_connector(1)->get_source();
+        const auto& shape_infer_seq = ov::snippets::utils::get_first_parent_shape_infer_expr_seq(brgemm_in1.get_expr());
+        const auto source =
+            shape_infer_seq.empty() ? brgemm_in1 : shape_infer_seq.back()->get_input_port_connector(0)->get_source();
+        if (is_type<ov::op::v0::Parameter>(source.get_expr()->get_node()))
             return std::vector<size_t>{};
         const auto repacking_expr = brgemm_utils::repacking::get_copy_b_expr(brgemm_expr);
         OPENVINO_ASSERT(repacking_expr, "BrgemmCopyB expression is not found");
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
index 78f9b928298a9d..6f9a652620df2d 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.cpp
@@ -14,59 +14,144 @@
 namespace ov {
 namespace intel_cpu {
 
+const size_t BrgemmExternalRepackingAdjuster::brgemm_kernel_rank = 2;
+
 BrgemmExternalRepackingAdjuster::BrgemmExternalRepackingAdjuster(const ov::snippets::lowered::LinearIRCPtr& linear_ir,
                                                                 const CPURuntimeConfigurator* configurator)
     : snippets::lowered::pass::RuntimeOptimizer(configurator) {
     const auto& params = linear_ir->get_parameters();
     for (size_t i = 0; i < params.size(); ++i) {
         const auto& param = params[i];
-        const auto consumers = param->get_output_port_connector(0)->get_consumers();
-        const bool brgemm_with_extracted_repacking =
-            std::any_of(consumers.begin(), consumers.end(), [](const ov::snippets::lowered::ExpressionPort& port) {
-                auto brgemm = ov::as_type_ptr<BrgemmCPU>(port.get_expr()->get_node());
-                return brgemm && brgemm_utils::with_repacking(brgemm->get_type()) && port.get_index() == 1;
-            });
-        if (brgemm_with_extracted_repacking) {
-            m_param_idces_with_external_repacking.insert(i);
-            // Ticket 157339: Support non-planar layout
-            OPENVINO_ASSERT(ov::snippets::utils::is_planar_layout(configurator->get_io_descs()[i]->get_layout()),
-                            "Non-planar layout is not supported for external repacking");
-        }
+        const auto& shape_infer_consumers = ov::snippets::utils::get_first_child_shape_infer_expr_seq(param);
+        const auto& out = shape_infer_consumers.empty() ? param->get_output_port(0)
+                                                        : shape_infer_consumers.back()->get_output_port(0);
+        const auto consumers = out.get_connected_ports();
+
+        for (const auto& consumer : consumers) {
+            auto brgemm = ov::as_type_ptr<BrgemmCPU>(consumer.get_expr()->get_node());
+            if (brgemm && brgemm_utils::with_repacking(brgemm->get_type()) && consumer.get_index() == 1) {
+                const auto src_prc = brgemm->get_input_element_type(0);
+                const auto wei_prc = brgemm->get_input_element_type(1);
+                const auto isa = brgemm_utils::get_primitive_isa(src_prc, brgemm_utils::with_amx(brgemm->get_type()));
+                const auto inner_n_block = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
+                const auto is_transposed_b =
+                    BrgemmCopyB::is_transposed(m_configurator->get_io_descs()[i]->get_layout());
+                auto config = BrgemmCopyBKernelConfig(src_prc, wei_prc, isa, false, is_transposed_b, inner_n_block);
+                m_executors[i] = std::make_shared<BrgemmCopyBKernelExecutor>(
+                    static_cast<const CPURuntimeConfigurator*>(m_configurator)->get_cache(),
+                    config);
+            }
+        }
     }
 }
 
+VectorDims BrgemmExternalRepackingAdjuster::get_blk_order(size_t shape_rank) {
+    VectorDims order(shape_rank - brgemm_kernel_rank);
+    std::iota(order.begin(), order.end(), 0);
+    const auto last_idx = shape_rank - 1;
+    order.insert(order.end(), {last_idx - 1, last_idx, last_idx - 1});
+    return order;
+}
+
+VectorDims BrgemmExternalRepackingAdjuster::get_blk_shape(const VectorDims& planar_shape, ov::element::Type prc) {
+    const auto vnni_factor = brgemm_utils::compute_vnni_factor(prc);
+    const auto K = *++planar_shape.rbegin();
+    const auto N = *planar_shape.rbegin();
+    const auto new_K = snippets::utils::div_up(K, vnni_factor);
+    const auto new_N = std::max(N, brgemm_utils::repacking::compute_inner_n_block(prc));
+    VectorDims blk_shape(planar_shape.begin(), planar_shape.end() - brgemm_kernel_rank);
+    blk_shape.insert(blk_shape.end(), {new_K, new_N, vnni_factor});
+    return blk_shape;
+}
+
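To make `get_blk_order`/`get_blk_shape` concrete: batch dims keep their positions, K is split by the VNNI factor, and N is padded up to the inner block. A standalone rendering with hypothetical VNNI and inner-block values (the plugin derives both from the weight precision):

    #include <algorithm>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    using VectorDims = std::vector<size_t>;

    // Blocked shape: [batch..., ceil(K / vnni), max(N, inner_n), vnni].
    static VectorDims blk_shape(const VectorDims& planar, size_t vnni, size_t inner_n) {
        const size_t K = planar[planar.size() - 2];
        const size_t N = planar[planar.size() - 1];
        VectorDims blk(planar.begin(), planar.end() - 2);  // batch dims stay
        blk.insert(blk.end(), {(K + vnni - 1) / vnni, std::max(N, inner_n), vnni});
        return blk;
    }

    // Blocked order: [0..rank-3, K, N, K] -- the trailing K marks the VNNI tail dim.
    static VectorDims blk_order(size_t planar_rank) {
        VectorDims order(planar_rank - 2);
        std::iota(order.begin(), order.end(), 0);
        order.insert(order.end(), {planar_rank - 2, planar_rank - 1, planar_rank - 2});
        return order;
    }

    int main() {
        const auto s = blk_shape({8, 384, 64}, 2, 64);  // -> {8, 192, 64, 2}
        const auto o = blk_order(3);                    // -> {0, 1, 2, 1}
        return (s[1] == 192 && s[3] == 2 && o[3] == 1) ? 0 : 1;
    }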
+void BrgemmExternalRepackingAdjuster::update_kernel(const RepackExecutorPtr& executor,
+                                                    const VectorDims& shape,
+                                                    const VectorDims& layout,
+                                                    size_t N,
+                                                    size_t K,
+                                                    ov::element::Type prc) {
+    const auto generic_config = executor->get_config().get_clone_ptr();
+    auto config = static_cast<BrgemmCopyBKernelConfig*>(generic_config.get());
+    const auto idx = config->is_transposed_B() ? 0 : 1;
+    const auto copy_wei_stride = ov::snippets::utils::get_dim_in_stride(shape, layout, idx) * prc.size();
+    config->update(N, N, K, K, copy_wei_stride, brgemm_utils::repacking::compute_LDB(N, prc));
+    executor->update_by_config(*config);
+}
+
 bool BrgemmExternalRepackingAdjuster::run(const snippets::lowered::LinearIR& linear_ir) {
     OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmExternalRepackingAdjuster")
     const auto& cpu_config = ov::as_type_ptr<CPURuntimeConfig>(m_configurator->get_config());
-    auto& optimal_descs = cpu_config->m_in_requested_descs;
-    for (const auto& i : m_param_idces_with_external_repacking) {
+
+    size_t data_size = 0;
+    for (const auto& p : m_executors) {
+        const auto& i = p.first;
         const auto& shape = cpu_config->io_shapes[i];
-        const auto& K = *++shape.rbegin();
-        const auto& N = *shape.rbegin();
-
-        const auto& precision = linear_ir.get_parameters()[i]->get_node()->get_output_element_type(0);
-        const auto vnni_factor = brgemm_utils::compute_vnni_factor(precision);
-        const size_t brgemm_kernel_rank = 2;
-        // Firstly, batch dims are set
-        VectorDims requested_blocked_shape(shape.begin(), shape.end() - brgemm_kernel_rank);
-        // Then, the blocked dims are formed
-        requested_blocked_shape.insert(requested_blocked_shape.end(),
-                                       {snippets::utils::div_up(K, vnni_factor),
-                                        std::max(N, brgemm_utils::repacking::compute_inner_n_block(precision)),
-                                        vnni_factor});
-
-        VectorDims requested_order(shape.size() - brgemm_kernel_rank);
-        std::iota(requested_order.begin(), requested_order.end(), 0);
-        const auto last_idx = shape.size() - 1;
-        requested_order.insert(requested_order.end(), {last_idx - 1, last_idx, last_idx - 1});
-
-        optimal_descs[i] =
-            std::make_shared<CpuBlockedMemoryDesc>(precision, Shape(shape), requested_blocked_shape, requested_order);
-
-        ov::snippets::VectorDims shape_for_offset(cpu_config->tensor_rank - shape.size(), 1);
-        shape_for_offset.insert(shape_for_offset.end(), requested_blocked_shape.begin(), requested_blocked_shape.end());
-        m_configurator->compute_offsets(shape_for_offset, i, 0);
+
+        const auto& layout = cpu_config->io_layouts[i];
+        const auto planar_shape = ov::snippets::utils::get_planar_vdims(shape, layout);
+        const auto& K = *++planar_shape.rbegin();
+        const auto& N = *planar_shape.rbegin();
+
+        const auto& prc = linear_ir.get_parameters()[i]->get_node()->get_output_element_type(0);
+        const auto blk_shape = get_blk_shape(planar_shape, prc);
+
+        // src data + dst data per kernel call
+        const auto src_data = N * K * prc.size();
+        const auto dst_data =
+            std::accumulate(blk_shape.rbegin(), blk_shape.rbegin() + 3, prc.size(), std::multiplies<size_t>());
+        data_size += src_data + dst_data;
+
+        update_kernel(p.second, shape, layout, N, K, prc);
     }
+
+    const auto L2_cache_size = dnnl::utils::get_cache_size(2, true);
+    const auto fit_into_L2 = data_size < L2_cache_size;
+    // Heuristic: if the external repacking data doesn't fit into the L2 cache,
+    // repacking should be executed in a separate parallel section before kernel execution.
+    cpu_config->repacking_impl_type =
+        fit_into_L2 ? CPURuntimeConfig::RepackingImplType::IN_PARALLEL : CPURuntimeConfig::RepackingImplType::SEPARATE;
+
+    const auto is_impl_parallel = cpu_config->repacking_impl_type == CPURuntimeConfig::RepackingImplType::IN_PARALLEL;
+
+    for (const auto& p : m_executors) {
+        const auto& i = p.first;
+        const auto& shape = cpu_config->io_shapes[i];
+        auto& repacked_in = cpu_config->repacked_inputs[i];
+
+        const auto& prc = linear_ir.get_parameters()[i]->get_node()->get_output_element_type(0);
+        auto planar_shape = ov::snippets::utils::get_planar_vdims(shape, cpu_config->io_layouts[i]);
+        auto blk_shape = get_blk_shape(planar_shape, prc);
+        // In the parallel impl, each thread needs only a buffer of shape [K_blk, N_blk, VNNI] to store repacked data
+        if (is_impl_parallel) {
+            std::fill(planar_shape.rbegin() + brgemm_kernel_rank, planar_shape.rend(), 1);
+            std::fill(blk_shape.rbegin() + brgemm_kernel_rank + 1, blk_shape.rend(), 1);
+        }
+        const auto order = get_blk_order(planar_shape.size());
+        const auto desc = std::make_shared<CpuBlockedMemoryDesc>(prc, Shape(planar_shape), blk_shape, order);
+
+        // Save the original input offsets for the input before repacking.
+        // If the shape has not been changed, a `RepackedInput` was already created for this input on a
+        // previous pass call, and `cpu_config->io_data_offsets[i]` no longer contains the original input's
+        // offsets - they were updated for blocked shapes / zeroed by the previous initialization, so they
+        // cannot be used as the original offsets.
+        const auto in_offsets =
+            shape == cpu_config->latest_shapes[i] ? repacked_in.in_offsets() : cpu_config->io_data_offsets[i];
+
+        // In the parallel case, the kernel should not add offsets to repacked inputs, because
+        // they will be applied during repacking at the execution stage
+        if (is_impl_parallel) {
+            auto& offsets = cpu_config->io_data_offsets[i];
+            std::fill(offsets.begin(), offsets.end(), 0);
+        } else {
+            ov::snippets::VectorDims shape_for_offset(cpu_config->tensor_rank - shape.size(), 1);
+            shape_for_offset.insert(shape_for_offset.end(), blk_shape.begin(), blk_shape.end());
+            m_configurator->compute_offsets(shape_for_offset, i, 0);
+        }
+        const auto out_offsets = cpu_config->io_data_offsets[i];
+
+        repacked_in = CPURuntimeConfig::RepackedInput(p.second->get_kernel(), desc, in_offsets, out_offsets);
+    }
+
     return true;
 }
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
index 4d0c9586f3be31..5efc5a738c5d76 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp
@@ -24,11 +24,23 @@ class BrgemmExternalRepackingAdjuster : public ov::snippets::lowered::pass::Runt
     bool run(const snippets::lowered::LinearIR& linear_ir) override;
 
     bool applicable() const override {
-        return !m_param_idces_with_external_repacking.empty();
+        return !m_executors.empty();
     }
 
 private:
-    std::set<size_t> m_param_idces_with_external_repacking;
+    using RepackExecutorPtr = std::shared_ptr<BrgemmCopyBKernelExecutor>;
+    static VectorDims get_blk_order(size_t shape_rank);
+    static VectorDims get_blk_shape(const VectorDims& planar_shape, ov::element::Type prc);
+
+    void update_kernel(const RepackExecutorPtr& executor,
+                       const VectorDims& shape,
+                       const VectorDims& layout,
+                       size_t N,
+                       size_t K,
+                       ov::element::Type prc);
+
+    static const size_t brgemm_kernel_rank;
+    std::unordered_map<size_t, RepackExecutorPtr> m_executors;
 };
 
 }  // namespace intel_cpu
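A closing note on the dispatch heuristic above: the adjuster sums the per-call source and repacked destination footprints over all inputs and compares the total against the L2 size queried via `dnnl::utils::get_cache_size(2, true)`; repacking stays inside the parallel kernel section only when the data fits. A sketch of that decision under those assumptions:

    #include <cstddef>

    enum class RepackingImpl { InParallel, Separate };

    // If all repacking source + destination data fits into L2, repack in the same
    // parallel section as the kernel; otherwise run repacking as a separate stage.
    static RepackingImpl choose_repacking_impl(std::size_t src_plus_dst_bytes, std::size_t l2_bytes) {
        return src_plus_dst_bytes < l2_bytes ? RepackingImpl::InParallel : RepackingImpl::Separate;
    }

    int main() {
        // e.g. 1 MiB of repacking data vs a 2 MiB L2 -> keep it in the parallel section
        return choose_repacking_impl(1u << 20, 2u << 20) == RepackingImpl::InParallel ? 0 : 1;
    }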