From 1a960878451c50fd0e087b055b694fa110dae416 Mon Sep 17 00:00:00 2001
From: Saravanabalagi Ramachandran
Date: Thu, 7 Mar 2024 11:20:56 +0000
Subject: [PATCH] Fix conv layer with 1 output channel being treated as depthwise conv layer

---
 nni/compression/speedup/dependency.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/nni/compression/speedup/dependency.py b/nni/compression/speedup/dependency.py
index cd9e590a45..91b98c8384 100644
--- a/nni/compression/speedup/dependency.py
+++ b/nni/compression/speedup/dependency.py
@@ -316,8 +316,12 @@ def build_channel_dependency(graph_module: torch.fx.GraphModule,
         if node.op == 'call_module':
             submodule = graph_module.get_submodule(node.target)
             # additional denpendency for (group number == output channel number) depth-wise conv:
-            if (isinstance(submodule, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)) and submodule.groups == submodule.out_channels) \
-                or (isinstance(submodule, torch.nn.GroupNorm) and submodule.num_groups == submodule.num_channels):
+            if (
+                # depth-wise conv: groups == out_channels, excluding the trivial single-output-channel case
+                isinstance(submodule, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
+                and submodule.groups == submodule.out_channels
+                and submodule.out_channels != 1
+            ) or (isinstance(submodule, torch.nn.GroupNorm) and submodule.num_groups == submodule.num_channels):
                 d_set = set([node] + find_adjacent_layers(node, graph_module, target_types, 'parent'))
         elif node.op == 'call_function':
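
For illustration, a minimal standalone sketch (not part of the patch; the helper names is_depthwise_old and is_depthwise_new are hypothetical) of why the extra out_channels != 1 guard is needed: a regular Conv2d with a single output channel has groups == out_channels (1 == 1), so the old test misclassified it as depth-wise, while a genuine depth-wise conv still matches the patched test.

import torch

regular_conv = torch.nn.Conv2d(4, 1, kernel_size=3)              # groups=1, out_channels=1
depthwise_conv = torch.nn.Conv2d(8, 8, kernel_size=3, groups=8)  # groups == out_channels == 8

def is_depthwise_old(m):
    # old check: conv whose group count equals its output channel count
    return isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)) and m.groups == m.out_channels

def is_depthwise_new(m):
    # patched check: additionally require more than one output channel
    return (isinstance(m, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
            and m.groups == m.out_channels
            and m.out_channels != 1)

print(is_depthwise_old(regular_conv))    # True  -> wrongly treated as depth-wise before the patch
print(is_depthwise_new(regular_conv))    # False -> correctly excluded after the patch
print(is_depthwise_old(depthwise_conv))  # True
print(is_depthwise_new(depthwise_conv))  # True  -> real depth-wise conv still detected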