num_training_batches is inf in configure_optimizers #16060

Open
davidgilbertson opened this issue Dec 15, 2022 · 6 comments · May be fixed by #20148
Labels: bug (Something isn't working), data handling (Generic data-related topic), loops (Related to the Loop API)

davidgilbertson (Contributor) commented Dec 15, 2022

Bug description

The value of trainer.num_training_batches is inf when referenced inside configure_optimizers(). It seems it doesn't get its correct value until some point later in the fit setup. This causes a very hard-to-find problem because training runs without error, except that the loss is nan.

Something inside torch.optim.lr_scheduler.CyclicLR then sets the optimizer's lr to nan, because the inf ends up being used as the scheduler's step size.
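
For reference, the same nan shows up in plain PyTorch when inf is used as the step size, so the root cause appears to be num_training_batches flowing into CyclicLR as inf. A minimal sketch, independent of Lightning:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
# Passing inf as the step sizes (which is what num_training_batches * 1 / * 2
# amounts to here) makes CyclicLR's internal step ratio inf / inf = nan,
# and that nan propagates into the learning rate on construction.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer=optimizer,
    base_lr=0.01,
    max_lr=0.1,
    step_size_up=float("inf"),
    step_size_down=float("inf"),
    cycle_momentum=False,
)
print(optimizer.param_groups[0]["lr"])  # nan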

It would be nice if:

  • This value was available by the time configure_optimizers() is called, or
  • There was a warning when it is accessed before it's set (a user-side guard along these lines is sketched below)
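
In the meantime, a guard in the model code catches the problem early. This is just an illustration of the kind of check I mean, not an existing Lightning API:

def configure_optimizers(self):
    import math

    n = self.trainer.num_training_batches
    if not math.isfinite(n):
        # Fail loudly instead of silently building a scheduler from inf
        raise RuntimeError(f"num_training_batches is {n}; it has not been populated yet at this point")
    ...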

How to reproduce the bug

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        print(f"{optimizer.param_groups[0]['lr'] = }")  # 0.1
        lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer=optimizer,
            base_lr=0.01,
            max_lr=0.1,
            step_size_up=self.trainer.num_training_batches * 1,  # problematic!
            step_size_down=self.trainer.num_training_batches * 2,  # problematic!
            cycle_momentum=False,
        )
        print(f"{optimizer.param_groups[0]['lr'] = }")  # nan
        return [optimizer], [lr_scheduler]


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Error messages and logs

The main hint that something is wrong is actually TensorBoard printing "NaN or Inf found in input tensor", but even that doesn't come with a traceback telling me where the message is printed from.

Environment

Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 1.10):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):

More info

No response

cc @justusschock @awaelchli @carmocca

@davidgilbertson added the needs triage label on Dec 15, 2022
stale bot commented Jan 21, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

@stale bot added the won't fix label on Jan 21, 2023
davidgilbertson (Contributor, Author) commented:
FWIW, I question the logic of "stale bots". Do you not want GitHub to be a place where users can inform you of potential problems? What does automatic closure achieve other than sweeping issues under the carpet? I'd much rather a human being apply the "won't fix" label; I'm OK with that.

@stale bot removed the won't fix label on Jan 22, 2023
stale bot commented Apr 14, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

@stale bot added the won't fix label on Apr 14, 2023
S-aiueo32 (Contributor) commented Apr 29, 2023

I'm hitting this on v1.9.0. It seems to be a problem around DDP: one of the workers gets the correct value, but the others get inf. Are there any workarounds?

@stale bot removed the won't fix label on Apr 29, 2023
S-aiueo32 (Contributor) commented:
I found the trainer.estimated_stepping_batches property, which works fine with DDP.
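
For anyone landing here, a sketch of how that slots into the repro above. Note that estimated_stepping_batches counts optimizer steps over the whole run (all epochs), so with max_epochs > 1 the cycle lengths below differ from the per-epoch values the original code intended:

def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
    # estimated_stepping_batches is already computed at this point, even under DDP,
    # unlike num_training_batches, which is still inf here.
    total_steps = self.trainer.estimated_stepping_batches
    lr_scheduler = torch.optim.lr_scheduler.CyclicLR(
        optimizer=optimizer,
        base_lr=0.01,
        max_lr=0.1,
        step_size_up=total_steps * 1,
        step_size_down=total_steps * 2,
        cycle_momentum=False,
    )
    return [optimizer], [lr_scheduler]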

@awaelchli added the bug, data handling, and loops labels and removed the needs triage label on Sep 12, 2023
@awaelchli added this to the 2.0.x milestone on Sep 12, 2023
@Borda changed the milestone from 2.0.x to 2.1.x on Oct 12, 2023
@awaelchli changed the milestone from 2.1.x to 2.2.x on Feb 8, 2024
@awaelchli changed the milestone from 2.2.x to 2.3.x on Jun 13, 2024
@awaelchli changed the milestone from 2.3.x to 2.4.x on Aug 7, 2024
nourgana commented Aug 20, 2024

@davidgilbertson have you found a workaround for this?
