-
Notifications
You must be signed in to change notification settings - Fork 240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Standardize Axes in Random Transforms. Add Random Axis to RandomMotion #1185
base: main
Are you sure you want to change the base?
Conversation
…e _flip_image and _parse_restore into class
All checks pass now. |
@@ -392,6 +402,15 @@ def axis_name_to_index(self, axis: str) -> int: | |||
if not isinstance(axis, str): | |||
raise ValueError('Axis must be a string') | |||
axis = axis[0].upper() | |||
if axis not in 'LRPAISTB': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if axis not in 'LRPAISTB': | |
if axis not in FLIP_AXIS: |
for name, image in self.get_images_dict(subject).items(): | ||
is_2d = image.is_2d() | ||
axes = [a for a in self.axes if a != 2] if is_2d else self.axes | ||
for name, _ in self.get_images_dict(subject).items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for name, _ in self.get_images_dict(subject).items(): | |
for name in self.get_images_dict(subject): |
@@ -134,6 +142,7 @@ class Motion(IntensityTransform, FourierTransform): | |||
simulate motion artifacts for data augmentation. | |||
|
|||
Args: | |||
axis: Integer representing the axis along which the simulated movements |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this sentence is incomplete.
reader.ReadImageInformation() | ||
num_channels = reader.GetNumberOfComponents() | ||
num_dimensions = reader.GetDimension() | ||
try: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are the changes in this method related to the goal of this PR?
axes: Tuple[int, ...], | ||
num_ghosts_range: Tuple[int, int], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the point of this change?
@@ -211,35 +228,18 @@ def get_rigid_transforms( | |||
) -> List[sitk.Euler3DTransform]: | |||
center_ijk = np.array(image.GetSize()) / 2 | |||
center_lps = image.TransformContinuousIndexToPhysicalPoint(center_ijk) | |||
identity = np.eye(4) | |||
matrices = [identity] | |||
ident_transform = sitk.Euler3DTransform() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you rename "ident" -> "identity"?
is_2d = subject.get_first_image().is_2d() | ||
if is_2d and 2 in self.axes: | ||
warnings.warn( | ||
f'Input image is 2D, but "2" is in axes: {self.axes}', | ||
RuntimeWarning, | ||
stacklevel=2, | ||
) | ||
self.axes = list(self.axes) | ||
self.axes.remove(2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this not needed anymore?
super().__init__(**kwargs) | ||
self.axes = _parse_axes(axes) | ||
self.args_names = ('axes',) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect the type checker is checking that this is a tuple somewhere, but I might be wrong.
@staticmethod | ||
def flip_image(image, axes): | ||
spatial_axes = np.array(axes, int) + 1 | ||
data = image.numpy() | ||
data = np.flip(data, axis=spatial_axes) | ||
data = data.copy() # remove negative strides | ||
data = torch.as_tensor(data) | ||
image.set_data(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did this need to be moved here?
@@ -386,7 +385,7 @@ def guess_external_viewer() -> Optional[Path]: | |||
def parse_spatial_shape(shape): | |||
result = to_tuple(shape, length=3) | |||
for n in result: | |||
if n < 1 or n % 1: | |||
if isinstance(n, (str, bytes)) or n < 1 or n % 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we also need to check for bytes?
Fixes #81
Description
This pull request standardizes the use of axes in random transforms. It also adds random axes options in the
RandomMotion
transform. Includes the following changes:to_tuple
to see strings as singular values (this also required adding number of changes to checks for strings in areas where numbers are necessary). Tests all pass.parse_axes
andensure_axes_indices
to baseTransform
class. Standardizes the use of axes in all transforms. Capable of accepting integer and specific string axis values. This is now used inRandomFlip
,RandomMotion
,RandomGhosting
,RandomAnisotropy
. Tests were added to check these additional axis values.RandomMotion
. This now works properly on 2D images and can apply motion to any axis of 3D images. Also removed redundant change between matrix and transform.Checklist
CONTRIBUTING
docs and have a developer setup (especially important arepre-commit
andpytest
)pytest
make html
inside thedocs/
folder