From 6240dc994c00fb48846c0d348763291f479ed1d7 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Fri, 3 May 2024 14:50:35 +0530 Subject: [PATCH 01/13] reverse weight decay --- olmo/checkpoint.py | 4 +++- olmo/config.py | 6 ++++++ olmo/optim.py | 17 +++++++++++++++++ olmo/train.py | 6 ++++++ 4 files changed, 32 insertions(+), 1 deletion(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index f369888da..b9a846d85 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -55,7 +55,9 @@ gc_cuda, get_fs_local_rank, get_global_rank, - get_world_size, get_local_world_size, get_local_rank, + get_local_rank, + get_local_world_size, + get_world_size, ) from .util import ( _get_s3_client, diff --git a/olmo/config.py b/olmo/config.py index a354ab51b..08baa2485 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -481,6 +481,12 @@ class OptimizerConfig(BaseConfig): If not set, defaults to the wandb `log_interval`. """ + reverse_embedding_decay: bool = False + """ + Applying weight decay to embeddings may make them too small, potentially causing spikes. + Setting this parameter to true is a way of applying "reverse weight decay" to embeddings. + """ + def __post_init__(self): self.betas = tuple(self.betas) # type: ignore[assignment] diff --git a/olmo/optim.py b/olmo/optim.py index 2f2634238..b2355132c 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -617,6 +617,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] # Separate out parameters that we don't want to apply weight decay to, like norms and biases. decay = set() no_decay = set() + embeddings_decay = set() all_params = {} for mn, m in model.named_modules(): for pn, p in m.named_parameters(): @@ -644,6 +645,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] elif pn.endswith("weight") and isinstance(m, nn.Embedding): if cfg.optimizer.decay_embeddings: decay.add(fpn) + elif cfg.optimizer.reverse_embedding_decay: + embeddings_decay.add(fpn) else: no_decay.add(fpn) @@ -658,12 +661,15 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] # Create the pytorch optimizer groups. decay_sorted = sorted(list(decay)) no_decay_sorted = sorted(list(no_decay)) + embeddings_decay_sorted = sorted(list(embeddings_decay)) + param_groups = [] if len(decay_sorted) > 0: param_groups.append( { "params": [all_params[pn] for pn in decay_sorted], "param_names": decay_sorted, + "name": "decay_group", **param_group_defaults, } ) @@ -673,6 +679,17 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] "params": [all_params[pn] for pn in no_decay_sorted], "param_names": no_decay_sorted, "weight_decay": 0.0, + "name": "no_decay_group", + **param_group_defaults, + } + ) + if len(embeddings_decay_sorted) > 0: + # the weight_decay value will be multiplied by emb_decay_factor in olmo/train.py + param_groups.append( + { + "params": [all_params[pn] for pn in embeddings_decay_sorted], + "param_names": embeddings_decay_sorted, + "name": "embedding_decay_group", **param_group_defaults, } ) diff --git a/olmo/train.py b/olmo/train.py index e9ddf95f9..b55efcd41 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -711,6 +711,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> process_group=self.fsdp_model.process_group, ) + # TODO: confirm + emb_decay_factor = 1.0 - optim_metrics["optim/param/transformer.wte.weight.norm"] + # Adjust the learning rate. for group in self.optim.param_groups: # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group @@ -726,6 +729,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max ) + if group["name"] == "embedding_decay_group": + group["weight_decay"] *= emb_decay_factor + # Optimizer step. self.optim.step() From 0f5e28f9ff716ffee9c116b9a0eb913b9f41fdee Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Fri, 3 May 2024 04:58:41 -0700 Subject: [PATCH 02/13] bug fix --- olmo/optim.py | 4 ++-- olmo/train.py | 7 ++++-- test_fixtures/reverse_wd.yaml | 42 +++++++++++++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 test_fixtures/reverse_wd.yaml diff --git a/olmo/optim.py b/olmo/optim.py index b2355132c..d667783eb 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -651,8 +651,8 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] no_decay.add(fpn) # Validate that we've considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay + inter_params = decay & no_decay & embeddings_decay + union_params = decay | no_decay | embeddings_decay assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!" assert ( len(all_params.keys() - union_params) == 0 diff --git a/olmo/train.py b/olmo/train.py index b55efcd41..b0784bd9e 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -711,8 +711,11 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> process_group=self.fsdp_model.process_group, ) - # TODO: confirm - emb_decay_factor = 1.0 - optim_metrics["optim/param/transformer.wte.weight.norm"] + # TODO: what to do otherwise? + if should_log_optim_metrics_this_step: + emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"] + else: + emb_decay_factor = 1.0 # Adjust the learning rate. for group in self.optim.param_groups: diff --git a/test_fixtures/reverse_wd.yaml b/test_fixtures/reverse_wd.yaml new file mode 100644 index 000000000..1bf426613 --- /dev/null +++ b/test_fixtures/reverse_wd.yaml @@ -0,0 +1,42 @@ +run_name: "reverse_test" +save_folder: "/tmp/olmo-train-tiny" +wandb: + name: ${run_name} + project: reverse-test +model: + d_model: 128 + n_heads: 4 + n_layers: 4 + mlp_ratio: 4 + alibi: false + alibi_bias_max: 8.0 + attention_dropout: 0.1 + attention_layer_norm: false + residual_dropout: 0.1 + embedding_dropout: 0.1 + max_sequence_length: 512 + vocab_size: 50257 + eos_token_id: 50256 + pad_token_id: 50256 + init_device: null + init_std: 0.02 +optimizer: + learning_rate: 0.001 + reverse_embedding_decay: true + metrics_log_interval: 1 +scheduler: + name: "cosine_with_warmup" + t_warmup: 10 +data: + paths: + - "test_fixtures/mup-sample-data/part-010-00002.npy" + persistent_workers: false + num_workers: 0 + prefetch_factor: null +tokenizer: + identifier: "gpt2" +save_overwrite: true +max_duration: 4 +global_train_batch_size: 8 +device_train_microbatch_size: 8 +precision: "fp32" From b7dc57e234349def2380c9b41d4f6df17d847812 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Fri, 3 May 2024 21:01:47 +0530 Subject: [PATCH 03/13] std, not norm --- olmo/optim.py | 2 +- olmo/train.py | 8 +++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/olmo/optim.py b/olmo/optim.py index d667783eb..09123a437 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -232,7 +232,7 @@ def is_grad_norm_metric(metric_name: str) -> bool: all_metrics["clipping_rate"] = clipping_rate return all_metrics else: - return {} + return all_metrics @torch.no_grad() def _do_adaptive_clipping( diff --git a/olmo/train.py b/olmo/train.py index b0784bd9e..0c00f513b 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -711,11 +711,9 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> process_group=self.fsdp_model.process_group, ) - # TODO: what to do otherwise? - if should_log_optim_metrics_this_step: - emb_decay_factor = 1.0 - optim_metrics["param/transformer.wte.weight.norm"] - else: - emb_decay_factor = 1.0 + emb_norm = optim_metrics["param/transformer.wte.weight.norm"] + emb_std = math.sqrt(emb_norm^2 / (self.cfg.embedding_size * self.cfg.vocab_size)) + emb_decay_factor = 1.0 - emb_std # Adjust the learning rate. for group in self.optim.param_groups: From 1fc07cd990397cf5d85a26277bcd2d87ab9c6ed4 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Fri, 3 May 2024 21:04:29 +0530 Subject: [PATCH 04/13] right config key --- olmo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 0c00f513b..1631cd859 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -712,7 +712,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> ) emb_norm = optim_metrics["param/transformer.wte.weight.norm"] - emb_std = math.sqrt(emb_norm^2 / (self.cfg.embedding_size * self.cfg.vocab_size)) + emb_std = math.sqrt(emb_norm^2 / (self.cfg.model.embedding_size * self.cfg.model.vocab_size)) emb_decay_factor = 1.0 - emb_std # Adjust the learning rate. From 4c5c4b1793167e19c63fc927199185a958a0f3ea Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 09:10:14 -0700 Subject: [PATCH 05/13] pow, not bitwise --- olmo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 1631cd859..c5f47d4ad 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -712,7 +712,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> ) emb_norm = optim_metrics["param/transformer.wte.weight.norm"] - emb_std = math.sqrt(emb_norm^2 / (self.cfg.model.embedding_size * self.cfg.model.vocab_size)) + emb_std = math.sqrt(math.pow(emb_norm, 2) / (self.cfg.model.embedding_size * self.cfg.model.vocab_size)) emb_decay_factor = 1.0 - emb_std # Adjust the learning rate. From 49a6f83e09435977302580f0cfff1e537b7638cd Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 09:50:23 -0700 Subject: [PATCH 06/13] always compute param norm if reverse decay --- olmo/optim.py | 6 ++++-- olmo/train.py | 1 + test_fixtures/reverse_wd.yaml | 7 ++++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/olmo/optim.py b/olmo/optim.py index 09123a437..b11ea2972 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics( global_step: int, collect_param_metrics: bool = True, process_group: Optional[dist.ProcessGroup] = None, + reverse_embedding_decay: bool = False, ) -> Dict[str, torch.Tensor]: """ Clips gradients for every group that has the field `max_grad_norm`. @@ -85,11 +86,12 @@ def clip_grads_and_collect_metrics( for name, p in zip(group["param_names"], group["params"]): name = self._clean_param_name(name) - # Always need to collect the norm of gradients for clipping, even if we're not collecting + # Always need to collect the norm of gradients and parameters for clipping, even if we're not collecting # other metrics. tensors: List[Optional[torch.Tensor]] = [p.grad] prefixes: List[str] = [f"grad/{name}"] - if collect_param_metrics: + # TODO: only do this for the embedding group + if collect_param_metrics or reverse_embedding_decay: state = self.get_state_for_param(p) sorted_state_keys = sorted([k for k in state.keys()]) tensors.extend([p] + [state[key] for key in sorted_state_keys]) diff --git a/olmo/train.py b/olmo/train.py index c5f47d4ad..d5eb93c48 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -709,6 +709,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> # passing this process group here ensures metrics are reduced correctly when we're using # HYBRID sharding. process_group=self.fsdp_model.process_group, + reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay, ) emb_norm = optim_metrics["param/transformer.wte.weight.norm"] diff --git a/test_fixtures/reverse_wd.yaml b/test_fixtures/reverse_wd.yaml index 1bf426613..5d208542c 100644 --- a/test_fixtures/reverse_wd.yaml +++ b/test_fixtures/reverse_wd.yaml @@ -23,20 +23,21 @@ model: optimizer: learning_rate: 0.001 reverse_embedding_decay: true - metrics_log_interval: 1 + metrics_log_interval: 100 scheduler: name: "cosine_with_warmup" t_warmup: 10 data: paths: - - "test_fixtures/mup-sample-data/part-010-00002.npy" + - "/net/nfs.cirrascale/allennlp/llm-data/c4/en/c4-train.00000-00099.npy" persistent_workers: false num_workers: 0 prefetch_factor: null tokenizer: identifier: "gpt2" save_overwrite: true -max_duration: 4 +max_duration: 16 +stop_at: ${max_duration} global_train_batch_size: 8 device_train_microbatch_size: 8 precision: "fp32" From 9fae31a540b234dd6d97d6983c207403f4c5e7b4 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 09:57:43 -0700 Subject: [PATCH 07/13] only for embedding group --- olmo/optim.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/optim.py b/olmo/optim.py index b11ea2972..9a25b5cfb 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -84,14 +84,14 @@ def clip_grads_and_collect_metrics( # with ReLoRa, for example. assert group.get("sharded", True) is True + is_embedding_group = group["name"] == "embedding_decay_group" for name, p in zip(group["param_names"], group["params"]): name = self._clean_param_name(name) # Always need to collect the norm of gradients and parameters for clipping, even if we're not collecting # other metrics. tensors: List[Optional[torch.Tensor]] = [p.grad] prefixes: List[str] = [f"grad/{name}"] - # TODO: only do this for the embedding group - if collect_param_metrics or reverse_embedding_decay: + if collect_param_metrics or (reverse_embedding_decay and is_embedding_group): state = self.get_state_for_param(p) sorted_state_keys = sorted([k for k in state.keys()]) tensors.extend([p] + [state[key] for key in sorted_state_keys]) From d6d5345e36ceb0aa8df11c795f9d0fdae7492720 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 10:10:10 -0700 Subject: [PATCH 08/13] make mypy happy --- olmo/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index d5eb93c48..298c0422b 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -713,7 +713,8 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> ) emb_norm = optim_metrics["param/transformer.wte.weight.norm"] - emb_std = math.sqrt(math.pow(emb_norm, 2) / (self.cfg.model.embedding_size * self.cfg.model.vocab_size)) + emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size + emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size)) emb_decay_factor = 1.0 - emb_std # Adjust the learning rate. From 70d12b898d1ef43ec53167bb4b41c38347a13a81 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 10:10:42 -0700 Subject: [PATCH 09/13] changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f4bf369a..cdd05e76c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks - Tokenizer patch - Added option to specify number of model replicas when using hybrid sharding. +- Added reverse_embedding_decay option. ### Changed From d2f6ea2a0e9cdc42aa2ae4d785b899f7e387bdfe Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 12:04:22 -0700 Subject: [PATCH 10/13] isort --- olmo/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index c0f9701c4..cd4e625fb 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -19,11 +19,10 @@ import torch import torch.distributed as dist import torch.nn.functional as F +import wandb from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader -import wandb - from .aliases import PathOrStr from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer from .config import ( From 717925eb31cbc674b16ea60b890108ca73b333fd Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 12:32:16 -0700 Subject: [PATCH 11/13] rename to regularize_embeddings --- olmo/config.py | 2 +- olmo/optim.py | 6 +++--- olmo/train.py | 2 +- test_fixtures/reverse_wd.yaml | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index bc1819d50..6ddf55142 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -482,7 +482,7 @@ class OptimizerConfig(BaseConfig): If not set, defaults to the wandb `log_interval`. """ - reverse_embedding_decay: bool = False + regularize_embeddings: bool = False """ Applying weight decay to embeddings may make them too small, potentially causing spikes. Setting this parameter to true is a way of applying "reverse weight decay" to embeddings. diff --git a/olmo/optim.py b/olmo/optim.py index 9a25b5cfb..f9e73fa3e 100644 --- a/olmo/optim.py +++ b/olmo/optim.py @@ -43,7 +43,7 @@ def clip_grads_and_collect_metrics( global_step: int, collect_param_metrics: bool = True, process_group: Optional[dist.ProcessGroup] = None, - reverse_embedding_decay: bool = False, + regularize_embeddings: bool = False, ) -> Dict[str, torch.Tensor]: """ Clips gradients for every group that has the field `max_grad_norm`. @@ -91,7 +91,7 @@ def clip_grads_and_collect_metrics( # other metrics. tensors: List[Optional[torch.Tensor]] = [p.grad] prefixes: List[str] = [f"grad/{name}"] - if collect_param_metrics or (reverse_embedding_decay and is_embedding_group): + if collect_param_metrics or (regularize_embeddings and is_embedding_group): state = self.get_state_for_param(p) sorted_state_keys = sorted([k for k in state.keys()]) tensors.extend([p] + [state[key] for key in sorted_state_keys]) @@ -647,7 +647,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]] elif pn.endswith("weight") and isinstance(m, nn.Embedding): if cfg.optimizer.decay_embeddings: decay.add(fpn) - elif cfg.optimizer.reverse_embedding_decay: + elif cfg.optimizer.regularize_embeddings: embeddings_decay.add(fpn) else: no_decay.add(fpn) diff --git a/olmo/train.py b/olmo/train.py index cd4e625fb..ec1570667 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -719,7 +719,7 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) -> # passing this process group here ensures metrics are reduced correctly when we're using # HYBRID sharding. process_group=self.fsdp_model.process_group, - reverse_embedding_decay=self.cfg.optimizer.reverse_embedding_decay, + regularize_embeddings=self.cfg.optimizer.regularize_embeddings, ) emb_norm = optim_metrics["param/transformer.wte.weight.norm"] diff --git a/test_fixtures/reverse_wd.yaml b/test_fixtures/reverse_wd.yaml index 5d208542c..c75da3274 100644 --- a/test_fixtures/reverse_wd.yaml +++ b/test_fixtures/reverse_wd.yaml @@ -22,7 +22,7 @@ model: init_std: 0.02 optimizer: learning_rate: 0.001 - reverse_embedding_decay: true + regularize_embeddings: true metrics_log_interval: 100 scheduler: name: "cosine_with_warmup" From 962b983978f262e468c1db202a0311d40a3ff19b Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 13:17:05 -0700 Subject: [PATCH 12/13] change docstring --- olmo/config.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index 6ddf55142..a7efb0167 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -484,8 +484,7 @@ class OptimizerConfig(BaseConfig): regularize_embeddings: bool = False """ - Applying weight decay to embeddings may make them too small, potentially causing spikes. - Setting this parameter to true is a way of applying "reverse weight decay" to embeddings. + Applies a regularizer to the embeddings that tries to pull them towards a standard deviation of 1. """ def __post_init__(self): From 465d14316826567acf083d810a9f7eebab1a9988 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 6 May 2024 13:44:35 -0700 Subject: [PATCH 13/13] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e671bc838..85d3da27f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,7 @@ shared memory implementation can be used by passing `use_legacy_shared_mem_impl` - Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks - Tokenizer patch - Added option to specify number of model replicas when using hybrid sharding. -- Added reverse_embedding_decay option. +- Added `regularize_embeddings` option. ### Changed