Memory leak with TorchIO 0.20.1 #1222
Comments
Hi, for TorchIO 0.20.1 I used Platform: Linux-4.15.0-213-generic-x86_64-with-glibc2.27, and I also ran the previous TorchIO version. Note that for this older TorchIO version I added the SubjectDataloader defined here. Although I do not see the memory consumption grow, I do see a large drop in speed! Printing the total time at the end, with TorchIO 0.19.1 the 200 epochs took 320 s, whereas it took 725 s with the recent TorchIO version... Do you also see speed changes?
Hi! Thank you for your feedback.

Platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.31

I also tested the previous TorchIO versions (0.20.0 and 0.19.1) without any issues. I adapted the code based on your recommendations for the older TorchIO version #1179 (comment). However, I noticed that you were using a different version of Torch than mine. I tested with Torch 2.3.1, and during training I observed no significant changes in RAM usage. It seems there might be an issue specifically with Torch 2.4.1 and TorchIO 0.20.1. I'll continue to investigate to determine if that's the case.

Concerning the duration of epochs, I haven't noticed any impact of the TorchIO version in my various tests.
Hi @FlorianScalvini, I am observing a similar issue, which is related to the tensors created during the transforms. My current workaround is to overwrite the transform so that the previous tensor is explicitly deleted and garbage-collected before the new one is set (see the reproduction code below).
This significantly improved my memory usage during preprocessing. For context, it could be that the memory allocator changed in a recent Torch release. Maybe this is helpful, maybe this is unrelated.
Code to reproduce the missing de-allocation of memory:

import gc

import torch
import torchio as tio
class DummyTransform(tio.Transform):
def apply_transform(self, subject: tio.Subject) -> tio.Subject:
for img in subject.get_images_dict().values():
# Make shallow copy - could be the memory leak
new_img = img.data.clone().float().numpy()
# NOTE: Comment these 3 lines to reproduce increasing memory
if hasattr(img, tio.DATA):
img[tio.DATA] = None
gc.collect()
new_img += 1
img.set_data(new_img)
return subject
# Main function here -> called in __main__ or separately
def main(length: int = 4):
# roughly 3GB in size (4 * 0.75GB)
subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})
transforms = tio.Compose([
DummyTransform(copy=False),
DummyTransform(copy=False),
DummyTransform(copy=False),
])
lazy_subjects_dataset = tio.SubjectsDataset(
subjects=length * [subject],
transform=transforms)
lazy_subjects_loader = tio.SubjectsLoader(
dataset=lazy_subjects_dataset,
batch_size=1,
num_workers=0)
# Monitor the Memory usage
for subject in lazy_subjects_loader:
print("Processed subject") By commeting the 3 lines inidicated by One idea of mine would be to have all transformations and subjects using Maybe someone with more knowledge about memory (de-) allocation in python and PyTorch than me has a better idea! |
Since Florian's issue is happening only with a specific torch and torchio version, @nicoloesch can you report the output of your environment as well?
Hello @nicoloesch and @romainVala! My problem may be related to @nicoloesch's concerning the missing memory de-allocation in the transform functions. Indeed, when the transform application is disabled by setting the transform parameter to None, the memory issue disappears.

PS: I had to modify the initialization of the Subject object from my example to take into account the input size constraints of the U-Net model (divisible by 32), in order to test the impact of the preprocessing on memory.
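A minimal sketch of that check (hypothetical shapes and names, not Florian's exact script): the transform pipeline is disabled via transform=None and the spatial size is chosen to be divisible by 32 so the subject can still be fed to a U-Net.

import torch
import torchio as tio

# 128, 128 and 96 are all divisible by 32, satisfying a typical U-Net's pooling constraints
subject = tio.Subject(image=tio.ScalarImage(tensor=torch.randn(1, 128, 128, 96)))
dataset = tio.SubjectsDataset(subjects=100 * [subject], transform=None)  # preprocessing disabled
loader = tio.SubjectsLoader(dataset=dataset, batch_size=1, num_workers=0)

for batch in loader:
    pass  # monitor RAM here; with transform=None it should stay flat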
Hi @romainVala, my current environment is:
However, the issue also appeared with other configurations I tested. I utilised the workaround described above, but my solution with the explicit deletion of tio.DATA followed by gc.collect() only partially alleviates it. Based on @FlorianScalvini's comment, it appears that both of our memory issues originate in the application of the transforms.

EDIT: The residual memory of the example above originates from the call I investigate in my next comment.
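For anyone who wants to watch this behaviour without a full profiler, a generic sketch (psutil-based, with made-up shapes; not code from this thread) that prints the resident memory while iterating the loader:

import os

import psutil
import torch
import torchio as tio

process = psutil.Process(os.getpid())
subject = tio.Subject(image=tio.ScalarImage(tensor=torch.randn(1, 256, 256, 64)))
dataset = tio.SubjectsDataset(subjects=20 * [subject], transform=tio.RandomNoise())
loader = tio.SubjectsLoader(dataset=dataset, batch_size=1, num_workers=0)

for batch in loader:
    rss_gb = process.memory_info().rss / 1e9
    print(f"RSS after batch: {rss_gb:.2f} GB")  # a steady climb here indicates the leak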
Hi @romainVala,

I dedicated some time this morning to memory profiling exact values to figure out the issue of the memory leak. For that, I utilised the memory profiler (despite it not being in active development) and created the following dummy example. I created two new dummy Image classes, NewImageNP, which keeps everything as a NumPy array, and NewImagePT, which keeps everything as a PyTorch tensor, together with matching NewSubject and NewSubjectsDataset classes.

import copy
from typing import Callable, Dict, List, Optional, Sequence, Union

import numpy as np
import torch
import torchio as tio
from memory_profiler import profile
from torch.utils.data import Dataset


class NewImageNP(dict):
def __init__(
self,
tensor: Union[np.ndarray, torch.Tensor],
):
super().__init__()
tensor = tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor
assert isinstance(tensor, np.ndarray), f"Expected tensor to be of type np.ndarray but got {type(tensor)}"
self[tio.DATA] = tensor
@property
def data(self) -> np.ndarray:
return self[tio.DATA]
@data.setter
def data(self, tensor: np.ndarray) -> None:
self[tio.DATA] = tensor
class NewImagePT(dict):
def __init__(
self,
tensor: Union[np.ndarray, torch.Tensor],
):
super().__init__()
tensor = tensor if isinstance(tensor, torch.Tensor) else torch.as_tensor(tensor)
assert isinstance(tensor, torch.Tensor), f"Expected tensor to be of type torch.Tensor but got {type(tensor)}"
self[tio.DATA] = tensor
@property
def data(self) -> torch.Tensor:
return self[tio.DATA]
@data.setter
def data(self, tensor: torch.Tensor) -> None:
self[tio.DATA] = tensor
class NewSubject(dict):
def __init__(
self,
*args,
**kwargs
):
super().__init__(*args, **kwargs)
self.update_attributes()
def update_attributes(self) -> None:
# This allows to get images using attribute notation, e.g. subject.t1
self.__dict__.update(self)
def get_images(
self,
intensity_only=True,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
) -> List[Union[NewImageNP, NewImagePT]]:
images_dict = self.get_images_dict(
intensity_only=intensity_only,
include=include,
exclude=exclude,
)
return list(images_dict.values())
def get_images_dict(
self,
intensity_only=True,
include: Optional[Sequence[str]] = None,
exclude: Optional[Sequence[str]] = None,
) -> Dict[str, Union[NewImageNP, NewImagePT]]:
images = {}
for image_name, image in self.items():
if not isinstance(image, (NewImageNP, NewImagePT)):
continue
if include is not None and image_name not in include:
continue
if exclude is not None and image_name in exclude:
continue
images[image_name] = image
return images
class NewSubjectsDataset(Dataset):
def __init__(
self,
subjects: Sequence[NewSubject],
transform: Optional[Callable] = None,
load_getitem: bool = True,
):
self._subjects = subjects
self._transform = transform # Skip the check for now
self.load_getitem = load_getitem
def __len__(self):
return len(self._subjects)
def __getitem__(self, index: int) -> NewSubject:
try:
index = int(index)
except (RuntimeError, TypeError) as err:
message = (
f'Index "{index}" must be int or compatible dtype,'
f' but an object of type "{type(index)}" was passed'
)
raise ValueError(message) from err
subject = self._subjects[index]
subject = copy.deepcopy(subject)
# Apply transform (this is usually the bottleneck)
if self._transform is not None:
subject = self._transform(subject)
return subject
class DummyTransform(tio.Transform):
@profile
def apply_transform(self, subject: NewSubject) -> NewSubject:
for img in subject.get_images_dict().values():
# Make shallow copy - that must be the memory leak
if isinstance(img, NewImageNP):
new_img = img.data.copy()
new_img += 1
img.data = new_img
elif isinstance(img, NewImagePT):
new_img = img.data.clone()
#if hasattr(img, tio.DATA):
# del img[tio.DATA]
# gc.collect()
new_img += 1
img.data = new_img
elif isinstance(img, tio.Image):
new_img = img.data.clone().float().numpy()
new_img += 1
new_img = torch.as_tensor(new_img)
img.set_data(new_img)
else:
raise ValueError(f"Expected data to be of type np.ndarray or torch.Tensor but got {type(img.data)}")
return subject
@profile
def run_main(
length: int = 4,
use_new: bool = True,
use_np: bool = False,
):
if use_new:
transforms = tio.Compose([
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"], parse_input=False),
])
img_cls = NewImageNP if use_np else NewImagePT
# roughly 3GB in size (4 * 0.75GB)
subject = NewSubject(**{seq: img_cls(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})
lazy_subjects_dataset = NewSubjectsDataset(
subjects=length * [subject],
transform=transforms
)
else:
transforms = tio.Compose(
[DummyTransform(copy=False),
DummyTransform(copy=False),
DummyTransform(copy=False),
]
)
# roughly 3GB in size (4 * 0.75GB)
subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})
lazy_subjects_dataset = tio.SubjectsDataset(
subjects=length * [subject],
transform=transforms
)
lazy_subjects_loader = tio.SubjectsLoader(
dataset=lazy_subjects_dataset,
batch_size=1,
num_workers=0
)
for subject in lazy_subjects_loader:
print("Processed subject")
if __name__ == "__main__":
    run_main()

As indicated by the @profile decorator, both apply_transform and run_main are profiled line by line.
Doing this revealed the memory leak: every now and then, the profiled method retains the memory it allocated instead of releasing it.
The memory eventually gets somewhat deallocated, as the initial RAM usage prior to entering the function differs between calls; sometimes the call even directly frees the memory (e.g. profiler lines 55, 85, 146, 176). You may have already noticed that my data structures utilise the attribute setter (img.data = new_img) instead of set_data. I therefore concluded that it MUST be in the setter for the data attribute. So now, what is going on? I have no clue. Maybe it is a pass-by-reference vs. pass-by-copy issue? Maybe someone with more knowledge of Python can help out... What should be changed? I also have no clue, as I was unable to find the exact origin of the memory leak. I hope this is still somewhat useful...
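On the pass-by-reference question, a small generic illustration (plain Python and PyTorch, independent of torchio): rebinding a dict entry only releases the old tensor once no other name references it.

import sys

import torch

container = {"data": torch.randn(1, 256, 256, 64)}  # ~16 MB tensor
old = container["data"]                    # second reference to the same tensor
print(sys.getrefcount(old))                # 3: 'old', the dict entry, and the call argument
container["data"] = container["data"] + 1  # the dict now holds a new tensor
print(sys.getrefcount(old))                # 2: only 'old' (plus the call argument) remains
del old                                    # last reference gone; the buffer can be released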
Hi @romainVala, hi @FlorianScalvini,

Update: I think I found the memory leak, this time (mostly) definitively. I will provide a step-by-step overview of what I did.
The following code is required to reproduce the memory leak:

import torchio as tio
from memory_profiler import profile
import torch
class DummyTransform(tio.Transform):
@profile
def apply_transform(self, subject: tio.Subject) -> tio.Subject:
for img in subject.get_images_dict().values():
# Make shallow copy - that must be the memory leak
if isinstance(img, tio.Image):
new_img = img.data.clone().float().numpy()
new_img += 1
new_img = torch.as_tensor(new_img)
img.set_data(new_img)
else:
raise ValueError(f"Expected data to be of type np.ndarray or torch.Tensor but got {type(img.data)}")
return subject
def run_main(
length: int = 4,
copy_compose: bool = True,
):
transforms = tio.Compose([
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
DummyTransform(copy=False, include=["t1c", "t1n", "t2w", "t2f"]),
], copy=copy_compose)
# roughly 3GB in size (4 * 0.75GB)
subject = tio.Subject(**{seq: tio.ScalarImage(tensor=torch.randn(1, 1024, 1024, 192)) for seq in ["t1c", "t1n", "t2w", "t2f"]})
subjects_dataset = tio.SubjectsDataset(
subjects=length * [subject],
transform=transforms
)
subjects_loader = tio.SubjectsLoader(
dataset=subjects_dataset,
batch_size=1,
num_workers=0 # increasing it will multiply the memory leak per worker
)
for subject in subjects_loader:
print("Processed subject")
if __name__ == "__main__":
    run_main()

How to reproduce the memory leak? Every TorchIO Transform checks its copy argument in __call__ and, if it is set, copies the subject before applying the transformation:

...
if self.copy:
subject = copy.copy(subject)
...

This in theory will create a shallow copy of the Subject object. However, Subject defines its own __copy__, which delegates to a helper:

def __copy__(self):
return _subject_copy_helper(self, type(self))
...
def _subject_copy_helper(
old_obj: Subject,
new_subj_cls: Callable[[Dict[str, Any]], Subject],
):
result_dict = {}
for key, value in old_obj.items():
if isinstance(value, Image):
value = copy.copy(value)
else:
value = copy.deepcopy(value)
result_dict[key] = value
new = new_subj_cls(**result_dict) # type: ignore[call-arg]
new.applied_transforms = old_obj.applied_transforms[:]
    return new

And exactly here lies the issue. This helper does not perform a plain shallow copy: every Image value is copied with copy.copy, every other value is copied with copy.deepcopy, and a brand-new Subject is then constructed from the resulting dictionary.
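To make the difference concrete, a generic illustration (plain PyTorch, not torchio code): copy.copy shares the underlying storage, whereas copy.deepcopy duplicates it, which is where the large transient allocations come from.

import copy

import torch

original = {"image": torch.randn(1, 64, 64, 64)}

shallow = copy.copy(original)
print(shallow["image"].data_ptr() == original["image"].data_ptr())  # True: same storage, no extra memory

deep = copy.deepcopy(original)
print(deep["image"].data_ptr() == original["image"].data_ptr())     # False: a second buffer was allocated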
I was wondering why there is a differentiation between the copying behaviour of Image values and all the other values of a Subject in the first place.

Note: The memory behaviour occasionally changes between runs, where sometimes memory is freed and sometimes it is not. It can still be traced back to the implementation of __copy__ and _subject_copy_helper. One way forward could be to define memory-safe copy semantics for Subject. I hope this helps and would love to hear some thoughts :)
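Purely as an illustration of that "memory-safe copy" idea, a sketch that replaces the deepcopy of non-Image values with a shallow copy (an assumption about the direction of a fix, not the actual torchio implementation; _memory_safe_subject_copy is a hypothetical name):

import copy

def _memory_safe_subject_copy(old_subject):
    # Shallow-copy every value so no tensor or large attribute is duplicated;
    # only the containers and the applied_transforms bookkeeping are new objects.
    result_dict = {key: copy.copy(value) for key, value in old_subject.items()}
    new_subject = type(old_subject)(**result_dict)
    new_subject.applied_transforms = old_subject.applied_transforms[:]
    return new_subject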
Thanks for the detailed report.
Pinging @justusschock as he worked on the subject copying logic.
Hi folks, sorry, I only saw the ping just now (pretty busy these days ^^). @fepegar I only outsourced the existing implementation to a function to correctly handle subclasses of Subject. That being said though, I agree with @nicoloesch on the issue and that it should be safe to use a plain shallow copy here. My recommendation would be to adjust the copying behaviour accordingly.
Sadly, I won't have the time to work on it myself, but I'd be happy to review a PR if someone else takes this on :)
Hi everyone, I can take on the task for now! It actively restricts my workflow currently, so I will need to find a solution regardless. I will check the current logic in terms of copying and prepare a PR.

@justusschock Thank you for offering to review the PR; I will keep you posted on any updates.

Cheers,
Hi everyone, I changed the code and verified the absence of the memory leak on my machine. I included two additional tests that verify the behaviour of a shallow copy vs. a deep copy. I will create a PR, but it could also be beneficial if someone independent of me verifies the absence of the memory leak using the code from above together with my changes.

One IMPORTANT side note: we create one copy of the Subject in SubjectsDataset.__getitem__ regardless of the transform's copy setting:

def __getitem__(self, index: int) -> Subject:
try:
index = int(index)
except (RuntimeError, TypeError) as err:
message = (
f'Index "{index}" must be int or compatible dtype,'
f' but an object of type "{type(index)}" was passed'
)
raise ValueError(message) from err
subject = self._subjects[index]
subject = copy.deepcopy(subject) # cheap since images not loaded yet
if self.load_getitem:
subject.load()
# Apply transform (this is usually the bottleneck)
if self._transform is not None:
subject = self._transform(subject)
    return subject

I followed the contributing guidelines, with the exception of installing the newest 3.12 version of Python. Utilising the command from the guidelines will pull Python 3.13.0, which is incompatible with the project's dependencies, so I pinned the minor version explicitly, which pulled Python 3.12.7 instead.

The only question I would have: while searching for the remaining copy calls in the code base, I was unsure whether all of them are still required. One notable exception would be the cases where we want to remove negative strides, which is documented at some occurrences. Each of them references the original GitHub issue within a Facebook project but can also be seen in the PyTorch forums. There are a few examples of these throughout the code base. Based on the code preceding the operation, it appears that it might alleviate the negative stride issue, but I wanted to double-check with you!

Cheers,
Thanks @nicoloesch for taking care of this. I could reproduce the memory growth when running the example above. Note that I did not reproduce it when I used an older version of Python and torch, but that is quite an old setup and one should not care too much (just to point out that it is not an obvious issue).

About the torchio logic, I do not fully understand the copy argument of the Transform class. Why a shallow copy and not a full copy? If one wants to keep the original subject (before the transform), then I expect all the image content to be untouched, so I do want a deep copy, no? Can you see cases where one wishes to have just a shallow copy?

In the NumPy documentation there is an example of a shallow copy. If I make the analogy with a torchio image, then it would mean that we copy the image attributes, but not the image data... that does not make sense to me...
Many thanks for working on this, @nicoloesch. I'm not sure I understand your concern about NumPy copying:

>>> import numpy as np
>>> x = np.arange(3)
>>> y = x.copy()
>>> x *= 2
>>> y
array([0, 1, 2]) As explained in the NumPy docs, the copy is shallow only for elements of dtype >>> import numpy as np
>>> x = np.array([{"parrot": "dead"}])
>>> y = x.copy()
>>> x[0]["parrot"] = "blue"
>>> y
array([{'parrot': 'blue'}], dtype=object)
I encountered the same memory leak issue after updating my code to the newer TorchIO version. Here are the relevant package versions:
Is there an existing issue for this?
Bug summary
A memory leak is observed in TorchIO version 0.20.1 during prolonged training sessions. The memory usage gradually increases over time, leading to an eventual out-of-memory error. This issue is not present in version 0.20.0, where memory usage remains stable under similar conditions.
The following graphs show the evolution of RAM utilization during a basic DL training session, using the provided example code, with TorchIO version 0.20.0:
Here with TorchIO 0.20.1:
Do you have any tips for preventing this problem with TorchIO 0.20.1?
Thanks :)
Code for reproduction
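A minimal sketch in the spirit of this report (hypothetical shapes, transforms, and epoch count; not the exact original script):

import torch
import torchio as tio

subject = tio.Subject(image=tio.ScalarImage(tensor=torch.randn(1, 128, 128, 96)))
transform = tio.Compose([tio.RandomFlip(), tio.RandomAffine()])
dataset = tio.SubjectsDataset(subjects=50 * [subject], transform=transform)
loader = tio.SubjectsLoader(dataset=dataset, batch_size=2, num_workers=4)

for epoch in range(200):
    for batch in loader:
        images = batch["image"][tio.DATA]  # training step omitted; RAM grows epoch after epoch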
Actual outcome
After a long period:
Process finished with exit code 137 (interrupted by signal 9: SIGKILL)
Error messages
No response
Expected outcome
The output of the example code is irrelevant here; the expected behaviour is simply that RAM utilization remains relatively constant.
System info