
Memory leak with TorchIO 0.20.1 #1222

Open

FlorianScalvini opened this issue Oct 13, 2024 · 17 comments · May be fixed by #1227
Labels
bug Something isn't working

Comments

@FlorianScalvini

FlorianScalvini commented Oct 13, 2024

Is there an existing issue for this?

  • I have searched the existing issues

Bug summary

A memory leak is observed in TorchIO version 0.20.1 during prolonged training sessions. Memory usage gradually increases over time, eventually leading to an out-of-memory error. The issue is not present in version 0.20.0, where memory usage remains stable under similar conditions.

The following graphs show the evolution of RAM utilization during a basic DL training session, using the provided example code, with TorchIO version 0.20.0:
[Figure: RAM usage during training with TorchIO 0.20.0]

Here with TorchIO 0.20.1:

[Figure: RAM usage during training with TorchIO 0.20.1]

Do you have any tips for preventing this problem with TorchIO 0.20.1?

Thanks :)

Code for reproduction

import copy
from typing import Any
import torch
import torch.utils.data
import torchio as tio
import pytorch_lightning as pl
import monai
import psutil
import matplotlib.pyplot as plt


ram_usage = []
timestamps = []


def get_ram_usage():
    return psutil.virtual_memory().percent

# %% Lightning module
class BasicPLModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.net = monai.networks.nets.Unet(spatial_dims=3, in_channels=1, out_channels=1, channels=(8, 16, 32), strides=(2, 2)).float()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001)
        return optimizer

    def training_step(self, batch, batch_idx):
        inputs = batch['t1'][tio.DATA].float()
        targets = batch['brain'][tio.DATA].float()
        preds = self.net(inputs)
        return torch.nn.functional.mse_loss(preds, targets)

    def on_train_epoch_end(self):
        ram_usage.append(get_ram_usage())
        timestamps.append(self.current_epoch)


# %% Main program
if __name__ == '__main__':

    subject = tio.datasets.Colin27()
    # convert to instance of new subject class
    subject = tio.Subject(
        t1=subject['t1'],
        head=subject['head'],
        brain=subject['brain'],
    )
    subjects = 10 * [subject]
    transforms = tio.Compose([
        tio.ZNormalization(masking_method=tio.ZNormalization.mean),
        tio.transforms.CropOrPad(target_shape=192)
    ])
    subject_dataset = tio.SubjectsDataset(subjects, transforms)
    training_loader = tio.SubjectsLoader(subject_dataset, batch_size=1)
    reg_net = BasicPLModule()

    trainer_args = {
        'max_epochs': 200,
    }

    trainer_reg = pl.Trainer(**trainer_args)
    trainer_reg.fit(reg_net, training_loader)

    # Plotting the data
    plt.plot(timestamps, ram_usage)
    plt.title('RAM Usage Over Time')
    plt.xlabel('Time')
    plt.ylabel('RAM Usage (%)')
    plt.grid(True)
    plt.show()

Actual outcome

After a long period:
Process finished with exit code 137 (interrupted by signal 9: SIGKILL)

Error messages

No response

Expected outcome

The training output itself is irrelevant in this example, but the RAM utilization should remain relatively constant.

System info

Platform:   Windows-10-10.0.22631-SP0
TorchIO:    0.20.1
PyTorch:    2.4.1+cu121
SimpleITK:  2.4.0 (ITK 5.4)
NumPy:      2.0.2
Python:     3.9.13 (tags/v3.9.13:6de2ca5, May 17 2022, 16:36:42) [MSC v.1929 64 bit (AMD64)]
@FlorianScalvini FlorianScalvini added the bug Something isn't working label Oct 13, 2024
@romainVala
Contributor

Hi,
Thanks for reporting (and for the simple example).
Unfortunately, I could not reproduce this: neither version I tested showed a memory increase.
(I initially saw increasing memory, but I redid the test, taking care not to use extra memory during it, and ran the versions one by one.)
Can you reproduce the same result several times, running only the test script?

For TorchIO 0.20.1 I used:

Platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.27
TorchIO: 0.20.1
PyTorch: 2.3.1
SimpleITK: 2.0.0rc2.dev912-g1eec0 (ITK 5.1)
NumPy: 1.26.4
Python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]

and for the previous TorchIO version:
Platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.27
TorchIO: 0.19.1
PyTorch: 2.3.1
SimpleITK: 2.0.0rc2.dev912-g1eec0 (ITK 5.1)
NumPy: 1.26.4
Python: 3.10.14 (main, May 6 2024, 19:42:50) [GCC 11.2.0]

Note that for this older TorchIO version I added the SubjectDataloader defined in #1179 (comment) so that the script works.

Although I do not see increasing memory consumption, I do see a large drop in speed!
I changed your line
timestamps.append(self.current_epoch)
to
timestamps.append(datetime.datetime.now().timestamp())

and printed the total time at the end:
timestamps[-1] - timestamps[0]

With TorchIO 0.19.1 the 200 epochs took 320 s, whereas they took 725 s with the recent TorchIO version... Do you also see speed changes?

@FlorianScalvini
Author

FlorianScalvini commented Oct 14, 2024

Hi!

Thank you for your feedback.
I was able to reproduce the same error on a different PC, observing a rapid increase in memory usage with the previous code example (number of epochs = 50).

[Figure: RAM usage with PyTorch 2.4.1 and TorchIO 0.20.1]

Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.31
TorchIO: 0.20.1
PyTorch: 2.4.1+cu121
SimpleITK: 2.4.0 (ITK 5.4)
NumPy: 2.1.2
Python: 3.11.10 (main, Sep 7 2024, 18:35:42) [GCC 9.4.0]

I also tested the previous TorchIO versions (0.20.0 and 0.19.1) without any issues. I adapted the code based on your recommendations for the older TorchIO version (#1179 (comment)).

However, I noticed that you were using a different version of Torch than mine. I tested with Torch 2.3.1, and during training, I observed no significant changes in RAM usage. It seems there might be an issue specifically with Torch 2.4.1 and TorchIO 0.20.1. I’ll continue to investigate to determine if that’s the case.

Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.31
TorchIO: 0.20.1
PyTorch: 2.3.1+cu121
SimpleITK: 2.4.0 (ITK 5.4)
NumPy: 2.1.2
Python: 3.11.10 (main, Sep 7 2024, 18:35:42) [GCC 9.4.0]

[Figure: RAM usage with PyTorch 2.3.1]

Concerning epoch duration, I have not noticed any impact of the TorchIO version in my various tests.

@nicoloesch
Contributor

Hi @FlorianScalvini,

I am observing a similar issue, related to tensors created during tio.Transforms not being correctly freed upon modification. I am not entirely sure whether the two are related, as I am mostly focusing on the preprocessing and have not yet tracked memory usage over time during training.

My current workaround is to overwrite set_data for the Image class as follows:

  1. set the self[DATA] of the Image to None
  2. call the garbage collector with gc.collect()
  3. set the new tensor back to self[DATA] = modified_tensor

This significantly improved my memory usage during preprocessing (see the sketch below).
I am also calling each transform with copy=False - this might not be required, but I am still investigating the issue.
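In code, the workaround looks roughly like this (a minimal sketch; MemorySafeScalarImage is just an illustrative name, and whether the explicit gc.collect() is really needed is part of what I am still investigating):

import gc

import torch
import torchio as tio

class MemorySafeScalarImage(tio.ScalarImage):
    """Hypothetical subclass illustrating the workaround; not part of TorchIO."""

    def set_data(self, tensor: torch.Tensor) -> None:
        # 1. Drop the reference to the old tensor so it becomes collectible
        if tio.DATA in self:
            self[tio.DATA] = None
        # 2. Run the garbage collector explicitly
        gc.collect()
        # 3. Store the new tensor via the regular TorchIO path
        super().set_data(tensor)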

For context
I am loading rather big images for each subject (roughly 2-3GB for a single subject with 4 sequences) and the memory leak propagates rather quickly throughout the entire preprocessing pipeline as tensors are copied but never/seldom freed. With multiple workers, I easily exceed 32GB of RAM during loading of the samples, which should not happen that quickly.

It could be that the memory allocator changed in torch (I updated to 2.4 in the same go). I also tried pin_memory=False in the constructor of the torch.utils.data.DataLoader, but that didn't seem to fix my issue.

Maybe this is helpful, maybe this is unrelated.

@nicoloesch
Contributor

Code to reproduce the missing de-allocation of memory:

import gc

import torch
import torchio as tio

class DummyTransform(tio.Transform):
    def apply_transform(self, subject: tio.Subject) -> tio.Subject:
        for img in subject.get_images_dict().values():
            # Make shallow copy - could be the memory leak
            new_img = img.data.clone().float().numpy()
            
            # NOTE: Comment these 3 lines to reproduce increasing memory
            if hasattr(img, tio.DATA):
                img[tio.DATA] = None
                gc.collect()
            new_img += 1
            img.set_data(new_img)
        return subject
        
# Main function here -> called in __main__ or separately
def main(length: int = 4):
    # roughly 3GB in size (4 * 0.75GB)
    subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})
    
    transforms = tio.Compose([
        DummyTransform(copy=False),
        DummyTransform(copy=False),
        DummyTransform(copy=False),
    ])

    lazy_subjects_dataset = tio.SubjectsDataset(
        subjects=length * [subject],
        transform=transforms)

    lazy_subjects_loader = tio.SubjectsLoader(
        dataset=lazy_subjects_dataset,
        batch_size=1,
        num_workers=0)
        
    # Monitor the memory usage while iterating over the loader
    for subject in lazy_subjects_loader:
        print("Processed subject")

By commenting out the 3 lines indicated by # NOTE, one can see the increasing memory allocation compared to the version that frees the memory using the garbage collector.
I am also seeing this behaviour in torch=2.3.1.

One idea of mine would be to have all transformations and subjects use np.ndarray and cast the arrays to torch.Tensor only before returning them to the DataLoader (e.g. in __getitem__ of the Subject/Image). This would circumvent the PyTorch-specific memory allocator. I am aware this would require a significant amount of refactoring, and I have not tested whether it actually solves the issue.

Maybe someone with more knowledge about memory (de-)allocation in Python and PyTorch than I have has a better idea!
And finally, maybe my issue and @FlorianScalvini's issue are not related at all.

@romainVala
Contributor

Since Florian's issue happens only with a specific torch and torchio version, @nicoloesch, can you report the output of
python <(curl -s https://raw.githubusercontent.com/fepegar/torchio/main/print_system.py)

@FlorianScalvini
Author

FlorianScalvini commented Oct 15, 2024

Hello @nicoloesch and @romainVala!

My problem may be related to @nicoloesch's concerning the missing memory de-allocation in the transform functions.

Indeed, when the transform application is disabled by setting the transform parameter to None (tio.SubjectsDataset(subjects, transform=None)), the memory problems disappear.

PS: I had to modify the initialization of the Subject object in my example to account for the input size constraints of the Unet model (divisible by 32), in order to test the impact of the preprocessing on memory:

  subject = tio.Subject(
      t1=tio.ScalarImage(tensor=torch.rand(1, 192, 192, 192)),
      head=tio.ScalarImage(tensor=torch.rand(1, 192, 192, 192)),
      brain=tio.ScalarImage(tensor=torch.rand(1, 192, 192, 192)),
  )

@nicoloesch
Contributor

nicoloesch commented Oct 15, 2024

Hi @romainVala,

My current environment is

Platform:   Linux-6.8.0-45-generic-x86_64-with-glibc2.35
TorchIO:    0.20.1
PyTorch:    2.3.1
SimpleITK:  2.4.0 (ITK 5.4)
NumPy:      2.0.2
Python:     3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]

However, the issue also appeared on PyTorch 2.2.2 with TorchIO 0.18.90. I decided to upgrade to the newest version of TorchIO (and a somewhat more recent version of PyTorch) to see if it fixed the issue. I assume the issue has been around for a while, but I did not notice it as I used "regularly" sized tensors as opposed to large multi-GB subjects. As a result, the potential memory leak wasn't as noticeable, either because it never exceeded the RAM of my system or because the garbage collector eventually caught up without explicitly calling it.

I utilised htop to monitor my RAM usage. I am aware this is not the most accurate approach, as PyTorch apparently caches some memory and does not immediately free it to the system. However, I have noticed my program crash (presumably due to OOM with SIGTERM) at the exact moment the RAM is fully occupied (or exceeded).
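(For a slightly more precise per-process reading than htop, a psutil snippet along these lines can be logged during training - just a sketch, and the value still includes whatever PyTorch keeps cached:)

import os

import psutil

def rss_gb() -> float:
    # Resident set size of the current process in GiB; note that with
    # num_workers > 0 each DataLoader worker is a separate process.
    return psutil.Process(os.getpid()).memory_info().rss / 1024**3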

Finally, my solution with the gc only partially resolves the issue. There is still some memory that is not freed and lingers around. In addition, if the transform makes a shallow copy of the tensor at the start and requires some long processing, the RAM might fill up with enough num_workers, as none of the memory is freed prior to calling set_data. I modified the transforms to set the DATA member attribute of tio.Image to None immediately after the copy is made and to call gc.collect(). This frees the memory more quickly but still leaves some memory that is never freed; at least it does not accumulate over time as severely as before, but rather settles at a specific value depending on the tensor size and num_workers.

Based on @FlorianScalvini's comment, it appears that both of our memory issues originate in tio.Transform but are most likely attributable to PyTorch and its memory management of dereferenced/overwritten tensors...

EDIT: The residual memory of the example above originates from the call copy.deepcopy(subject) in SubjectsDataset.__getitem__. I allocated the tensor up front so the file is not loaded from disk, which means the subject is then copied (and the copy is no longer cheap, as the image has already been "loaded"). As a result, the memory de-allocation with the garbage collector seems to work as intended for the Image.

@nicoloesch
Contributor

Hi @romainVala,

I dedicated some time this morning to profiling exact memory values to figure out the source of the memory leak. For that, I utilised the memory profiler (despite it not being in active development) and created the following dummy example: two new dummy Image classes, one that does everything in PyTorch (NewImagePT) and one that does everything in NumPy (NewImageNP).
As both new image datatypes are not supported by the regular torchio.SubjectsDataset and torchio.Subject, I created a very basic version of both that contains only the functionality required to test the memory consumption - there are no checks/parsing etc. Finally, I also removed the parse_input of the DummyTransform so the check for my new datatypes does not fail. The full code follows:

import copy
from typing import Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import torchio as tio
from memory_profiler import profile
from torch.utils.data import Dataset


class NewImageNP(dict):
    def __init__(
        self,
        tensor: Union[np.ndarray, torch.Tensor],
    ):
        super().__init__()
        tensor = tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor
        assert isinstance(tensor, np.ndarray), f"Expected tensor to be of type np.ndarray but got {type(tensor)}"
        self[tio.DATA] = tensor

    @property
    def data(self) -> np.ndarray:
        return self[tio.DATA]
    
    @data.setter
    def data(self, tensor: np.ndarray) -> None:
        self[tio.DATA] = tensor

class NewImagePT(dict):
    def __init__(
        self,
        tensor: Union[np.ndarray, torch.Tensor],
    ):
        super().__init__()
        tensor = tensor if isinstance(tensor, torch.Tensor) else torch.as_tensor(tensor)
        assert isinstance(tensor, torch.Tensor), f"Expected tensor to be of type torch.Tensor but got {type(tensor)}"
        self[tio.DATA] = tensor

    @property
    def data(self) -> torch.Tensor:
        return self[tio.DATA]
    
    @data.setter
    def data(self, tensor: torch.Tensor) -> None:
        self[tio.DATA] = tensor
    
class NewSubject(dict):
    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.update_attributes()
        
    def update_attributes(self) -> None:
        # This allows to get images using attribute notation, e.g. subject.t1
        self.__dict__.update(self)

    def get_images(
        self,
        intensity_only=True,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ) -> List[Union[NewImageNP, NewImagePT]]:
        images_dict = self.get_images_dict(
            intensity_only=intensity_only,
            include=include,
            exclude=exclude,
        )
        return list(images_dict.values())
    
    def get_images_dict(
        self,
        intensity_only=True,
        include: Optional[Sequence[str]] = None,
        exclude: Optional[Sequence[str]] = None,
    ) -> Dict[str, Union[NewImageNP, NewImagePT]]:
        images = {}
        for image_name, image in self.items():
            if not isinstance(image, (NewImageNP, NewImagePT)):
                continue
            if include is not None and image_name not in include:
                continue
            if exclude is not None and image_name in exclude:
                continue
            images[image_name] = image
        return images

class NewSubjectsDataset(Dataset):
    def __init__(
        self,
        subjects: Sequence[NewSubject],
        transform: Optional[Callable] = None,
        load_getitem: bool = True,
    ):
        self._subjects = subjects
        self._transform = transform  # Skip the check for now
        self.load_getitem = load_getitem

    def __len__(self):
        return len(self._subjects)
    
    def __getitem__(self, index: int) -> NewSubject:
        try:
            index = int(index)
        except (RuntimeError, TypeError) as err:
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from err

        subject = self._subjects[index]
        subject = copy.deepcopy(subject)

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject
            

class DummyTransform(tio.Transform):

    @profile
    def apply_transform(self, subject: NewSubject) -> NewSubject:

        for img in subject.get_images_dict().values():
            # Make shallow copy - that must be the memory leak
            if isinstance(img, NewImageNP):
                new_img = img.data.copy()
                new_img += 1
                img.data = new_img
            elif isinstance(img, NewImagePT):
                new_img = img.data.clone()
                #if hasattr(img, tio.DATA):
                #    del img[tio.DATA]
                #    gc.collect()
                new_img += 1
                img.data = new_img
            elif isinstance(img, tio.Image):
                new_img = img.data.clone().float().numpy()
                new_img += 1
                new_img = torch.as_tensor(new_img)
                img.set_data(new_img)
            else:
                raise ValueError(f"Expected data to be of type np.ndarray or torch.Tensor but got {type(img.data)}")
        return subject

@profile
def run_main(
    length: int = 4,
    use_new: bool = True,
    use_np: bool = False,
):
    if use_new:
        transforms = tio.Compose([
                DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
                DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
                DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
            ])
        img_cls = NewImageNP if use_np else NewImagePT

        # roughly 3GB in size (4 * 0.75GB)
        subject = NewSubject(**{seq: img_cls(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})

        lazy_subjects_dataset = NewSubjectsDataset(
            subjects=length * [subject],
            transform=transforms
        )
    
    else:
        transforms = tio.Compose(
            [DummyTransform(copy=False),
                DummyTransform(copy=False),
                DummyTransform(copy=False),
            ]
        )
        # roughly 3GB in size (4 * 0.75GB)
        subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})

        lazy_subjects_dataset = tio.SubjectsDataset(
            subjects=length * [subject],
            transform=transforms
        )

    lazy_subjects_loader = tio.SubjectsLoader(
        dataset=lazy_subjects_dataset,
        batch_size=1,
        num_workers=0
    )

    for subject in lazy_subjects_loader:
        print("Processed subject")
    
if __name__ == "__main__":
    run_main()

As indicated by the decorator @profile, the entire function run_main and the specific subroutine apply_transform of the DummyTransform are profiled. I attached two logs I have obtained to this comment:

  1. tio.txt: the output of the profiler when running the script with use_new=False. This means the torchio data structures are used everywhere except in my DummyTransform.
  2. mine.txt: the output of the profiler when running the script with use_new=True and use_np=False. This means my data structures are used throughout the script instead of the default ones.

Doing this revealed the memory leak.

Every now and then, the method set_data, which in theory overwrites the tensor in the dict, does not free any data. Examples for this in tio.txt are

  • L 25
  • L 116
  • L 207

The memory eventually gets somewhat deallocated, as the initial RAM usage prior to entering the function differs - sometimes the call even frees the memory directly (e.g. L 55, 85, 146, 176).
This varying behaviour is not observed with my data structures, where the memory is always freed directly.

You may have already noticed that my data structures use the attribute setter instead of set_data. To test whether this makes a difference, I also used the now deprecated attribute setter with the TorchIO classes. The output is in tio_property.txt. Spoiler: nothing changes. It is the same as with tio.txt, with memory sometimes not being freed correctly.

I therefore concluded that it MUST be in the setter for the attribute (set_data / img.data = ...). However, I could not figure out what the difference is. I even removed the _parse_tensor call (see tio_nocheck.txt) and also set the attribute directly using the deprecated setter to achieve parity, but nothing worked. I am therefore currently at a loss as to why the bottleneck exists.

So now, what is going on? I have no clue. Maybe it is a pass-by-reference vs. pass-by-copy issue? Maybe someone with more knowledge of Python can help out...

What should be changed? I also have no clue as I was unable to find the exact origin of the memory leak.

I hope this is still somewhat useful...

---- FILES ----

tio_property.txt
tio_nocheck.txt
mine.txt
tio.txt

@nicoloesch
Contributor

Hi @romainVala, hi @FlorianScalvini,

Update: I think I found the memory leak - this time more definitively. I will provide a step-by-step overview of what I did.
Let's start with my environment:

Platform:   Linux-6.8.0-45-generic-x86_64-with-glibc2.35
TorchIO:    0.20.1
PyTorch:    2.3.1
SimpleITK:  2.4.0 (ITK 5.4)
NumPy:      2.0.2
Python:     3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0]

The following code is required to reproduce the memory leak:

import torchio as tio
from memory_profiler import profile
import torch

class DummyTransform(tio.Transform):

    @profile
    def apply_transform(self, subject: tio.Subject) -> tio.Subject:

        for img in subject.get_images_dict().values():
            # Make shallow copy - that must be the memory leak
            if isinstance(img, tio.Image):
                new_img = img.data.clone().float().numpy()
                new_img += 1
                new_img = torch.as_tensor(new_img)
                img.set_data(new_img)
            else:
                raise ValueError(f"Expected data to be of type np.ndarray or torch.Tensor but got {type(img.data)}")
        return subject
        
        
def run_main(
    length: int = 4,
    copy_compose: bool = True,
):
    transforms = tio.Compose([
        DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
        DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
        DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
    ], copy=copy_compose)

    # roughly 3GB in size (4 * 0.75GB)
    subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})

    subjects_dataset = tio.SubjectsDataset(
        subjects=length * [subject],
        transform=transforms
    )

    subjects_loader = tio.SubjectsLoader(
        dataset=subjects_dataset,
        batch_size=1,
        num_workers=0  # increasing it will multiply the memory leak per worker
    )

    for subject in subjects_loader:
        print("Processed subject")

if __name__ == "__main__":
    run_main()

How to reproduce the memory leak?
The memory leak appears as soon as copy_compose=True. In that case, the base class tio.Transform of tio.Compose will make a copy of the subject in __call__ (see L. 162) [all transforms would do this, but I construct my DummyTransform explicitly with copy=False to reduce the memory leak and isolate it to Compose]:

...
if self.copy:
    subject = copy.copy(subject)
...

This in theory will create a shallow copy of the object (tio.Subject is of type dict), which calls __copy__ of the tio.Subject if implemented. The method is implemented in subject.py to apply custom logic:

def __copy__(self):
    return _subject_copy_helper(self, type(self))
    
...

def _subject_copy_helper(
    old_obj: Subject,
    new_subj_cls: Callable[[Dict[str, Any]], Subject],
):
    result_dict = {}
    for key, value in old_obj.items():
        if isinstance(value, Image):
            value = copy.copy(value)
        else:
            value = copy.deepcopy(value)
        result_dict[key] = value

    new = new_subj_cls(**result_dict)  # type: ignore[call-arg]
    new.applied_transforms = old_obj.applied_transforms[:]
    return new

And exactly here lies the issue. This method _subject_copy_helper is responsible for the memory leak. You can see the origin of the memory leak in my profiled data.

  1. subject_copycompose-False.txt: This runs the script with copy_compose=False. As a result, none of the transforms copies the tio.Subject and we do not call the __copy__ method (at least from the transforms). This results in moderate memory usage despite the rather large tensors of roughly 10GB at peak.

  2. subject_copycompose-True.txt: We run the same script, but this time with copy_compose=True. This in turn invokes copy.copy on the tio.Subject in each transform's __call__, which calls _subject_copy_helper under the hood. You see exploding memory in the logs, around 28GB at peak.

  3. subject_copycompose-True_comment-copy.txt: Run the same script as in 2., but this time we simply comment out the __copy__ method of tio.Subject in the source file. As a result, the copy.copy call creates a shallow copy without custom logic (apparently using __new__ and __init__ - I am unsure about the exact Python semantics). You will see memory usage plummet from 28GB back to the original 10GB, indicative of a shallow copy with shared memory.

I was wondering why there is a differentiation between the copying behaviour of a tio.Image and any other attribute of the dict in _subject_copy_helper. Why are other attributes copied with deepcopy even though the method is called from __copy__ and not from __deepcopy__?

Note: The memory behaviour occasionally changes between runs, where sometimes memory is freed and sometimes it is not. It can still be traced back to the implementation of __copy__, and ultimately to the question of why this is required, as calling copy.copy and copy.deepcopy on the tio.Subject without any implementation of __copy__ works as intended. It might be useful to figure out when it is absolutely necessary to call copy.deepcopy (expensive with large tensors) and when copy.copy is enough, as opposed to a mixed version of both.

One way could be to define memory-safe __copy__ and __deepcopy__ methods as described here, but I am so far unsure whether that is required, as the default behaviour without any __copy__ implementation copies the subject as intended.
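For reference, the generic recipe for such methods looks roughly like this (a sketch of the general Python pattern, not TorchIO code; a dict subclass like tio.Subject would additionally need to copy its items, not only __dict__):

import copy

class Example:
    def __copy__(self):
        cls = self.__class__
        new = cls.__new__(cls)
        # Shallow copy: share references to the same objects (no tensor duplication)
        new.__dict__.update(self.__dict__)
        return new

    def __deepcopy__(self, memo):
        cls = self.__class__
        new = cls.__new__(cls)
        memo[id(self)] = new  # register early to handle reference cycles
        for key, value in self.__dict__.items():
            setattr(new, key, copy.deepcopy(value, memo))
        return new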

I hope this helps and would love to hear some thoughts :)
Nico

@romainVala
Contributor

Thanks for the detailed report.
It goes above my poor Python skills, so I am not sure I can help here...

@fepegar
Owner

fepegar commented Oct 18, 2024

Pinging @justusschock as he worked on _subject_copy_helper IIRC.

@justusschock
Contributor

Hi folks,

sorry, I only saw the ping just now (pretty busy these days ^^).

@fepegar I only outsourced the existing implementation to a function to correctly handle subclasses of Subject :)

That being said though, I agree with @nicoloesch on the issue and that it should be safe to use copy for most things.
Probably the reason for going with deepcopy here is that Python sometimes has weird rules about call by reference vs. call by value (and therefore also about when a copy vs. a deepcopy is required). And with objects like tensors, where the object is merely a pointer to the actually consumed memory, these things get even less clear.

My recommendation would be:

  • Check whether just switching to a shallow copy works as expected in all cases, since it seems to reduce the memory footprint (we might still need a separate deepcopy path, though, as we might want to exclude images there); see the sketch below.
  • If it doesn't act as expected, make a list of the things that we want to be shallow-copied and handle them manually in this function.
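A rough sketch of what the first option might look like (untested, just to make the idea concrete, assuming the helper keeps its current signature):

import copy

def _subject_copy_helper(old_obj, new_subj_cls):
    result_dict = {}
    for key, value in old_obj.items():
        # Shallow-copy every value, images and plain attributes alike,
        # so large tensors are never duplicated during __copy__.
        result_dict[key] = copy.copy(value)

    new = new_subj_cls(**result_dict)
    new.applied_transforms = old_obj.applied_transforms[:]
    return new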

Sadly, I won't have the time to work on it myself, but I'd be happy to review a PR if someone else takes this on :)

@nicoloesch
Contributor

Hi everyone,

I can take on the task for now! It actively restricts my workflow at the moment, so I will need to find a solution regardless. I will check the current copying logic and prepare a PR covering all Transforms and everywhere a Subject is used, which may include additional tests regarding shallow and deep copies and the respective checks of attributes.

@justusschock Thank you for offering to review the PR - I will keep you posted on any updates.

Cheers,
Nico

@nicoloesch
Contributor

Hi everyone,

I changed the code and verified the absence of the memory leak on my machine. I included two additional tests that verify the behaviour of a shallow copy vs. a deep copy. I will create a PR, but it would also be beneficial if someone other than me verified the absence of the memory leak using the code from above with the DummyTransform.

One IMPORTANT side note: we create one copy of the Subject in each iteration over the SubjectsLoader. This originates from the __getitem__ call in SubjectsDataset (see code below), which creates a deep copy of the subject. The comment on that line states that this is cheap since we have not loaded images yet. This is true for the regular usage of providing paths to images, as opposed to my example, where I allocate the tensor straight away using torchio.ScalarImage(tensor=...).

def __getitem__(self, index: int) -> Subject:
        try:
            index = int(index)
        except (RuntimeError, TypeError) as err:
            message = (
                f'Index "{index}" must be int or compatible dtype,'
                f' but an object of type "{type(index)}" was passed'
            )
            raise ValueError(message) from err

        subject = self._subjects[index]
        subject = copy.deepcopy(subject)  # cheap since images not loaded yet
        if self.load_getitem:
            subject.load()

        # Apply transform (this is usually the bottleneck)
        if self._transform is not None:
            subject = self._transform(subject)
        return subject
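To make that distinction concrete (the file path below is hypothetical; only the tensor= form makes the deepcopy expensive):

import torch
import torchio as tio

# Lazy: only the path is stored until the image is actually loaded,
# so copy.deepcopy(subject) in __getitem__ stays cheap.
lazy_image = tio.ScalarImage('sub-01_T1w.nii.gz')

# Eager: the tensor already lives in memory, so every deepcopy
# duplicates the full ~0.75 GB array.
eager_image = tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192))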

I followed the contributing guidelines with the exception of installing the newest 3.12 version of python. Utilising the command

conda create --name torchioenv python --yes

will pull Python 3.13.0, which is incompatible with torchio (Requirement: < 3.13). I therefore utilised

conda create --name torchioenv python=3.12

which pulled Python 3.12.7 instead.

The only question I have is that, while searching for calls of copy.copy() and copy.deepcopy(), I noticed that a lot of np.ndarrays are copied using the np.ndarray.copy() method. According to the NumPy documentation (which also requires navigating to the documentation of numpy.copy), this creates a shallow copy instead of a deep copy. I am not entirely sure if that is what we want, as some of the methods modify the array in place and then re-set the original attribute with the modified one. With a shallow copy, the original array should already reflect the changes, so creating the copy that way might not be necessary.

One notable exception would be cases where we want to remove negative strides, which is documented at some occurrences. Each of them references the original GitHub issue in a Facebook project, but the problem can also be seen in the PyTorch forums.

Examples for these np.ndarray.copy() calls with subsequent in-place array modifications are:

  • torchio.transforms.augmentation.spatial.ElasticDeformation.get_bspline_transform [L. 240]
  • torchio.transforms.preprocessing.spatial.Crop.apply_transform [L. 47] (we are making a copy of affine, alter new_affine in place and then overwrite the original affine again)
  • torchio.transforms.preprocessing.spatial.Pad.apply_transform [L. 92] (we are making a copy of affine, alter new_affine in place and then overwrite the original affine again)

Based on the code preceding the operation, it appears that it might alleviate the negative stride issue, but I wanted to double-check with you!

Cheers,
Nico

@nicoloesch nicoloesch linked a pull request Oct 24, 2024 that will close this issue
@romainVala
Contributor

Thanks @nicoloesch for taking care of this.
This is important.

I could reproduce the memory growth when copy_compose=True and not when it is False.

Note that I did not reproduce it when I used an older version of Python and torch:
PyTorch: 1.6.0
Python: 3.8.3

but OK, that is quite an old version and one should not care too much (just to point out that it is not an obvious issue).
Anyway, I do reproduce it with:
PyTorch: 2.3.1
Python: 3.10.14

About the torchio logic: I do not fully understand the copy argument of the Transform class, whose docstring says
copy: Make a shallow copy of the input before applying the transform.

Why a shallow copy and not a full copy?

If one wants to keep the original subject (before the transform), then I expect all the image content to remain untouched, so I do want a deep copy, no?

Can you see cases where one would want just a shallow copy?
If yes, what exactly is a shallow copy of an image?

In the NumPy documentation, in this example
a = np.array([1, 'm', [2, 3, 4]], dtype=object)
b = a.copy()
I finally understood what a shallow copy is by noticing that
a[0] and a[1] are copied but a[2] is not (b[0] = 3 won't change a, but b[2][0] = 3 does change a).

If I make the analogy with a torchio image, it would mean that we copy the image attributes but not the image data... that does not make sense to me...

@fepegar
Owner

fepegar commented Oct 24, 2024

Many thanks for working on this, @nicoloesch. I'm not sure I understand your concern about NumPy copying.

>>> import numpy as np
>>> x = np.arange(3)
>>> y = x.copy()
>>> x *= 2
>>> y
array([0, 1, 2])

As explained in the NumPy docs, the copy is shallow only for elements of dtype object:

>>> import numpy as np
>>> x = np.array([{"parrot": "dead"}])
>>> y = x.copy()
>>> x[0]["parrot"] = "blue"
>>> y
array([{'parrot': 'blue'}], dtype=object)

@Bigsealion

I encountered the same memory leak issue after modifying my code to use tio.SubjectsLoader instead of torch.utils.data.DataLoader. This causes my Python script to run out of memory and stop after a certain period.

Here are the relevant package versions:

In [5]: tio.__version__
Out[5]: '0.20.1'

In [6]: torch.__version__
Out[6]: '2.5.1+cu121'
