Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Standardize Axes in Random Transforms. Add Random Axis to RandomMotion #1185

Open
wants to merge 15 commits into
base: main
Choose a base branch
from

Conversation

blakedewey
Copy link
Contributor

Fixes #81

Description

This pull request standardizes the use of axes in random transforms. It also adds random axes options in the RandomMotion transform. Includes the following changes:

  • Update to_tuple to see strings as singular values (this also required adding number of changes to checks for strings in areas where numbers are necessary). Tests all pass.
  • Add parse_axes and ensure_axes_indices to base Transform class. Standardizes the use of axes in all transforms. Capable of accepting integer and specific string axis values. This is now used in RandomFlip, RandomMotion, RandomGhosting, RandomAnisotropy. Tests were added to check these additional axis values.
  • Added a random axis parameter to RandomMotion. This now works properly on 2D images and can apply motion to any axis of 3D images. Also removed redundant change between matrix and transform.

Checklist

  • I have read the CONTRIBUTING docs and have a developer setup (especially important are pre-commitand pytest)
  • Non-breaking change (would not break existing functionality)
  • Breaking change (would cause existing functionality to change)
  • Tests added or modified to cover the changes
  • Integration tests passed locally by running pytest
  • In-line docstrings updated
  • Documentation updated, tested running make html inside the docs/ folder
  • This pull request is ready to be reviewed

@blakedewey
Copy link
Contributor Author

All checks pass now.

@@ -392,6 +402,15 @@ def axis_name_to_index(self, axis: str) -> int:
if not isinstance(axis, str):
raise ValueError('Axis must be a string')
axis = axis[0].upper()
if axis not in 'LRPAISTB':
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if axis not in 'LRPAISTB':
if axis not in FLIP_AXIS:

for name, image in self.get_images_dict(subject).items():
is_2d = image.is_2d()
axes = [a for a in self.axes if a != 2] if is_2d else self.axes
for name, _ in self.get_images_dict(subject).items():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for name, _ in self.get_images_dict(subject).items():
for name in self.get_images_dict(subject):

@@ -134,6 +142,7 @@ class Motion(IntensityTransform, FourierTransform):
simulate motion artifacts for data augmentation.

Args:
axis: Integer representing the axis along which the simulated movements
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this sentence is incomplete.

reader.ReadImageInformation()
num_channels = reader.GetNumberOfComponents()
num_dimensions = reader.GetDimension()
try:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the changes in this method related to the goal of this PR?

axes: Tuple[int, ...],
num_ghosts_range: Tuple[int, int],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the point of this change?

@@ -211,35 +228,18 @@ def get_rigid_transforms(
) -> List[sitk.Euler3DTransform]:
center_ijk = np.array(image.GetSize()) / 2
center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk)
identity = np.eye(4)
matrices = [identity]
ident_transform = sitk.Euler3DTransform()
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you rename "ident" -> "identity"?

Comment on lines -87 to -95
is_2d = subject.get_first_image().is_2d()
if is_2d and 2 in self.axes:
warnings.warn(
f'Input image is 2D, but "2" is in axes: {self.axes}',
RuntimeWarning,
stacklevel=2,
)
self.axes = list(self.axes)
self.axes.remove(2)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this not needed anymore?

super().__init__(**kwargs)
self.axes = _parse_axes(axes)
self.args_names = ('axes',)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the type checker is checking that this is a tuple somewhere, but I might be wrong.

Comment on lines +100 to +107
@staticmethod
def flip_image(image, axes):
spatial_axes = np.array(axes, int) + 1
data = image.numpy()
data = np.flip(data, axis=spatial_axes)
data = data.copy() # remove negative strides
data = torch.as_tensor(data)
image.set_data(data)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this need to be moved here?

@@ -386,7 +385,7 @@ def guess_external_viewer() -> Optional[Path]:
def parse_spatial_shape(shape):
result = to_tuple(shape, length=3)
for n in result:
if n < 1 or n % 1:
if isinstance(n, (str, bytes)) or n < 1 or n % 1:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also need to check for bytes?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Randomize k-space filling axis in RandomMotion
2 participants