
v1.1.0: ORTTrainer, Seq2SeqORTTrainer, ONNX Runtime optimization and quantization API improvements


ORTTrainer and Seq2SeqORTTrainer

The ORTTrainer and Seq2SeqORTTrainer are two new experimental classes.

  • Both ORTTrainer and Seq2SeqORTTrainer expose a user-facing API similar to that of the Trainer and Seq2SeqTrainer of the Transformers library (a short usage sketch follows this list).
  • ORTTrainer allows using the ONNX Runtime backend to accelerate the training of a given PyTorch model. ONNX Runtime runs the forward and backward passes on an optimized, automatically exported ONNX computation graph, while the rest of the training loop is executed by native PyTorch.
  • ORTTrainer allows the usage of ONNX Runtime inferencing during both the evaluation and prediction steps.
  • For Seq2SeqORTTrainer, ONNX Runtime inferencing is incompatible with --predict_with_generate, as the generate method is not supported yet.
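
Below is a minimal sketch of what training with ORTTrainer can look like, assuming the drop-in Trainer-like API described above; the checkpoint name, the tiny illustrative dataset and the training arguments are placeholders rather than a reference recipe, and the exact constructor signature may differ.

```python
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
from optimum.onnxruntime import ORTTrainer

model_id = "distilbert-base-uncased"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)

# Tiny illustrative dataset: two labelled sentences tokenized into fixed-length features.
raw = Dataset.from_dict({"text": ["great movie", "terrible movie"], "label": [1, 0]})
encoded = raw.map(
    lambda ex: tokenizer(ex["text"], truncation=True, padding="max_length", max_length=32)
)

args = TrainingArguments(
    output_dir="ort_trainer_output",
    num_train_epochs=1,
    per_device_train_batch_size=2,
)

# Same constructor arguments as transformers.Trainer.
trainer = ORTTrainer(
    model=model,
    args=args,
    train_dataset=encoded,
    eval_dataset=encoded,
    tokenizer=tokenizer,
)

trainer.train()               # forward/backward passes run on the exported ONNX graph via ONNX Runtime
metrics = trainer.evaluate()  # evaluation; per the notes above it can also rely on ONNX Runtime inferencing
```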

ONNX Runtime optimization and quantization API improvements

The ORTQuantizer and ORTOptimizer classes underwent a major refactoring to provide a simpler and more flexible user-facing API.

  • The ORTQuantizer method partial_fit was added, making it possible to iteratively compute the activation quantization ranges when applying static quantization. This is especially useful with memory-hungry calibration methods such as the Entropy and Percentile methods.
  • When using the MinMax calibration method, it is now possible to compute a moving average of the minimum and maximum values defining the activation quantization ranges, instead of the global minimum and maximum (available with onnxruntime v1.11.0 or higher).
  • The OptimizationConfig, QuantizationConfig and CalibrationConfig classes were added to better separate the different ONNX Runtime related parameters, replacing the single ORTConfig configuration (see the sketches after this list).
  • The QuantizationPreprocessor class was added to find the nodes to include and / or exclude from quantization, by matching nodes that follow a given pattern (such as the nodes forming a LayerNorm, for example). This is particularly useful for static quantization, where quantizing operators such as LayerNorm or GELU is often responsible for a significant drop in accuracy.
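
A hedged sketch of the refactored static quantization flow follows, combining a QuantizationConfig built through the AutoQuantizationConfig helper, a MinMax CalibrationConfig with the new moving-average option, and (commented out) the iterative partial_fit path; the checkpoint, dataset, file paths, shard count and several argument names are assumptions for illustration and may not match the exact signatures.

```python
from pathlib import Path

from transformers import AutoTokenizer
from optimum.onnxruntime import ORTQuantizer
from optimum.onnxruntime.configuration import AutoCalibrationConfig, AutoQuantizationConfig

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint
onnx_path = Path("model.onnx")
quantized_path = Path("model-quantized.onnx")

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Static quantization parameters, grouped in a dedicated QuantizationConfig.
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=True, per_channel=False)

quantizer = ORTQuantizer.from_pretrained(model_id, feature="sequence-classification")

# Calibration dataset used to estimate the activation quantization ranges.
calibration_dataset = quantizer.get_calibration_dataset(
    "glue",
    dataset_config_name="sst2",
    preprocess_function=lambda ex: tokenizer(
        ex["sentence"], truncation=True, padding="max_length", max_length=128
    ),
    num_samples=120,
    dataset_split="train",
)

# MinMax calibration with a moving average of the min/max values (onnxruntime >= 1.11).
calibration_config = AutoCalibrationConfig.minmax(
    calibration_dataset, moving_average=True, averaging_constant=0.01
)

# For memory-hungry methods (Entropy, Percentile), the ranges can instead be
# accumulated shard by shard with partial_fit, e.g.:
#   calibration_config = AutoCalibrationConfig.entropy(calibration_dataset)
#   for i in range(4):
#       quantizer.partial_fit(
#           dataset=calibration_dataset.shard(4, i),
#           calibration_config=calibration_config,
#           onnx_model_path=onnx_path,
#           operators_to_quantize=qconfig.operators_to_quantize,
#       )
#   ranges = quantizer.compute_ranges()

ranges = quantizer.fit(
    dataset=calibration_dataset,
    calibration_config=calibration_config,
    onnx_model_path=onnx_path,
    operators_to_quantize=qconfig.operators_to_quantize,
)

quantizer.export(
    onnx_model_path=onnx_path,
    onnx_quantized_model_output_path=quantized_path,
    calibration_tensors_range=ranges,
    quantization_config=qconfig,
)
```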
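
And a similarly hedged sketch of graph optimization with the new OptimizationConfig; again, the checkpoint, paths and the feature argument are illustrative and the exact signatures may differ.

```python
from pathlib import Path

from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

model_id = "distilbert-base-uncased-finetuned-sst-2-english"  # illustrative checkpoint

# Graph optimization parameters now live in their own OptimizationConfig
# (optimization_level=99 enables all available ONNX Runtime graph optimizations).
optimization_config = OptimizationConfig(optimization_level=99)

optimizer = ORTOptimizer.from_pretrained(model_id, feature="sequence-classification")
optimizer.export(
    onnx_model_path=Path("model.onnx"),
    onnx_optimized_model_output_path=Path("model-optimized.onnx"),
    optimization_config=optimization_config,
)
```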