Skip to content

Commit

Permalink
Merge latest modifications from the BiRefNet repository
Browse files Browse the repository at this point in the history
  • Loading branch information
dimitribarbot committed Aug 21, 2024
1 parent d1a2b3d commit b80fb2f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions birefnet/models/birefnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def forward(self, features):
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
if self.config.out_ref:
p4_gdt = self.gdt_convs_4(p4)
if self.training:
Expand All @@ -206,7 +206,7 @@ def forward(self, features):
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
if self.config.out_ref:
p3_gdt = self.gdt_convs_3(p3)
if self.training:
Expand All @@ -232,7 +232,7 @@ def forward(self, features):
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
if self.config.out_ref:
p2_gdt = self.gdt_convs_2(p2)
if self.training:
Expand Down Expand Up @@ -260,7 +260,7 @@ def forward(self, features):
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)

if self.config.ms_supervision:
if self.config.ms_supervision and self.training:
outs.append(m4)
outs.append(m3)
outs.append(m2)
Expand Down

0 comments on commit b80fb2f

Please sign in to comment.