[Torch FX] Post Quantize Weights Compression (#2984)
### Changes

Adds a transformation that removes fake quantize nodes and saves all weights
to disk in int8 format after quantization. It works as follows (a minimal
sketch of the resulting decompression step is shown after this list):
1. Reshape the scale if the quantize-dequantize (qdq) operation is per-channel.
2. Pattern-match the quantize-dequantize nodes.
3. Filter the matches to keep only quantize-dequantize ops with a constant input.
4. Replace each matched pair with a multiplication of the scale and the int8 input.
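
A minimal sketch of the decompression that replaces each matched pair, assuming a symmetric scheme (scale only, no zero point) as described above; the function name and the `axis` argument are illustrative and are not the actual NNCF helpers:

```python
import torch


def decompress_int8_weight(int8_weight: torch.Tensor,
                           scale: torch.Tensor,
                           axis: int = 0) -> torch.Tensor:
    """Element-wise decompression inserted in place of the removed
    quantize-dequantize pair (illustrative sketch, not the NNCF code)."""
    if scale.ndim == 1:
        # Per-channel case: reshape the scale so it broadcasts along `axis`.
        broadcast_shape = [1] * int8_weight.ndim
        broadcast_shape[axis] = -1
        scale = scale.reshape(broadcast_shape)
    # Decompression is a plain multiplication of the int8 weight by the scale.
    return int8_weight.to(torch.float32) * scale
```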

### Reason for changes

To compress the model after quantization: weights are stored in int8 instead of fp32 constants wrapped in fake quantize nodes.

### Tests

Adds `test_post_quantization_compression()` in
`tests/torch/fx/test_model_transformer.py`, which checks the data type of
every weight in the model after quantization and verifies the values
produced by the decompression step (the element-wise multiplication
operation); a minimal sketch of the dtype check is shown below.
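
A hedged sketch of the dtype part of that check; the helper name and the `"weight"` name filter are hypothetical, and the value comparison against the reference fake-quantized weights is omitted:

```python
import torch
from torch.fx import GraphModule


def _check_weights_are_compressed(quantized_model: GraphModule) -> None:
    # After compression every weight constant should be stored as int8;
    # the value check would compare int8_weight * scale with the reference
    # fake-quantized weight produced before compression.
    for name, tensor in quantized_model.state_dict().items():
        if "weight" in name:
            assert tensor.dtype == torch.int8, f"{name} is not compressed to int8"
```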

### Tickets
#2766

---------

Co-authored-by: Daniil Lyakhov <[email protected]>
anzr299 and daniil-lyakhov authored Oct 21, 2024
1 parent 7385c41 commit 7c94b23
Showing 19 changed files with 15,068 additions and 3,490 deletions.
31 changes: 31 additions & 0 deletions nncf/experimental/torch/fx/quantization/backend_parameters.py
@@ -0,0 +1,31 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters


class FXBackendParameters:
COMPRESS_WEIGHTS = "compress_weights"


def is_weight_compression_needed(advanced_parameters: Optional[AdvancedQuantizationParameters]) -> bool:
"""
Determines whether weight compression is needed based on the provided
advanced quantization parameters.

:param advanced_parameters: Advanced quantization parameters.
:return: True if weight compression is needed, False otherwise.
"""
if advanced_parameters is not None and advanced_parameters.backend_params is not None:
return advanced_parameters.backend_params.get(FXBackendParameters.COMPRESS_WEIGHTS, True)
return True
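
For reference, a usage sketch of the new backend parameter; `model` and `calibration_dataset` are placeholders, and the default (`True`) keeps int8 compression enabled:

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters

# Compression is enabled by default; passing compress_weights=False keeps the
# fake-quantize representation of the weights instead of int8 buffers.
quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    advanced_parameters=AdvancedQuantizationParameters(
        backend_params={"compress_weights": False}
    ),
)
```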
11 changes: 11 additions & 0 deletions nncf/experimental/torch/fx/quantization/quantize_model.py
@@ -26,7 +26,10 @@
from nncf.common.logging import nncf_logger
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
from nncf.experimental.torch.fx.quantization.backend_parameters import is_weight_compression_needed
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.experimental.torch.fx.transformations import compress_post_quantize_transformation
from nncf.experimental.torch.fx.transformations import fq_weights_transformation
from nncf.experimental.torch.fx.transformations import revert_quantization_transformations
from nncf.experimental.torch.fx.transformations import shared_constants_unification_transformation
from nncf.parameters import BackupMode
@@ -94,6 +97,11 @@ def quantize_impl(
# bias configuration.
revert_quantization_transformations(quantized_model)

if is_weight_compression_needed(advanced_parameters):
compress_post_quantize_transformation(quantized_model)
else:
fq_weights_transformation(quantized_model)

# Magic. Without this call the compiled model
# is not performant
quantized_model = GraphModule(quantized_model, quantized_model.graph)
@@ -107,6 +115,9 @@ def quantize_impl(

quantized_model.meta.update(original_graph_meta)
quantized_model = _disallow_eval_train(quantized_model)
# Each transformation adds a duplicate tensor value to the model buffer.
# This step removes the duplicated tensor values from the buffer.
quantized_model = GraphModule(quantized_model, quantized_model.graph)

return quantized_model
