Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix initial_lr when resuming training #243

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Lauler
Copy link

@Lauler Lauler commented Nov 17, 2024

The argument initial_lr in lambda_lr() is initialized incorrectly from the learning rate of an optimizer's parameter groups. This causes the learning rate to be set incorrectly when models are resumed from checkpoints trained with standard parametrization LR schedulers.

for param_group in optimizer.get_base_optimizer().param_groups:
lr_lambdas.append(get_lr_lambda_for_param_group(lr=param_group["lr"]))

It should instead be initialized from the initial learning rate of the optimizer's param groups. However, the key "initial_lr" in the optimizer does not exist when training is started, only when training is resumed from a checkpoint. I've therefore set this argument to lr_scheduler_args.learning_rate, which seems to work in standard parametrization, but almost certainly breaks something in mu-Parametrization.

See this issue comment for context: #233 (comment)

The argument initial_lr in lambda_lr() is initialized incorrectly from the learning rate of an optimizer's parameter groups. It should instead be initialized from the INITIAL learning rate of the optimizer's param groups.
@Lauler
Copy link
Author

Lauler commented Nov 18, 2024

This might break μ-Parametrization. In that case should probably unbundle the logic for standard parametrization and μP.

@Lauler
Copy link
Author

Lauler commented Nov 22, 2024

This was fixed better in other pull request here: #245 .

@Lauler
Copy link
Author

Lauler commented Nov 27, 2024

My mistake. This was not fixed in #245 .

Original training run with the latest commit of Nanotron as of writing this. Learning rate at Step 11001 is 0.000103.

[0534:0]:11/27/2024 19:06:55 [INFO|DP=0|PP=0|TP=0|lrdn0534]: iteration: 11001 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 1.28K | tokens_per_sec: 102K | tokens_per_sec_per_gpu: 6.38K | global_batch_size: 64 | lm_loss: 2.99 | lr: 0.000103 | model_tflops_per_gpu: 73.3 | hardware_tflops_per_gpu: 73.3 | grad_norm: 0.246 | cuda_memory_allocated: 8.86G | cuda_max_memory_reserved: 34.8G | hd_total_memory_tb: 9.48G | hd_used_memory_tb: 8.23G | hd_free_memory_tb: 1.24G
[0534:0]:11/27/2024 19:07:01 [INFO|DP=0|PP=0|TP=0|lrdn0534]: iteration: 11011 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 665 | tokens_per_sec: 197K | tokens_per_sec_per_gpu: 12.3K | global_batch_size: 64 | lm_loss: 3.01 | lr: 0.000103 | model_tflops_per_gpu: 141 | hardware_tflops_per_gpu: 141 | grad_norm: 0.231
[0534:0]:11/27/2024 19:07:07 [INFO|DP=0|PP=0|TP=0|lrdn0534]: iteration: 11021 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 666 | tokens_per_sec: 197K | tokens_per_sec_per_gpu: 12.3K | global_batch_size: 64 | lm_loss: 2.99 | lr: 0.000102 | model_tflops_per_gpu: 141 | hardware_tflops_per_gpu: 141 | grad_norm: 0.236

Resuming training without my patch applied resumes at the incorrect learning rate (0.000145):

[0419:0]:11/27/2024 21:32:54 [INFO|DP=0|PP=0|TP=0|lrdn0419]: iteration: 11011 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 677 | tokens_per_sec: 194K | tokens_per_sec_per_gpu: 12.1K | global_batch_size: 64 | lm_loss: 3.01 | lr: 0.000145 | model_tflops_per_gpu: 139 | hardware_tflops_per_gpu: 139 | grad_norm: 0.243
[0419:0]:11/27/2024 21:33:00 [INFO|DP=0|PP=0|TP=0|lrdn0419]: iteration: 11021 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 673 | tokens_per_sec: 195K | tokens_per_sec_per_gpu: 12.2K | global_batch_size: 64 | lm_loss: 3 | lr: 0.000144 | model_tflops_per_gpu: 140 | hardware_tflops_per_gpu: 140 | grad_norm: 0.246
[0419:0]:11/27/2024 21:33:07 [INFO|DP=0|PP=0|TP=0|lrdn0419]: iteration: 11031 / 16700 | consumed_tokens: 1.45G | elapsed_time_per_iteration_ms: 670 | tokens_per_sec: 196K | tokens_per_sec_per_gpu: 12.2K | global_batch_size: 64 | lm_loss: 3.04 | lr: 0.000144 | model_tflops_per_gpu: 141 | hardware_tflops_per_gpu: 141 | grad_norm: 0.237

