Skip to content

Commit

Permalink
🖨 Add Script Utilities section to the documentation (#2407)
Browse files Browse the repository at this point in the history
* Add script_utils.md to the documentation

* Refactor ScriptArguments class documentation

* Refactor TrlParser class to improve code organization and readability
  • Loading branch information
qgallouedec authored Nov 28, 2024
1 parent c10cc89 commit a34e9bf
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 27 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
title: Data Utilities
- local: text_environments
title: Text Environments
- local: script_utils
title: Script Utilities
title: API
- sections:
- local: example_overview
Expand Down
9 changes: 9 additions & 0 deletions docs/source/script_utils.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Scripts Utilities

## ScriptArguments

[[autodoc]] ScriptArguments

## TrlParser

[[autodoc]] TrlParser
34 changes: 20 additions & 14 deletions trl/commands/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,26 +151,29 @@ class ChatArguments:


class TrlParser(HfArgumentParser):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parsers for users that pass a valid `config` field and merge the values that are set in the config
with the processed parsers.
Args:
parsers (`List[argparse.ArgumentParser]`):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
and not raise errors.
"""

def __init__(self, parsers, ignore_extra_args=False):
"""
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
parsers for users that pass a valid `config` field and merge the values that are set in the config
with the processed parsers.
Args:
parsers (`list[argparse.ArgumentParser`]):
List of parsers.
ignore_extra_args (`bool`):
Whether to ignore extra arguments passed by the config
and not raise errors.
"""
super().__init__(parsers)
self.yaml_parser = YamlConfigParser()
self.ignore_extra_args = ignore_extra_args

def post_process_dataclasses(self, dataclasses):
# Apply additional post-processing in case some arguments needs a special
# care
"""
Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments.
"""

training_args = trl_args = None
training_args_index = None

Expand All @@ -192,6 +195,9 @@ def post_process_dataclasses(self, dataclasses):
return dataclasses

def parse_args_and_config(self, return_remaining_strings=False):
"""
Parse the command line arguments and the config file.
"""
yaml_config = None
if "--config" in sys.argv:
config_index = sys.argv.index("--config")
Expand Down
27 changes: 14 additions & 13 deletions trl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,20 @@ class ScriptArguments:
"""
Arguments common to all scripts.
dataset_name (`str`):
Dataset name.
dataset_train_split (`str`, *optional*, defaults to `"train"`):
Dataset split to use for training.
dataset_test_split (`str`, *optional*, defaults to `"test"`):
Dataset split to use for evaluation.
config (`str` or `None`, *optional*, defaults to `None`):
Path to the optional config file.
gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
Whether to apply `use_reentrant` for gradient_checkpointing.
ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type,
inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
Args:
dataset_name (`str`):
Dataset name.
dataset_train_split (`str`, *optional*, defaults to `"train"`):
Dataset split to use for training.
dataset_test_split (`str`, *optional*, defaults to `"test"`):
Dataset split to use for evaluation.
config (`str` or `None`, *optional*, defaults to `None`):
Path to the optional config file.
gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`):
Whether to apply `use_reentrant` for gradient_checkpointing.
ignore_bias_buffers (`bool`, *optional*, defaults to `False`):
Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar
type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992.
"""

dataset_name: str
Expand Down

0 comments on commit a34e9bf

Please sign in to comment.