Skip to content

Commit

Permalink
add back the orpo nll labels
Browse files Browse the repository at this point in the history
  • Loading branch information
kashif committed Dec 28, 2024
1 parent 5fae1b2 commit e1918b7
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,12 +816,17 @@ def cross_entropy_loss(logits, labels):
 loss = loss_fct(logits, labels)
 return loss
 
-labels = concatenated_batch["concatenated_labels"].clone()
+if self.is_encoder_decoder:
+labels = concatenated_batch["concatenated_labels"].clone()
+else:
+labels = concatenated_batch["concatenated_input_ids"].clone()
+attention_mask = concatenated_batch["concatenated_attention_mask"]
+labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
 chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
 all_logps = self.get_batch_logps(
 all_logits,
-labels,
+concatenated_batch["concatenated_labels"],
 average_log_prob=True,
 is_encoder_decoder=self.is_encoder_decoder,
 label_pad_token_id=self.label_pad_token_id,
Expand Down

0 comments on commit e1918b7

Please sign in to comment.