From 1956e27ed85e14683b23706c6591fbe0f8297f82 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Mon, 23 Dec 2024 18:03:03 +0100 Subject: [PATCH 1/4] add base structure --- docs/source/en/_toctree.yml | 2 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 2 + .../models/d-fine/modular_d_fine.py | 314 ++++++++++++++++++ 5 files changed, 321 insertions(+) create mode 100644 src/transformers/models/d-fine/modular_d_fine.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index a076f704b8ede2..9c5a6718f55ade 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -703,6 +703,8 @@ title: ResNet - local: model_doc/rt_detr title: RT-DETR + - local: model_doc/d-fine + title: D-FINE - local: model_doc/segformer title: SegFormer - local: model_doc/seggpt diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index ff03d09966a4d6..55e54f9af1a803 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -74,6 +74,7 @@ deprecated, depth_anything, detr, + d_fine, dialogpt, dinat, dinov2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 6c052aa0eaa0f3..f35471c988ff2e 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -251,6 +251,7 @@ ("roformer", "RoFormerConfig"), ("rt_detr", "RTDetrConfig"), ("rt_detr_resnet", "RTDetrResNetConfig"), + ("d_fine", "DFineConfig"), ("rwkv", "RwkvConfig"), ("sam", "SamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), @@ -579,6 +580,7 @@ ("roformer", "RoFormer"), ("rt_detr", "RT-DETR"), ("rt_detr_resnet", "RT-DETR-ResNet"), + ("d_fine", "D-FINE"), ("rwkv", "RWKV"), ("sam", "SAM"), ("seamless_m4t", "SeamlessM4T"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 861754f591769b..2cf1821ac7baae 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -232,6 +232,7 @@ ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), ("rt_detr", "RTDetrModel"), + ("d_fine", "DFineModel"), ("rwkv", "RwkvModel"), ("sam", "SamModel"), ("seamless_m4t", "SeamlessM4TModel"), @@ -875,6 +876,7 @@ ("deta", "DetaForObjectDetection"), ("detr", "DetrForObjectDetection"), ("rt_detr", "RTDetrForObjectDetection"), + ("d_fine", "DFineForObjectDetection"), ("table-transformer", "TableTransformerForObjectDetection"), ("yolos", "YolosForObjectDetection"), ] diff --git a/src/transformers/models/d-fine/modular_d_fine.py b/src/transformers/models/d-fine/modular_d_fine.py new file mode 100644 index 00000000000000..c4aab1cef40033 --- /dev/null +++ b/src/transformers/models/d-fine/modular_d_fine.py @@ -0,0 +1,314 @@ +from ..rt_detr.configuration_rt_detr import RTDetrConfig +from ..rt_detr.modeling_rt_detr import ( + RTDetrDecoderLayer, RTDetrModelOutput, RTDetrObjectDetectionOutput, RTDetrHybridEncoder, + RTDetrEncoderLayer, RTDetrConvEncoder, RTDetrMLPPredictionHead, RTDetrDecoderOutput, + RTDetrMultiscaleDeformableAttention, MultiScaleDeformableAttentionFunction, get_contrastive_denoising_training_group, + inverse_sigmoid +) + +import torch +from torch import nn +import torch.nn.init as init +import torch.nn.functional as F +import math +from typing import List, Optional +import functools + + +class DFineConfig(RTDetrConfig): + model_type = "d-fine" + + 
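    # Inherited RT-DETR fields are kept as-is; at this stage the subclass only adds the
    # sampling-offset scale consumed by DFineMultiscaleDeformableAttention below.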
def __init__(self, + decoder_offset_scale=0.5, # default value + **super_kwargs): + super().__init__(**super_kwargs) + + self.decoder_offset_scale = decoder_offset_scale + + +def deformable_attention_core_func_v2(\ + value: torch.Tensor, + value_spatial_shapes, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + num_points_list: List[int], + method='default'): + """ + Args: + value (Tensor): [bs, value_length, n_head, c] + value_spatial_shapes (Tensor|List): [n_levels, 2] + value_level_start_index (Tensor|List): [n_levels] + sampling_locations (Tensor): [bs, query_length, n_head, n_levels * n_points, 2] + attention_weights (Tensor): [bs, query_length, n_head, n_levels * n_points] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, n_head, c, _ = value[0].shape + _, Len_q, _, _, _ = sampling_locations.shape + + # sampling_offsets [8, 480, 8, 12, 2] + if method == 'default': + sampling_grids = 2 * sampling_locations - 1 + + elif method == 'discrete': + sampling_grids = sampling_locations + + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) + sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) + + sampling_value_list = [] + for level, (h, w) in enumerate(value_spatial_shapes): + value_l = value[level].reshape(bs * n_head, c, h, w) + sampling_grid_l: torch.Tensor = sampling_locations_list[level] + + if method == 'default': + sampling_value_l = F.grid_sample( + value_l, + sampling_grid_l, + mode='bilinear', + padding_mode='zeros', + align_corners=False) + + elif method == 'discrete': + # n * m, seq, n, 2 + sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5).to(torch.int64) + + # FIX ME? for rectangle input + sampling_coord = sampling_coord.clamp(0, h - 1) + sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) + + s_idx = torch.arange(sampling_coord.shape[0], device=value_l.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1]) + sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c + + sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level]) + + sampling_value_list.append(sampling_value_l) + + attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, Len_q, sum(num_points_list)) + weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights + output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, Len_q) + + return output.permute(0, 2, 1) + + +class DFineMultiscaleDeformableAttention(nn.Module): + def __init__( + self, + config: DFineConfig, + method='default', + ): + """ + D-Fine version of multiscale deformable attention + """ + super(DFineMultiscaleDeformableAttention, self).__init__() + self.d_model = config.d_model + self.n_levels = config.num_feature_levels + self.n_heads = config.decoder_attention_heads + self.n_points = config.decoder_n_points + self.offset_scale = config.decoder_offset_scale + + if isinstance(self.n_points, list): + assert len(self.n_points) == self.n_levels, '' + num_points_list = self.n_points + else: + num_points_list = [self.n_points for _ in range(self.n_levels)] + + self.num_points_list = num_points_list + + num_points_scale = [1/n for n in num_points_list for _ in range(n)] + self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32)) + + self.total_points = self.n_heads * sum(num_points_list) + self.method = method + + self.head_dim = 
self.d_model // self.n_heads + assert self.head_dim * self.n_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + self.sampling_offsets = nn.Linear(self.d_model, self.total_points * 2) + self.attention_weights = nn.Linear(self.d_model, self.total_points) + + self.ms_deformable_attn_core = functools.partial(deformable_attention_core_func_v2, method=self.method) + + self._reset_parameters() + + if method == 'discrete': + for p in self.sampling_offsets.parameters(): + p.requires_grad = False + + def _reset_parameters(self): + # sampling_offsets + init.constant_(self.sampling_offsets.weight, 0) + thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values + grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) + scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1) + grid_init *= scaling + self.sampling_offsets.bias.data[...] = grid_init.flatten() + + # attention_weights + init.constant_(self.attention_weights.weight, 0) + init.constant_(self.attention_weights.bias, 0) + + + def forward(self, + query: torch.Tensor, + reference_points: torch.Tensor, + value: torch.Tensor, + value_spatial_shapes: List[int]): + """ + Args: + query (Tensor): [bs, query_length, C] + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area + value (Tensor): [bs, value_length, C] + value_spatial_shapes (List): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + + Returns: + output (Tensor): [bs, Length_{query}, C] + """ + bs, Len_q = query.shape[:2] + + sampling_offsets: torch.Tensor = self.sampling_offsets(query) + sampling_offsets = sampling_offsets.reshape(bs, Len_q, self.num_heads, sum(self.num_points_list), 2) + + attention_weights = self.attention_weights(query).reshape(bs, Len_q, self.num_heads, sum(self.num_points_list)) + attention_weights = F.softmax(attention_weights, dim=-1) + + if reference_points.shape[-1] == 2: + offset_normalizer = torch.tensor(value_spatial_shapes) + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) + sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer + elif reference_points.shape[-1] == 4: + # reference_points [8, 480, None, 1, 4] + # sampling_offsets [8, 480, 8, 12, 2] + num_points_scale = self.num_points_scale.to(dtype=query.dtype).unsqueeze(-1) + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * self.offset_scale + sampling_locations = reference_points[:, :, None, :, :2] + offset + else: + raise ValueError( + "Last dim of reference_points must be 2 or 4, but get {} instead.". 
+ format(reference_points.shape[-1])) + + output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list) + + return output + + +class DFineDecoderLayer(RTDetrDecoderLayer): + def __init__(self, config: DFineConfig): + # initialize parent class + super().__init__(config) + + # override the encoder attention module with d-fine version + self.encoder_attn = DFineMultiscaleDeformableAttention(config=config) + # gate + self.gateway = Gate(config.d_model) + self._reset_parameters() + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[torch.Tensor] = None, + reference_points=None, + spatial_shapes=None, + spatial_shapes_list=None, + level_start_index=None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ): + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(seq_len, batch, embed_dim)`. + position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings that are added to the queries and keys in the self-attention layer. + reference_points (`torch.FloatTensor`, *optional*): + Reference points. + spatial_shapes (`torch.LongTensor`, *optional*): + Spatial shapes. + level_start_index (`torch.LongTensor`, *optional*): + Level start index. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(seq_len, batch, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative + values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=encoder_attention_mask, + position_embeddings=position_embeddings, + output_attentions=output_attentions, + ) + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + second_residual = hidden_states + + # Cross-Attention + cross_attn_weights = None + hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + position_embeddings=position_embeddings, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + spatial_shapes_list=spatial_shapes_list, + level_start_index=level_start_index, + output_attentions=output_attentions, + ) + + hidden_states = self.gateway(second_residual, nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)) + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states.clamp(min=-65504, max=65504)) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + return outputs + + + def _reset_parameters(self): + init.xavier_uniform_(self.fc1.weight) + init.xavier_uniform_(self.fc2.weight) + + +class Gate(nn.Module): + def __init__(self, d_model): + super(Gate, self).__init__() + self.gate = nn.Linear(2 * d_model, 2 * d_model) + bias = self._bias_init_with_prob(0.5) + init.constant_(self.gate.bias, bias) + init.constant_(self.gate.weight, 0) + self.norm = nn.LayerNorm(d_model) + + def forward(self, x1, x2): + gate_input = torch.cat([x1, x2], dim=-1) + gates = torch.sigmoid(self.gate(gate_input)) + gate1, gate2 = gates.chunk(2, dim=-1) + return self.norm(gate1 * x1 + gate2 * x2) + + def _bias_init_with_prob(self, prior_prob=0.01): + """initialize conv/fc bias value according to a given probability value.""" + bias_init = float(-math.log((1 - prior_prob) / prior_prob)) + return bias_init From 7700b4e25760d65646427fb2ff33fbb86dca954d Mon Sep 17 00:00:00 2001 From: Vladislav Date: Thu, 26 Dec 2024 15:02:38 +0100 Subject: [PATCH 2/4] add decoder --- .../models/d-fine/modular_d_fine.py | 362 ++++++++++++++---- 1 file changed, 294 insertions(+), 68 deletions(-) diff --git a/src/transformers/models/d-fine/modular_d_fine.py b/src/transformers/models/d-fine/modular_d_fine.py index c4aab1cef40033..edfb9380f53c4a 100644 --- a/src/transformers/models/d-fine/modular_d_fine.py +++ b/src/transformers/models/d-fine/modular_d_fine.py @@ -1,9 +1,10 @@ from ..rt_detr.configuration_rt_detr import RTDetrConfig from ..rt_detr.modeling_rt_detr import ( - RTDetrDecoderLayer, RTDetrModelOutput, RTDetrObjectDetectionOutput, RTDetrHybridEncoder, - RTDetrEncoderLayer, RTDetrConvEncoder, RTDetrMLPPredictionHead, RTDetrDecoderOutput, - RTDetrMultiscaleDeformableAttention, MultiScaleDeformableAttentionFunction, get_contrastive_denoising_training_group, - inverse_sigmoid + RTDetrDecoderLayer, + RTDetrDecoder, + RTDetrMLPPredictionHead, + RTDetrDecoderOutput, + 
inverse_sigmoid, ) import torch @@ -13,26 +14,28 @@ import math from typing import List, Optional import functools +from ...activations import ACT2CLS class DFineConfig(RTDetrConfig): model_type = "d-fine" - def __init__(self, - decoder_offset_scale=0.5, # default value - **super_kwargs): + def __init__(self, decoder_offset_scale=0.5, eval_idx=-1, layer_scale=2, **super_kwargs): super().__init__(**super_kwargs) self.decoder_offset_scale = decoder_offset_scale + self.eval_idx = eval_idx + self.layer_scale = layer_scale -def deformable_attention_core_func_v2(\ +def deformable_attention_core_func_v2( value: torch.Tensor, value_spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int], - method='default'): + method="default", +): """ Args: value (Tensor): [bs, value_length, n_head, c] @@ -48,10 +51,10 @@ def deformable_attention_core_func_v2(\ _, Len_q, _, _, _ = sampling_locations.shape # sampling_offsets [8, 480, 8, 12, 2] - if method == 'default': + if method == "default": sampling_grids = 2 * sampling_locations - 1 - elif method == 'discrete': + elif method == "discrete": sampling_grids = sampling_locations sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) @@ -62,15 +65,12 @@ def deformable_attention_core_func_v2(\ value_l = value[level].reshape(bs * n_head, c, h, w) sampling_grid_l: torch.Tensor = sampling_locations_list[level] - if method == 'default': + if method == "default": sampling_value_l = F.grid_sample( - value_l, - sampling_grid_l, - mode='bilinear', - padding_mode='zeros', - align_corners=False) + value_l, sampling_grid_l, mode="bilinear", padding_mode="zeros", align_corners=False + ) - elif method == 'discrete': + elif method == "discrete": # n * m, seq, n, 2 sampling_coord = (sampling_grid_l * torch.tensor([[w, h]], device=value_l.device) + 0.5).to(torch.int64) @@ -78,8 +78,12 @@ def deformable_attention_core_func_v2(\ sampling_coord = sampling_coord.clamp(0, h - 1) sampling_coord = sampling_coord.reshape(bs * n_head, Len_q * num_points_list[level], 2) - s_idx = torch.arange(sampling_coord.shape[0], device=value_l.device).unsqueeze(-1).repeat(1, sampling_coord.shape[1]) - sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c + s_idx = ( + torch.arange(sampling_coord.shape[0], device=value_l.device) + .unsqueeze(-1) + .repeat(1, sampling_coord.shape[1]) + ) + sampling_value_l: torch.Tensor = value_l[s_idx, :, sampling_coord[..., 1], sampling_coord[..., 0]] # n l c sampling_value_l = sampling_value_l.permute(0, 2, 1).reshape(bs * n_head, c, Len_q, num_points_list[level]) @@ -96,7 +100,7 @@ class DFineMultiscaleDeformableAttention(nn.Module): def __init__( self, config: DFineConfig, - method='default', + method="default", ): """ D-Fine version of multiscale deformable attention @@ -109,15 +113,15 @@ def __init__( self.offset_scale = config.decoder_offset_scale if isinstance(self.n_points, list): - assert len(self.n_points) == self.n_levels, '' + assert len(self.n_points) == self.n_levels, "" num_points_list = self.n_points else: num_points_list = [self.n_points for _ in range(self.n_levels)] self.num_points_list = num_points_list - num_points_scale = [1/n for n in num_points_list for _ in range(n)] - self.register_buffer('num_points_scale', torch.tensor(num_points_scale, dtype=torch.float32)) + num_points_scale = [1 / n for n in num_points_list for _ in range(n)] + self.register_buffer("num_points_scale", torch.tensor(num_points_scale, dtype=torch.float32)) 
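+        # num_points_scale holds 1 / n for each of the n sampling points of a level; forward()
+        # multiplies the predicted offsets by it (together with the reference box size and
+        # decoder_offset_scale) so the sampling spread stays comparable across point counts.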
self.total_points = self.n_heads * sum(num_points_list) self.method = method @@ -132,7 +136,7 @@ def __init__( self._reset_parameters() - if method == 'discrete': + if method == "discrete": for p in self.sampling_offsets.parameters(): p.requires_grad = False @@ -151,12 +155,9 @@ def _reset_parameters(self): init.constant_(self.attention_weights.weight, 0) init.constant_(self.attention_weights.bias, 0) - - def forward(self, - query: torch.Tensor, - reference_points: torch.Tensor, - value: torch.Tensor, - value_spatial_shapes: List[int]): + def forward( + self, query: torch.Tensor, reference_points: torch.Tensor, value: torch.Tensor, value_spatial_shapes: List[int] + ): """ Args: query (Tensor): [bs, query_length, C] @@ -179,7 +180,9 @@ def forward(self, if reference_points.shape[-1] == 2: offset_normalizer = torch.tensor(value_spatial_shapes) offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) - sampling_locations = reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer + sampling_locations = ( + reference_points.reshape(bs, Len_q, 1, self.num_levels, 1, 2) + sampling_offsets / offset_normalizer + ) elif reference_points.shape[-1] == 4: # reference_points [8, 480, None, 1, 4] # sampling_offsets [8, 480, 8, 12, 2] @@ -188,10 +191,12 @@ def forward(self, sampling_locations = reference_points[:, :, None, :, :2] + offset else: raise ValueError( - "Last dim of reference_points must be 2 or 4, but get {} instead.". - format(reference_points.shape[-1])) + "Last dim of reference_points must be 2 or 4, but get {} instead.".format(reference_points.shape[-1]) + ) - output = self.ms_deformable_attn_core(value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list) + output = self.ms_deformable_attn_core( + value, value_spatial_shapes, sampling_locations, attention_weights, self.num_points_list + ) return output @@ -206,40 +211,17 @@ def __init__(self, config: DFineConfig): # gate self.gateway = Gate(config.d_model) self._reset_parameters() - + def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[torch.Tensor] = None, reference_points=None, spatial_shapes=None, - spatial_shapes_list=None, - level_start_index=None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, ): - """ - Args: - hidden_states (`torch.FloatTensor`): - Input to the layer of shape `(seq_len, batch, embed_dim)`. - position_embeddings (`torch.FloatTensor`, *optional*): - Position embeddings that are added to the queries and keys in the self-attention layer. - reference_points (`torch.FloatTensor`, *optional*): - Reference points. - spatial_shapes (`torch.LongTensor`, *optional*): - Spatial shapes. - level_start_index (`torch.LongTensor`, *optional*): - Level start index. - encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(seq_len, batch, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative - values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
- """ residual = hidden_states # Self Attention @@ -259,17 +241,15 @@ def forward( # Cross-Attention cross_attn_weights = None hidden_states, cross_attn_weights = self.encoder_attn( - hidden_states=hidden_states, - encoder_hidden_states=encoder_hidden_states, - position_embeddings=position_embeddings, + query=hidden_states, + value=encoder_hidden_states, reference_points=reference_points, spatial_shapes=spatial_shapes, - spatial_shapes_list=spatial_shapes_list, - level_start_index=level_start_index, - output_attentions=output_attentions, ) - hidden_states = self.gateway(second_residual, nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)) + hidden_states = self.gateway( + second_residual, nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + ) # Fully Connected residual = hidden_states @@ -287,12 +267,258 @@ def forward( return outputs - def _reset_parameters(self): init.xavier_uniform_(self.fc1.weight) init.xavier_uniform_(self.fc2.weight) +def weighting_function(reg_max, up, reg_scale): + """ + Generates the non-uniform Weighting Function W(n) for bounding box regression. + + Args: + reg_max (int): Max number of the discrete bins. + up (Tensor): Controls upper bounds of the sequence, + where maximum offset is ±up * H / W. + reg_scale (float): Controls the curvature of the Weighting Function. + Larger values result in flatter weights near the central axis W(reg_max/2)=0 + and steeper weights at both ends. + Returns: + Tensor: Sequence of Weighting Function. + """ + upper_bound1 = abs(up[0]) * abs(reg_scale) + upper_bound2 = abs(up[0]) * abs(reg_scale) * 2 + step = (upper_bound1 + 1) ** (2 / (reg_max - 2)) + left_values = [-((step) ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)] + right_values = [(step) ** i - 1 for i in range(1, reg_max // 2)] + values = [-upper_bound2] + left_values + [torch.zeros_like(up[0][None])] + right_values + [upper_bound2] + return torch.cat(values, 0) + + +class DFineMLPPredictionHead(RTDetrMLPPredictionHead): + pass + + +def box_xyxy_to_cxcywh(x: torch.Tensor) -> torch.Tensor: + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +def distance2bbox(points, distance, reg_scale): + """ + Decodes edge-distances into bounding box coordinates. + + Args: + points (Tensor): (B, N, 4) or (N, 4) format, representing [x, y, w, h], + where (x, y) is the center and (w, h) are width and height. + distance (Tensor): (B, N, 4) or (N, 4), representing distances from the + point to the left, top, right, and bottom boundaries. + + reg_scale (float): Controls the curvature of the Weighting Function. + + Returns: + Tensor: Bounding boxes in (N, 4) or (B, N, 4) format [cx, cy, w, h]. 
+ """ + reg_scale = abs(reg_scale) + x1 = points[..., 0] - (0.5 * reg_scale + distance[..., 0]) * (points[..., 2] / reg_scale) + y1 = points[..., 1] - (0.5 * reg_scale + distance[..., 1]) * (points[..., 3] / reg_scale) + x2 = points[..., 0] + (0.5 * reg_scale + distance[..., 2]) * (points[..., 2] / reg_scale) + y2 = points[..., 1] + (0.5 * reg_scale + distance[..., 3]) * (points[..., 3] / reg_scale) + + bboxes = torch.stack([x1, y1, x2, y2], -1) + + return box_xyxy_to_cxcywh(bboxes) + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, act="relu"): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + self.act = ACT2CLS[act]() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class LQE(nn.Module): + def __init__(self, k, hidden_dim, num_layers, reg_max): + super(LQE, self).__init__() + self.k = k + self.reg_max = reg_max + self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers) + init.constant_(self.reg_conf.layers[-1].bias, 0) + init.constant_(self.reg_conf.layers[-1].weight, 0) + + def forward(self, scores, pred_corners): + B, L, _ = pred_corners.size() + prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), dim=-1) + prob_topk, _ = prob.topk(self.k, dim=-1) + stat = torch.cat([prob_topk, prob_topk.mean(dim=-1, keepdim=True)], dim=-1) + quality_score = self.reg_conf(stat.reshape(B, L, -1)) + return scores + quality_score + + +class DFineDecoderOutput(RTDetrDecoderOutput): + pass + + +class DFineDecoder(RTDetrDecoder): + """ + D-FINE Decoder implementing Fine-grained Distribution Refinement (FDR). + + This decoder refines object detection predictions through iterative updates across multiple layers, + utilizing attention mechanisms, location quality estimators, and distribution refinement techniques + to improve bounding box accuracy and robustness. + """ + + def __init__(self, config: DFineConfig, num_layers, reg_max, reg_scale, up): + super().__init__(config=config) + self.d_model = config.d_model + self.num_layers = num_layers + self.layer_scale = config.layer_scale + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else num_layers + config.eval_idx + self.num_head = config.decoder_attention_heads + self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max + self.layers = nn.ModuleList( + [DFineDecoderLayer(config=config) for _ in range(config.decoder_layers)] + + [DFineDecoderLayer(config=config) for _ in range(config.decoder_layers - self.eval_idx - 1)] + ) + self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(config.decoder_layers)]) + + def value_op(self, memory, value_proj, value_scale, memory_mask, memory_spatial_shapes): + """ + Preprocess values for MSDeformableAttention. 
+ """ + value = value_proj(memory) if value_proj is not None else memory + value = F.interpolate(memory, size=value_scale) if value_scale is not None else value + if memory_mask is not None: + value = value * memory_mask.to(value.dtype).unsqueeze(-1) + value = value.reshape(value.shape[0], value.shape[1], self.num_head, -1) + split_shape = [h * w for h, w in memory_spatial_shapes] + return value.permute(0, 2, 3, 1).split(split_shape, dim=-1) + + def forward( + self, + inputs_embeds, + ref_points_unact, + memory, + spatial_shapes, + bbox_head, + score_head, + pre_bbox_head, + integral, + up, + reg_scale, + encoder_attn_mask=None, + memory_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = () if (output_attentions and memory is not None) else None + intermediate = () + intermediate_reference_points = () + intermediate_logits = () + + output_detach = pred_corners_undetach = 0 + value = self.value_op(memory, None, None, memory_mask, spatial_shapes) + + project = weighting_function(self.reg_max, up, reg_scale) + ref_points_detach = F.sigmoid(ref_points_unact) + + for i, decoder_layer in enumerate(self.layers): + ref_points_input = ref_points_detach.unsqueeze(2) + query_pos_embed = self.query_pos_head(ref_points_detach).clamp(min=-10, max=10) + + # TODO Adjust scale if needed for detachable wider layers + if i >= self.eval_idx + 1 and self.layer_scale > 1: + query_pos_embed = F.interpolate(query_pos_embed, scale_factor=self.layer_scale) + value = self.value_op(memory, None, query_pos_embed.shape[-1], memory_mask, spatial_shapes) + hidden_states = F.interpolate(hidden_states, size=query_pos_embed.shape[-1]) + output_detach = hidden_states.detach() + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = decoder_layer( + hidden_states=hidden_states, + position_embeddings=query_pos_embed, + reference_points=ref_points_input, + spatial_shapes=spatial_shapes, + encoder_hidden_states=value, + encoder_attention_mask=encoder_attn_mask, + output_attentions=output_attentions, + ) + + hidden_states = output[0] + + if i == 0: + # Initial bounding box predictions with inverse sigmoid refinement + pre_bboxes = F.sigmoid(pre_bbox_head(output) + inverse_sigmoid(ref_points_detach)) + pre_scores = score_head[0](hidden_states) + ref_points_initial = pre_bboxes.detach() + + # Refine bounding box corners using FDR, integrating previous layer's corrections + pred_corners = bbox_head[i](hidden_states + output_detach) + pred_corners_undetach + inter_ref_bbox = distance2bbox(ref_points_initial, integral(pred_corners, project), reg_scale) + + pred_corners_undetach = pred_corners + ref_points_detach = inter_ref_bbox.detach() + output_detach = hidden_states.detach() + + # Store intermediate results + intermediate += (hidden_states,) + intermediate_reference_points += (inter_ref_bbox,) + intermediate_logits += (pre_scores,) if i == 0 else (score_head[i](hidden_states),) + + if output_attentions: + all_self_attns += 
(output[1],) + if memory is not None: + all_cross_attentions += (output[2],) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + intermediate, + intermediate_logits, + intermediate_reference_points, + all_hidden_states, + all_self_attns, + all_cross_attentions, + ] + if v is not None + ) + + return DFineDecoderOutput( + last_hidden_state=hidden_states, + intermediate_hidden_states=torch.stack(intermediate, dim=1), + intermediate_logits=torch.stack(intermediate_logits, dim=1), + intermediate_reference_points=torch.stack(intermediate_reference_points, dim=1), + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + class Gate(nn.Module): def __init__(self, d_model): super(Gate, self).__init__() @@ -307,7 +533,7 @@ def forward(self, x1, x2): gates = torch.sigmoid(self.gate(gate_input)) gate1, gate2 = gates.chunk(2, dim=-1) return self.norm(gate1 * x1 + gate2 * x2) - + def _bias_init_with_prob(self, prior_prob=0.01): """initialize conv/fc bias value according to a given probability value.""" bias_init = float(-math.log((1 - prior_prob) / prior_prob)) From 2fcfe04c0949b4d22314bdb6f5cae7422b1e87d2 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 27 Dec 2024 16:46:11 +0100 Subject: [PATCH 3/4] update decoder, add encoder --- .../models/d-fine/modular_d_fine.py | 153 ++++++++++++------ 1 file changed, 104 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/d-fine/modular_d_fine.py b/src/transformers/models/d-fine/modular_d_fine.py index edfb9380f53c4a..c60b72d77751cd 100644 --- a/src/transformers/models/d-fine/modular_d_fine.py +++ b/src/transformers/models/d-fine/modular_d_fine.py @@ -2,8 +2,13 @@ from ..rt_detr.modeling_rt_detr import ( RTDetrDecoderLayer, RTDetrDecoder, + RTDetrModel, RTDetrMLPPredictionHead, RTDetrDecoderOutput, + RTDetrHybridEncoder, + RTDetrRepVggBlock, + RTDetrCSPRepLayer, + RTDetrConvNormLayer, inverse_sigmoid, ) @@ -20,12 +25,13 @@ class DFineConfig(RTDetrConfig): model_type = "d-fine" - def __init__(self, decoder_offset_scale=0.5, eval_idx=-1, layer_scale=2, **super_kwargs): + def __init__(self, decoder_offset_scale=0.5, eval_idx=-1, layer_scale=2, reg_max=32, **super_kwargs): super().__init__(**super_kwargs) self.decoder_offset_scale = decoder_offset_scale self.eval_idx = eval_idx self.layer_scale = layer_scale + self.reg_max = reg_max def deformable_attention_core_func_v2( @@ -364,7 +370,47 @@ def forward(self, scores, pred_corners): class DFineDecoderOutput(RTDetrDecoderOutput): + dec_out_bboxes:torch.FloatTensor = None, + dec_out_logits:torch.FloatTensor = None, + dec_out_pred_corners:torch.FloatTensor = None, + dec_out_refs:torch.FloatTensor = None, + pre_bboxes:torch.FloatTensor = None, + pre_scores:torch.FloatTensor = None, + + +class DFineVggBlock(RTDetrRepVggBlock): + pass + + +class DFineConvNormLayer(RTDetrConvNormLayer): + pass + + +class DFineCSPRepLayer(RTDetrCSPRepLayer): pass + + +class RepNCSPELAN4(nn.Module): + # csp-elan + def __init__(self, config: DFineConfig, c1, c2, c3, c4, n=3, + bias=False, + act="silu"): + super().__init__() + self.c = c3//2 + self.cv1 = DFineConvNormLayer(config, c1, c3, 1, 1, activation=act) + self.cv2 = nn.Sequential(DFineCSPRepLayer(c3//2, c4, n, 1, bias=bias, act=act, bottletype=DFineVggBlock), DFineConvNormLayer(config, c4, c4, 3, 1, activation=act)) + self.cv3 = nn.Sequential(DFineCSPRepLayer(c4, c4, n, 1, bias=bias, act=act, bottletype=DFineVggBlock), DFineConvNormLayer(config, c4, c4, 3, 1, activation=act)) + self.cv4 = 
DFineConvNormLayer(config, c3+(2*c4), c2, 1, 1, activation=act) + + def forward_chunk(self, x): + y = list(self.cv1(x).chunk(2, 1)) + y.extend((m(y[-1])) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) + + def forward(self, x): + y = list(self.cv1(x).split((self.c, self.c), 1)) + y.extend(m(y[-1]) for m in [self.cv2, self.cv3]) + return self.cv4(torch.cat(y, 1)) class DFineDecoder(RTDetrDecoder): @@ -376,19 +422,17 @@ class DFineDecoder(RTDetrDecoder): to improve bounding box accuracy and robustness. """ - def __init__(self, config: DFineConfig, num_layers, reg_max, reg_scale, up): + def __init__(self, config: DFineConfig): super().__init__(config=config) self.d_model = config.d_model - self.num_layers = num_layers self.layer_scale = config.layer_scale - self.eval_idx = config.eval_idx if config.eval_idx >= 0 else num_layers + config.eval_idx + self.eval_idx = config.eval_idx if config.eval_idx >= 0 else config.decoder_layers + config.eval_idx self.num_head = config.decoder_attention_heads - self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max self.layers = nn.ModuleList( [DFineDecoderLayer(config=config) for _ in range(config.decoder_layers)] + [DFineDecoderLayer(config=config) for _ in range(config.decoder_layers - self.eval_idx - 1)] ) - self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max) for _ in range(config.decoder_layers)]) + self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, config.reg_max) for _ in range(config.decoder_layers)]) def value_op(self, memory, value_proj, value_scale, memory_mask, memory_spatial_shapes): """ @@ -417,25 +461,17 @@ def forward( encoder_attn_mask=None, memory_mask=None, output_attentions=None, - output_hidden_states=None, return_dict=None, ): - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is not None: hidden_states = inputs_embeds - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and memory is not None) else None - intermediate = () - intermediate_reference_points = () - intermediate_logits = () + dec_out_bboxes = [] + dec_out_logits = [] + dec_out_pred_corners = [] + dec_out_refs = [] output_detach = pred_corners_undetach = 0 value = self.value_op(memory, None, None, memory_mask, spatial_shapes) @@ -454,9 +490,6 @@ def forward( hidden_states = F.interpolate(hidden_states, size=query_pos_embed.shape[-1]) output_detach = hidden_states.detach() - if output_hidden_states: - all_hidden_states += (hidden_states,) - output = decoder_layer( hidden_states=hidden_states, position_embeddings=query_pos_embed, @@ -479,46 +512,68 @@ def forward( pred_corners = bbox_head[i](hidden_states + output_detach) + pred_corners_undetach inter_ref_bbox = distance2bbox(ref_points_initial, integral(pred_corners, project), reg_scale) + if self.training or i == self.eval_idx: + scores = score_head[i](output) + # Lqe does not affect the performance here. 
+ scores = self.lqe_layers[i](scores, pred_corners) + dec_out_logits.append(scores) + dec_out_bboxes.append(inter_ref_bbox) + dec_out_pred_corners.append(pred_corners) + dec_out_refs.append(ref_points_initial) + + if not self.training: + break + pred_corners_undetach = pred_corners ref_points_detach = inter_ref_bbox.detach() output_detach = hidden_states.detach() - # Store intermediate results - intermediate += (hidden_states,) - intermediate_reference_points += (inter_ref_bbox,) - intermediate_logits += (pre_scores,) if i == 0 else (score_head[i](hidden_states),) - - if output_attentions: - all_self_attns += (output[1],) - if memory is not None: - all_cross_attentions += (output[2],) - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - intermediate, - intermediate_logits, - intermediate_reference_points, - all_hidden_states, - all_self_attns, - all_cross_attentions, - ] - if v is not None + return ( + torch.stack(dec_out_bboxes), + torch.stack(dec_out_logits), + torch.stack(dec_out_pred_corners), + torch.stack(dec_out_refs), + pre_bboxes, + pre_scores, ) return DFineDecoderOutput( last_hidden_state=hidden_states, - intermediate_hidden_states=torch.stack(intermediate, dim=1), - intermediate_logits=torch.stack(intermediate_logits, dim=1), - intermediate_reference_points=torch.stack(intermediate_reference_points, dim=1), - hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, + dec_out_bboxes=torch.stack(dec_out_bboxes), + dec_out_logits=torch.stack(dec_out_logits), + dec_out_pred_corners=torch.stack(dec_out_pred_corners), + dec_out_refs=torch.stack(dec_out_refs), + pre_bboxes=pre_bboxes, + pre_scores=pre_scores, ) +class DFineHybridEncoder(RTDetrHybridEncoder): + def __init__(self, config: DFineConfig): + super().__init__(config=config) + # top-down fpn + self.lateral_convs = nn.ModuleList() + self.fpn_blocks = nn.ModuleList() + for _ in range(len(self.in_channels) - 1, 0, -1): + self.lateral_convs.append(DFineConvNormLayer(config, hidden_dim, hidden_dim, 1, 1)) + self.fpn_blocks.append( + RepNCSPELAN4(config, hidden_dim * 2, hidden_dim, hidden_dim * 2, round(expansion * hidden_dim // 2), round(3 * depth_mult)) + ) + + +class DFineModel(RTDetrModel): + def __init__(self, config: DFineConfig): + super().__init__(config) + + # decoder + self.decoder = DFineDecoder(config) + + # create encoder + self.encoder = DFineHybridEncoder(config=config) + + + class Gate(nn.Module): def __init__(self, d_model): super(Gate, self).__init__() From f98051f7b6a19974d80e8027d01bba80661c9bad Mon Sep 17 00:00:00 2001 From: Vladislav Date: Fri, 27 Dec 2024 16:47:03 +0100 Subject: [PATCH 4/4] temp remove --- src/transformers/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 55e54f9af1a803..ff03d09966a4d6 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -74,7 +74,6 @@ deprecated, depth_anything, detr, - d_fine, dialogpt, dinat, dinov2,
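To make the Fine-grained Distribution Refinement step in DFineDecoder.forward easier to follow, here is a small standalone sketch of one corner-decoding pass using the weighting_function and distance2bbox helpers added in modular_d_fine.py above. The `integral` module is not part of these patches, so it is approximated below by the usual softmax expectation over the projection bins; `up`, `reg_scale` and all tensors are placeholder values, not values taken from the patch.

import torch
import torch.nn.functional as F

# Assumes weighting_function and distance2bbox from modular_d_fine.py are in scope.
reg_max = 32                                    # matches the DFineConfig default added in patch 3
up = torch.tensor([0.5])                        # assumed value for the upper-bound parameter
reg_scale = torch.tensor([4.0])                 # assumed value for the curvature parameter
project = weighting_function(reg_max, up, reg_scale)        # (reg_max + 1,) bin values W(n)

batch, num_queries = 2, 300
pred_corners = torch.randn(batch, num_queries, 4 * (reg_max + 1))    # per-edge bin logits
prob = F.softmax(pred_corners.reshape(batch, num_queries, 4, reg_max + 1), dim=-1)
distance = (prob * project).sum(-1)             # expected signed offset per edge (stand-in for `integral`)

ref_points = torch.rand(batch, num_queries, 4)               # reference boxes [cx, cy, w, h] in [0, 1]
refined = distance2bbox(ref_points, distance, reg_scale)     # refined boxes, still [cx, cy, w, h]

Once the modular file is converted and the d_fine entries added to configuration_auto.py and modeling_auto.py are active, the model should be reachable through the auto classes. DFineForObjectDetection is not implemented in this patch series yet, so the lines below only illustrate the wiring those mapping entries set up:

from transformers import AutoConfig, AutoModelForObjectDetection

config = AutoConfig.for_model("d_fine", decoder_offset_scale=0.5)    # resolves to DFineConfig
model = AutoModelForObjectDetection.from_config(config)              # resolves to DFineForObjectDetection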