Find and fix an export issue due to dynamic shapes
LLMs must be exported with dynamic shapes, and it is common for
a dynamic dimension to turn into a static one. The error message from
PyTorch tells the user to set TORCH_LOGS="+dynamic",
but this produces a very long list of messages in which we need
to find the string range_refined_to_singleton,
and that string does not really indicate where the issue comes from.
This example shows how to tweak PyTorch to get that information until
the error message gets better.
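Searching the log by hand can be sketched as a small filter. This is only an illustration: the helper name `find_singleton_refinements` and the sample text are made up here, and real `TORCH_LOGS="+dynamic"` lines are longer and formatted differently.

```python
def find_singleton_refinements(log_text, marker="range_refined_to_singleton"):
    """Return (line_number, line) pairs for every log line containing the marker."""
    hits = []
    for lineno, line in enumerate(log_text.splitlines(), start=1):
        if marker in line:
            hits.append((lineno, line.strip()))
    return hits


# Placeholder text standing in for a captured log (NOT real PyTorch output).
sample_log = (
    "placeholder line about creating a symbol\n"
    "placeholder line mentioning range_refined_to_singleton\n"
    "placeholder line about evaluating an expression\n"
)

hits = find_singleton_refinements(sample_log)
for lineno, line in hits:
    print(f"{lineno}: {line}")
```

With a real run, the log would first be captured to a file, for example by redirecting the standard error stream of the export script, and the file content passed to the helper. Even then, the matching lines name the symbol but not the model code responsible, which is the gap the rest of this example addresses.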
A model with an export issue
The following model implies that the first dimension of x is equal to 1
or to the number of elements in the list ys.
It is not really dynamic. This looks obvious here, but
it is difficult to find deep inside a big model.
import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z
inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.1896, 0.0919, 0.1302, 0.2041],
[0.6827, 0.3633, 0.4235, 0.7472],
[0.0214, 0.0209, 0.0168, 0.0108]],
[[0.4849, 0.3434, 0.4207, 0.4044],
[0.1987, 0.2693, 0.3528, 0.3519],
[0.5722, 0.1479, 0.4143, 0.1571]]])
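The constraint on the first dimension of x comes from the broadcasting rule applied by the multiplication. The sketch below re-implements that rule in plain Python to make the constraint visible; the helper names `broadcast_dim` and `broadcast_shape` are illustrative and not part of torch.

```python
def broadcast_dim(a, b):
    """Broadcast two dimension sizes: they must be equal, or one of them must be 1."""
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"cannot broadcast dimensions {a} and {b}")


def broadcast_shape(shape_a, shape_b):
    """Broadcast two shapes, aligning them on the right as tensor libraries do."""
    n = max(len(shape_a), len(shape_b))
    padded_a = (1,) * (n - len(shape_a)) + tuple(shape_a)
    padded_b = (1,) * (n - len(shape_b)) + tuple(shape_b)
    return tuple(broadcast_dim(a, b) for a, b in zip(padded_a, padded_b))


n_ys = 2                   # number of tensors in the list ys
caty_shape = (n_ys, 3, 4)  # shape of the concatenated tensor caty

same = broadcast_shape((2, 3, 1), caty_shape)  # first dimension equals len(ys)
one = broadcast_shape((1, 3, 1), caty_shape)   # first dimension equals 1

try:
    broadcast_shape((5, 3, 1), caty_shape)     # any other value is rejected
    other_fails = False
except ValueError:
    other_fails = True
```

Since the length of the list ys is fixed at trace time, the only first dimensions accepted for x are 1 and that fixed length, so the exporter cannot keep it dynamic.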
Let’s export.
DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
-- ERROR:
Constraints violated (L['args'][0][0].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0]) are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
The error shows:
Constraints violated (L['args'][0][0].size()[0])!
For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
Where does it happen? That is the tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assert such as the following:
assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
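Such an assert can be injected without editing the installed package by temporarily wrapping the method. The following is a minimal sketch of that idea on a stand-in class: `FakeShapeEnv`, its simplified `_set_replacement` signature, and the context manager `stop_if_static` are all hypothetical and are not the code used by torch or onnx-diagnostic.

```python
import contextlib


class FakeShapeEnv:
    """Hypothetical stand-in for the class owning `_set_replacement`."""

    def __init__(self):
        self.replacements = {}

    def _set_replacement(self, a, tgt, msg):
        self.replacements[a] = tgt


@contextlib.contextmanager
def stop_if_static(cls):
    """Temporarily wrap `cls._set_replacement` so that it asserts on the marker."""
    original = cls._set_replacement

    def patched(self, a, tgt, msg):
        assert msg != "range_refined_to_singleton", (
            f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}"
        )
        return original(self, a, tgt, msg)

    cls._set_replacement = patched
    try:
        yield
    finally:
        # Always restore the original method, even if the body raised.
        cls._set_replacement = original


env = FakeShapeEnv()
with stop_if_static(FakeShapeEnv):
    env._set_replacement("s1", "s0", "some_other_reason")  # passes through
    try:
        env._set_replacement("s0", 2, "range_refined_to_singleton")
        raised = False
    except AssertionError:
        raised = True
```

The point of raising instead of logging is that the exception carries the full call stack, including the frame of the model code that triggered the refinement.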
Stop when a dynamic dimension turns static
We use bypass_export_some_errors
to replace the torch implementation with a new one that raises the exception
mentioned in the previous section.
with bypass_export_some_errors(stop_if_static=True, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as excepted.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())
# The stack trace is quite long, but the frame located in the forward method
# of this example is the following one. It points to the line turning
# a dynamic dimension into a static one.
#
# .. code-block::
#
#     File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
#         z = x * caty
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_register_cache_serialization] register MambaCache
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] assert when a dynamic dimension turns static
[bypass_export_some_errors] done patching
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
torch._C._set_onednn_allow_tf32(_allow_tf32)
-- It failed as excepted.
-- final error is A dynamic dimension becomes static! a=s0, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 274, in export
return _export(
^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1099, in wrapper
raise e
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1072, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2111, in _export
ep = _export_for_training(
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1099, in wrapper
raise e
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1072, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1973, in _export_for_training
export_artifact = export_func(
^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1916, in _non_strict_export
aten_export_artifact = _to_aten_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1703, in _export_to_aten_ir_make_fx
gm, graph_signature = transform(_make_fx_helper)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1844, in _aot_export_non_strict
gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1623, in _make_fx_helper
gm = make_fx(
^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2253, in wrapped
return make_fx_tracer.trace(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2191, in trace
return self._trace_inner(f, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2162, in _trace_inner
t = dispatch_trace(
^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
return disable_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 841, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1187, in dispatch_trace
graph = tracer.trace(root, concrete_args) # type: ignore[arg-type]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1751, in trace
res = super().trace(root, concrete_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 838, in trace
(self.create_arg(fn(*args)),),
^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1242, in wrapped
out = f(*tensors) # type:ignore[call-arg]
^^^^^^^^^^^
File "<string>", line 1, in <lambda>
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1527, in wrapped_fn
return tuple(flat_fn(*args))
^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
tree_out = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
out = mod(*args[params_len:], **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1821, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1828, in forward
tree_out = mod(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper
return self.call_module(mod, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1821, in call_module
return Tracer.call_module(self, m, forward, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 531, in call_module
ret_val = forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 806, in forward
return _orig_module_call(mod, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
z = x * caty
~~^~~~~~
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1290, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1337, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 689, in __torch_function__
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 875, in handler
return torch._library.utils.handle_dispatch_mode(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1392, in __torch_dispatch__
return proxy_call(self, func, self.pre_dispatch, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 927, in proxy_call
out = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
return self._op(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1393, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2314, in _dispatch_impl
return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 921, in fast_binary_impl
final_shape = infer_size(final_shape, shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 102, in patched_infer_size
b3 = guard_size_oblivious(sizeA == sizeB)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 408, in guard_size_oblivious
return expr.node.guard_size_oblivious("", 0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 588, in guard_size_oblivious
r = self.evaluate(size_oblivious=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
return self.shape_env.evaluate_sym_node(self, size_oblivious)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6667, in evaluate_sym_node
return self.evaluate_expr(
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
return retlog(fn(*args, **kwargs))
^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6683, in evaluate_expr
return self._evaluate_expr(
^^^^^^^^^^^^^^^^^^^^
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6998, in _evaluate_expr
self._maybe_guard_rel(g)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6324, in _maybe_guard_rel
self._refine_ranges(expr)
File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7202, in _refine_ranges
self._set_replacement(
File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 316, in _set_replacement
assert msg != "range_refined_to_singleton", (
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: A dynamic dimension becomes static! a=s0, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
[bypass_export_some_errors] remove patches
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored ShapeEnv._set_replacement
[bypass_export_some_errors] restored shape constraints
[_unregister_cache_serialization] unregistered MambaCache
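Most frames in such a trace belong to installed packages. A small filter over the traceback can surface the frames coming from user code; the sketch below uses only the standard `traceback` module, and the helper `user_frames` with its default `skip` patterns is an assumption made for illustration.

```python
import traceback


def user_frames(exc, skip=("site-packages",)):
    """Return the traceback frames whose file path contains none of the skip patterns."""
    frames = traceback.extract_tb(exc.__traceback__)
    return [f for f in frames if not any(pattern in f.filename for pattern in skip)]


def inner():
    raise AssertionError("A dynamic dimension becomes static!")


def outer():
    inner()


try:
    outer()
except AssertionError as e:
    frames = user_frames(e)
    for f in frames:
        print(f"{f.filename}:{f.lineno} in {f.name}")
```

Applied to the trace above, the patched functions of onnx-diagnostic would also be kept unless a matching pattern is added to `skip`.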
doc.plot_legend(
"dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)

Total running time of the script: (0 minutes 9.888 seconds)