Synchronize validation and test logging #20492

Open
jialiangZ opened this issue Dec 12, 2024 · 0 comments
Labels
needs triage (Waiting to be triaged by maintainers), refactor

Comments


jialiangZ commented Dec 12, 2024

Outline & Motivation

In the validation step after the distributed training process, two approaches are given: one continues the validation in distributed mode, and the other runs it separately on rank 0. However, an earlier tutorial hinted that under DDP the validation process should be carried out on rank 0, so I am wondering which one should be used. Thanks!

Synchronize validation and test logging

When running in distributed mode, we have to ensure that the validation and test step logging calls are synchronized across processes. This is done by adding sync_dist=True to all self.log calls in the validation and test step. This ensures that each GPU worker has the same behaviour when tracking model checkpoints, which is important for later downstream tasks such as testing the best checkpoint across all workers. The sync_dist option can also be used in logging calls during the step methods, but be aware that this can lead to significant communication overhead and slow down your training.

Note: if you use any built-in metrics or custom metrics based on TorchMetrics, their logging calls do not need to be updated; synchronization is handled automatically for you.
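
For instance, a metric object from TorchMetrics can be logged directly and Lightning reduces it across processes, so the sync_dist flag shown in the calls below is not needed. A minimal sketch, where the Accuracy task and num_classes=10 are placeholder choices, not from the original:

import torchmetrics


def __init__(self):
    super().__init__()
    # TorchMetrics handles cross-process synchronization internally,
    # so no sync_dist flag is needed when logging the metric object
    self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)


def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    self.val_acc.update(logits, y)
    self.log("val_acc", self.val_acc, on_step=False, on_epoch=True)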

def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)


def test_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Add sync_dist=True to sync logging across all GPU workers (may have performance impact)
    self.log("test_loss", loss, on_step=True, on_epoch=True, sync_dist=True)

It is possible to perform some computation manually and log the reduced result on rank 0 as follows:

def __init__(self):
    super().__init__()
    self.outputs = []


def test_step(self, batch, batch_idx):
    x, y = batch
    tensors = self(x)
    self.outputs.append(tensors)
    return tensors


def on_test_epoch_end(self):
    # Concatenate the per-batch outputs before gathering so a single tensor is reduced
    mean = torch.mean(self.all_gather(torch.cat(self.outputs)))
    self.outputs.clear()  # free memory

    # When you call `self.log` only on rank 0, don't forget to add
    # `rank_zero_only=True` to avoid deadlocks on synchronization.
    # Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/lightning/issues/15852
    if self.trainer.is_global_zero:
        self.log("my_reduced_metric", mean, rank_zero_only=True)

Note

It is recommended to validate on a single device to ensure each sample/batch gets evaluated exactly once. This is helpful for making sure benchmarking for research papers is done the right way. Otherwise, in a multi-device setting, samples could be duplicated when DistributedSampler is used, e.g. with strategy="ddp": it replicates some samples on some devices to make sure all devices have the same batch size in case of uneven inputs.
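
One way to follow this recommendation is to run validation/testing with a separate single-device Trainer after training. A minimal sketch, where MyModel, MyDataModule, and the checkpoint path are hypothetical placeholders:

from lightning.pytorch import Trainer

# Hypothetical model/datamodule classes and checkpoint path, for illustration only
model = MyModel.load_from_checkpoint("path/to/best.ckpt")
datamodule = MyDataModule()

# On a single device no DistributedSampler is needed,
# so every sample is evaluated exactly once
trainer = Trainer(accelerator="gpu", devices=1)
trainer.validate(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)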

Pitch

Please state which approach is the default when performing ordinary tasks.

Additional context

No response

cc @justusschock @awaelchli
