Find and fix an export issue due to dynamic shapes
LLMs must be exported with dynamic shapes, and it is common for
a dynamic dimension to turn into a static one. The error message from
PyTorch tells the user to set TORCH_LOGS="+dynamic",
but this produces a very long list of messages in which we need
to find the string range_refined_to_singleton, and that string
does not really indicate where the issue comes from. This example
shows how to tweak PyTorch to get that information until
the error reporting improves.
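Searching the log by hand is tedious, so a small filter can help. The sketch below assumes the output produced with TORCH_LOGS="+dynamic" has already been captured as lines of text; the helper name `find_marker` is hypothetical, and the sample lines are placeholders, not real PyTorch output.

```python
from typing import Iterable, Iterator

# String named in the text above; it marks a dynamic dimension reduced to one value.
MARKER = "range_refined_to_singleton"


def find_marker(lines: Iterable[str], marker: str = MARKER) -> Iterator[tuple[int, str]]:
    """Yield (line_number, line) for every log line containing the marker."""
    for number, line in enumerate(lines, start=1):
        if marker in line:
            yield number, line.rstrip("\n")


# Placeholder lines, NOT real PyTorch output: they only exercise the filter.
sample_log = [
    "placeholder line one\n",
    "placeholder line mentioning range_refined_to_singleton\n",
    "placeholder line three\n",
]
hits = list(find_marker(sample_log))
print(hits)
```

This only locates the message in the log. As noted above, the message itself does not say which line of the model triggered it.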
A model with an export issue
The following model implies that the first dimension of x is equal to 1
or to the number of elements in the list ys.
It is not really dynamic. This looks obvious here, but
it is difficult to find deep inside a big model.
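The constraint comes from the broadcasting rule applied by the product x * caty. The sketch below re-implements that rule on plain tuples, without torch, to show which first dimensions of x are accepted; `broadcast_shapes` is an illustrative helper written for this page, not a PyTorch function.

```python
def broadcast_shapes(a: tuple[int, ...], b: tuple[int, ...]) -> tuple[int, ...]:
    """Return the broadcast shape of a and b, or raise ValueError if incompatible."""
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    rank = max(len(a), len(b))
    a = (1,) * (rank - len(a)) + tuple(a)
    b = (1,) * (rank - len(b)) + tuple(b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast {da} with {db}")
    return tuple(result)


n_ys = 2                   # len(ys): a Python list length, fixed when the model is traced
caty_shape = (n_ys, 3, 4)  # shape of the concatenation of the unsqueezed ys

# First dimension of x equal to len(ys) or to 1: the product is defined.
print(broadcast_shapes((2, 3, 1), caty_shape))
print(broadcast_shapes((1, 3, 1), caty_shape))

# Any other value is rejected.
try:
    broadcast_shapes((5, 3, 1), caty_shape)
except ValueError as e:
    print("rejected:", e)
```

Because the length of ys is a constant during tracing, the only values left for the first dimension of x are 1 and that constant, which is why it cannot stay dynamic.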
import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_patches


class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z


inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.3113, 0.5510, 0.4920, 0.3743],
[0.2655, 0.0550, 0.0221, 0.1182],
[0.3164, 0.2754, 0.1045, 0.1932]],
[[0.2115, 0.5465, 0.0468, 0.5540],
[0.2177, 0.1544, 0.1444, 0.2867],
[0.4253, 0.2722, 0.4732, 0.5565]]])
Let’s export.
DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
-- ERROR:
Constraints violated (L['x'].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
- You marked L['x'].size()[0] as dynamic but your code specialized it to be a constant (2). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
The error shows:
Constraints violated (L['args'][0][0].size()[0])!
For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
Where does it happen? That is the tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assertion such as the following:
assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
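Editing an installed PyTorch by hand is inconvenient, so the same assertion can be injected temporarily by replacing the method and restoring it afterwards. The sketch below illustrates that general technique on a stand-in class; `FakeShapeEnv`, its simplified signature, and `stop_if_singleton` are hypothetical names used only for illustration and are neither PyTorch nor onnx-diagnostic code.

```python
import contextlib


class FakeShapeEnv:
    """Hypothetical stand-in for the class to instrument (not PyTorch code)."""

    def _set_replacement(self, a, tgt, msg):
        return (a, tgt, msg)


@contextlib.contextmanager
def stop_if_singleton(cls):
    """Temporarily wrap cls._set_replacement with an assertion, then restore it."""
    original = cls._set_replacement

    def patched(self, a, tgt, msg):
        assert msg != "range_refined_to_singleton", (
            f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}"
        )
        return original(self, a, tgt, msg)

    cls._set_replacement = patched
    try:
        yield
    finally:
        # Always restore the original method, even if the body raised.
        cls._set_replacement = original


env = FakeShapeEnv()
caught = None
with stop_if_singleton(FakeShapeEnv):
    try:
        env._set_replacement("s0", 2, "range_refined_to_singleton")
    except AssertionError as e:
        caught = str(e)
print(caught)

# Outside the context manager the original behaviour is back.
restored = env._set_replacement("s0", 2, "range_refined_to_singleton")
print(restored)
```

Raising from inside the patched method is what makes the stack trace useful: the exception propagates through every caller, including the line of the model that triggered the refinement.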
Stop when a dynamic dimension turns static
We use torch_export_patches
to replace the torch implementation with a new one that raises the exception
mentioned in the previous section.
with torch_export_patches(stop_if_static=1, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as expected.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())

# The stack trace is quite long, but the innermost line referring to this example
# is the following one. It points to the line turning a dynamic dimension into
# a static one.
#
# .. code-block::
#
#   File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
#       z = x * caty
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250723+cu126'
[torch_export_patches] stop_if_static=1
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] assert when a dynamic dimension turns static
[torch_export_patches] replaces ShapeEnv._set_replacement
[torch_export_patches] replaces ShapeEnv._log_guard
[torch_export_patches] done patching
-- It failed as expected.
-- final error is patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 304, in export
raise e
File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 271, in export
return _export(
^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1150, in wrapper
raise e
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1116, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2163, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1150, in wrapper
raise e
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1116, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2026, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1968, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1762, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1891, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1676, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2351, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2283, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2254, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1004, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1283, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1865, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1004, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1341, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1580, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1138, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1875, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1935, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
z = x * caty
~~^~~~~~
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1389, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1436, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 1066, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 961, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1491, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 840, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2068, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2587, in _dispatch_impl
return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 962, in fast_binary_impl
final_shape = infer_size(final_shape, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 119, in patched_infer_size
b3 = guard_size_oblivious(sizeA == sizeB)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 476, in guard_size_oblivious
return expr.node.guard_size_oblivious("", 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 596, in guard_size_oblivious
r = self.evaluate(size_oblivious=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7237, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7337, in evaluate_expr
return self._inner_evaluate_expr(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7360, in _inner_evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7629, in _evaluate_expr
self._maybe_guard_rel(g)
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6882, in _maybe_guard_rel
self._refine_ranges(expr)
File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7833, in _refine_ranges
self._set_replacement(
File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 343, in _set_replacement
assert msg != "range_refined_to_singleton", (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored ShapeEnv._set_replacement
[torch_export_patches] restored ShapeEnv._log_guard
[torch_export_patches] restored shape constraints
doc.plot_legend(
"dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)
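Since the stack trace is long, the relevant frame can also be extracted programmatically. The sketch below uses the standard traceback module to return the innermost frame whose file path contains a given fragment; `innermost_frame_in` is a hypothetical helper, and the toy `forward` and `run_export` functions merely stand in for the model and the export call.

```python
import traceback


def innermost_frame_in(exc: BaseException, path_fragment: str):
    """Return the innermost traceback frame whose filename contains path_fragment."""
    frames = traceback.extract_tb(exc.__traceback__)
    # extract_tb lists frames from outermost to innermost, so scan in reverse.
    for frame in reversed(frames):
        if path_fragment in frame.filename:
            return frame
    return None


def forward():
    # Stand-in for the model line that triggers the patched assertion.
    raise AssertionError("A dynamic dimension becomes static!")


def run_export():
    # Stand-in for the call to torch.export.export.
    forward()


frame = None
try:
    run_export()
except AssertionError as e:
    # Use the file defining `forward` as the fragment identifying "our" code.
    frame = innermost_frame_in(e, forward.__code__.co_filename)

print(f'File "{frame.filename}", line {frame.lineno}, in {frame.name}')
```

For the trace above, passing the name of the example script as the fragment would be expected to select the frame in forward rather than the frames inside torch or the patch module.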

Total running time of the script: (0 minutes 0.635 seconds)