Find and fix an export issue due to dynamic shapes
LLMs must be exported with dynamic shapes, and it is common for
a dynamic dimension to turn into a static one. The error message from
PyTorch tells the user to set TORCH_LOGS="+dynamic",
but that produces a very long list of messages in which we need
to find the string range_refined_to_singleton,
and even that does not really indicate where the issue comes from. This example
shows how to tweak PyTorch to get that information until
the error reporting gets better.
A model with an export issue
The following model implies that the first dimension of x is equal to 1
or to the number of elements in the list ys.
It is not really dynamic. This looks obvious here, but
it is difficult to find deep inside a big model.
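The constraint follows from broadcasting: the multiplication only works when each pair of aligned dimensions is equal or one of them is 1. The sketch below replays that rule in plain Python. The helpers `broadcast_shape` and `allowed_first_dims` are hypothetical names written for this illustration; they are a simplified stand-in, not PyTorch's implementation.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast shape of two shapes, or raise ValueError."""
    result = []
    # Align dimensions from the right; a missing leading dimension counts as 1.
    for i in range(1, max(len(shape_a), len(shape_b)) + 1):
        a = shape_a[-i] if i <= len(shape_a) else 1
        b = shape_b[-i] if i <= len(shape_b) else 1
        if a != b and a != 1 and b != 1:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
        result.append(b if a == 1 else a)
    return tuple(reversed(result))


def allowed_first_dims(n_ys, candidates=range(1, 6)):
    """List the first dimensions of x compatible with caty of shape (n_ys, 3, 4)."""
    caty_shape = (n_ys, 3, 4)
    accepted = []
    for d in candidates:
        try:
            broadcast_shape((d, 3, 1), caty_shape)
        except ValueError:
            continue
        accepted.append(d)
    return accepted


# With two tensors in ys, only 1 and 2 are accepted for x.shape[0].
print(allowed_first_dims(2))
```

Since the example input has a first dimension of 2, the only consistent choice during tracing is the constant 2, which is why that dimension cannot stay dynamic.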
import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_patches
class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z
inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.3481, 0.0791, 0.6101, 0.0440],
[0.0128, 0.1095, 0.0991, 0.0882],
[0.0015, 0.0774, 0.0349, 0.0051]],
[[0.4564, 0.8413, 0.4050, 0.1063],
[0.4719, 0.6740, 0.1819, 0.6136],
[0.7280, 0.0364, 0.0384, 0.5866]]])
Let’s export.
DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
-- ERROR:
Constraints violated (L['args'][0][0].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0]) are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.
The error shows:
Constraints violated (L['args'][0][0].size()[0])!
For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
Where does it happen? That is the tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assert such as the following:
assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
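Editing an installed package by hand is fragile, so the same check can be applied temporarily by swapping the method and restoring it afterwards. The following is a minimal sketch of that technique on a toy class: `ToyShapeEnv` and `stop_if_singleton` are hypothetical names, and the simplified `_set_replacement(a, tgt, msg)` signature is an assumption made for illustration. It is not the code used by onnx-diagnostic.

```python
import contextlib
import traceback


class ToyShapeEnv:
    """Hypothetical stand-in for the real ShapeEnv, for illustration only."""

    def _set_replacement(self, a, tgt, msg):
        return (a, tgt, msg)


@contextlib.contextmanager
def stop_if_singleton(cls):
    """Temporarily wrap cls._set_replacement with a check on msg."""
    original = cls._set_replacement

    def patched(self, a, tgt, msg):
        # Raise instead of silently specializing, so a stack trace is available.
        if msg == "range_refined_to_singleton":
            raise AssertionError(
                f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}"
            )
        return original(self, a, tgt, msg)

    cls._set_replacement = patched
    try:
        yield
    finally:
        # Always restore the original method, even if the body raised.
        cls._set_replacement = original


env = ToyShapeEnv()
caught = None
with stop_if_singleton(ToyShapeEnv):
    try:
        env._set_replacement("s0", 2, "range_refined_to_singleton")
    except AssertionError:
        caught = traceback.format_exc()

print(caught)
```

Restoring the method in a `finally` clause keeps the patch from leaking into later code when the export fails, which is the expected outcome here.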
Stop when a dynamic dimension turns static
We use torch_export_patches
to replace the torch implementation with a new one that raises the exception
mentioned in the previous section.
with torch_export_patches(stop_if_static=1, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as excepted.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())
# The stack trace is quite long, but the last frame referring to this example
# is the following one. It points to the line turning a dynamic dimension into
# a static one.
#
# .. code-block::
#
# File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
# z = x * caty
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_register_cache_serialization] register <class 'transformers.cache_utils.MambaCache'>
[_register_cache_serialization] register <class 'transformers.cache_utils.EncoderDecoderCache'>
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.8.0.dev20250423+cu126'
[torch_export_patches] stop_if_static=1
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] assert when a dynamic dimension turns static
[torch_export_patches] replaces ShapeEnv._set_replacement
[torch_export_patches] replaces ShapeEnv._log_guard
[torch_export_patches] done patching
-- It failed as excepted.
-- final error is patched_ShapeEnv: A dynamic dimension becomes static! a=s35, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 318, in export
raise e
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 285, in export
return _export(
^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1110, in wrapper
raise e
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1076, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2122, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1110, in wrapper
raise e
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1076, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1983, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1925, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1710, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1851, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1630, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2288, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2226, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2197, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 856, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1785, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 856, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 837, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1276, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1534, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1855, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 530, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 805, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1835, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1855, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 530, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 805, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1755, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1766, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
z = x * caty
~~^~~~~~
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1324, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1371, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 942, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 925, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1426, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 926, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 806, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1312, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1950, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1424, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2473, in _dispatch_impl
return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 957, in fast_binary_impl
final_shape = infer_size(final_shape, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 119, in patched_infer_size
b3 = guard_size_oblivious(sizeA == sizeB)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 417, in guard_size_oblivious
return expr.node.guard_size_oblivious("", 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 594, in guard_size_oblivious
r = self.evaluate(size_oblivious=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6779, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6795, in evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7110, in _evaluate_expr
self._maybe_guard_rel(g)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6429, in _maybe_guard_rel
self._refine_ranges(expr)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7314, in _refine_ranges
self._set_replacement(
File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 343, in _set_replacement
assert msg != "range_refined_to_singleton", (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: patched_ShapeEnv: A dynamic dimension becomes static! a=s35, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored ShapeEnv._set_replacement
[torch_export_patches] restored ShapeEnv._log_guard
[torch_export_patches] restored shape constraints
[_unregister_cache_serialization] unregistered MambaCache
[_unregister_cache_serialization] unregistered EncoderDecoderCache
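Reading such a long trace by eye is tedious. One way to locate the interesting frame is to drop every frame that belongs to a library directory and keep the last one that remains. The sketch below uses only the standard `traceback` module; `last_frame_outside` is a hypothetical helper, and the `json` package stands in for the library so that the example runs without PyTorch.

```python
import json
import os
import traceback


def last_frame_outside(exc, library_dirs):
    """Return the deepest traceback frame not located under library_dirs."""
    frames = traceback.extract_tb(exc.__traceback__)
    user_frames = [
        f for f in frames
        if not any(f.filename.startswith(d) for d in library_dirs)
    ]
    return user_frames[-1] if user_frames else None


def forward(text):
    # The failure happens inside the library, several frames below this call.
    return json.loads(text)


# Assumption for this demo: frames under the json package count as library code.
library_dirs = [os.path.dirname(json.__file__)]

try:
    forward("not json")
except ValueError as e:
    frame = last_frame_outside(e, library_dirs)

print(frame.name, frame.lineno)
```

For the trace above, the library directories would be the site-packages folder and the folder holding the patches, which leaves the `forward` frame containing `z = x * caty`.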
doc.plot_legend(
"dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)

Total running time of the script: (0 minutes 0.188 seconds)