Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix "same" padding torch issue #20270

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions keras/src/backend/torch/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def _compute_padding_length(


def _apply_same_padding(
inputs, kernel_size, strides, operation_type, dilation_rate=1
inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1
):
"""Apply same padding to the input tensor.

Expand Down Expand Up @@ -174,7 +174,10 @@ def _apply_same_padding(
spatial_shape[i], kernel_size[i], strides[i], dilation_rate[i]
)
mode = "constant"
padding = (padding_size,) + padding
if data_format == "channels_last":
padding = (padding_size,) + padding
else:
padding = padding + (padding_size,)

if all([left == right for left, right in padding]):
return inputs, [left for left, _ in padding]
Expand Down Expand Up @@ -252,7 +255,7 @@ def max_pool(
# Torch does not natively support `"same"` padding, we need to manually
# apply the right amount of padding to `inputs`.
inputs, padding = _apply_same_padding(
inputs, pool_size, strides, operation_type="pooling"
inputs, pool_size, strides, data_format, operation_type="pooling"
)
else:
padding = 0
Expand Down Expand Up @@ -312,7 +315,7 @@ def average_pool(
# Torch does not natively support `"same"` padding, we need to manually
# apply the right amount of padding to `inputs`.
inputs, padding = _apply_same_padding(
inputs, pool_size, strides, operation_type="pooling"
inputs, pool_size, strides, data_format, operation_type="pooling"
)
else:
padding = 0
Expand Down Expand Up @@ -377,6 +380,7 @@ def conv(
inputs,
kernel.shape[2:],
strides,
data_format,
operation_type="conv",
dilation_rate=dilation_rate,
)
Expand Down
1 change: 1 addition & 0 deletions keras/src/layers/pooling/average_pooling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def test_average_pooling1d(
(2, 1, "same", "channels_first", (3, 5, 5, 4), (3, 5, 5, 4)),
((2, 3), (2, 2), "valid", "channels_last", (3, 5, 5, 4), (3, 2, 2, 4)),
((2, 3), (2, 2), "same", "channels_last", (3, 5, 5, 4), (3, 3, 3, 4)),
((2, 3), (3, 3), "same", "channels_first", (3, 5, 5, 4), (3, 5, 2, 2)),
)
def test_average_pooling2d(
self,
Expand Down
12 changes: 12 additions & 0 deletions keras/src/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,6 +1381,18 @@ def test_average_pool_same_padding(self):
knn.average_pool(x, 2, (2, 1), padding="same"),
np_avgpool2d(x, 2, (2, 1), padding="same", data_format=data_format),
)
# Test 2D average pooling with different pool size.
if data_format == "channels_last":
input_shape = (2, 10, 9, 3)
else:
input_shape = (2, 3, 10, 9)
x = np.arange(540, dtype=float).reshape(input_shape)
self.assertAllClose(
knn.average_pool(x, (2, 3), (3, 3), padding="same"),
np_avgpool2d(
x, (2, 3), (3, 3), padding="same", data_format=data_format
),
)

@parameterized.product(
strides=(1, 2, 3),
Expand Down
Loading