Skip to content

Commit

Permalink
[Snippets][CPU] Added external repacking via BrgemmCopyB
Browse files Browse the repository at this point in the history
  • Loading branch information
a-sidorova committed Dec 23, 2024
1 parent 6f3796b commit c6c62b5
Show file tree
Hide file tree
Showing 8 changed files with 328 additions and 65 deletions.
15 changes: 14 additions & 1 deletion src/common/snippets/include/snippets/utils/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,26 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_child_shape_infer_seq(const std
std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const std::shared_ptr<ov::Node>& start_node);

/**
*
* @param Get stride of input/output dimension
* @param expr_port target port that contains shape and layout info
* @param idx index of the target dimension starting from the shape's end (default = 1)
*/

int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx = 1);
/**
* @brief Get stride of input dimension
* @param shape target shape
* @param layout target layout
* @param idx index of the target dimension starting from the shape's end (default = 1)
*/
int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);
/**
* @brief Get stride of output dimension
* @param shape target shape
* @param layout target layout
* @param idx index of the target dimension starting from the shape's end (default = 1)
*/
int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx = 1);

/**
* @brief Traverses path starting from "expr", and calls "func" for each expression.
Expand Down
17 changes: 12 additions & 5 deletions src/common/snippets/src/utils/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -317,14 +317,21 @@ std::shared_ptr<ov::Node> get_leaf_node_of_first_parent_shape_infer_seq(const st
}

int64_t get_dim_stride(const lowered::ExpressionPort& expr_port, size_t idx) {
size_t dim_idx = 0;
const auto& shape = expr_port.get_descriptor_ptr()->get_shape();
const auto& layout = expr_port.get_descriptor_ptr()->get_layout();
switch (expr_port.get_type()) {
case lowered::ExpressionPort::Input: dim_idx = utils::get_input_dim_idx(layout, idx); break;
case lowered::ExpressionPort::Output: dim_idx = utils::get_output_dim_idx(layout, idx); break;
default: OPENVINO_THROW("Unsupported expression port type!");
case lowered::ExpressionPort::Input: return get_dim_in_stride(shape, layout, idx);
case lowered::ExpressionPort::Output: return get_dim_out_stride(shape, layout, idx);
}
return get_stride(dim_idx, expr_port.get_descriptor_ptr()->get_shape());
OPENVINO_THROW("Unsupported expression port type!");
}

int64_t get_dim_in_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
return get_stride(utils::get_input_dim_idx(layout, idx), shape);
}

int64_t get_dim_out_stride(const VectorDims& shape, const VectorDims& layout, size_t idx) {
return get_stride(utils::get_output_dim_idx(layout, idx), shape);
}

void visit_path(const lowered::ExpressionPtr& expr,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ std::string CPURuntimeConfig::to_string() const {
}
#endif

CPURuntimeConfigurator::CPURuntimeConfigurator()
: ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()) {}
CPURuntimeConfigurator::CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache)
: ov::snippets::RuntimeConfigurator(std::make_shared<CPURuntimeConfig>()),
compiled_kernel_cache(std::move(cache)) {}

void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) {
RuntimeConfigurator::initialization(linear_ir);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
#pragma once

#include "emitters/snippets/jit_snippets_call_args.hpp"

#ifdef OPENVINO_ARCH_X86_64
# include "emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp"
#endif

#include "cache/multi_cache.h"
#include "memory_desc/cpu_blocked_memory_desc.h"
#include "snippets/lowered/port_descriptor.hpp"
#include "snippets/runtime_configurator.hpp"
Expand All @@ -21,27 +27,59 @@ class CPURuntimeConfig : public ov::snippets::RuntimeConfig {
std::string to_string() const override;
#endif

#ifdef OPENVINO_ARCH_X86_64
struct RepackedInput {
RepackedInput() = default;
RepackedInput(CpuBlockedMemoryDescPtr desc_,
std::shared_ptr<BrgemmCopyBKernelExecutor> executor_,
VectorDims in_offsets_,
VectorDims out_offsets_)
: desc(std::move(desc_)),
executor(std::move(executor_)),
in_offsets(std::move(in_offsets_)),
out_offsets(std::move(out_offsets_)) {}

CpuBlockedMemoryDescPtr desc{nullptr};
std::shared_ptr<BrgemmCopyBKernelExecutor> executor{nullptr};
VectorDims in_offsets{};
VectorDims out_offsets{};
};
std::unordered_map<size_t, RepackedInput> repacked_inputs = {};

enum class RepackingImplType {
NONE, // no kernel-outside repacking
IN_PARALLEL, // should be executed in parallel_nt by each thread
SEPARATE, // should be separathy from kernel executed
};
RepackingImplType repacking_impl_type = RepackingImplType::NONE;
#endif // OPENVINO_ARCH_X86_64

std::vector<jit_snippets_call_args::loop_args_t> loop_args = {};
std::unordered_map<size_t, CpuBlockedMemoryDescPtr> m_in_requested_descs = {};
};

class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator {
public:
CPURuntimeConfigurator();
CPURuntimeConfigurator(ov::intel_cpu::MultiCacheWeakPtr cache = {});

/**
* @brief Calculate Loop parameters of Loop emitters and update these values in CPURuntimeConfig
* @param linear_ir LinearIR
*/
void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const;

const ov::intel_cpu::MultiCacheWeakPtr& get_cache() const {
return compiled_kernel_cache;
}

protected:
void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;
void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override;
void init_tensor_rank(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const override;
void initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override;

static const size_t rank6D;

ov::intel_cpu::MultiCacheWeakPtr compiled_kernel_cache;
};

} // namespace intel_cpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ class jit_snippet : public dnnl::impl::cpu::x64::jit_generator {

intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa,
ov::intel_cpu::MultiCacheWeakPtr cache)
: TargetMachine(std::make_shared<CPURuntimeConfigurator>()),
: TargetMachine(std::make_shared<CPURuntimeConfigurator>(cache)),
h(new jit_snippet()),
isa(host_isa),
compiled_kernel_cache(std::move(cache)) {
Expand Down
Loading

0 comments on commit c6c62b5

Please sign in to comment.