Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting TF fake_quant_with_min_max functions #20641

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

doncarlos999
Copy link

@doncarlos999 doncarlos999 commented Dec 13, 2024

Based on the discussion here: #20319 I started porting the fake_quant_with_min_max functions from tensorflow to keras3.
This PR contains those ported functions and the relevant tests from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/tests/fake_quant_ops_test.py.

I didn't implement tf.quantization.fake_quant_with_min_max_vars as it looks the same as tf.quantization.fake_quant_with_min_max_args. But, I can add this one too if required.

For the CLA I am waiting on our CTO to add me to the Edge Impulse <-> Google CLA. But I figured that I can work on revisions to the PR in the meantime.

CC: @matpalm, @dansitu, @james77777778

* adds fake_quant_with_min_max functions from TF to keras3
Copy link

google-cla bot commented Dec 13, 2024

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@codecov-commenter
Copy link

codecov-commenter commented Dec 13, 2024

Codecov Report

Attention: Patch coverage is 89.80892% with 16 lines in your changes missing coverage. Please review.

Project coverage is 72.50%. Comparing base (84b531c) to head (5c48be2).

Files with missing lines Patch % Lines
keras/src/quantizers/quantizers.py 91.72% 7 Missing and 5 partials ⚠️
keras/api/_tf_keras/keras/quantizers/__init__.py 0.00% 4 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (84b531c) and HEAD (5c48be2). Click for more details.

HEAD has 4 uploads less than BASE
Flag BASE (84b531c) HEAD (5c48be2)
keras 5 3
keras-numpy 1 0
keras-jax 1 0
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20641      +/-   ##
==========================================
- Coverage   81.95%   72.50%   -9.46%     
==========================================
  Files         543      543              
  Lines       50663    50820     +157     
  Branches     7828     7842      +14     
==========================================
- Hits        41523    36849    -4674     
- Misses       7246    12204    +4958     
+ Partials     1894     1767     -127     
Flag Coverage Δ
keras 72.37% <89.17%> (-9.42%) ⬇️
keras-jax ?
keras-numpy ?
keras-openvino 29.89% <12.10%> (-0.06%) ⬇️
keras-tensorflow 64.73% <89.17%> (+0.07%) ⬆️
keras-torch 63.80% <88.53%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

keras/src/quantizers/quantizers.py Outdated Show resolved Hide resolved
keras/src/quantizers/quantizers.py Outdated Show resolved Hide resolved
keras/src/quantizers/quantizers.py Outdated Show resolved Hide resolved
Copy link
Contributor

@james77777778 james77777778 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @doncarlos999
I have left some comments.

Additionally, I think we still need fake_quant_with_min_max_vars, as it is used in TFMOT:
https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/python/core/quantization/keras/quant_ops.py#L340

@@ -12,4 +12,14 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For QAT purposes, I don't think we need *_gradient ops.

fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel_gradient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For QAT purposes, I don't think we need *_gradient ops.

@@ -12,4 +12,14 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel_gradient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@@ -6,6 +6,16 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_args_gradient,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here.

@keras_export(
"keras.quantizers.fake_quant_with_min_max_vars_per_channel_gradient"
)
def fake_quant_with_min_max_vars_per_channel_gradient(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For QAT purposes, I don't think we need *_gradient ops.

@@ -100,3 +100,759 @@ def test_quantize_and_dequantize(self):
)
# A loose assertion due to an expected quantization error
self.assertAllClose(qdq_values, values, atol=5e-1)

def _TestOp(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use @parameterized.named_parameters and named_product to organize similar tests like this one:
https://github.com/keras-team/keras/blob/master/keras/src/ops/nn_test.py#L2355-L2365

num_bits=num_bits,
narrow_range=narrow_range,
)
self.assertAllClose(outputs, expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think verifying the output values alone is not sufficient.
We can add an assertion for the gradient, similar to this one:
https://github.com/keras-team/keras/blob/master/keras/src/layers/core/dense_test.py#L584

)
self.assertAllClose(outputs, expected)

def _TestGradOp(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For QAT purposes, I don't think we need *_gradient ops.

)
self.assertAllClose(outputs, expected)

def _TestChannelsGradOp(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For QAT purposes, I don't think we need *_gradient ops.

@doncarlos999
Copy link
Author

@james77777778 thank you for the review. I'm working on revisions now.

Regarding the *_gradient functions, I added those as a way to test that the gradients that come from the main functions were being calculated correctly. Should we keep them just for testing purposes but not expose them in the public facing API? If not I will remove them.

@james77777778
Copy link
Contributor

Regarding the *_gradient functions, I added those as a way to test that the gradients that come from the main functions were being calculated correctly. Should we keep them just for testing purposes but not expose them in the public facing API? If not I will remove them.

We can test the gradients of fake_* functions using:

  • tensorflow: tf.GradientTape() + tape.gradient
  • torch: loss.backward() + variable.grad
  • jax: jax.grad

You can refer to this test for an example:
https://github.com/keras-team/keras/blob/master/keras/src/layers/core/dense_test.py#L584-L649

Using a different function, separate from the user-facing function, for testing purposes seems redundant and fragile to me. However, we should wait for calls from @fchollet

@doncarlos999
Copy link
Author

I agree that having two separate functions is fragile I simply kept the functions separate as that was how they were tested in the Tensorflow repo.
I will start adding tests based on your example in the meantime. Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants