Skip to content

Commit

Permalink
Include stop token in policy model's generation_config
Browse files: browse the repository at this point in the history
  • Loading branch information
dawidm committed Dec 28, 2024
1 parent aed5da5 commit 15ec8f4
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions trl/trainer/ppo_trainer.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -138,10 +138,19 @@ def __init__(
if data_collator is None:
data_collator = DataCollatorWithPadding(self.processing_class)

self.policy_model.generation_config.eos_token_id = (
None # disable `pad_token_id` and `eos_token_id` because we just want to
)
self.policy_model.generation_config.pad_token_id = None # generate tokens without truncation / padding
# Handle stop token settings
if args.stop_token and args.stop_token_id:
raise ValueError("You cannot set both `stop_token` and `stop_token_id`. ")
if args.stop_token:
if args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
else:
raise ValueError(
f"Unknown `stop_token` {args.stop_token}. "
f"Allowed values are: `eos`, None (no stop token)"
)
# Update policy model's generation_config to use provided stop token
self.policy_model.generation_config.eos_token_id = args.stop_token_id

# peft support
if not is_peft_available() and peft_config is not None:
Expand Down Expand Up @@ -220,8 +229,6 @@ def __init__(
for module in [self.policy_model, self.ref_model, self.value_model, self.reward_model]:
if module is not None:
disable_dropout_in_model(module)
if args.stop_token and args.stop_token == "eos":
args.stop_token_id = processing_class.eos_token_id
self.model = PolicyAndValueWrapper(self.policy_model, self.value_model)
self.model.config = self.policy_model.config # needed for pushing to hub
self.create_optimizer_and_scheduler(
Expand Down

0 comments on commit 15ec8f4

Please sign in to comment.