Find and fix an export issue due to dynamic shapes

LLMs must be exported with dynamic shapes, and it is common for a dynamic dimension to turn into a static one. The error message from PyTorch tells the user to set TORCH_LOGS="+dynamic", but that produces a very long list of messages in which we need to find the string range_refined_to_singleton, and even that does not really indicate where the problem comes from. This example shows how to patch PyTorch to get that information until the error reporting gets better.
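
Searching a long log for that string can be scripted. The sketch below is only an illustration: the helper find_marker is a hypothetical name, and the sample lines are placeholders, not real PyTorch output.

```python
def find_marker(lines, marker="range_refined_to_singleton", context=2):
    """Yields (line_number, surrounding_lines) for each line containing marker."""
    lines = list(lines)
    for i, line in enumerate(lines):
        if marker in line:
            start = max(0, i - context)
            # Line numbers are 1-based, like in a text editor.
            yield i + 1, lines[start : i + context + 1]


# Placeholder lines standing in for a captured log (not real PyTorch output).
log = ["line 1", "line 2", "xx range_refined_to_singleton xx", "line 4"]
for number, block in find_marker(log, context=1):
    print(number, block)
```

In practice the lines would come from the log captured while running the export with TORCH_LOGS="+dynamic".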

A model with an export issue

The following model implies that the first dimension of x is equal to 1 or to the number of elements in the list ys, because x is multiplied by the concatenation of the tensors in ys. That dimension is not really dynamic. It looks obvious here, but it is difficult to find deep inside a big model.

import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z


inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.8976, 0.6263, 0.6848, 0.7944],
         [0.0745, 0.1040, 0.0168, 0.0408],
         [0.0587, 0.3104, 0.1343, 0.4373]],

        [[0.0235, 0.0430, 0.1726, 0.0675],
         [0.0363, 0.0874, 0.4531, 0.4320],
         [0.0159, 0.0114, 0.0463, 0.1817]]])
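
The constraint on the first dimension of x comes from the broadcasting rule applied by the multiplication. The following pure-Python sketch reproduces that rule for two shapes of the same rank. The helpers broadcast_dim and broadcast_shape are illustrative names, not PyTorch functions.

```python
def broadcast_dim(a: int, b: int) -> int:
    """Returns the size of a dimension after broadcasting sizes a and b."""
    if a == b or b == 1:
        return a
    if a == 1:
        return b
    raise ValueError(f"sizes {a} and {b} cannot be broadcast together")


def broadcast_shape(shape_a: tuple, shape_b: tuple) -> tuple:
    """Broadcasts two shapes of the same rank, dimension by dimension."""
    assert len(shape_a) == len(shape_b), "this sketch assumes equal ranks"
    return tuple(broadcast_dim(a, b) for a, b in zip(shape_a, shape_b))


# x has shape (2, 3, 1), caty has shape (len(ys), 3, 4) = (2, 3, 4).
print(broadcast_shape((2, 3, 1), (2, 3, 4)))
# A first dimension of 1 for x is accepted as well...
print(broadcast_shape((1, 3, 1), (2, 3, 4)))
# ...but any other value fails when ys holds two tensors.
try:
    broadcast_shape((5, 3, 1), (2, 3, 4))
except ValueError as e:
    print(e)
```

Since the length of ys is fixed when the model is traced, the only values left for the first dimension of x are 1 and 2, and the exporter ends up specializing it.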

Let’s export.

DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
-- ERROR:
Constraints violated (L['args'][0][0].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0]) are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).

The error shows:

Constraints violated (L['args'][0][0].size()[0])!
    For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
    are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).

Where does it happen? That’s a tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assert such as the following:

assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
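
The idea of temporarily wrapping a method so that it asserts on a given message can be sketched on a toy class. ToyShapeEnv, its simplified signature and the message "some_other_reason" are hypothetical. This is a minimal illustration of the pattern, not the way PyTorch or onnx_diagnostic implement it.

```python
import contextlib


class ToyShapeEnv:
    """Hypothetical stand-in for a class whose method we want to intercept."""

    def _set_replacement(self, a, tgt, msg):
        return f"{a} -> {tgt} ({msg})"


@contextlib.contextmanager
def stop_if_static(cls):
    """Temporarily replaces cls._set_replacement with a checking wrapper."""
    original = cls._set_replacement

    def checked(self, a, tgt, msg):
        # Fail loudly so that the stack trace points to the calling code.
        assert msg != "range_refined_to_singleton", (
            f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}"
        )
        return original(self, a, tgt, msg)

    cls._set_replacement = checked
    try:
        yield
    finally:
        # Always restore the original method, even if an assertion fired.
        cls._set_replacement = original


env = ToyShapeEnv()
with stop_if_static(ToyShapeEnv):
    print(env._set_replacement("s0", "s1", "some_other_reason"))
    try:
        env._set_replacement("s0", 2, "range_refined_to_singleton")
    except AssertionError as e:
        print(f"caught: {e}")
```

Restoring the method in a finally clause keeps the patch local to the with block, so code running afterwards sees the original behaviour.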

Stop when a dynamic dimension turns static

We use bypass_export_some_errors to replace the torch implementation with a new one that raises the exception mentioned in the previous section.

with bypass_export_some_errors(stop_if_static=True, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as expected.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())

# The stack trace is quite long but the first line referring to this example
# is the following one. It points out the line turning a dynamic dimension into
# a static one.
#
# .. code-block::
#
#   File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward
#       z = x * caty
[bypass_export_some_errors] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[_register_cache_serialization] register MambaCache
[_register_cache_serialization] DynamicCache is unregistered first.
[_register_cache_serialization] register DynamicCache
[bypass_export_some_errors] patch sympy
[bypass_export_some_errors] patch pytorch
[bypass_export_some_errors] modifies shape constraints
[bypass_export_some_errors] assert when a dynamic dimension turns static
[bypass_export_some_errors] done patching
/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/backends/mkldnn/__init__.py:78: UserWarning: TF32 acceleration on top of oneDNN is available for Intel GPUs. The current Torch version does not have Intel GPU Support. (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:148.)
  torch._C._set_onednn_allow_tf32(_allow_tf32)
-- It failed as expected.
-- final error is A dynamic dimension becomes static! a=s0, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
  File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
    torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 274, in export
    return _export(
           ^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1107, in wrapper
    raise e
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1073, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2119, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1107, in wrapper
    raise e
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1073, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 122, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1981, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1924, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1711, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1852, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1631, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2253, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2191, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2162, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 849, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1187, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1751, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 837, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1242, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1535, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 903, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1821, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 530, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 805, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1836, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 812, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1821, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 530, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 805, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
    z = x * caty
        ~~^~~~~~
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1290, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1337, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 689, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 875, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1392, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 927, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1295, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1915, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1407, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2424, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 921, in fast_binary_impl
    final_shape = infer_size(final_shape, shape)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 102, in patched_infer_size
    b3 = guard_size_oblivious(sizeA == sizeB)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 410, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 588, in guard_size_oblivious
    r = self.evaluate(size_oblivious=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6700, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 263, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6716, in evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7031, in _evaluate_expr
    self._maybe_guard_rel(g)
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6357, in _maybe_guard_rel
    self._refine_ranges(expr)
  File "/home/xadupre/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7235, in _refine_ranges
    self._set_replacement(
  File "/home/xadupre/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 316, in _set_replacement
    assert msg != "range_refined_to_singleton", (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: A dynamic dimension becomes static! a=s0, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]

[bypass_export_some_errors] remove patches
[bypass_export_some_errors] restored sympy functions
[bypass_export_some_errors] restored pytorch functions
[bypass_export_some_errors] restored ShapeEnv._set_replacement
[bypass_export_some_errors] restored shape constraints
[_unregister_cache_serialization] unregistered MambaCache
[_unregister_cache_serialization] unregistered DynamicCache
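
Reading such a long stack trace by eye is tedious. A possible shortcut, sketched below with the standard traceback module only, keeps the frames whose filename contains a given string. The helper frames_in_file and the function failing_forward are hypothetical names used for illustration.

```python
import traceback


def frames_in_file(exc: BaseException, filename_part: str) -> list:
    """Returns the traceback frames whose filename contains filename_part."""
    return [
        frame
        for frame in traceback.extract_tb(exc.__traceback__)
        if filename_part in frame.filename
    ]


def failing_forward():
    # Stand-in for the model code that triggers the assertion.
    raise AssertionError("A dynamic dimension becomes static!")


# File holding this sketch; for the run above it would be the example script.
this_file = failing_forward.__code__.co_filename

try:
    failing_forward()
except AssertionError as e:
    selected = frames_in_file(e, this_file)
    # The deepest matching frame is the user code closest to the error.
    last = selected[-1]
    print(f"{last.name}, line {last.lineno}: {last.line}")
```

Applied to the trace above with the name of the example script, the deepest matching frame would be the one in forward showing z = x * caty.
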
doc.plot_legend(
    "dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)
Related examples

Export Tiny-LLM with patches

Export a model with a control flow (If)

Intermediate results with (ONNX) ReferenceEvaluator