Resuming training with my hacky patch applied (lr resumes at 0.000103):

[0435:0]:11/27/2024 21:52:00 [INFO|DP=0|PP=0|TP=0|lrdn0435]: iteration: 11011 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 673 | tokens_per_sec: 195K | tokens_per_sec_per_gpu: 12.2K | global_batch_size: 64 | lm_loss: 3.01 | lr: 0.000103 | model_tflops_per_gpu: 140 | hardware_tflops_per_gpu: 140 | grad_norm: 0.231
[0435:0]:11/27/2024 21:52:07 [INFO|DP=0|PP=0|TP=0|lrdn0435]: iteration: 11021 / 16700 | consumed_tokens: 1.44G | elapsed_time_per_iteration_ms: 669 | tokens_per_sec: 196K | tokens_per_sec_per_gpu: 12.2K | global_batch_size: 64 | lm_loss: 2.99 | lr: 0.000102 | model_tflops_per_gpu: 141 | hardware_tflops_per_gpu: 141 | grad_norm: 0.236
[0435:0]:11/27/2024 21:52:13 [INFO|DP=0|PP=0|TP=0|lrdn0435]: iteration: 11031 / 16700 | consumed_tokens: 1.45G | elapsed_time_per_iteration_ms: 688 | tokens_per_sec: 191K | tokens_per_sec_per_gpu: 11.9K | global_batch_size: 64 | lm_loss: 3.03 | lr: 0.000102 | model_tflops_per_gpu: 137 | hardware_tflops_per_gpu: 137 | grad_norm: 0.229

Initial LR and optimizer LR still aren't in sync in src/nanotron/helper.py @NouamaneTazi . When resuming training in the lr_decay_steps stages of training, it will resume at the incorrect learning rate. When testing this bug, make sure you are resuming from a checkpoint that was saved in the LR decay stage of training, as opposed to at the stage where LR is at the learning_rate specified in yaml config. The bigger the difference between last LR in checkpoint and initial LR, the bigger the discrepancy when resuming.

I'm not knowledgeable enough about μ-Parametrization to suggest a general fix for this. Perhaps a True/False flag somewhere to signal whether the current training run is resuming from a checkpoint or not. And if it is, then initialize the LR scheduler with param_group["initial_lr"] instead.

Since most people including me are not using μP, this works as a temporary fix for us.

@Lauler Lauler reopened this Nov 27, 2024
gritukan added a commit to thenno/nanotron-tractoai that referenced this pull request Nov 28, 2024
@NouamaneTazi
Copy link
Member

Thanks a lot for opening the PR and the details explanation! I think we merged a fix recently here https://github.com/huggingface/nanotron/pull/245/files
Can you confirm that it works for you?

@Lauler
Copy link
Author

Lauler commented Dec 4, 2024

Unfortunately #245 did not fix this issue. Learning rates still don't match up when resuming in the latest commit of Nanotron with a clean environment.

If you want to reproduce my test below, here is the yaml config file.

Original training run at step 1401 out of 1500 shows learning rate at lr: 4.37e-05:

12/04/2024 13:14:03 [WARNING|DP=0|PP=0|TP=0]: Saving checkpoint at checkpoints/test_lr/1400
Saving weights: 100%|████████████████████████████████████████████████████████████| 195/195 [00:01<00:00, 126.66it/s]
12/04/2024 13:14:12 [INFO|DP=0|PP=0|TP=0]: iteration: 1401 / 1500 | consumed_tokens: 2.87M | elapsed_time_per_iteration_ms: 255 | tokens_per_sec: 8.04K | tokens_per_sec_per_gpu: 8.04K | global_batch_size: 4 | lm_loss: 6.67 | lr: 4.37e-05 | model_tflops_per_gpu: 22.6 | hardware_tflops_per_gpu: 22.6 | grad_norm: 1.58 | cuda_memory_allocated: 7.92G | cuda_max_memory_reserved: 13G | hd_total_memory_tb: 78.2G | hd_used_memory_tb: 65.4G | hd_free_memory_tb: 8.69G
12/04/2024 13:14:14 [INFO|DP=0|PP=0|TP=0]: iteration: 1411 / 1500 | consumed_tokens: 2.89M | elapsed_time_per_iteration_ms: 337 | tokens_per_sec: 6.07K | tokens_per_sec_per_gpu: 6.07K | global_batch_size: 4 | lm_loss: 6.16 | lr: 3.55e-05 | model_tflops_per_gpu: 17.1 | hardware_tflops_per_gpu: 17.1 | grad_norm: 1.09

