
Fix FP8 ONNX export to use standard QuantizeLinear/DequantizeLinear ops#1037

Open
IgorBaratta wants to merge 3 commits into NVIDIA:main from IgorBaratta:igor/export_fp8

Conversation


@IgorBaratta IgorBaratta commented Mar 13, 2026

What does this PR do?

Type of change: Bug fix

When exporting FP8-quantized PyTorch models to ONNX, the previous implementation
emitted custom TensorRT ops (trt::TRT_FP8QuantizeLinear / trt::TRT_FP8DequantizeLinear).
These ops have no registered ONNX shape inference functions, so PyTorch's
_jit_pass_onnx pass loses shape information on those nodes, breaking any
downstream ONNX tool that relies on shape inference.

This PR replaces the TRT custom FP8 ops with standard ONNX QuantizeLinear /
DequantizeLinear ops using a FLOAT8E4M3FN zero_point (opset 19), which have
built-in ONNX shape inference. The compress_weights() function in
fp8_exporter.py and fold_fp8_qdq_to_dq() in surgeon_utils.py are updated
to detect both the legacy TRT ops (backward compatibility) and the new standard
ONNX QDQ ops.
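The property the new export is meant to guarantee can be illustrated without ONNX installed. A hypothetical sketch (the `uses_only_standard_qdq` helper and the `(op_type, domain)` pairs are illustrative, not part of this PR):

```python
# Illustrative check: legacy exports carry trt-domain FP8 ops, while the fixed
# export uses only default-domain QuantizeLinear/DequantizeLinear nodes.
LEGACY_TRT_FP8_OPS = {"TRT_FP8QuantizeLinear", "TRT_FP8DequantizeLinear"}

def uses_only_standard_qdq(nodes):
    """nodes: iterable of (op_type, domain) pairs taken from an ONNX graph."""
    return not any(
        op in LEGACY_TRT_FP8_OPS or domain == "trt" for op, domain in nodes
    )

old_export = [("TRT_FP8QuantizeLinear", "trt"), ("Conv", "")]
new_export = [("QuantizeLinear", ""), ("DequantizeLinear", ""), ("Conv", "")]
assert not uses_only_standard_qdq(old_export)  # legacy path fails the check
assert uses_only_standard_qdq(new_export)      # standard QDQ path passes
```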

Usage

# Add a code snippet demonstrating how to use this

Testing

tests/unit/torch/quantization/test_fp8_onnx_shape.py:
- test_trt_fp8_ops_unsupported_by_onnx_inference: documents the root cause — ONNX shape inference cannot process TRT custom domain ops
- test_fp8_onnx_export_shape_preserved: exports a real FP8-quantized SimpleConv at opset 19, runs shape inference, and asserts no TRT custom ops remain and all QDQ outputs have non-empty shape info

Additional Information

Requires opset 19 (FP8 support was added to the ONNX spec in opset 19). INT8 export is unaffected.
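For context, FLOAT8E4M3FN has a finite maximum of 448, and opset 19's QuantizeLinear saturates to that range. A rough pure-Python sketch of the saturation behavior (the helper is illustrative; actual FP8 rounding and bit encoding are omitted):

```python
# Rough sketch of FP8 E4M3FN saturation: values are scaled, then clamped
# to the format's finite range [-448, 448] before encoding.
FP8_E4M3FN_MAX = 448.0

def quantize_fp8_sat(x, scale):
    """Mimic the scale-then-saturate step of QuantizeLinear to FLOAT8E4M3FN."""
    y = x / scale
    return max(-FP8_E4M3FN_MAX, min(FP8_E4M3FN_MAX, y))

assert quantize_fp8_sat(1000.0, 1.0) == 448.0    # clamped to max
assert quantize_fp8_sat(-1000.0, 1.0) == -448.0  # clamped to min
assert quantize_fp8_sat(100.0, 0.5) == 200.0     # in-range values pass through
```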

Summary by CodeRabbit

  • New Features

    • Extended Float8 quantization support for ONNX export, adding broader standard-path handling and updated dtype mapping
    • Added runtime detection for Float8 constants to improve export/rewiring behavior
  • Bug Fixes

    • Improved handling of Float8 weight compression and dequantization paths to preserve types and shapes
  • Tests

    • Added unit tests to verify Float8 ONNX exports preserve shape information

Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
@IgorBaratta IgorBaratta requested review from a team as code owners March 13, 2026 17:54
Contributor

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

Walkthrough

Broaden FP8 ONNX handling: detect FLOAT8 constants, fold FP8 QDQ patterns for both TRT-specific and standard QuantizeLinear/DequantizeLinear paths, export Float8 dtypes from PyTorch, and add tests ensuring FP8 ONNX exports preserve shape information.

Changes

Cohort / File(s) Summary
FP8 export & weight compression
modelopt/onnx/export/fp8_exporter.py
Extended FP8 weight compression to detect and handle both TRT_FP8QuantizeLinear and standard QuantizeLinear FP8 paths; compute FP8 weight tensors from torch weights/scales, materialize FP8 constants, and adjust downstream dequantize op and dtype propagation.
Graph surgery / QDQ folding
modelopt/onnx/llm_export_utils/surgeon_utils.py
Generalized QDQ folding to recognize FP8 QDQ via is_fp8_constant; enforce constant-weight input, accept TRT_FP8DequantizeLinear or DequantizeLinear, create FP8 weight constants, rewire DQ inputs/outputs, and preserve dtype propagation and cleanup.
ONNX utils FP8 detection
modelopt/onnx/utils.py
Added is_fp8_constant(const: Constant) -> bool to detect gs.Constant values wrapping FLOAT8E4M3FN tensors safely (guards against LazyValues/internal changes).
Torch ONNX export FP8 support
modelopt/torch/quantization/export_onnx.py
Added "Float8": torch.float8_e4m3fn to torch_dtype_map; replaced TRT-specific FP8 custom ops with standard QuantizeLinear/DequantizeLinear flows (opset 19 style), including explicit casts, zero_point handling, and output type preservation/conditional casts.
Unit tests for FP8 ONNX shapes
tests/unit/torch/quantization/test_fp8_onnx_shape.py
New tests: one showing TRT FP8 ops may be unsupported by ONNX shape inference, another verifying exported FP8 models (opset 19 / standard ONNX FP8 ops) preserve shape info for QuantizeLinear/DequantizeLinear outputs.

Sequence Diagram(s)

sequenceDiagram
    participant PyTorchExporter as PyTorch Exporter
    participant ExportONNX as export_onnx.py
    participant ONNXGraph as ONNX Graph (nodes)
    participant Surgeon as surgeon_utils
    participant FP8Exporter as fp8_exporter
    participant ShapeInfer as ONNX Shape Inference / Tests

    PyTorchExporter->>ExportONNX: request export with Float8 dtype
    ExportONNX->>ONNXGraph: emit QuantizeLinear / DequantizeLinear (FLOAT8) nodes
    ONNXGraph->>Surgeon: detect QDQ patterns (is_fp8_constant)
    Surgeon->>FP8Exporter: fold QDQ -> materialize FP8 weight constants
    FP8Exporter->>ONNXGraph: replace DQ inputs/outputs, adjust dtypes
    ONNXGraph->>ShapeInfer: run shape inference / tests
    ShapeInfer-->>PyTorchExporter: report preserved shapes / issues

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~28 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title directly and concisely summarizes the main change: replacing custom TRT FP8 ops with standard ONNX QuantizeLinear/DequantizeLinear operators.
Docstring Coverage ✅ Passed Docstring coverage is 90.91% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed No security anti-patterns found. All modified files safely handle quantization exports without unsafe deserialization, code execution, or dependency risks.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/unit/torch/quantization/test_fp8_onnx_shape.py (1)

67-71: Bare except Exception is overly broad.

Catching all exceptions obscures which specific error indicates "newer ONNX rejects unknown domains." Consider catching onnx.onnx_cpp2py_export.shape_inference.InferenceError or similar specific exception.

Suggested refinement
     try:
         inferred = onnx.shape_inference.infer_shapes(model, strict_mode=False)
-    except Exception:
+    except onnx.shape_inference.InferenceError:
         # Newer ONNX rejects unknown domains outright — shape inference is impossible.
         return

If the specific exception type varies across ONNX versions, the current approach is acceptable but should be documented.
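A minimal sketch of the suggested pattern, using a stand-in exception class since the exact ONNX error type varies across versions (only the expected failure is swallowed; anything else propagates):

```python
# Sketch: swallow only the expected shape-inference failure, let the rest raise.
class InferenceError(Exception):
    """Stand-in for ONNX's shape-inference error type."""

def safe_infer(infer):
    try:
        return infer()
    except InferenceError:
        # Newer ONNX rejects unknown domains outright; inference is impossible.
        return None

assert safe_infer(lambda: "shapes") == "shapes"

def _rejects():
    raise InferenceError("unknown domain")

assert safe_infer(_rejects) is None  # expected failure handled, not re-raised
```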

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit/torch/quantization/test_fp8_onnx_shape.py` around lines 67 - 71,
Replace the broad try/except around onnx.shape_inference.infer_shapes(model,
strict_mode=False) with a targeted exception handler: catch the ONNX-specific
inference error (e.g., onnx.onnx_cpp2py_export.shape_inference.InferenceError or
the appropriate exception class exported by your ONNX version) and return in
that case; if multiple ONNX versions raise different exception types, catch
those specific types explicitly and only fall back to a broad except that
re-raises unexpected exceptions (or document why a broad catch is required).
This change should target the infer_shapes call and its surrounding exception
handling so only expected ONNX shape-inference failures are swallowed.
modelopt/onnx/export/fp8_exporter.py (1)

28-32: Duplicate helper function — consolidate with surgeon_utils.py.

This function is identical to _is_float8_constant in modelopt/onnx/llm_export_utils/surgeon_utils.py. Extract to a shared module to maintain a single source of truth.

Suggested location for shared utility

Create modelopt/onnx/utils.py (or similar) with:

def is_float8_constant(const: gs.Constant) -> bool:
    """Return True if the gs.Constant holds a FLOAT8E4M3FN tensor."""
    if isinstance(const.values, LazyValues):
        return const.values._tensor.data_type == onnx.TensorProto.FLOAT8E4M3FN
    return False

Then import from both fp8_exporter.py and surgeon_utils.py.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/export/fp8_exporter.py` around lines 28 - 32, Extract the
duplicate _is_float8_constant implementation into a single shared utility (e.g.,
add is_float8_constant in a new module modelopt/onnx/utils.py) and have both
fp8_exporter.py and surgeon_utils.py import and use that shared function instead
of defining their own copies; remove the local _is_float8_constant definition
from fp8_exporter.py and replace its usage with the imported is_float8_constant
(ensure LazyValues and onnx.TensorProto references remain available via existing
imports or by importing them into the new utils module).
modelopt/onnx/llm_export_utils/surgeon_utils.py (1)

27-31: Accessing private _tensor attribute may break on library updates.

The helper accesses const.values._tensor.data_type, which is an internal attribute of LazyValues. This could break if onnx_graphsurgeon changes its internal implementation.

Additionally, this function is duplicated verbatim in modelopt/onnx/export/fp8_exporter.py. Consider extracting it to a shared utility module (e.g., modelopt/onnx/utils.py) to avoid drift.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/onnx/llm_export_utils/surgeon_utils.py` around lines 27 - 31, The
helper _is_float8_constant currently accesses the private LazyValues attribute
_tensor which is fragile; replace that direct access with a safe lookup of
public attributes (e.g., prefer const.values.tensor.data_type or
const.values.raw.data_type or const.values.dtype using getattr/hasattr checks)
so you only use documented/public fields when determining
onnx.TensorProto.FLOAT8E4M3FN, and move the function out of both
surgeon_utils.py and fp8_exporter.py into a single shared utility (e.g., create
modelopt/onnx/utils.py with _is_float8_constant and import it from
modelopt/onnx/llm_export_utils/surgeon_utils.py and
modelopt/onnx/export/fp8_exporter.py) to avoid duplication and future drift.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/unit/torch/quantization/test_fp8_onnx_shape.py`:
- Around line 123-136: The shape lookup currently only populates shape_by_name
from inferred.graph.value_info and therefore misses shapes that are only
declared in inferred.graph.output; update the population logic for shape_by_name
(the code around shape_by_name creation and the loop over
inferred.graph.value_info) to also iterate inferred.graph.output and add each
output's name -> shape (extracting dim.dim_value or -1 like the existing logic)
so that later checks for node.op_type in ("QuantizeLinear","DequantizeLinear")
find graph outputs as well; ensure you reference the same shape extraction
behavior and keep the missing collection logic for node.output unchanged.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 76bdd9c0-d449-4f22-a38d-75c65c5f073c

📥 Commits

Reviewing files that changed from the base of the PR and between 58417e5 and a7b4a7c.

📒 Files selected for processing (4)
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/llm_export_utils/surgeon_utils.py
  • modelopt/torch/quantization/export_onnx.py
  • tests/unit/torch/quantization/test_fp8_onnx_shape.py


codecov bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 63.15789% with 7 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.26%. Comparing base (58417e5) to head (0e8b060).

Files with missing lines Patch % Lines
modelopt/onnx/utils.py 25.00% 6 Missing ⚠️
modelopt/torch/quantization/export_onnx.py 90.90% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1037      +/-   ##
==========================================
+ Coverage   70.07%   70.26%   +0.19%     
==========================================
  Files         221      221              
  Lines       25499    25512      +13     
==========================================
+ Hits        17869    17927      +58     
+ Misses       7630     7585      -45     


…ore, narrow except, include graph outputs in shape lookup

Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 23-24: The compress_weights function uses LazyValues(tensor) but
LazyValues is not imported, causing a NameError; fix by importing LazyValues
from its defining module (where LazyValues is implemented) at the top of this
file and ensure the symbol is referenced in compress_weights (the use at
LazyValues(tensor) around lines 97-98), e.g., add an import for LazyValues
alongside is_fp8_constant so compress_weights can construct LazyValues without
runtime errors.

In `@modelopt/onnx/llm_export_utils/surgeon_utils.py`:
- Around line 24-25: The file is missing an import for LazyValues which causes a
NameError when fold_fp8_qdq_to_dq() processes FP8 weights; add LazyValues to the
import list (alongside is_fp8_constant) so the symbol is available where used in
fold_fp8_qdq_to_dq and related FP8 handling around lines referencing LazyValues.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 59c379b3-11cc-4d4e-b563-ae338d0d3667

📥 Commits

Reviewing files that changed from the base of the PR and between a7b4a7c and 8338dd2.

📒 Files selected for processing (5)
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/llm_export_utils/surgeon_utils.py
  • modelopt/onnx/utils.py
  • modelopt/torch/quantization/export_onnx.py
  • tests/unit/torch/quantization/test_fp8_onnx_shape.py

Signed-off-by: Igor Baratta <ialmeidabara@nvidia.com>
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/llm_export_utils/surgeon_utils.py`:
- Around line 105-108: The code dereferences node.outputs[0].outputs[0] into
dq_op before checking its existence, which can raise IndexError; modify the
logic in surgeon_utils.py so you first verify node.outputs is non-empty and
node.outputs[0].outputs is non-empty (e.g., if not node.outputs or not
node.outputs[0].outputs: raise a descriptive error or return) before assigning
dq_op, then perform the existing assert on dq_op.op; reference the variables
node and dq_op to locate the fix.
- Around line 110-113: The FP8 folding currently computes numpy_weights from
(torch_weights / torch_scale) but ignores QuantizeLinear's zero_point; update
the folding in surgeon_utils.py to either validate that torch_zero_point == 0
and raise/log if not, or incorporate the zero point into the conversion by
computing (torch_weights / torch_scale + torch_zero_point) before casting to
FP8; modify the numpy_weights creation (referencing numpy_weights,
torch_weights, torch_scale, torch_zero_point) to include this change and ensure
correct rounding/typing for the subsequent .to(torch.float8_e4m3fn).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 35b56199-39c9-488f-a001-21d875d55823

📥 Commits

Reviewing files that changed from the base of the PR and between 8338dd2 and 0e8b060.

📒 Files selected for processing (3)
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/llm_export_utils/surgeon_utils.py
  • tests/unit/torch/quantization/test_fp8_onnx_shape.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unit/torch/quantization/test_fp8_onnx_shape.py
  • modelopt/onnx/export/fp8_exporter.py

Comment on lines +105 to +108
dq_op = node.outputs[0].outputs[0]
assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
    f"QDQ does not occur in pairs. You reached {dq_op.op}"
)
Contributor

⚠️ Potential issue | 🟠 Major

Guard Q output-chain indexing before dereference.

At Line 105, node.outputs[0].outputs[0] can raise IndexError on non-canonical graphs before your pair check at Line 106 runs.

🔧 Proposed fix
-        dq_op = node.outputs[0].outputs[0]
-        assert dq_op.op in ("TRT_FP8DequantizeLinear", "DequantizeLinear"), (
-            f"QDQ does not occur in pairs. You reached {dq_op.op}"
-        )
+        if not node.outputs or not node.outputs[0].outputs:
+            continue
+        dq_op = node.outputs[0].outputs[0]
+        if dq_op.op not in ("TRT_FP8DequantizeLinear", "DequantizeLinear"):
+            continue

Comment on lines +110 to +113
# Replace it with Dequantize with FP8 weights. This is a WAR because numpy does not support fp8.
numpy_weights = (
    (torch_weights / torch_scale).to(torch.float8_e4m3fn).view(torch.uint8).numpy()
)
Contributor

⚠️ Potential issue | 🟠 Major


FP8 QuantizeLinear folding ignores zero_point in dequantization formula.

At Line 112, the code folds weights / scale without accounting for zero_point. ONNX QuantizeLinear applies the formula saturate(round(x / scale) + zero_point), and the paired DequantizeLinear applies (x - zero_point) * scale. When folding, the stored FP8 weights should encode (x / scale) + zero_point so that subsequent dequantization yields correct results. If zero_point != 0, the current implementation produces numerically incorrect folded weights.

Recommend either:

  1. Enforce zero_point == 0 before folding (add validation), or
  2. Include zero_point in the weight conversion: (torch_weights / torch_scale + torch_zero_point).
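The numeric effect can be checked with plain arithmetic, no ONNX required (made-up values; saturation and rounding to actual FP8 are omitted):

```python
# Plain-arithmetic sketch of the folding bug with illustrative values.
# ONNX semantics: QuantizeLinear y = round(x/scale) + zp,
#                 DequantizeLinear y = (x - zp) * scale.
w, scale, zp = 3.0, 0.5, 2.0

correct_fold = round(w / scale) + zp       # encode zero_point into stored weights
assert (correct_fold - zp) * scale == w    # dequantization round-trips exactly

buggy_fold = round(w / scale)              # current code: zero_point ignored
# Dequantizing the buggy fold yields w - zp * scale instead of w.
assert (buggy_fold - zp) * scale == w - zp * scale
```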
