
Commit

refactor
qgallouedec committed Dec 26, 2024
1 parent b60247e commit f753449
39 changes: 18 additions & 21 deletions trl/trainer/dpo_trainer.py
@@ -1162,31 +1162,18 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
             # Flatten the input_ids, position_ids, and loss_mask
             # input_ids = [[a, b, c, 0],  ->  input_ids    = [a, b, c, d, e, f, g]
             #              [d, e, f, g]]      position_ids = [0, 1, 2, 0, 1, 2, 3]
-            model_kwargs["input_ids"] = input_ids[attention_mask.bool()].unsqueeze(0)
+            input_ids = input_ids[attention_mask.bool()].unsqueeze(0)
+            loss_mask = loss_mask[attention_mask.bool()].unsqueeze(0)
             model_kwargs["position_ids"] = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
         else:
-            model_kwargs["input_ids"] = input_ids
             model_kwargs["attention_mask"] = attention_mask
 
-        outputs = model(**model_kwargs)
-
-        if self.padding_free:
-            # Reverse flattenings
-            batch_size, seq_len = input_ids.shape
-            vocab_size = outputs.logits.shape[-1]
-            logits = torch.zeros(
-                batch_size, seq_len, vocab_size, device=outputs.logits.device, dtype=outputs.logits.dtype
-            )
-            flat_logits = logits.view(batch_size * seq_len, vocab_size)  # (B, L, V) -> (B * L, V)
-            flat_attention_mask = attention_mask.flatten()  # (B, L) -> (B * L)
-            flat_logits[flat_attention_mask.bool()] = outputs.logits
-        else:
-            logits = outputs.logits
+        outputs = model(input_ids, **model_kwargs)
+        logits = outputs.logits
 
         # Offset the logits by one to align with the labels
-        logits = logits[:, :-1, :]
-        labels = input_ids[:, 1:].clone()
-        loss_mask = loss_mask[:, 1:].bool()
+        labels = torch.roll(input_ids, shifts=-1, dims=1)
+        loss_mask = torch.roll(loss_mask, shifts=-1, dims=1).bool()
 
         if self.use_num_logits_to_keep:
            # Align labels with logits
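
As an aside, here is a minimal standalone sketch (not part of this commit; all names illustrative) of what the new padding-free branch computes, using toy tensors:

    import torch

    # Two sequences, the first padded with one 0 token.
    input_ids = torch.tensor([[5, 6, 7, 0],
                              [8, 9, 10, 11]])
    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 1, 1]])

    # Flatten the valid tokens into a single packed row, as the new branch does.
    packed_ids = input_ids[attention_mask.bool()].unsqueeze(0)
    # packed_ids == tensor([[ 5,  6,  7,  8,  9, 10, 11]])

    # cumsum restarts the position count at each sequence boundary.
    position_ids = attention_mask.cumsum(1)[attention_mask.bool()].unsqueeze(0) - 1
    # position_ids == tensor([[0, 1, 2, 0, 1, 2, 3]])

    # Labels become "next token" via a roll instead of slicing, so shapes stay equal.
    labels = torch.roll(packed_ids, shifts=-1, dims=1)
    # labels == tensor([[ 6,  7,  8,  9, 10, 11,  5]]); the wrapped-around last
    # entry picks up the mask of the first position (a prompt token, presumably
    # masked), so it does not contribute to the loss.
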
@@ -1207,6 +1194,16 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to
         labels[~loss_mask] = 0  # dummy token; we'll ignore the losses on these tokens later
         per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
         per_token_logps[~loss_mask] = 0
+        per_token_logps = torch.roll(per_token_logps, shifts=1, dims=1)
+
+        if self.padding_free:
+            batch_size, seq_len = attention_mask.shape
+            per_token_logps_ = torch.zeros(
+                batch_size, seq_len, device=outputs.logits.device, dtype=outputs.logits.dtype
+            )
+            per_token_logps_[attention_mask.bool()] = per_token_logps
+            per_token_logps = per_token_logps_
+
         all_logps = per_token_logps.sum(-1)
 
         output = {}
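
A matching sketch (again not part of the commit; names illustrative) for this second hunk: the roll by +1 moves each gathered log-prob back to its own token's slot, after which the packed row is scattered into a padded (batch_size, seq_len) grid so the per-sequence sum works unchanged:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0],
                                   [1, 1, 1, 1]])
    # Pretend these came out of the gather above, one value per packed token.
    per_token_logps = torch.tensor([[-0.1, -0.2, -0.3, -0.4, -0.5, -0.6, -0.7]])

    batch_size, seq_len = attention_mask.shape
    unpacked = torch.zeros(batch_size, seq_len, dtype=per_token_logps.dtype)
    # squeeze(0) used here for clarity; the diff assigns the (1, N) row directly.
    unpacked[attention_mask.bool()] = per_token_logps.squeeze(0)
    # unpacked == tensor([[-0.1, -0.2, -0.3,  0.0],
    #                     [-0.4, -0.5, -0.6, -0.7]])

    all_logps = unpacked.sum(-1)  # one summed log-prob per original sequence
    # all_logps == tensor([-0.6000, -2.2000])
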
@@ -1237,8 +1234,8 @@ def concatenated_forward(self, model: nn.Module, batch: dict[str, Union[list, to

output["chosen_logps"] = all_logps[:num_examples]
output["rejected_logps"] = all_logps[num_examples:]
output["mean_chosen_logits"] = logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = logits[num_examples:][loss_mask[num_examples:]].mean()
output["mean_chosen_logits"] = torch.zeros(1) # logits[:num_examples][loss_mask[:num_examples]].mean()
output["mean_rejected_logits"] = torch.zeros(1) # logits[num_examples:][loss_mask[num_examples:]].mean()

if self.aux_loss_enabled:
output["aux_loss"] = outputs.aux_loss
