diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py
index 51897eeb44..804f2b427b 100644
--- a/trl/trainer/ppo_trainer.py
+++ b/trl/trainer/ppo_trainer.py
@@ -138,10 +138,19 @@ def __init__(
         if data_collator is None:
             data_collator = DataCollatorWithPadding(self.processing_class)
 
-        self.policy_model.generation_config.eos_token_id = (
-            None  # disable `pad_token_id` and `eos_token_id` because we just want to
-        )
-        self.policy_model.generation_config.pad_token_id = None  # generate tokens without truncation / padding
+        # Handle stop token settings
+        if args.stop_token and args.stop_token_id:
+            raise ValueError("You cannot set both `stop_token` and `stop_token_id`.")
+        if args.stop_token:
+            if args.stop_token == "eos":
+                args.stop_token_id = processing_class.eos_token_id
+            else:
+                raise ValueError(
+                    f"Unknown `stop_token` {args.stop_token}. "
+                    f"Allowed values are: `eos`, None (no stop token)"
+                )
+        # Update policy model's generation_config to use provided stop token
+        self.policy_model.generation_config.eos_token_id = args.stop_token_id
 
         # peft support
         if not is_peft_available() and peft_config is not None:
@@ -220,8 +229,6 @@ def __init__(
         for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
             if module is not None:
                 disable_dropout_in_model(module)
-        if args.stop_token and args.stop_token == "eos":
-            args.stop_token_id = processing_class.eos_token_id
         self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
         self.model.config = self.policy_model.config  # needed for pushing to hub
         self.create_optimizer_and_scheduler(
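
This moves stop-token resolution to the top of `__init__`: `stop_token="eos"` is translated to the tokenizer's `eos_token_id` up front, setting both options is rejected, and the resolved id is written into the policy model's `generation_config` instead of `eos_token_id`/`pad_token_id` being unconditionally disabled. The second hunk removes the old late resolution, which is now redundant. A minimal usage sketch of the resulting behavior, assuming `args` in the diff is a trl `PPOConfig` exposing the `stop_token` and `stop_token_id` fields:

    from transformers import AutoTokenizer
    from trl import PPOConfig

    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

    # Name the stop token symbolically; with this patch, PPOTrainer.__init__
    # resolves "eos" to processing_class.eos_token_id and writes it into the
    # policy model's generation_config.
    config = PPOConfig(output_dir="ppo_out", stop_token="eos")

    # Or pass an explicit id; it is used directly as generation_config.eos_token_id.
    config = PPOConfig(output_dir="ppo_out", stop_token_id=tokenizer.eos_token_id)

    # Setting both fields now makes PPOTrainer.__init__ fail fast with a
    # ValueError instead of one setting silently overriding the other.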