diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8bd8c8af7..85d3da27f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -23,6 +23,7 @@ shared memory implementation can be used by passing `use_legacy_shared_mem_impl`
 - Added MMLU multiple choice (A/B/C/D) 5-shot variant downstream tasks
 - Tokenizer patch
 - Added option to specify number of model replicas when using hybrid sharding.
+- Added `regularize_embeddings` option.
 
 ### Changed
 
diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py
index 343bdfa31..2d45f2552 100644
--- a/olmo/checkpoint.py
+++ b/olmo/checkpoint.py
@@ -55,7 +55,9 @@
     gc_cuda,
     get_fs_local_rank,
     get_global_rank,
-    get_world_size, get_local_world_size, get_local_rank,
+    get_local_rank,
+    get_local_world_size,
+    get_world_size,
 )
 from .util import (
     _get_s3_client,
diff --git a/olmo/config.py b/olmo/config.py
index 0d954101b..a7efb0167 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -482,6 +482,11 @@ class OptimizerConfig(BaseConfig):
     If not set, defaults to the wandb `log_interval`.
     """
 
+    regularize_embeddings: bool = False
+    """
+    Applies a regularizer to the embeddings that tries to pull them towards a standard deviation of 1.
+    """
+
     def __post_init__(self):
         self.betas = tuple(self.betas)  # type: ignore[assignment]
 
diff --git a/olmo/optim.py b/olmo/optim.py
index 2f2634238..f9e73fa3e 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -43,6 +43,7 @@ def clip_grads_and_collect_metrics(
         global_step: int,
         collect_param_metrics: bool = True,
         process_group: Optional[dist.ProcessGroup] = None,
+        regularize_embeddings: bool = False,
     ) -> Dict[str, torch.Tensor]:
         """
         Clips gradients for every group that has the field `max_grad_norm`.
@@ -83,13 +84,14 @@
                 # with ReLoRa, for example.
                 assert group.get("sharded", True) is True
 
+            is_embedding_group = group["name"] == "embedding_decay_group"
             for name, p in zip(group["param_names"], group["params"]):
                 name = self._clean_param_name(name)
-                # Always need to collect the norm of gradients for clipping, even if we're not collecting
+                # Always need to collect the norm of gradients and parameters for clipping, even if we're not collecting
                 # other metrics.
                 tensors: List[Optional[torch.Tensor]] = [p.grad]
                 prefixes: List[str] = [f"grad/{name}"]
-                if collect_param_metrics:
+                if collect_param_metrics or (regularize_embeddings and is_embedding_group):
                     state = self.get_state_for_param(p)
                     sorted_state_keys = sorted([k for k in state.keys()])
                     tensors.extend([p] + [state[key] for key in sorted_state_keys])
@@ -232,7 +234,7 @@ def is_grad_norm_metric(metric_name: str) -> bool:
                 all_metrics["clipping_rate"] = clipping_rate
             return all_metrics
         else:
-            return {}
+            return all_metrics
 
     @torch.no_grad()
     def _do_adaptive_clipping(
@@ -617,6 +619,7 @@ def get_param_groups(cfg: TrainConfig, model: nn.Module) -> List[Dict[str, Any]]
     # Separate out parameters that we don't want to apply weight decay to, like norms and biases.
     decay = set()
     no_decay = set()
+    embeddings_decay = set()
     all_params = {}
     for mn, m in model.named_modules():
         for pn, p in m.named_parameters():
@@ -644,12 +647,14 @@
             elif pn.endswith("weight") and isinstance(m, nn.Embedding):
                 if cfg.optimizer.decay_embeddings:
                     decay.add(fpn)
+                elif cfg.optimizer.regularize_embeddings:
+                    embeddings_decay.add(fpn)
                 else:
                     no_decay.add(fpn)
 
     # Validate that we've considered every parameter
-    inter_params = decay & no_decay
-    union_params = decay | no_decay
+    inter_params = decay & no_decay & embeddings_decay
+    union_params = decay | no_decay | embeddings_decay
     assert len(inter_params) == 0, f"parameters {inter_params} made it into both decay/no_decay sets!"
     assert (
         len(all_params.keys() - union_params) == 0
@@ -658,12 +663,15 @@
     # Create the pytorch optimizer groups.
     decay_sorted = sorted(list(decay))
     no_decay_sorted = sorted(list(no_decay))
+    embeddings_decay_sorted = sorted(list(embeddings_decay))
+
     param_groups = []
     if len(decay_sorted) > 0:
         param_groups.append(
             {
                 "params": [all_params[pn] for pn in decay_sorted],
                 "param_names": decay_sorted,
+                "name": "decay_group",
                 **param_group_defaults,
             }
         )
@@ -673,6 +681,17 @@
                 "params": [all_params[pn] for pn in no_decay_sorted],
                 "param_names": no_decay_sorted,
                 "weight_decay": 0.0,
+                "name": "no_decay_group",
                 **param_group_defaults,
             }
         )
+    if len(embeddings_decay_sorted) > 0:
+        # the weight_decay value will be multiplied by emb_decay_factor in olmo/train.py
+        param_groups.append(
+            {
+                "params": [all_params[pn] for pn in embeddings_decay_sorted],
+                "param_names": embeddings_decay_sorted,
+                "name": "embedding_decay_group",
+                **param_group_defaults,
+            }
+        )
diff --git a/olmo/train.py b/olmo/train.py
index 0653b6bf3..ec1570667 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -19,11 +19,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+import wandb
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader
 
-import wandb
-
 from .aliases import PathOrStr
 from .checkpoint import Checkpointer, FullCheckpointer, build_sharded_checkpointer
 from .config import (
@@ -720,8 +719,14 @@ def train_step(self, batch: Dict[str, Any], reduce_global_loss: bool = True) ->
             # passing this process group here ensures metrics are reduced correctly when we're using
             # HYBRID sharding.
             process_group=self.fsdp_model.process_group,
+            regularize_embeddings=self.cfg.optimizer.regularize_embeddings,
         )
 
+        emb_norm = optim_metrics["param/transformer.wte.weight.norm"]
+        emb_size = self.cfg.model.embedding_size or self.cfg.model.vocab_size
+        emb_std = math.sqrt(math.pow(emb_norm, 2) / float(emb_size * self.cfg.model.vocab_size))
+        emb_decay_factor = 1.0 - emb_std
+
        # Adjust the learning rate.
        for group in self.optim.param_groups:
            # TODO (epwalsh): if we want to enable different LRs or gradient clipping settings per group
@@ -737,6 +742,9 @@
                self.cfg.max_grad_norm_ratio, self.scheduler_current, self.scheduler_max
            )
 
+            if group["name"] == "embedding_decay_group":
+                group["weight_decay"] *= emb_decay_factor
+
        # Optimizer step.
        self.optim.step()
 
diff --git a/test_fixtures/reverse_wd.yaml b/test_fixtures/reverse_wd.yaml
new file mode 100644
index 000000000..c75da3274
--- /dev/null
+++ b/test_fixtures/reverse_wd.yaml
@@ -0,0 +1,43 @@
+run_name: "reverse_test"
+save_folder: "/tmp/olmo-train-tiny"
+wandb:
+  name: ${run_name}
+  project: reverse-test
+model:
+  d_model: 128
+  n_heads: 4
+  n_layers: 4
+  mlp_ratio: 4
+  alibi: false
+  alibi_bias_max: 8.0
+  attention_dropout: 0.1
+  attention_layer_norm: false
+  residual_dropout: 0.1
+  embedding_dropout: 0.1
+  max_sequence_length: 512
+  vocab_size: 50257
+  eos_token_id: 50256
+  pad_token_id: 50256
+  init_device: null
+  init_std: 0.02
+optimizer:
+  learning_rate: 0.001
+  regularize_embeddings: true
+  metrics_log_interval: 100
+scheduler:
+  name: "cosine_with_warmup"
+  t_warmup: 10
+data:
+  paths:
+    - "/net/nfs.cirrascale/allennlp/llm-data/c4/en/c4-train.00000-00099.npy"
+  persistent_workers: false
+  num_workers: 0
+  prefetch_factor: null
+tokenizer:
+  identifier: "gpt2"
+save_overwrite: true
+max_duration: 16
+stop_at: ${max_duration}
+global_train_batch_size: 8
+device_train_microbatch_size: 8
+precision: "fp32"
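
Taken together, the patch (a) gives each optimizer param group a `"name"`, (b) routes embedding weights into an `embedding_decay_group` when `regularize_embeddings` is set, and (c) has the trainer rescale that group's weight decay every step by `emb_decay_factor = 1.0 - emb_std`, where `emb_std` is derived from the embedding weight norm that `clip_grads_and_collect_metrics` now returns even when other param metrics are skipped. The sketch below is a minimal illustration of that loop using a toy model and plain `torch.optim.AdamW` rather than the repository's `Trainer`; names like `base_embedding_wd` are invented for the example.

```python
# Minimal sketch (not the OLMo Trainer): a named "embedding_decay_group" whose weight decay
# is rescaled every step based on the current standard deviation of the embedding weights.
import math

import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model: an embedding table followed by a linear head.
model = nn.Sequential(nn.Embedding(100, 16), nn.Flatten(), nn.Linear(4 * 16, 2))
nn.init.normal_(model[0].weight, std=0.5)  # start the embeddings well below std 1.0

emb_params = [model[0].weight]
other_params = [p for n, p in model.named_parameters() if not n.startswith("0.")]

base_embedding_wd = 0.1  # hypothetical base weight decay for the embedding group
optim = torch.optim.AdamW(
    [
        {"params": other_params, "name": "decay_group", "weight_decay": 0.1},
        {"params": emb_params, "name": "embedding_decay_group", "weight_decay": base_embedding_wd},
    ],
    lr=1e-3,
)

for step in range(5):
    tokens = torch.randint(0, 100, (8, 4))  # batch of 8 sequences of length 4
    loss = model(tokens).pow(2).mean()
    optim.zero_grad()
    loss.backward()

    # Recover the embedding std from its Frobenius norm: sqrt(||W||^2 / numel).
    # (This equals the std only for roughly zero-mean weights.)
    w = model[0].weight
    emb_std = math.sqrt(w.norm().item() ** 2 / w.numel())
    emb_decay_factor = 1.0 - emb_std  # zero at std 1.0, negative above it

    # Rescale the embedding group's weight decay before stepping. The patch multiplies the
    # group's current value in place; here we re-derive it from a fixed base each step.
    for group in optim.param_groups:
        if group.get("name") == "embedding_decay_group":
            group["weight_decay"] = base_embedding_wd * emb_decay_factor

    optim.step()
    print(f"step {step}: emb_std={emb_std:.3f} embedding_wd={base_embedding_wd * emb_decay_factor:.4f}")
```

The multiplier is zero when the measured standard deviation is exactly 1 and changes sign above it, so whatever base weight decay the embedding group is configured with is modulated toward nothing as the embeddings approach the target scale.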
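The per-step std is recovered from a single scalar (the logged parameter norm) rather than from the full tensor. The snippet below checks the identity this relies on, using the weight's element count as the denominator and assuming roughly zero-mean weights:

```python
import torch

w = torch.randn(50257, 128) * 0.02                   # zero-mean, embedding-style init with std ~0.02
std_from_norm = (w.norm() ** 2 / w.numel()).sqrt()   # sqrt(||W||^2 / numel), i.e. the RMS of the entries
print(std_from_norm.item(), w.std().item())          # both ~0.02 because the mean is ~0
```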