From 1661bc295e694dc1ec9d50f80746612347051da3 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Wed, 9 Oct 2024 12:13:18 +0200
Subject: [PATCH] [GKD] interpolate in prob. space (#2204)

* interpolate in prob. space

* better var names

* use logsumexp

* set beta dtype

* beta tensor

---
 trl/trainer/gkd_trainer.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/trl/trainer/gkd_trainer.py b/trl/trainer/gkd_trainer.py
index 5b640a2aec..b1a944ed97 100644
--- a/trl/trainer/gkd_trainer.py
+++ b/trl/trainer/gkd_trainer.py
@@ -124,13 +124,18 @@ def generalized_jsd_loss(
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
 
-        # Compute the interpolated log probabilities
-        interpolated_log_probs = beta * student_log_probs + (1 - beta) * teacher_log_probs
+        # Compute the log of the mixture distribution
+        # log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
+        beta = torch.tensor(beta, dtype=student_log_probs.dtype)
+        mixture_log_probs = torch.logsumexp(
+            torch.stack([student_log_probs + torch.log(beta), teacher_log_probs + torch.log(1 - beta)]),
+            dim=0,
+        )
 
         # Compute KL divergences using F.kl_div
         # PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
-        kl_teacher = F.kl_div(interpolated_log_probs, teacher_log_probs, reduction="none", log_target=True)
-        kl_student = F.kl_div(interpolated_log_probs, student_log_probs, reduction="none", log_target=True)
+        kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
 
         # Compute the Generalized Jensen-Shannon Divergence
         jsd = beta * kl_teacher + (1 - beta) * kl_student
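
Note on the change: the old code interpolated the *log*-probabilities directly, which corresponds to a weighted geometric mean of the two distributions rather than the mixture beta * p_student + (1 - beta) * p_teacher that the generalized JSD is defined against. The patch forms the mixture in probability space but keeps the computation in log space via logsumexp, avoiding underflow from exponentiating log-probabilities.

Below is a minimal standalone sketch of the patched computation for reference. It is illustrative only: the free-function name generalized_jsd, the random tensor shapes, and the final sum/mean reduction are assumptions, not part of the patch, and it assumes 0 < beta < 1 so that torch.log(beta) and torch.log(1 - beta) are finite.

import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Hypothetical standalone version of the patched loss, for illustration only
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    # Mixture in probability space, computed stably in log space:
    # log(beta * p_s + (1 - beta) * p_t) via logsumexp over the log-weighted terms
    beta_t = torch.tensor(beta, dtype=student_log_probs.dtype)
    mixture_log_probs = torch.logsumexp(
        torch.stack([student_log_probs + torch.log(beta_t),
                     teacher_log_probs + torch.log(1 - beta_t)]),
        dim=0,
    )

    # F.kl_div(input, target, log_target=True) computes KL(target || input),
    # hence the mixture goes in the first (input) position
    kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
    kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)

    # Generalized Jensen-Shannon divergence, reduced here by an assumed sum-then-mean
    jsd = beta_t * kl_teacher + (1 - beta_t) * kl_student
    return jsd.sum(-1).mean()

# Hypothetical usage with random logits of shape (batch, seq_len, vocab)
student_logits = torch.randn(2, 4, 8)
teacher_logits = torch.randn(2, 4, 8)
print(generalized_jsd(student_logits, teacher_logits, beta=0.5))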