[GPU] Skip state initializer subgraphs of cross attention for decoder stateful model #28164

Open
wants to merge 10 commits into
base: master

andrew-k-park
Contributor

Details:

  • In a stateful decoder model, cross attention is also generated as stateful. However, the cross-attention KV computation (which becomes the initializer subgraph of a ReadValue) needs to be executed only once (i.e., when the first token is generated).
  • Optimize the stateful cross-attention model by skipping state initializer subgraphs once the state is set.
  • Add marking of state initializer subgraphs.
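
The idea in the list above can be illustrated with a minimal, self-contained sketch. This is not the GPU plugin's code: `ToyNode`, `ToyNetwork`, and `run_once` are hypothetical names, and the rule that a variable counts as "set" after the first run is an assumption made for illustration.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical toy types for illustration; these are not cldnn classes.
struct ToyNode {
    std::string id;
    // Non-empty when the node belongs to the initializer subgraph of a state variable.
    std::string init_subgraph_of;
};

struct ToyNetwork {
    std::vector<ToyNode> exec_order;
    std::unordered_map<std::string, bool> state_is_set;

    // Runs one inference and returns the ids of the nodes that were executed.
    std::vector<std::string> run_once() {
        std::vector<std::string> executed;
        for (const auto& node : exec_order) {
            const bool is_initializer = !node.init_subgraph_of.empty();
            if (is_initializer && state_is_set[node.init_subgraph_of])
                continue;  // the state already holds the value, so skip recomputing it
            executed.push_back(node.id);
        }
        // Assumption: after a run, every variable that has an initializer is set.
        for (const auto& node : exec_order)
            if (!node.init_subgraph_of.empty())
                state_is_set[node.init_subgraph_of] = true;
        return executed;
    }
};
```

In this sketch the first call to `run_once` executes every node, while later calls skip the nodes tagged as initializers of an already-set variable, which mirrors the "only at first token generation" behavior described above.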

Tickets:

  • 144944

andrew-k-park requested review from a team as code owners on December 20, 2024 07:13
andrew-k-park force-pushed the andrew_skip_init_subgraph branch 2 times, most recently from 608f280 to abddffc, on December 22, 2024 06:59
andrew-k-park force-pushed the andrew_skip_init_subgraph branch from 1c7c4e3 to 07a5bbb on December 23, 2024 02:14
Comment on lines 702 to 703
for (auto& inst : _exec_order) {
    if (inst->get_node().is_type<read_value>() && !inst->get_state_initializers().empty()) {
Contributor

I think we should save all read_value ops in a member of the network to avoid looping over the full exec order. Then, when the model has no state variables, the overhead will be negligible (just a check that the read_value vector is empty).

Contributor Author

Applied as suggested: all read_value ops are now saved in the network and checked from there.
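
A rough sketch of the suggestion in this thread, under stated assumptions: `ToyInst` and `ToyNetwork` are hypothetical stand-ins (not `cldnn::network` or `primitive_inst`), and the cache is built once in the constructor. The per-inference check then touches only the cached read_value entries instead of the whole execution order.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for a primitive instance; not the cldnn API.
struct ToyInst {
    std::string id;
    bool is_read_value;
    bool has_state_initializers;
};

class ToyNetwork {
public:
    explicit ToyNetwork(std::vector<ToyInst> exec_order) : _exec_order(std::move(exec_order)) {
        // One-time pass at build time: remember only the read_value instances.
        for (std::size_t i = 0; i < _exec_order.size(); ++i) {
            if (_exec_order[i].is_read_value)
                _read_value_indices.push_back(i);
        }
    }

    // Cheap check for models without state variables.
    bool has_read_values() const { return !_read_value_indices.empty(); }

    // Per-inference work: visits only the cached read_value instances.
    std::size_t count_read_values_with_initializers() const {
        std::size_t count = 0;
        for (std::size_t idx : _read_value_indices) {
            if (_exec_order[idx].has_state_initializers)
                ++count;
        }
        return count;
    }

private:
    std::vector<ToyInst> _exec_order;
    std::vector<std::size_t> _read_value_indices;
};
```

For a model without state variables, `has_read_values()` returns false and the per-inference cost reduces to an emptiness check, which is the overhead argument made above.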

Comment on lines 520 to 522
std::string state_id_of_init_subgraph;
std::vector<primitive_id> state_initializers;

Contributor

I suggest not storing that in program_node. Instead, you could add a std::map<std::string, std::vector<std::string>> to the program itself to reflect the relation between a variable and its init subgraph.

Contributor Author

Moved state_initializers from program_node to program.
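
One way such a program-level mapping could look, as a hedged sketch only. `ToyProgram` and its method names are hypothetical and are not taken from the PR; only the `std::map<std::string, std::vector<std::string>>` shape comes from the suggestion above.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical program-level registry; class and method names are illustrative only.
class ToyProgram {
public:
    // Records that primitive_id belongs to the init subgraph of variable_id.
    void add_state_initializer(const std::string& variable_id, const std::string& primitive_id) {
        _state_initializers[variable_id].push_back(primitive_id);
    }

    bool has_state_initializers(const std::string& variable_id) const {
        auto it = _state_initializers.find(variable_id);
        return it != _state_initializers.end() && !it->second.empty();
    }

    // Returns a copy of the initializer ids, or an empty vector for an unknown variable.
    std::vector<std::string> get_state_initializers(const std::string& variable_id) const {
        auto it = _state_initializers.find(variable_id);
        if (it == _state_initializers.end())
            return std::vector<std::string>();
        return it->second;
    }

private:
    // variable id -> ids of the primitives in its init subgraph
    std::map<std::string, std::vector<std::string>> _state_initializers;
};
```

Keeping the relation in one place means individual nodes do not need extra fields that are meaningful only for a few of them.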


void mark_state_init_subgraphs::run(program& p) {
    auto rit = p.get_processing_order().rbegin();
    if (p.is_new_shape_infer()) {
Contributor

Where does this limitation come from?

Contributor Author

It's an unnecessary limitation. Removed.


using namespace cldnn;

void mark_state_init_subgraphs::mark_node(program_node* node) {
Contributor

Suggested change
- void mark_state_init_subgraphs::mark_node(program_node* node) {
+ void mark_state_init_subgraphs::mark_init_subgraph(read_value_node& node) {

Contributor Author

applied

for (auto& u : dep_node->get_users()) {
    if (u == cur_node)
        continue;
    if (u->get_state_variable_id_of_init_subgraph().compare(variable_id) != 0) {
Contributor

Will it work well if part or all of an init subgraph is shared between multiple variables? Please add a test for that case.

Contributor Author

Added a test case with multiple variables.
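
To make the shared-subgraph question concrete, here is a minimal sketch of a marking pass on a toy graph. It is not the PR's implementation: `ToyNode`, `connect`, and this `mark_init_subgraph` are hypothetical, and tagging the read_value root with its own variable id is an assumption. A dependency is marked only when every one of its users already belongs to the same variable's init subgraph, so a node shared between two variables stays unmarked in this sketch.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical graph node; not cldnn::program_node.
struct ToyNode {
    std::string id;
    std::vector<ToyNode*> deps;
    std::vector<ToyNode*> users;
    std::string init_subgraph_of;  // empty means "not part of any init subgraph"
};

// Connects producer -> consumer in both directions.
inline void connect(ToyNode& producer, ToyNode& consumer) {
    producer.users.push_back(&consumer);
    consumer.deps.push_back(&producer);
}

// Marks every node that feeds only the given read_value as part of its init subgraph.
inline void mark_init_subgraph(ToyNode& read_value, const std::string& variable_id) {
    // Assumption: the root carries the variable id so that its dependencies pass the check.
    read_value.init_subgraph_of = variable_id;

    auto can_be_marked = [&](const ToyNode* dep) {
        for (const ToyNode* user : dep->users) {
            if (user->init_subgraph_of != variable_id)
                return false;  // some user lives outside this init subgraph
        }
        return true;
    };

    // Walk backwards from the read_value through its dependencies.
    std::vector<ToyNode*> stack{&read_value};
    while (!stack.empty()) {
        ToyNode* cur = stack.back();
        stack.pop_back();
        for (ToyNode* dep : cur->deps) {
            if (dep->init_subgraph_of.empty() && can_be_marked(dep)) {
                dep->init_subgraph_of = variable_id;
                stack.push_back(dep);
            }
        }
    }
}
```

With a node that feeds the initializers of two different variables, neither pass marks it, so it would still run on every inference; whether the actual pass treats shared nodes the same way is exactly what the requested test should confirm.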


auto can_be_marked = [&](const program_node* dep_node, const program_node* cur_node) {
    for (auto& u : dep_node->get_users()) {
        if (u == cur_node)
Contributor

It looks like the cur_node argument is redundant, since cur_node is supposed to always pass the get_state_variable_id_of_init_subgraph check.

Contributor Author

Removed the cur_node argument.

for (auto& dep : cur_node->get_dependencies()) {
    if (can_be_marked(dep.first, cur_node)) {
        dep.first->set_state_variable_id_of_init_subgraph(variable_id);
        if (!dep.first->is_constant())
Contributor

Do we need to skip constants here? What's the purpose of that?

Contributor Author

removed

@@ -61,6 +63,13 @@ struct read_value_impl : public typed_primitive_impl<read_value> {
} else {
    variable.get_memory()->fill(stream);
}
if (!instance.get_user_insts().empty()) {
    auto user_inst = instance.get_user_insts().front();
    if (!(user_inst->get_node().is_type<assign>() || user_inst->get_node().is_type<kv_cache>()) &&
Contributor

I don't understand why you added this variable.set here. Could you explain? My expectation is that set should be called by either the assign or the kv_cache node, not by read_value.

Contributor Author

[image]
In the cross-attention block of the stateful Whisper decoder model, read_value originally had assign as its user node. However, since the LoRA adapter update (https://github.com/openvinotoolkit/openvino/pull/26951/files#diff-1ad765a71b01b7bc7785efa67e742843153f65236bba833b3baa9644d1d7c538R79) was merged, assign primitives are no longer created when their input is a read_value. Therefore, in this case (when the user of read_value is neither assign nor kv_cache), there is no choice but to set the state directly in read_value.
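
The decision described in this reply can be sketched as follows. This is a simplified illustration with hypothetical names (`ToyOp`, `ToyVariable`, `read_value_must_set_state`, `execute_read_value`); the quoted condition in the diff continues after `&&`, and that remaining part is not reproduced here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical operation kinds; not the cldnn primitive types.
enum class ToyOp { ReadValue, Assign, KVCache, Other };

struct ToyVariable {
    bool is_set = false;
};

// Decides whether read_value itself has to mark the variable as set.
// Like the quoted snippet, only the first user is inspected.
inline bool read_value_must_set_state(const std::vector<ToyOp>& user_ops) {
    if (user_ops.empty())
        return false;
    const ToyOp first_user = user_ops.front();
    // assign and kv_cache set the state themselves, so read_value leaves it to them.
    return first_user != ToyOp::Assign && first_user != ToyOp::KVCache;
}

inline void execute_read_value(ToyVariable& variable, const std::vector<ToyOp>& user_ops) {
    if (read_value_must_set_state(user_ops))
        variable.is_set = true;
}
```

In this sketch, a read_value followed by any other operation sets the state itself, while one followed by assign or kv_cache does not.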

@@ -355,6 +357,9 @@ class primitive_inst {
// List of depandant shape_of primitives for shape_of subgraphs
std::vector<primitive_inst*> dependant_shape_of_insts;

// List of dependant primitives for state initializer subgraphs
std::vector<primitive_inst*> state_initializers;
Contributor

That can also be moved to a cldnn::network member. It's not a common property of primitive_inst, so it shouldn't be part of the base class.

Contributor Author

moved from primitive_inst to network

Contributor

yeonbok left a comment

LGTM, but please follow up on Vladimir's comments.

Labels
category: GPU OpenVINO GPU plugin
3 participants