From a34e9bf84f2c7d016557c6e158a14aa0c0cb2972 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 28 Nov 2024 16:43:08 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=96=A8=20Add=20Script=20Utilities=20secti?= =?UTF-8?q?on=20to=20the=20documentation=20(#2407)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add script_utils.md to the documentation * Refactor ScriptArguments class documentation * Refactor TrlParser class to improve code organization and readability --- docs/source/_toctree.yml | 2 ++ docs/source/script_utils.md | 9 +++++++++ trl/commands/cli_utils.py | 34 ++++++++++++++++++++-------------- trl/utils.py | 27 ++++++++++++++------------- 4 files changed, 45 insertions(+), 27 deletions(-) create mode 100644 docs/source/script_utils.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f3d39ba445..95fe5ab611 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -65,6 +65,8 @@ title: Data Utilities - local: text_environments title: Text Environments + - local: script_utils + title: Script Utilities title: API - sections: - local: example_overview diff --git a/docs/source/script_utils.md b/docs/source/script_utils.md new file mode 100644 index 0000000000..344d13aaef --- /dev/null +++ b/docs/source/script_utils.md @@ -0,0 +1,9 @@ +# Scripts Utilities + +## ScriptArguments + +[[autodoc]] ScriptArguments + +## TrlParser + +[[autodoc]] TrlParser diff --git a/trl/commands/cli_utils.py b/trl/commands/cli_utils.py index 76cc777da2..384daf4927 100644 --- a/trl/commands/cli_utils.py +++ b/trl/commands/cli_utils.py @@ -151,26 +151,29 @@ class ChatArguments: class TrlParser(HfArgumentParser): + """ + The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config + parsers for users that pass a valid `config` field and merge the values that are set in the config + with the processed parsers. + + Args: + parsers (`List[argparse.ArgumentParser]`): + List of parsers. + ignore_extra_args (`bool`): + Whether to ignore extra arguments passed by the config + and not raise errors. + """ + def __init__(self, parsers, ignore_extra_args=False): - """ - The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config - parsers for users that pass a valid `config` field and merge the values that are set in the config - with the processed parsers. - - Args: - parsers (`list[argparse.ArgumentParser`]): - List of parsers. - ignore_extra_args (`bool`): - Whether to ignore extra arguments passed by the config - and not raise errors. - """ super().__init__(parsers) self.yaml_parser = YamlConfigParser() self.ignore_extra_args = ignore_extra_args def post_process_dataclasses(self, dataclasses): - # Apply additional post-processing in case some arguments needs a special - # care + """ + Post process dataclasses to merge the TrainingArguments with the SFTScriptArguments or DPOScriptArguments. + """ + training_args = trl_args = None training_args_index = None @@ -192,6 +195,9 @@ def post_process_dataclasses(self, dataclasses): return dataclasses def parse_args_and_config(self, return_remaining_strings=False): + """ + Parse the command line arguments and the config file. + """ yaml_config = None if "--config" in sys.argv: config_index = sys.argv.index("--config") diff --git a/trl/utils.py b/trl/utils.py index 2c20b51668..eaea8c78aa 100644 --- a/trl/utils.py +++ b/trl/utils.py @@ -21,19 +21,20 @@ class ScriptArguments: """ Arguments common to all scripts. - dataset_name (`str`): - Dataset name. - dataset_train_split (`str`, *optional*, defaults to `"train"`): - Dataset split to use for training. - dataset_test_split (`str`, *optional*, defaults to `"test"`): - Dataset split to use for evaluation. - config (`str` or `None`, *optional*, defaults to `None`): - Path to the optional config file. - gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): - Whether to apply `use_reentrant` for gradient_checkpointing. - ignore_bias_buffers (`bool`, *optional*, defaults to `False`): - Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar type, - inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. + Args: + dataset_name (`str`): + Dataset name. + dataset_train_split (`str`, *optional*, defaults to `"train"`): + Dataset split to use for training. + dataset_test_split (`str`, *optional*, defaults to `"test"`): + Dataset split to use for evaluation. + config (`str` or `None`, *optional*, defaults to `None`): + Path to the optional config file. + gradient_checkpointing_use_reentrant (`bool`, *optional*, defaults to `False`): + Whether to apply `use_reentrant` for gradient_checkpointing. + ignore_bias_buffers (`bool`, *optional*, defaults to `False`): + Debug argument for distributed training. Fix for DDP issues with LM bias/mask buffers - invalid scalar + type, inplace operation. See https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992. """ dataset_name: str