[GPU] Skip state initializer subgraphs of cross attention for decoder stateful model #28164

Open · wants to merge 10 commits into master
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/network.hpp
@@ -181,6 +181,7 @@ struct network {
bool is_primary_stream() const { return _is_primary_stream; }
bool is_dynamic() const { return _is_dynamic; }
size_t get_weights_cache_capacity() const { return _weights_cache_capacity; }
bool contains_state(const std::string& variable_id);

memory_pool& get_memory_pool() const {
return *_memory_pool;
@@ -225,6 +226,8 @@ struct network {

ov::intel_gpu::VariablesMap _variables_states;
ov::intel_gpu::VariablesInfoMap _variables_state_info;
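// read_value instances of this network and, per state variable, the primitive instances
// forming that variable's initializer subgraph (resolved from the program's state_initializers)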
std::vector<std::shared_ptr<primitive_inst>> _read_values;
std::unordered_map<primitive_id, std::vector<std::shared_ptr<primitive_inst>>> _state_initializers;

program::primitives_info _prims_info;
size_t _weights_cache_capacity = 1;
7 changes: 7 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/graph/program.hpp
@@ -173,6 +173,11 @@ struct program {
program_node const& get_node(primitive_id const& id) const;
std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) { return nodes_map.at(prim); }
std::shared_ptr<program_node> get_node_ptr(const primitive_id& prim) const { return nodes_map.at(prim); }
void set_state_initializers(const std::string& variable_id, const primitive_id& id);
bool has_state_initializers(const std::string& variable_id, const primitive_id& id);
bool contains_state(const std::string& variable_id);
const std::vector<primitive_id>& get_initializers(const std::string& variable_id) { return state_initializers.at(variable_id); }
const std::map<std::string, std::vector<primitive_id>>& get_state_initializers() const { return state_initializers; }

// returns already existing program_node for given primitive 'prim' (lookup in 'nodes_map')
// if it was previously created, otherwise creates and then returns program_node
@@ -322,6 +327,8 @@ struct program {
primitives_info prim_info;
graph_optimizer_info optimizer_passes_info;

std::map<std::string, std::vector<primitive_id>> state_initializers;

primitives_info get_current_stage_info() const;
/*
** High-level functions, in order of usage
@@ -0,0 +1,60 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "read_value_inst.h"
#include "pass_manager.h"
#include <queue>

#include "intel_gpu/graph/program.hpp"

using namespace cldnn;

void mark_state_init_subgraphs::mark_init_subgraph(program& p, read_value_node& node) {
const auto& variable_id = node.get_primitive()->variable_id;
if (p.contains_state(variable_id))
return;

std::queue<program_node*> q;
q.push(&node);

auto can_be_marked = [&](const program_node* dep_node) {
if (p.has_state_initializers(variable_id, dep_node->id()))
return false;

for (auto& u : dep_node->get_users()) {
if (u == &node)
continue;
if (p.has_state_initializers(variable_id, u->id()))
continue;
else
return false;
}
GPU_DEBUG_TRACE_DETAIL << "marked " << dep_node->id() << " as a node in an init_subgraph for " << node.id() << std::endl;
return true;
};

while (!q.empty()) {
auto cur_size = q.size();
for (size_t i = 0; i < cur_size; ++i) {
auto* cur_node = q.front();  // copy the pointer before pop() to avoid a dangling reference
q.pop();
for (auto& dep : cur_node->get_dependencies()) {
if (can_be_marked(dep.first)) {
p.set_state_initializers(variable_id, dep.first->id());
q.push(dep.first);
}
}
}
}
}

void mark_state_init_subgraphs::run(program& p) {
auto rit = p.get_processing_order().rbegin();
for (; rit != p.get_processing_order().rend(); rit++) {
auto& node = *rit;
if (node->is_type<read_value>()) {
mark_init_subgraph(p, node->as<read_value>());
}
}
}
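For illustration, the marking logic above boils down to a reverse BFS from each read_value node in which a dependency joins the initializer subgraph only if all of its users are the read_value itself or nodes already marked for the same variable. Below is a minimal, self-contained sketch of that idea; ToyNode and the example graph in main() are illustrative stand-ins, not plugin APIs:

// Toy illustration (not plugin code) of the reverse-BFS marking performed by
// mark_state_init_subgraphs: a dependency is added to a state's initializer
// subgraph only when every one of its users is either the read_value node
// itself or a node that has already been marked for the same variable.
#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
    std::string id;
    std::vector<ToyNode*> deps;   // nodes this node reads from
    std::vector<ToyNode*> users;  // nodes that read from this node
};

std::set<std::string> mark_init_subgraph(ToyNode& read_value) {
    std::set<std::string> marked;
    auto can_be_marked = [&](const ToyNode* dep) {
        if (marked.count(dep->id))
            return false;                     // already part of the subgraph
        for (auto* u : dep->users)
            if (u != &read_value && !marked.count(u->id))
                return false;                 // also used outside the initializer subgraph
        return true;
    };
    std::queue<ToyNode*> q;
    q.push(&read_value);
    while (!q.empty()) {
        ToyNode* cur = q.front();
        q.pop();
        for (auto* dep : cur->deps) {
            if (can_be_marked(dep)) {
                marked.insert(dep->id);
                q.push(dep);
            }
        }
    }
    return marked;
}

int main() {
    // const -> matmul -> read_value, where matmul feeds only the read_value
    ToyNode c{"const", {}, {}}, m{"matmul", {}, {}}, rv{"read_value", {}, {}};
    m.deps = {&c};  c.users = {&m};
    rv.deps = {&m}; m.users = {&rv};
    for (const auto& id : mark_init_subgraph(rv))
        std::cout << id << " belongs to the initializer subgraph\n";
}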
9 changes: 9 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/cpu/read_value.cpp
@@ -3,6 +3,8 @@
//

#include "impls/cpu/cpu_impl_helpers.hpp"
#include "assign_inst.h"
#include "kv_cache_inst.h"
#include "read_value_inst.h"
#include "impls/registry/implementation_map.hpp"
#include "register.hpp"
@@ -61,6 +63,13 @@ struct read_value_impl : public typed_primitive_impl<read_value> {
} else {
variable.get_memory()->fill(stream);
}
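// If the user of this read_value is neither an assign nor a kv_cache primitive (e.g. the
// cross-attention branch of a stateful decoder, where redundant assign nodes are no longer
// created), no other primitive will set the variable, so set it here so that its
// initializer subgraph can be skipped on later iterations.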
if (!instance.get_user_insts().empty()) {
auto user_inst = instance.get_user_insts().front();
if (!(user_inst->get_node().is_type<assign>() || user_inst->get_node().is_type<kv_cache>()) &&
Contributor (review comment):

I didn't get why you added this variable.set() here. Could you explain? My expectation is that set() should be called by either an assign or kv_cache node, not read_value.

Contributor Author:

[image]
For the cross-attention block of the Whisper decoder stateful model, read_value originally had assign as its user node. But since the LoRA adapter update (https://github.com/openvinotoolkit/openvino/pull/26951/files#diff-1ad765a71b01b7bc7785efa67e742843153f65236bba833b3baa9644d1d7c538R79) was merged, assign primitives are no longer created when their input is a read_value. Therefore, in this case (when the user of read_value is neither assign nor kv_cache), there is no choice but to set the state directly in read_value.

instance.get_network().contains_state(variable_id)) {
variable.set();
}
}
}

if (!instance.can_be_optimized()) {
11 changes: 11 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/pass_manager.h
@@ -9,6 +9,7 @@
#include "quantize_inst.h"
#include "eltwise_inst.h"
#include "convolution_inst.h"
#include "read_value_inst.h"
#include <string>
#include <vector>
#include <memory>
@@ -89,6 +90,16 @@ class mark_nodes : public base_pass {
void run(program& p) override;
};

class mark_state_init_subgraphs : public base_pass {
// This optimization pass aggregates nodes into state initializer subgraphs
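// Starting from each read_value node, dependencies are traversed in reverse (BFS) and a node
// is added to the subgraph only if all of its users are the read_value itself or nodes already
// marked for the same variable, so the whole subgraph can be skipped once the state is set.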
public:
mark_state_init_subgraphs() : base_pass("mark_state_init_subgraphs") {}

private:
void run(program& p) override;
void mark_init_subgraph(program& p, read_value_node& node);
};

class mark_shape_of_subgraphs : public base_pass {
// This optimization pass aggregates nodes into shape_of subgraphs for further optimizations.
// There are few key requirements to decide if node belongs to shape_of subgraph or not:
31 changes: 31 additions & 0 deletions src/plugins/intel_gpu/src/graph/network.cpp
@@ -661,6 +661,15 @@ void network::build_exec_order() {
}
}
}

bool network::contains_state(const std::string& variable_id) {
return _state_initializers.find(variable_id) != _state_initializers.end();
}

void network::add_to_exec_order(const primitive_id& id) {
auto inst = get_primitive(id);
_exec_order.push_back(inst);
@@ -698,6 +707,19 @@ std::map<primitive_id, network_output> network::execute(const std::vector<event:
}
}

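// A state variable that has already been set no longer needs its initializer subgraph,
// so mark the corresponding instances to be skipped for this execution.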
for (auto& inst : _read_values) {
const auto& prim = inst->get_node().as<read_value>().get_primitive();
auto it = _state_initializers.find(prim->variable_id);
if (it != _state_initializers.end()) {
const auto& variable = get_variable(prim->variable_id);
if (variable.is_set()) {
for (auto& init_inst : it->second) {
init_inst->set_flag(ExecutionFlags::SKIP);
}
}
}
}

// We shouldn't call surfaces_lock::create() function constantly here, but due to
// some changes in assembler code, performance drops in case if we move it under
// `shared_mem_found` condition (it somehow connected with get_cl_queue() - this function call
@@ -913,6 +935,15 @@ void network::allocate_primitive_instance(program_node const& node) {
if (node.is_type<data>())
_data_outputs.push_back(inst);
}
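// Remember read_value instances and resolve, per state variable, the initializer
// primitive instances that were marked during program compilation.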
if (node.is_type<read_value>()) {
_read_values.push_back(inst);
const auto& variable_id = node.as<read_value>().get_primitive()->variable_id;
if (_program->contains_state(variable_id)) {
for (const auto& id : _program->get_initializers(variable_id)) {
_state_initializers[variable_id].push_back(get_primitive(id));
}
}
}
if (auto state_prim = std::dynamic_pointer_cast<memory_state::variable>(inst)) {
auto prim = inst->get_node().get_primitive();
set_variables_state_info(state_prim->variable_id(), node.get_output_layout(0), state_prim->get_user_specified_type(), prim.get());
23 changes: 23 additions & 0 deletions src/plugins/intel_gpu/src/graph/program.cpp
@@ -653,6 +653,8 @@ void program::post_optimize_graph(bool is_internal) {
// for OOO queue
if (_config.get_property(ov::intel_gpu::queue_type) == QueueTypes::out_of_order)
get_processing_order().calculate_BFS_processing_order();

apply_opt_pass<mark_state_init_subgraphs>();
}

// mark if the node is constant assuming that all dependencies are marked properly
@@ -830,6 +832,27 @@ void program::reverse_connection(program_node& dep_node, program_node& user_node
}
}

void program::set_state_initializers(const std::string& variable_id, const primitive_id& id) {
state_initializers[variable_id].push_back(id);
}

bool program::has_state_initializers(const std::string& variable_id, const primitive_id& id) {
auto it = state_initializers.find(variable_id);
if (it != state_initializers.end()) {
const auto& initializers = it->second;
return std::find(initializers.begin(), initializers.end(), id) != initializers.end();
}
return false;
}

bool program::contains_state(const std::string& variable_id) {
return state_initializers.find(variable_id) != state_initializers.end();
}

program_node& program::get_or_create(std::shared_ptr<primitive> prim) {
auto itr = nodes_map.lower_bound(prim->id);
if (itr != nodes_map.end() && itr->first == prim->id)
2 changes: 0 additions & 2 deletions src/plugins/intel_gpu/src/graph/program_node.cpp
@@ -17,8 +17,6 @@
#include "gemm_inst.h"
#include "fully_connected_inst.h"
#include "deconvolution_inst.h"
#include "quantize_inst.h"
#include "reorder_inst.h"
#include "pooling_inst.h"
#include "reduce_inst.h"
#include <impls/onednn/utils.hpp>