Skip to content

Commit

Permalink
passing self.args.use_liger_loss without liger installed should raise…
Browse files Browse the repository at this point in the history
…d an error
  • Loading branch information
kashif committed Dec 15, 2024
1 parent 44aa20c commit 7682e31
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,18 +357,22 @@ def make_inputs_require_grad(module, input, output):
)

# Import Liger loss if enabled
if self.args.use_liger_loss and is_liger_kernel_available():
if self.args.use_liger_loss:
if not is_liger_kernel_available():
raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. "
"Please install liger-kernel first: `pip install liger-kernel`"
)
try:
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

self.orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta)
self._using_liger = True
except ImportError:
warnings.warn(
"Liger package not found. Falling back to default ORPO loss implementation. "
"Install liger-kernel for optimized performance."
raise ImportError(
"Failed to import LigerFusedLinearORPOLoss from liger-kernel. "
"Please ensure you have the correct liger-kernel version installed."
)
self._using_liger = False
else:
self._using_liger = False

Expand Down

0 comments on commit 7682e31

Please sign in to comment.