Skip to content

Commit

Permalink
Fixed bug in PReLU emitter, enabled PReLU, Sqrt, Round tokenization
Browse files Browse the repository at this point in the history
  • Loading branch information
0xfedcafe committed Dec 29, 2024
1 parent 82d553e commit dbc6f96
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "common/utils.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"

namespace ov {
Expand Down Expand Up @@ -2128,7 +2129,7 @@ size_t jit_prelu_emitter::get_aux_vecs_count() const {

/// Reports the input precision combinations accepted by the PReLU emitter.
///
/// The single supported combination lists f32 twice, i.e. one precision per
/// emitter input (PReLU is a binary operation: data plus slope — see the
/// OpenVINO PRelu specification).
///
/// @param node  Node the emitter would be created for; not inspected here.
/// @return A set holding exactly one combination: {f32, f32}.
std::set<std::vector<element::Type>> jit_prelu_emitter::get_supported_precisions(
    const std::shared_ptr<ov::Node>& node) {
    // A one-element combination ({f32}) would describe a unary emitter and
    // would shadow this return if left in front of it, so only the two-input
    // form is kept.
    return {{element::f32, element::f32}};
}

void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
#include "emitters/snippets/cpu_runtime_configurator.hpp"
#include "emitters/utils.hpp"
#include "jit_snippets_emitters.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/round.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/opsets/opset13.hpp"
#include "snippets/emitter.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
Expand Down Expand Up @@ -73,6 +78,38 @@ namespace ov {
} \
}

/*
 * Expands to a brace-enclosed pair of lambdas for ov::op::v5::Round:
 *   1. an emitter factory that takes a lowered expression, looks at the Round
 *      mode of its node and instantiates `e_type_from_zero` for
 *      HALF_AWAY_FROM_ZERO or `e_type_even` for HALF_TO_EVEN;
 *   2. a precision query that forwards to get_supported_precisions() of the
 *      emitter type matching the node's Round mode.
 * Both lambdas throw when the node is not an ov::op::v5::Round or when the
 * mode is neither of the two above.
 * The factory captures `this`; `h` and `isa` are resolved at the expansion
 * site (presumably members of the enclosing class - confirm at the call site).
 */
#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \
    { \
        [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
            const auto& node = expr->get_node(); \
            const auto round_op = std::dynamic_pointer_cast<ov::op::v5::Round>(node); \
            if (!round_op) { \
                OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
            } \
            switch (round_op->get_mode()) { \
            case ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO: \
                return std::make_shared<e_type_from_zero>(h.get(), isa, node); \
            case ov::op::v5::Round::RoundMode::HALF_TO_EVEN: \
                return std::make_shared<e_type_even>(h.get(), isa, node); \
            default: \
                OPENVINO_THROW("Unsupported Round mode"); \
            } \
        }, \
        [](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
            const auto round_op = std::dynamic_pointer_cast<ov::op::v5::Round>(n); \
            if (!round_op) { \
                OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
            } \
            switch (round_op->get_mode()) { \
            case ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO: \
                return e_type_from_zero::get_supported_precisions(n); \
            case ov::op::v5::Round::RoundMode::HALF_TO_EVEN: \
                return e_type_even::get_supported_precisions(n); \
            default: \
                OPENVINO_THROW("Unsupported Round mode"); \
            } \
        } \
    }

class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_snippet)
Expand Down Expand Up @@ -149,8 +186,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa)
CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter);
jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_hswish_emitter);
jitters[ov::op::v4::Mish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mish_emitter);
jitters[ov::op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_prelu_emitter);
jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_relu_emitter);
jitters[ov::op::v5::Round::get_type_info_static()] =
CREATE_ROUND_V5_EMITTER(jit_round_half_away_from_zero_emitter, jit_round_half_to_even_emitter);
jitters[ov::op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sigmoid_emitter);
jitters[ov::op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sqrt_emitter);
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include <ov_ops/gather_compressed.hpp>

#include "openvino/op/paged_attention.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/round.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset2.hpp"
Expand Down Expand Up @@ -1127,9 +1130,10 @@ void Transformations::MainSnippets(void) {
ov::is_type<ov::op::v4::HSwish>(n) || ov::is_type<ov::op::v1::Maximum>(n) ||
ov::is_type<ov::op::v1::Minimum>(n) || ov::is_type<ov::op::v4::Mish>(n) ||
ov::is_type<ov::op::v1::Mod>(n) || ov::is_type<ov::op::v1::Multiply>(n) ||
ov::is_type<ov::op::v0::Relu>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
ov::is_type<ov::op::v1::Subtract>(n) || ov::is_type<ov::op::v4::Swish>(n) ||
ov::is_type<ov::op::v0::Tanh>(n));
ov::is_type<ov::op::v0::PRelu>(n) || ov::is_type<ov::op::v0::Relu>(n) ||
ov::is_type<ov::op::v5::Round>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
ov::is_type<ov::op::v0::Sqrt>(n) || ov::is_type<ov::op::v1::Subtract>(n) ||
ov::is_type<ov::op::v4::Swish>(n) || ov::is_type<ov::op::v0::Tanh>(n));
#else
// CPU Plugin supports Swish in Subgraph via conversion to SwishCPU, which assumes the second input to be constant,
// and CPU Plugin does not support Mish for x64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,22 +265,26 @@ const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activat

const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activationTypesSnippets() {
static const std::map<utils::ActivationTypes, std::vector<std::vector<float>>> activationTypes {
{Abs, {{}}},
{Exp, {{}}},
{Ceiling, {{}}},
{Clamp, {{-2.0f, 2.0f}}},
{Elu, {{0.1f}}},
{Floor, {{}}},
{GeluErf, {{}}},
{GeluTanh, {{}}},
{Relu, {{}}},
{HSwish, {{}}},
{Abs, {{}}},
{Exp, {{}}},
{Ceiling, {{}}},
{Clamp, {{-2.0f, 2.0f}}},
{Elu, {{0.1f}}},
{Floor, {{}}},
{GeluErf, {{}}},
{GeluTanh, {{}}},
{Relu, {{}}},
{HSwish, {{}}},
{PReLu, {{-0.01f}}},
{Sqrt, {{}}},
{RoundHalfToEven, {{}}},
{RoundHalfAwayFromZero, {{}}},
#if defined(OPENVINO_ARCH_ARM64)
{Mish, {{}}},
{Mish, {{}}},
#endif
{Sigmoid, {{}}},
{Swish, {{0.1f}}},
{Tanh, {{}}},
{Sigmoid, {{}}},
{Swish, {{0.1f}}},
{Tanh, {{}}},
};

return activationTypes;
Expand Down

0 comments on commit dbc6f96

Please sign in to comment.