-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Skip state initializer subgraphs of cross attention for decoder stateful model #28164
base: master
Are you sure you want to change the base?
[GPU] Skip state initializer subgraphs of cross attention for decoder stateful model #28164
Conversation
608f280
to
abddffc
Compare
src/plugins/intel_gpu/src/graph/graph_optimizer/mark_state_init_subgraphs.cpp
Outdated
Show resolved
Hide resolved
src/plugins/intel_gpu/src/graph/graph_optimizer/mark_state_init_subgraphs.cpp
Outdated
Show resolved
Hide resolved
src/plugins/intel_gpu/src/graph/graph_optimizer/mark_state_init_subgraphs.cpp
Outdated
Show resolved
Hide resolved
1c7c4e3
to
07a5bbb
Compare
for (auto& inst : _exec_order) { | ||
if (inst->get_node().is_type<read_value>() && !inst->get_state_initializers().empty()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we shall save all read_value ops to a network's member to avoid loop over full exec order. So in case when we don't have any state variables in the model the overhead will be negligible (just checking that read_value vec is empty).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied as comment to save and check all read_value ops from network
std::string state_id_of_init_subgraph; | ||
std::vector<primitive_id> state_initializers; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest not storing that in program_node. Instead, you can probably add a std::map<std::string, std::vector<std::string>>
to program itself which will reflect the relation between variable and init subgraph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
changed state_initializers
from program_node
to program
|
||
void mark_state_init_subgraphs::run(program& p) { | ||
auto rit = p.get_processing_order().rbegin(); | ||
if (p.is_new_shape_infer()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where this limitation comes from?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unnecessary limitation. removed
|
||
using namespace cldnn; | ||
|
||
void mark_state_init_subgraphs::mark_node(program_node* node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
void mark_state_init_subgraphs::mark_node(program_node* node) { | |
void mark_state_init_subgraphs::mark_init_subgraph(read_value_node& node) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
applied
for (auto& u : dep_node->get_users()) { | ||
if (u == cur_node) | ||
continue; | ||
if (u->get_state_variable_id_of_init_subgraph().compare(variable_id) != 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will it work well if a part of or full init subgraph is shared between multiple variables? Please add a test for such case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added multiple variables test case
|
||
auto can_be_marked = [&](const program_node* dep_node, const program_node* cur_node) { | ||
for (auto& u : dep_node->get_users()) { | ||
if (u == cur_node) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like this cur_node
arg is redundant as cur_node
is supposed to always pass get_state_variable_id_of_init_subgraph check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed cur_node
arg
for (auto& dep : cur_node->get_dependencies()) { | ||
if (can_be_marked(dep.first, cur_node)) { | ||
dep.first->set_state_variable_id_of_init_subgraph(variable_id); | ||
if (!dep.first->is_constant()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to skip constants here? What's the purpose of that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
removed
@@ -61,6 +63,13 @@ struct read_value_impl : public typed_primitive_impl<read_value> { | |||
} else { | |||
variable.get_memory()->fill(stream); | |||
} | |||
if (!instance.get_user_insts().empty()) { | |||
auto user_inst = instance.get_user_insts().front(); | |||
if (!(user_inst->get_node().is_type<assign>() || user_inst->get_node().is_type<kv_cache>()) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't get why you add this variable.set
here. Could you explain? My expectation is that set should be called by either assign or kv_cache node, not read_value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For cross attention block of whisper decoder stateful model, originally read_value has assign as user node. But since LoRA adapter update(https://github.com/openvinotoolkit/openvino/pull/26951/files#diff-1ad765a71b01b7bc7785efa67e742843153f65236bba833b3baa9644d1d7c538R79) is merged, assign primitives are not created because the input is read_value. Therefore, in this case(when the user of read_value is not assign or kv_cache), there is no choice but to set the state directly in read_value
@@ -355,6 +357,9 @@ class primitive_inst { | |||
// List of depandant shape_of primitives for shape_of subgraphs | |||
std::vector<primitive_inst*> dependant_shape_of_insts; | |||
|
|||
// List of dependant primitives for state initializer subgraphs | |||
std::vector<primitive_inst*> state_initializers; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That can also be moved to cldnn::network member. It's not a common property of primitive_inst, so it shouldn't be a part of base class.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
moved from primitive_inst
to network
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but please follow up Vladimir's comments
TODO: model_cache support & set variable_id to the init subgraph, instead of pointer
Signed-off-by: Andrew Park <[email protected]>
Signed-off-by: Andrew Park <[email protected]>
1da293e
to
d43ae11
Compare
Details:
ReadValue
) is needed to be executed only once (i.e., at 1st token generation)Tickets: