Find and fix an export issue due to dynamic shapes

LLMs must be exported with dynamic shapes, and it is common for a dynamic dimension to turn into a static one. The error message from PyTorch tells the user to set TORCH_LOGS="+dynamic", but that produces a very long list of messages in which we need to find the string range_refined_to_singleton, and even that does not really indicate where the issue comes from. This example shows how to tweak PyTorch to get that information until the error reporting gets better.

A model with an export issue

The following model implies that the first dimension of x is equal to 1 or to the number of elements in the list ys. It is not really dynamic. This looks obvious here, but it is difficult to spot deep inside a big model.

import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_patches


class ModelWithIssue(torch.nn.Module):
    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
        z = x * caty
        return z


inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
model = ModelWithIssue()
model(*inputs)
tensor([[[0.0979, 0.1511, 0.0430, 0.0020],
         [0.1890, 0.4872, 0.5730, 0.6068],
         [0.7902, 0.4932, 0.1538, 0.8080]],

        [[0.0049, 0.0039, 0.0048, 0.0021],
         [0.0868, 0.0329, 0.1637, 0.1236],
         [0.8890, 0.1376, 0.6391, 0.6488]]])
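
The constraint on the first dimension of x follows from the broadcasting rule applied by x * caty. The sketch below reimplements that rule on plain tuples, without torch, only to make the reasoning explicit. The helper name broadcast_shape is introduced here for illustration and is not part of the example or of PyTorch.

```python
def broadcast_shape(a, b):
    """Return the broadcast of two shapes, or raise ValueError if they are incompatible."""
    # Align both shapes on the right, padding the shorter one with 1s.
    n = max(len(a), len(b))
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    result = []
    for da, db in zip(a, b):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            raise ValueError(f"cannot broadcast dimension {da} with {db}")
    return tuple(result)


# caty stacks two tensors of shape (3, 4), so its shape is (2, 3, 4).
caty_shape = (2, 3, 4)

ok_equal = broadcast_shape((2, 3, 1), caty_shape)  # x.shape[0] == len(ys)
ok_one = broadcast_shape((1, 3, 1), caty_shape)    # x.shape[0] == 1
print(ok_equal, ok_one)

try:
    broadcast_shape((5, 3, 1), caty_shape)         # any other value fails
    failed = False
except ValueError as e:
    failed = True
    print("rejected:", e)
```

With a list of two tensors, the only valid values for x.shape[0] are 1 and 2, which is why the exporter cannot keep that dimension dynamic.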

Let’s export.

DYN = torch.export.Dim.DYNAMIC
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:")
    print(e)
-- ERROR:
Found the following conflicts between user-specified ranges and inferred ranges from model tracing:
- Received user-specified dim hint Dim.DYNAMIC(min=None, max=None), but tracing inferred a static shape of 2 for dimension inputs['x'].shape[0].

The error shows:

Constraints violated (L['args'][0][0].size()[0])!
    For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
    are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).

Where does it happen? That is a tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assert such as the following:

assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)
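
Rather than editing the installed torch sources, such a check can be injected temporarily by wrapping the method and restoring it afterwards. The sketch below shows the general pattern on a stand-in class. ToyShapeEnv, its simplified method signature, and the stop_if_static context manager are hypothetical names used only for illustration; this is not the actual implementation in torch or onnx_diagnostic.

```python
import contextlib


class ToyShapeEnv:
    """Hypothetical stand-in for ShapeEnv, with a simplified signature."""

    def _set_replacement(self, a, tgt, msg):
        return (a, tgt, msg)


@contextlib.contextmanager
def stop_if_static(cls):
    """Temporarily make cls._set_replacement fail when a dimension turns static."""
    original = cls._set_replacement

    def patched(self, a, tgt, msg):
        # Same condition as the assert above, written as an explicit raise.
        if msg == "range_refined_to_singleton":
            raise AssertionError(
                f"A dynamic dimension becomes static! a={a!r}, tgt={tgt!r}, msg={msg!r}"
            )
        return original(self, a, tgt, msg)

    cls._set_replacement = patched
    try:
        yield
    finally:
        # Always restore the original method, even if an exception escaped.
        cls._set_replacement = original


env = ToyShapeEnv()
with stop_if_static(ToyShapeEnv):
    passed = env._set_replacement("s0", "s1", "other_reason")  # goes through
    try:
        env._set_replacement("s77", 2, "range_refined_to_singleton")
        caught = None
    except AssertionError as e:
        caught = str(e)

print(caught)
```

The exception is raised at the exact moment the replacement happens, so its traceback contains the user code that triggered it.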

Stop when a dynamic dimension turns static

We use torch_export_patches to replace the torch implementation with a new one that raises the exception mentioned in the previous section.

with torch_export_patches(stop_if_static=1, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    except (AssertionError, ValueError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as excepted.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())

# The stack trace is quite long, but the last frame referring to this example
# is the following one. It points to the line turning a dynamic dimension into
# a static one.
#
# .. code-block::
#
#   File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
#       z = x * caty
[torch_export_patches] patch_sympy=True
                     . patch_torch=True
                     . patch_transformers=False
                     . patch_diffusers=False
                     . catch_constraints=True
                     . stop_if_static=1
                     . patch=True
                     . custom_patches=None
[torch_export_patches] dump_rewriting=None
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[torch_export_patches] sympy.__version__='1.14.0'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.10.0.dev20251106+cu130'
[torch_export_patches] stop_if_static=1
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] assert when a dynamic dimension turns static
[torch_export_patches] replaces ShapeEnv._set_replacement
[torch_export_patches] replaces ShapeEnv._log_guard
[torch_export_patches] done patching
-- It failed as excepted.
-- final error is patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
    torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 311, in export
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 277, in export
    return _export(
           ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1229, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1195, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2334, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1229, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1195, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 124, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2142, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2073, in _non_strict_export
    aten_export_artifact = _to_aten_func(
                           ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1861, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1991, in _aot_export_non_strict
    gm, sig = aot_export(stack, wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1773, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2603, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2528, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2490, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1138, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1451, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2079, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 869, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1517, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1660, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 193, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/graph_capture_wrappers.py", line 1369, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 844, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2166, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 837, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1975, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 844, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2166, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 560, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 837, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
    z = x * caty
        ~~^~~~~~
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1566, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1637, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 1118, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 958, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 313, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 29, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1692, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1136, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 837, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 29, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1390, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2148, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1537, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2674, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 1337, in fast_binary_impl
    return slow("both tensors nontrivially broadcast")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 1297, in slow
    return slow_ref(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 315, in _fn
    result = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1138, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_prims_common/wrappers.py", line 152, in _fn
    result = fn(**bound.arguments)
             ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_refs/__init__.py", line 1142, in _ref
    a, b = _maybe_broadcast(a, b)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 929, in patched__maybe_broadcast
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 929, in <genexpr>
    return tuple(__maybe_broadcast(x, common_shape) for x in args)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 923, in __maybe_broadcast
    return x.expand(common_shape)
           ^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 29, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1390, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2148, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1537, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2700, in _dispatch_impl
    decomposition_table[func](*args, **kwargs)
  File "~/vv/this312/lib/python3.12/site-packages/torch/_refs/__init__.py", line 3116, in expand
    return prims.broadcast_in_dim(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 996, in patched__broadcast_in_dim_meta
    torch._check(
  File "~/vv/this312/lib/python3.12/site-packages/torch/__init__.py", line 1718, in _check
    _check_with(RuntimeError, cond, message)  # pyrefly: ignore [bad-argument-type]
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/__init__.py", line 1681, in _check_with
    if expect_true(cond):
       ^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1771, in expect_true
    return a.node.expect_true(
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 554, in expect_true
    return self.guard_bool(file, line)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 538, in guard_bool
    r = self.evaluate()
        ^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7330, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7430, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 273, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7453, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 642, in _evaluate_expr
    self._maybe_guard_rel(g)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6969, in _maybe_guard_rel
    self._refine_ranges(expr)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7924, in _refine_ranges
    self._set_replacement(
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 378, in _set_replacement
    assert msg != "range_refined_to_singleton", (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]

[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored ShapeEnv._set_replacement
[torch_export_patches] restored ShapeEnv._log_guard
[torch_export_patches] restored shape constraints
doc.plot_legend(
    "dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)
Total running time of the script: (0 minutes 0.217 seconds)

Related examples

Export Tiny-LLM with patches

Export microsoft/phi-2

Export with dynamic dimensions in {0,1}