Current Nanotron commit after resuming at 1401 steps is at lr: 0.000108:

12/04/2024 14:06:15 [INFO|DP=0|PP=0|TP=0]: iteration: 1401 / 1500 | consumed_tokens: 2.87M | elapsed_time_per_iteration_ms: 3.19K | tokens_per_sec: 642 | tokens_per_sec_per_gpu: 642 | global_batch_size: 4 | lm_loss: 6.67 | lr: 0.000108 | model_tflops_per_gpu: 1.81 | hardware_tflops_per_gpu: 1.81 | grad_norm: 1.58 | cuda_memory_allocated: 7.9G | cuda_max_memory_reserved: 9.41G | hd_total_memory_tb: 78.2G | hd_used_memory_tb: 65.4G | hd_free_memory_tb: 8.69G
12/04/2024 14:06:17 [INFO|DP=0|PP=0|TP=0]: iteration: 1411 / 1500 | consumed_tokens: 2.89M | elapsed_time_per_iteration_ms: 283 | tokens_per_sec: 7.25K | tokens_per_sec_per_gpu: 7.25K | global_batch_size: 4 | lm_loss: 6.16 | lr: 0.0001 | model_tflops_per_gpu: 20.4 | hardware_tflops_per_gpu: 20.4 | grad_norm: 1.5

Applying the patch in #256 where LR scheduler builder is initialized before the optimizer is loaded results in resuming at lr: 4.29e-05:

12/04/2024 14:34:48 [INFO|DP=0|PP=0|TP=0]: iteration: 1401 / 1500 | consumed_tokens: 2.87M | elapsed_time_per_iteration_ms: 3.21K | tokens_per_sec: 637 | tokens_per_sec_per_gpu: 637 | global_batch_size: 4 | lm_loss: 6.67 | lr: 4.29e-05 | model_tflops_per_gpu: 1.79 | hardware_tflops_per_gpu: 1.79 | grad_norm: 1.58 | cuda_memory_allocated: 7.9G | cuda_max_memory_reserved: 9.41G | hd_total_memory_tb: 78.2G | hd_used_memory_tb: 65.4G | hd_free_memory_tb: 8.69G
12/04/2024 14:34:50 [INFO|DP=0|PP=0|TP=0]: iteration: 1411 / 1500 | consumed_tokens: 2.89M | elapsed_time_per_iteration_ms: 313 | tokens_per_sec: 6.53K | tokens_per_sec_per_gpu: 6.53K | global_batch_size: 4 | lm_loss: 6.16 | lr: 3.48e-05 | model_tflops_per_gpu: 18.4 | hardware_tflops_per_gpu: 18.4 | grad_norm: 1.13

My patch in this PR also results in resuming at the same LR as above lr: 4.29e-05:

12/04/2024 14:40:18 [INFO|DP=0|PP=0|TP=0]: iteration: 1401 / 1500 | consumed_tokens: 2.87M | elapsed_time_per_iteration_ms: 3.23K | tokens_per_sec: 635 | tokens_per_sec_per_gpu: 635 | global_batch_size: 4 | lm_loss: 6.67 | lr: 4.29e-05 | model_tflops_per_gpu: 1.79 | hardware_tflops_per_gpu: 1.79 | grad_norm: 1.58 | cuda_memory_allocated: 7.9G | cuda_max_memory_reserved: 9.41G | hd_total_memory_tb: 78.2G | hd_used_memory_tb: 65.4G | hd_free_memory_tb: 8.69G

The reason we are slightly off the value of the original training run is because you've added an ._initial_step() of the LR scheduler in #245 .

lr_scheduler._initial_step() # NOTE: this is required to set the initial learning rate

If we comment this line out, the training resumes at the correct value lr: 4.37e-05:

12/04/2024 14:51:48 [INFO|DP=0|PP=0|TP=0]: iteration: 1401 / 1500 | consumed_tokens: 2.87M | elapsed_time_per_iteration_ms: 3.24K | tokens_per_sec: 631 | tokens_per_sec_per_gpu: 631 | global_batch_size: 4 | lm_loss: 6.67 | lr: 4.37e-05 | model_tflops_per_gpu: 1.78 | hardware_tflops_per_gpu: 1.78 | grad_norm: 1.58 | cuda_memory_allocated: 7.9G | cuda_max_memory_reserved: 9.41G | hd_total_memory_tb: 78.2G | hd_used_memory_tb: 65.4G | hd_free_memory_tb: 8.69G

I would recommend

You can close this pull request. The solution in #256 looks much better.

@NouamaneTazi
Copy link
Member

Perfect ty @Lauler

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants