Skip to content

Commit

Permalink
Merge branch 'main' into rename-collator
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec authored Dec 22, 2024
2 parents 9b3d716 + 5239b94 commit bcf4f98
Show file tree
Hide file tree
Showing 16 changed files with 88 additions and 91 deletions.
30 changes: 0 additions & 30 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import platform
import subprocess
import tempfile
import unittest

Expand All @@ -24,34 +22,6 @@
from trl import RLOOConfig, RLOOTrainer


def test():
command = """\
python examples/scripts/rloo/rloo.py \
--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \
--dataset_train_split descriptiveness \
--learning_rate 3e-6 \
--output_dir models/minimal/rloo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 1 \
--total_episodes 10 \
--model_name_or_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \
--sft_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \
--reward_model_path trl-internal-testing/tiny-Qwen2ForCausalLM-2.5 \
--missing_eos_penalty 1.0 \
--save_strategy no \
--stop_token eos
"""
if platform.system() == "Windows":
# windows CI does not work with subprocesses for some reason
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
return
subprocess.run(
command,
shell=True,
check=True,
)


class RLOOTrainerTester(unittest.TestCase):
def setUp(self):
self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
Expand Down
3 changes: 3 additions & 0 deletions trl/trainer/bco_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class BCOConfig(TrainingArguments):
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
This argument is required if you want to use the default data collator.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model and reference model.
generate_during_eval (`bool`, *optional*, defaults to `False`):
If `True`, generates and logs completions from both the model and the reference model to W&B during
evaluation.
Expand Down Expand Up @@ -78,6 +80,7 @@ class BCOConfig(TrainingArguments):
label_pad_token_id: int = -100
padding_value: Optional[int] = None
truncation_mode: str = "keep_end"
disable_dropout: bool = True
generate_during_eval: bool = False
is_encoder_decoder: Optional[bool] = None
precompute_ref_log_probs: bool = False
Expand Down
11 changes: 5 additions & 6 deletions trl/trainer/bco_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,6 @@ class BCOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model` and `ref_model`.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
Expand Down Expand Up @@ -538,10 +536,11 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

# disable dropout in the model and reference model
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

self.max_length = max_length
self.generate_during_eval = args.generate_during_eval
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/cpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(model)

Expand Down
102 changes: 54 additions & 48 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import transformers
from accelerate import PartialState
from accelerate.utils import is_deepspeed_available, tqdm
from datasets import Dataset
from datasets import Dataset, IterableDataset
from packaging import version
from torch.utils.data import DataLoader
from transformers import (
Expand Down Expand Up @@ -376,6 +376,7 @@ def make_inputs_require_grad(module, input, output):
if data_collator is None:
data_collator = DataCollatorForPreference(pad_token_id=self.padding_value)

# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
Expand Down Expand Up @@ -436,53 +437,16 @@ def make_inputs_require_grad(module, input, output):
# that the warning has already been issued.
model.warnings_issued["estimate_tokens"] = True

# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
# Extract the prompt if needed, and apply the chat template if needed
train_dataset = train_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
)
train_dataset = train_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to train dataset",
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
)
eval_dataset = eval_dataset.map(
maybe_apply_chat_template,
fn_kwargs={"tokenizer": processing_class},
num_proc=args.dataset_num_proc,
desc="Applying chat template to eval dataset",
)

# tokenize the dataset, lower writer batch size to avoid OOM (frequent in vision models)
fn_kwargs = {
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": self.is_encoder_decoder,
}
train_dataset = train_dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
fn_kwargs=fn_kwargs,
num_proc=self.dataset_num_proc,
writer_batch_size=10,
desc="Tokenizing train dataset",
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
fn_kwargs=fn_kwargs,
num_proc=self.dataset_num_proc,
writer_batch_size=10,
desc="Tokenizing eval dataset",
)
# Dataset preparation
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")

super().__init__(
model=model,
Expand Down Expand Up @@ -540,6 +504,48 @@ def make_inputs_require_grad(module, input, output):
if self.loss_type == "bco_pair":
self.running = RunningMoments(self.accelerator)

def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin],
args: DPOConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Build the kwargs for the `map` function
map_kwargs = {"writer_batch_size": 10}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc

with PartialState().local_main_process_first():
# Extract prompt if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Extracting prompt in {dataset_name} dataset"
dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

# Apply the chat template if needed
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Applying chat template to {dataset_name} dataset"
dataset = dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}, **map_kwargs)

# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"

dataset = dataset.map(
self.tokenize_row if not self.is_vision_model else self.process_row,
remove_columns=["prompt", "chosen", "rejected"],
fn_kwargs={
"processing_class": processing_class,
"max_prompt_length": args.max_prompt_length,
"max_completion_length": args.max_completion_length,
# for enc-dec, we add the special tokens ([bos_token] + prompt + [eos_token]; completion + [eos_token])
"add_special_tokens": False,
},
**map_kwargs,
)

return dataset

@staticmethod
def tokenize_row(features, processing_class, max_prompt_length, max_completion_length, add_special_tokens):
"""
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/gkd_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GKDConfig(SFTConfig):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
from a string.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether or not to disable dropouts in `model`.
Whether to disable dropout in the model.
seq_kd (`bool`, *optional*, defaults to `False`):
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
on teacher-generated output).
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/gkd_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
else:
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)

# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(self.model)

Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/kto_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class KTOConfig(TrainingArguments):
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
Whether to disable dropout in the model and reference model.
"""

learning_rate: float = 1e-6
Expand Down
3 changes: 1 addition & 2 deletions trl/trainer/kto_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,6 @@ class KTOTrainer(Trainer):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model` and `ref_model`.
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
The function to use to compute the metrics. Must take a `EvalPrediction` and return
a dictionary string to metric values.
Expand Down Expand Up @@ -526,6 +524,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
Expand Down
2 changes: 1 addition & 1 deletion trl/trainer/online_dpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class OnlineDPOConfig(TrainingArguments):
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
Whether to disable dropout in the model and reference model.
"""

learning_rate: float = 5e-7
Expand Down
4 changes: 3 additions & 1 deletion trl/trainer/online_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,11 @@ def __init__(
# Get peft model with the given config
model = get_peft_model(model, peft_config)

# Disable dropout in the model if specified
# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)

# Handle the ref_model
# Usually, the user wants the ref model to be the initial version of the model. When using PEFT, it's easy to
Expand Down
1 change: 1 addition & 0 deletions trl/trainer/orpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def make_inputs_require_grad(module, input, output):
else:
self.use_dpo_data_collator = False

# Disable dropout in the model and reference model
if args.disable_dropout:
disable_dropout_in_model(model)

Expand Down
3 changes: 3 additions & 0 deletions trl/trainer/prm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class PRMConfig(TrainingArguments):
Maximum length of the sequences (prompt + completion) used for truncation.
max_completion_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
step_separator (`str`, *optional*, defaults to `"\n"`):
Separator used to separate each step of the reasoning process.
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
Expand All @@ -46,6 +48,7 @@ class PRMConfig(TrainingArguments):
learning_rate: float = 1e-5
max_length: Optional[int] = None
max_completion_length: Optional[int] = None
disable_dropout: bool = True
step_separator: str = "\n"
train_on_last_step_only: bool = False
dataset_num_proc: Optional[int] = None
6 changes: 5 additions & 1 deletion trl/trainer/prm_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from transformers.utils import is_peft_available

from .prm_config import PRMConfig
from .utils import compute_accuracy, generate_model_card
from .utils import compute_accuracy, disable_dropout_in_model, generate_model_card


if is_peft_available():
Expand Down Expand Up @@ -130,6 +130,10 @@ def __init__(

model = get_peft_model(model, peft_config)

# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(model)

if compute_metrics is None:
compute_metrics = compute_accuracy

Expand Down
3 changes: 3 additions & 0 deletions trl/trainer/reward_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments):
max_length (`Optional[int]`, *optional*, defaults to `None`):
Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
to use the default data collator.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
dataset_num_proc (`int`, *optional*, defaults to `None`):
Number of processes to use for processing the dataset.
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
Expand All @@ -42,6 +44,7 @@ class RewardConfig(TrainingArguments):
"""

max_length: Optional[int] = None
disable_dropout: bool = True
dataset_num_proc: Optional[int] = None
center_rewards_coefficient: Optional[float] = None
remove_unused_columns: bool = False
5 changes: 5 additions & 0 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
RewardDataCollatorWithPadding,
compute_accuracy,
decode_and_strip_padding,
disable_dropout_in_model,
generate_model_card,
get_comet_experiment_url,
log_table_to_comet_experiment,
Expand Down Expand Up @@ -169,6 +170,10 @@ def __init__(

model = get_peft_model(model, peft_config)

# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(model)

if compute_metrics is None:
compute_metrics = compute_accuracy

Expand Down

0 comments on commit bcf4f98

Please sign in to comment.