Adam update occurs during accelerator.backward rather than optimizer.step #3291

jsmith-1989 commented Dec 12, 2024

Accelerate seems to apply the Adam update when calling accelerator.backward rather than when calling optimizer.step. Moreover, accelerator.no_sync doesn't prevent the Adam update.

This caused a hard-to-debug issue for me, because I was trying to set the learning rate between calling backward and calling step (as in the script below), but my learning rate wasn't applied until the next update.

import torch
from accelerate import Accelerator
from torch import nn
from torch.optim import Adam


# Simple model with just one parameter
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.variable = nn.Parameter(torch.tensor([1.0]))

    def forward(self):
        return self.variable


# Monkey patch Adam to print something before step
original_step = torch.optim.Adam.step


# Create new step method that prints lr
def patched_step(self, *args, **kwargs):
    for i, group in enumerate(self.param_groups):
        print(f"Group {i} learning rate before step: {group['lr']}", flush=True)
    return original_step(self, *args, **kwargs)


torch.optim.Adam.step = patched_step

# Setup
accelerator = Accelerator()
model = SimpleModel()
optimizer = Adam(model.parameters(), lr=1000)  # Start with a large LR

# Prepare with accelerator
model, optimizer = accelerator.prepare(model, optimizer)

print(f"Initial lr: {optimizer.param_groups[0]['lr']}")  # Should show 0.1

loss = model()
print("before calling backward")
accelerator.backward(loss)
print("after calling backward")

# now let's change the LR
new_lr = 0.01
for param_group in optimizer.param_groups:
    param_group["lr"] = new_lr
print(f"Set new lr to: {new_lr}")

# Call step
optimizer.step()

print("After step")

Looking at the printouts, you can see that the Adam step happens during accelerator.backward, and it uses the original large LR rather than the one set afterwards.
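
My best guess (I haven't confirmed this in the Accelerate source, so treat it as an assumption) is that under DEEPSPEED the prepared optimizer is a wrapper whose step() defers to the DeepSpeed engine, and accelerator.backward drives both the engine's backward pass and its step. A quick way to inspect what accelerator.prepare actually returned, continuing the script above:

# Continuing the script above; this only inspects the prepared objects.
# Assumption: under DEEPSPEED, prepare() hands back a DeepSpeed engine for
# the model and a wrapped optimizer whose step() is driven by the engine.
print(type(model))      # I expect a DeepSpeed engine here
print(type(optimizer))  # I expect an Accelerate wrapper around Adam here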

I also tried replacing accelerator.backward(loss) with

with accelerator.no_sync(model):
    accelerator.backward(loss)

but this didn't seem to have any effect.
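
The only workaround I've found is to move the learning-rate change to before accelerator.backward, so the update that backward triggers already sees the new value. A minimal sketch of that reordering, assuming the update really does run inside backward (variable names as in the script above):

# Workaround sketch: set the LR before backward instead of after it.
new_lr = 0.01
for param_group in optimizer.param_groups:
    param_group["lr"] = new_lr
print(f"Set new lr to: {new_lr}")

loss = model()
accelerator.backward(loss)  # the Adam update appears to happen here, now with new_lr
optimizer.step()            # kept for portability; seems to be a no-op in this setup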

Here are my config files:

deepspeed.yaml:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: /PATH/TO/deepspeed_config.json
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config: {}
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: 2
use_cpu: false

deepspeed_config.json:

{
    "communication_data_type": "fp32",
    "bf16": {
        "enabled": false
    },
    "fp16": {
        "enabled": false
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "reduce_bucket_size": 16777216,
        "contiguous_gradients": true,
        "stage3_gather_16bit_weights_on_model_save": true,
        "stage3_prefetch_bucket_size": 15e6,
        "stage3_param_persistence_threshold": 40960,
        "stage3_max_live_parameters": 2e9,
        "stage3_max_reuse_distance": 2e9
    },
    "gradient_clipping": false,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 2000000
}