scale parameter for torch.nn.functional.scaled_dot_product_attention #2294

Status: Closed (wants to merge 4 commits)
14 changes: 11 additions & 3 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -6892,9 +6892,17 @@ def _broadcast_tensor_to_same_batch_dims(x: Var, batch_dims: List[int]) -> Var:

     # When len(inputs) == 7, the inputs are (q, k, v, attn_mask, dropout, is_causal, scale)
     if len(inputs) == 7 and inputs[6] is not None:
-        raise NotImplementedError(
-            "scaled_dot_product_attention op: scale parameter is not handled."
-        )
+        default_scale = q.shape[-1] ** -0.5
+        scale = inputs[6]
+
+        if scale.val == default_scale:
+            # No need to apply scale if it is the default value
+            pass
+        elif scale.val != 1.0:
+            # Apply correction to desired scale since default_scale
+            # will be applied downstream regardless
+            corrected_scale = scale.val / default_scale
+            q = mb.mul(x=corrected_scale, y=q, name=q.name + "_scaled")

     if attn_mask is not None and is_causal:
         raise ValueError(
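A note on why the correction above works: PyTorch's `scaled_dot_product_attention` uses `1 / sqrt(head_dim)` as its default scale, and, per the comment in the diff, that default is still applied downstream of this op. Pre-multiplying the query by `scale / default_scale` therefore produces the requested scale overall. Below is a minimal sketch of that identity in plain PyTorch; the shapes and the `requested_scale` value are illustrative only, not taken from the PR:

```python
import torch

# Illustrative values, not from the diff.
head_dim = 7
requested_scale = 0.5
default_scale = head_dim ** -0.5

q = torch.randn(2, 3, 5, head_dim)
k = torch.randn(2, 3, 5, head_dim)

# What the user asks for: attention logits scaled by requested_scale.
want = (q @ k.transpose(-2, -1)) * requested_scale

# What the converted graph computes: q is pre-multiplied by the correction
# factor, and default_scale is still applied downstream.
corrected_scale = requested_scale / default_scale
got = ((q * corrected_scale) @ k.transpose(-2, -1)) * default_scale

assert torch.allclose(want, got, rtol=1e-5, atol=1e-6)
```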
71 changes: 71 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -11324,6 +11324,77 @@ def test_attn_mask(
             input_as_shape=False,
         )

+    @pytest.mark.parametrize(
+        "compute_unit, backend, frontend, minimum_deployment_target, seq_lengths, bool_mask, dynamic, scale",
+        itertools.product(
+            compute_units,
+            backends,
+            frontends,
+            [None, ct.target.iOS18],
+            [(5, 5), (7, 5)],
+            [False, True],
+            [False, True],
+            [None, 0.5, 1.],
+        ),
+    )
+    def test_scale_argument(
+        self,
+        compute_unit,
+        backend,
+        frontend,
+        minimum_deployment_target,
+        seq_lengths,
+        bool_mask,
+        dynamic,
+        scale,
+    ):
+        if frontend == TorchFrontend.TORCHSCRIPT and bool_mask:
+            pytest.xfail(
+                "rdar://110499660 ([CI][Bug] test_attn_mask is occasionally failing when bool_mask = True)"
+            )
+
+        source_seq_len, target_seq_len = seq_lengths
+        query_shape = (2, 3, target_seq_len, 7)
+        key_shape = (2, 3, source_seq_len, 7)
+        value_shape = key_shape
+        mask_shape = (target_seq_len, source_seq_len)
+
+        query = generate_input_data(query_shape)
+        key = generate_input_data(key_shape)
+        value = generate_input_data(value_shape)
+        if bool_mask:
+            mask = torch.rand(mask_shape) > 0.5
+            mask = mask.bool()
+        else:
+            mask = generate_input_data(mask_shape)
+
+        model = ModuleWrapper(
+            function=nn.functional.scaled_dot_product_attention,
+            kwargs={"scale": scale},
+        )
+
+        if dynamic:
+            converter_input_type = [
+                ct.TensorType(
+                    shape=(ct.RangeDim(upper_bound=10, default=input_data.shape[0]),)
+                    + input_data.shape[1:]
+                )
+                for input_data in [query, key, value, mask]
+            ]
+        else:
+            converter_input_type = None
+
+        self.run_compare_torch(
+            (query, key, value, mask),
+            model,
+            frontend=frontend,
+            backend=backend,
+            converter_input_type=converter_input_type,
+            compute_unit=compute_unit,
+            minimum_deployment_target=minimum_deployment_target,
+            input_as_shape=False,
+        )
+
     @pytest.mark.parametrize(
         "compute_unit, backend, frontend, minimum_deployment_target, mask_as_input, dynamic",
         itertools.product(
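For completeness, here is a hedged end-to-end sketch of the user-facing behavior the new test covers: converting a module that calls SDPA with an explicit, non-default scale. The module name `ScaledAttention`, the shapes, the input names, and the iOS17 deployment target are illustrative choices, not taken from the PR or the test suite:

```python
import coremltools as ct
import torch
import torch.nn as nn


class ScaledAttention(nn.Module):
    """Calls SDPA with an explicit, non-default scale (requires torch >= 2.1)."""

    def forward(self, q, k, v):
        return nn.functional.scaled_dot_product_attention(q, k, v, scale=0.5)


# Shapes mirror the test: (batch, heads, seq, head_dim) with head_dim = 7.
q = torch.randn(2, 3, 5, 7)
k = torch.randn(2, 3, 5, 7)
v = torch.randn(2, 3, 5, 7)

traced = torch.jit.trace(ScaledAttention().eval(), (q, k, v))

mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(name="q", shape=q.shape),
        ct.TensorType(name="k", shape=k.shape),
        ct.TensorType(name="v", shape=v.shape),
    ],
    minimum_deployment_target=ct.target.iOS17,
)

# The converted model's output can then be compared against the PyTorch
# reference ScaledAttention()(q, k, v), which is what run_compare_torch
# automates in the test. (predict requires a macOS host.)
prediction = mlmodel.predict({"q": q.numpy(), "k": k.numpy(), "v": v.numpy()})
```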