Find and fix an export issue due to dynamic shapes
LLMs must be exported with dynamic shapes, and it is common that
a dynamic dimension turns into a static one. The error message from
PyTorch tells the user to set TORCH_LOGS="+dynamic",
but this prints a very long list of messages in which we need
to find the string range_refined_to_singleton, and even that
does not really indicate where the problem comes from. This example
shows how to tweak PyTorch to get that information until
its error reporting improves.
A model with an export issue
The following model implies that the first dimension of x is equal to 1
or to the number of elements in the list ys.
It is not really dynamic. It looks obvious here, but
it is difficult to spot deep inside a big model.
import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_patches
class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z
inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.0519, 0.0844, 0.1746, 0.0404],
         [0.5118, 0.2140, 0.5120, 0.5620],
         [0.2474, 0.1650, 0.1219, 0.1728]],
        [[0.6298, 0.4665, 0.4820, 0.2788],
         [0.1256, 0.7841, 0.7342, 0.7996],
         [0.2932, 0.7785, 0.6687, 0.1029]]])
Let’s export.
DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
-- ERROR:
Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but tracing inferred a static shape of 2 for dimension inputs['x'].shape[0].
The error shows:
Constraints violated (L['args'][0][0].size()[0])!
    For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
    are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
Where does it happen? That is the tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to get a stack trace by inserting an assert such as the following one into that method:
assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
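Editing the installed PyTorch sources is intrusive. A less invasive variant of the same idea, sketched below under assumptions, wraps the method at runtime and restores it afterwards; judging from its verbose log, torch_export_patches(stop_if_static=...) does something similar with ShapeEnv._set_replacement. The helper stop_on_message and the class ToyShapeEnv are hypothetical: the toy class stands in for ShapeEnv so that the sketch runs without PyTorch. Applying it to the real ShapeEnv assumes _set_replacement receives msg as its third argument after self, which may change between PyTorch versions.

```python
import contextlib
import functools
import traceback


@contextlib.contextmanager
def stop_on_message(cls, method_name, forbidden_msg, msg_pos=2):
    """Hypothetical helper: temporarily wrap ``cls.<method_name>`` so that it
    fails with an AssertionError when called with ``msg == forbidden_msg``.

    ``msg_pos`` is the position of ``msg`` among the positional arguments
    following ``self``; ``msg`` may also be passed as a keyword.
    """
    original = getattr(cls, method_name)

    @functools.wraps(original)
    def wrapper(self, *args, **kwargs):
        # Look for the message as a keyword first, then as a positional argument.
        msg = kwargs.get("msg", args[msg_pos] if len(args) > msg_pos else None)
        assert msg != forbidden_msg, (
            f"{cls.__name__}.{method_name} called with msg={msg!r}, args={args!r}"
        )
        return original(self, *args, **kwargs)

    setattr(cls, method_name, wrapper)
    try:
        yield
    finally:
        # Always restore the original method, even if the assertion fired.
        setattr(cls, method_name, original)


class ToyShapeEnv:
    """Hypothetical stand-in for ShapeEnv, used only to make the sketch runnable."""

    def __init__(self):
        self.replacements = {}

    def _set_replacement(self, a, tgt, msg):
        self.replacements[a] = tgt


def trace_multiply(env):
    # Pretend that tracing ``x * caty`` refines the symbol s77 to the constant 2.
    env._set_replacement("s77", 2, "range_refined_to_singleton")


env = ToyShapeEnv()
with stop_on_message(ToyShapeEnv, "_set_replacement", "range_refined_to_singleton"):
    try:
        trace_multiply(env)
    except AssertionError:
        # The printed stack trace goes through trace_multiply, the "culprit".
        traceback.print_exc()
```

Raising inside the call, rather than logging, is what turns the event into a stack trace pointing at the code that triggered it.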
Stop when a dynamic dimension turns static
We use torch_export_patches
to replace the torch implementation with one that raises the exception
mentioned in the previous section.
with torch_export_patches(stop_if_static=1, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as expected.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())
# The stack trace is quite long, but the first line referring to this example
# is the following one. It points to the line turning a dynamic dimension into
# a static one.
#
# .. code-block::
#
#   File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward
#       z = x * caty
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250820+cu126'
[torch_export_patches] stop_if_static=1
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] assert when a dynamic dimension turns static
[torch_export_patches] replaces ShapeEnv._set_replacement
[torch_export_patches] replaces ShapeEnv._log_guard
[torch_export_patches] done patching
-- It failed as expected.
-- final error is patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
    torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 307, in export
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 274, in export
    return _export(
           ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1158, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1124, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2192, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1158, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1124, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2055, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1997, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1788, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1917, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1701, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2354, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2285, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2256, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1005, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1283, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1867, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1341, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1588, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 187, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1354, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1937, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1901, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1937, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
    z = x * caty
        ~~^~~~~~
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1389, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1438, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 1067, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 961, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 286, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1493, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 840, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1376, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2092, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1511, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2611, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 1271, in fast_binary_impl
    final_shape = infer_size(final_shape, shape)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 119, in patched_infer_size
    b3 = guard_size_oblivious(sizeA == sizeB)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 476, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 596, in guard_size_oblivious
    r = self.evaluate(size_oblivious=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7238, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7338, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7361, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7630, in _evaluate_expr
    self._maybe_guard_rel(g)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6883, in _maybe_guard_rel
    self._refine_ranges(expr)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7834, in _refine_ranges
    self._set_replacement(
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 343, in _set_replacement
    assert msg != "range_refined_to_singleton", (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored ShapeEnv._set_replacement
[torch_export_patches] restored ShapeEnv._log_guard
[torch_export_patches] restored shape constraints
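The traceback above is long. Instead of scanning it by eye, its frames can be filtered programmatically with the standard traceback module. The sketch below is one possible way to do it; innermost_frame is a hypothetical helper, and json.loads only plays the role of the library (PyTorch here) raising deep inside the call stack.

```python
import json
import traceback


def innermost_frame(exc, keep):
    """Hypothetical helper: return the deepest traceback frame of ``exc``
    accepted by the predicate ``keep``, or None if no frame matches.

    ``traceback.extract_tb`` lists frames from the outermost call to the
    innermost one, so the last accepted frame is the closest to the error.
    """
    frames = [f for f in traceback.extract_tb(exc.__traceback__) if keep(f)]
    return frames[-1] if frames else None


def forward(text):
    return json.loads(text)  # the user line that triggers the library error


# Keep only frames coming from the file defining ``forward`` (the "user" code).
this_file = forward.__code__.co_filename
try:
    forward("{not json")
except ValueError as e:
    frame = innermost_frame(e, keep=lambda f: f.filename == this_file)
    print(f"{frame.filename}, line {frame.lineno}, in {frame.name}: {frame.line}")
```

For the export above, a predicate such as keep=lambda f: "plot_export_locate_issue.py" in f.filename, applied to the caught AssertionError, should return the frame of z = x * caty, assuming the exception is not re-wrapped on its way up.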
doc.plot_legend(
    "dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)
