Fix FP8 ONNX export to use standard QuantizeLinear/DequantizeLinear ops #1037

IgorBaratta wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
📝 Walkthrough

Broaden FP8 ONNX handling: detect FLOAT8 constants, fold FP8 QDQ patterns for both TRT-specific and standard QuantizeLinear/DequantizeLinear paths, export Float8 dtypes from PyTorch, and add tests ensuring FP8 ONNX exports preserve shape information.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant PyTorchExporter as PyTorch Exporter
    participant ExportONNX as export_onnx.py
    participant ONNXGraph as ONNX Graph (nodes)
    participant Surgeon as surgeon_utils
    participant FP8Exporter as fp8_exporter
    participant ShapeInfer as ONNX Shape Inference / Tests
    PyTorchExporter->>ExportONNX: request export with Float8 dtype
    ExportONNX->>ONNXGraph: emit QuantizeLinear / DequantizeLinear (FLOAT8) nodes
    ONNXGraph->>Surgeon: detect QDQ patterns (is_fp8_constant)
    Surgeon->>FP8Exporter: fold QDQ -> materialize FP8 weight constants
    FP8Exporter->>ONNXGraph: replace DQ inputs/outputs, adjust dtypes
    ONNXGraph->>ShapeInfer: run shape inference / tests
    ShapeInfer-->>PyTorchExporter: report preserved shapes / issues
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~28 minutes

🚥 Pre-merge checks: ✅ 4 passed
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/unit/torch/quantization/test_fp8_onnx_shape.py (1)

67-71: Bare `except Exception` is overly broad. Catching all exceptions obscures which specific error indicates "newer ONNX rejects unknown domains." Consider catching `onnx.onnx_cpp2py_export.shape_inference.InferenceError` or a similarly specific exception.

Suggested refinement:

```diff
 try:
     inferred = onnx.shape_inference.infer_shapes(model, strict_mode=False)
-except Exception:
+except onnx.shape_inference.InferenceError:
     # Newer ONNX rejects unknown domains outright — shape inference is impossible.
     return
```

If the specific exception type varies across ONNX versions, the current approach is acceptable but should be documented.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/quantization/test_fp8_onnx_shape.py` around lines 67-71: Replace the broad try/except around onnx.shape_inference.infer_shapes(model, strict_mode=False) with a targeted exception handler: catch the ONNX-specific inference error (e.g., onnx.onnx_cpp2py_export.shape_inference.InferenceError or the appropriate exception class exported by your ONNX version) and return in that case; if multiple ONNX versions raise different exception types, catch those specific types explicitly and only fall back to a broad except that re-raises unexpected exceptions (or document why a broad catch is required). This change should target the infer_shapes call and its surrounding exception handling so only expected ONNX shape-inference failures are swallowed.

modelopt/onnx/export/fp8_exporter.py (1)
28-32: Duplicate helper function — consolidate with `surgeon_utils.py`. This function is identical to `_is_float8_constant` in `modelopt/onnx/llm_export_utils/surgeon_utils.py`. Extract it to a shared module to maintain a single source of truth.

Suggested location for the shared utility: create `modelopt/onnx/utils.py` (or similar) with:

```python
def is_float8_constant(const: gs.Constant) -> bool:
    """Return True if the gs.Constant holds a FLOAT8E4M3FN tensor."""
    if isinstance(const.values, LazyValues):
        return const.values._tensor.data_type == onnx.TensorProto.FLOAT8E4M3FN
    return False
```

Then import it from both `fp8_exporter.py` and `surgeon_utils.py`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/export/fp8_exporter.py` around lines 28-32: Extract the duplicate _is_float8_constant implementation into a single shared utility (e.g., add is_float8_constant in a new module modelopt/onnx/utils.py) and have both fp8_exporter.py and surgeon_utils.py import and use that shared function instead of defining their own copies; remove the local _is_float8_constant definition from fp8_exporter.py and replace its usage with the imported is_float8_constant (ensure LazyValues and onnx.TensorProto references remain available via existing imports or by importing them into the new utils module).

modelopt/onnx/llm_export_utils/surgeon_utils.py (1)

27-31: Accessing the private `_tensor` attribute may break on library updates. The helper accesses `const.values._tensor.data_type`, which is an internal attribute of `LazyValues`. This could break if `onnx_graphsurgeon` changes its internal implementation.

Additionally, this function is duplicated verbatim in `modelopt/onnx/export/fp8_exporter.py`. Consider extracting it to a shared utility module (e.g., `modelopt/onnx/utils.py`) to avoid drift.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 27-31: The helper _is_float8_constant currently accesses the private LazyValues attribute _tensor, which is fragile; replace that direct access with a safe lookup of public attributes (e.g., prefer const.values.tensor.data_type or const.values.raw.data_type or const.values.dtype using getattr/hasattr checks) so you only use documented/public fields when determining onnx.TensorProto.FLOAT8E4M3FN, and move the function out of both surgeon_utils.py and fp8_exporter.py into a single shared utility (e.g., create modelopt/onnx/utils.py with _is_float8_constant and import it from modelopt/onnx/llm_export_utils/surgeon_utils.py and modelopt/onnx/export/fp8_exporter.py) to avoid duplication and future drift.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/unit/torch/quantization/test_fp8_onnx_shape.py`:
- Around line 123-136: The shape lookup currently only populates shape_by_name
from inferred.graph.value_info and therefore misses shapes that are only
declared in inferred.graph.output; update the population logic for shape_by_name
(the code around shape_by_name creation and the loop over
inferred.graph.value_info) to also iterate inferred.graph.output and add each
output's name -> shape (extracting dim.dim_value or -1 like the existing logic)
so that later checks for node.op_type in ("QuantizeLinear","DequantizeLinear")
find graph outputs as well; ensure you reference the same shape extraction
behavior and keep the missing collection logic for node.output unchanged.
---
Nitpick comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 28-32: Extract the duplicate _is_float8_constant implementation
into a single shared utility (e.g., add is_float8_constant in a new module
modelopt/onnx/utils.py) and have both fp8_exporter.py and surgeon_utils.py
import and use that shared function instead of defining their own copies; remove
the local _is_float8_constant definition from fp8_exporter.py and replace its
usage with the imported is_float8_constant (ensure LazyValues and
onnx.TensorProto references remain available via existing imports or by
importing them into the new utils module).
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py`:
- Around line 27-31: The helper _is_float8_constant currently accesses the
private LazyValues attribute _tensor which is fragile; replace that direct
access with a safe lookup of public attributes (e.g., prefer
const.values.tensor.data_type or const.values.raw.data_type or
const.values.dtype using getattr/hasattr checks) so you only use
documented/public fields when determining onnx.TensorProto.FLOAT8E4M3FN, and
move the function out of both surgeon_utils.py and fp8_exporter.py into a single
shared utility (e.g., create modelopt/onnx/utils.py with _is_float8_constant and
import it from modelopt/onnx/llm_export_utils/surgeon_utils.py and
modelopt/onnx/export/fp8_exporter.py) to avoid duplication and future drift.
In `@tests/unit/torch/quantization/test_fp8_onnx_shape.py`:
- Around line 67-71: Replace the broad try/except around
onnx.shape_inference.infer_shapes(model, strict_mode=False) with a targeted
exception handler: catch the ONNX-specific inference error (e.g.,
onnx.onnx_cpp2py_export.shape_inference.InferenceError or the appropriate
exception class exported by your ONNX version) and return in that case; if
multiple ONNX versions raise different exception types, catch those specific
types explicitly and only fall back to a broad except that re-raises unexpected
exceptions (or document why a broad catch is required). This change should
target the infer_shapes call and its surrounding exception handling so only
expected ONNX shape-inference failures are swallowed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 76bdd9c0-d449-4f22-a38d-75c65c5f073c
📒 Files selected for processing (4)
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/llm_export_utils/surgeon_utils.py
- modelopt/torch/quantization/export_onnx.py
- tests/unit/torch/quantization/test_fp8_onnx_shape.py
Codecov Report

❌ Patch coverage is

Additional details and impacted files

```
@@ Coverage Diff @@
##             main    #1037      +/-   ##
==========================================
+ Coverage   70.07%   70.26%   +0.19%
==========================================
  Files         221      221
  Lines       25499    25512      +13
==========================================
+ Hits        17869    17927      +58
+ Misses       7630     7585      -45
```

☔ View full report in Codecov by Sentry.
…ore, narrow except, include graph outputs in shape lookup Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 23-24: The compress_weights function uses LazyValues(tensor) but
LazyValues is not imported, causing a NameError; fix by importing LazyValues
from its defining module (where LazyValues is implemented) at the top of this
file and ensure the symbol is referenced in compress_weights (the use at
LazyValues(tensor) around lines 97-98), e.g., add an import for LazyValues
alongside is_fp8_constant so compress_weights can construct LazyValues without
runtime errors.
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py`:
- Around line 24-25: The file is missing an import for LazyValues which causes a
NameError when fold_fp8_qdq_to_dq() processes FP8 weights; add LazyValues to the
import list (alongside is_fp8_constant) so the symbol is available where used in
fold_fp8_qdq_to_dq and related FP8 handling around lines referencing LazyValues.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 59c379b3-11cc-4d4e-b563-ae338d0d3667
📒 Files selected for processing (5)
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/llm_export_utils/surgeon_utils.py
- modelopt/onnx/utils.py
- modelopt/torch/quantization/export_onnx.py
- tests/unit/torch/quantization/test_fp8_onnx_shape.py
Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py`:
- Around line 105-108: The code dereferences node.outputs[0].outputs[0] into
dq_op before checking its existence, which can raise IndexError; modify the
logic in surgeon_utils.py so you first verify node.outputs is non-empty and
node.outputs[0].outputs is non-empty (e.g., if not node.outputs or not
node.outputs[0].outputs: raise a descriptive error or return) before assigning
dq_op, then perform the existing assert on dq_op.op; reference the variables
node and dq_op to locate the fix.
- Around line 110-113: The FP8 folding currently computes numpy_weights from
(torch_weights / torch_scale) but ignores QuantizeLinear's zero_point; update
the folding in surgeon_utils.py to either validate that torch_zero_point == 0
and raise/log if not, or incorporate the zero point into the conversion by
computing (torch_weights / torch_scale + torch_zero_point) before casting to
FP8; modify the numpy_weights creation (referencing numpy_weights,
torch_weights, torch_scale, torch_zero_point) to include this change and ensure
correct rounding/typing for the subsequent .to(torch.float8_e4m3fn).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 35b56199-39c9-488f-a001-21d875d55823
📒 Files selected for processing (3)
- modelopt/onnx/export/fp8_exporter.py
- modelopt/onnx/llm_export_utils/surgeon_utils.py
- tests/unit/torch/quantization/test_fp8_onnx_shape.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unit/torch/quantization/test_fp8_onnx_shape.py
- modelopt/onnx/export/fp8_exporter.py
```python
dq_op = node.outputs[0].outputs[0]
assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
    f"QDQ does not occur in pairs. You reached {dq_op.op}"
)
```
Guard Q output-chain indexing before dereference.
At Line 105, node.outputs[0].outputs[0] can raise IndexError on non-canonical graphs before your pair check at Line 106 runs.
🔧 Proposed fix

```diff
-    dq_op = node.outputs[0].outputs[0]
-    assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
-        f"QDQ does not occur in pairs. You reached {dq_op.op}"
-    )
+    if not node.outputs or not node.outputs[0].outputs:
+        continue
+    dq_op = node.outputs[0].outputs[0]
+    if dq_op.op not in ("TRT_FP8DequantizeLinear", "DequantizeLinear"):
+        continue
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 105 - 108, The
code dereferences node.outputs[0].outputs[0] into dq_op before checking its
existence, which can raise IndexError; modify the logic in surgeon_utils.py so
you first verify node.outputs is non-empty and node.outputs[0].outputs is
non-empty (e.g., if not node.outputs or not node.outputs[0].outputs: raise a
descriptive error or return) before assigning dq_op, then perform the existing
assert on dq_op.op; reference the variables node and dq_op to locate the fix.
```python
# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
    (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
```
🧩 Analysis chain
The folding math, verified symbolically:

```
# ONNX standard formulas:
#   QuantizeLinear:   y = saturate(round(x / scale) + zero_point)
#   DequantizeLinear: y = (x - zero_point) * scale
#
# Current fold:     folded_weights = original_weights / scale
# Later dequantize: ((original_weights / scale) - zero_point) * scale
#                 = original_weights - (zero_point * scale)   <- wrong if zero_point != 0
#
# Correct fold:     folded_weights = (original_weights / scale) + zero_point
# Then dequantize:  (folded_weights - zero_point) * scale = original_weights
```

Issue confirmed: if zero_point != 0, folding is mathematically incorrect; the current code ignores zero_point in the folding calculation.
FP8 QuantizeLinear folding ignores zero_point in dequantization formula.
At Line 112, the code folds weights / scale without accounting for zero_point. ONNX QuantizeLinear applies the formula saturate(round(x / scale) + zero_point), and the paired DequantizeLinear applies (x - zero_point) * scale. When folding, the stored FP8 weights should encode (x / scale) + zero_point so that subsequent dequantization yields correct results. If zero_point != 0, the current implementation produces numerically incorrect folded weights.
Recommend either:
- Enforce `zero_point == 0` before folding (add validation), or
- Include `zero_point` in the weight conversion: `(torch_weights / torch_scale + torch_zero_point)`.
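The numerical effect can be sketched with plain floats (the `quantize`/`dequantize` helpers below just follow the ONNX formulas; FP8 saturation is omitted and values are chosen so rounding is exact — in practice FP8 exports use a zero_point of 0.0, where both folds agree):

```python
def quantize(x, scale, zero_point):
    # ONNX QuantizeLinear: y = saturate(round(x / scale) + zero_point)
    return round(x / scale) + zero_point

def dequantize(q, scale, zero_point):
    # ONNX DequantizeLinear: y = (q - zero_point) * scale
    return (q - zero_point) * scale

x, scale, zp = 6.0, 2.0, 1.0
bad_fold = x / scale        # current fold: zero_point ignored
good_fold = x / scale + zp  # fold that encodes zero_point

print(dequantize(bad_fold, scale, zp))   # off by zp * scale
print(dequantize(good_fold, scale, zp))  # round-trips back to x
```

With `zp = 1.0` and `scale = 2.0`, the ignored-zero-point fold dequantizes to 4.0 instead of 6.0, exactly the `zero_point * scale` error derived above.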
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 110 - 113, The
FP8 folding currently computes numpy_weights from (torch_weights / torch_scale)
but ignores QuantizeLinear's zero_point; update the folding in surgeon_utils.py
to either validate that torch_zero_point == 0 and raise/log if not, or
incorporate the zero point into the conversion by computing (torch_weights /
torch_scale + torch_zero_point) before casting to FP8; modify the numpy_weights
creation (referencing numpy_weights, torch_weights, torch_scale,
torch_zero_point) to include this change and ensure correct rounding/typing for
the subsequent .to(torch.float8_e4m3fn).
What does this PR do?
Type of change: Bug fix
When exporting FP8-quantized PyTorch models to ONNX, the previous implementation emitted custom TensorRT ops (trt::TRT_FP8QuantizeLinear / trt::TRT_FP8DequantizeLinear). These ops have no registered ONNX shape inference functions, so PyTorch's _jit_pass_onnx pass loses shape information on those nodes, breaking any downstream ONNX tool that relies on shape inference.

This PR replaces the TRT custom FP8 ops with standard ONNX QuantizeLinear / DequantizeLinear ops using a FLOAT8E4M3FN zero_point (opset 19), which have built-in ONNX shape inference. The compress_weights() function in fp8_exporter.py and fold_fp8_qdq_to_dq() in surgeon_utils.py are updated to detect both the legacy TRT ops (backward compatibility) and the new standard ONNX QDQ ops.
Usage

# Add a code snippet demonstrating how to use this

Testing
tests/unit/torch/quantization/test_fp8_onnx_shape.py:
- test_trt_fp8_ops_unsupported_by_onnx_inference: documents the root cause — ONNX shape inference cannot process TRT custom domain ops
- test_fp8_onnx_export_shape_preserved: exports a real FP8-quantized SimpleConv at opset 19, runs shape inference, and asserts no TRT custom ops remain and all QDQ outputs have non-empty shape info
Additional Information
Requires opset 19 (FP8 support was added to the ONNX spec in opset 19). INT8 export is unaffected.
Summary by CodeRabbit
New Features
Bug Fixes
Tests