Find and fix an export issue due to dynamic shapes

LLMs must be exported with dynamic shapes, and it is common that a dynamic dimension turns into a static one. The error message from PyTorch tells the user to define TORCH_LOGS="+dynamic", but this produces a very long list of messages in which we need to find the string range_refined_to_singleton, and that string does not really indicate where the issue comes from. This example shows how to tweak PyTorch to get that information until PyTorch itself reports it better.

A model with an export issue

The following model implies that the first dimension of x is equal to 1 or to the number of elements in the list ys, so it is not really dynamic. It looks obvious here, but it is difficult to find deep inside a big model.

import traceback
import torch
from onnx_diagnostic import doc
from onnx_diagnostic.torch_export_patches import torch_export_patches


class ModelWithIssue(torch.nn.Module):
    """Toy model whose export specializes a dimension declared as dynamic.

    ``forward`` stacks the tensors of ``ys`` along a new leading axis and
    multiplies ``x`` by the result. When ``x`` has the same rank as the
    stacked tensor, broadcasting forces the first dimension of ``x`` to be
    1 or ``len(ys)``, so that dimension cannot really be dynamic.
    """

    def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
        # Give every tensor of ys a leading axis, then concatenate them
        # along it.
        stacked = torch.cat([t.unsqueeze(0) for t in ys], dim=0)
        # Broadcasting multiplication: this is the operation that ties
        # x.shape[0] to len(ys) while exporting.
        return x * stacked


# One 3-D tensor x plus a list of two 2-D tensors; x.shape[0] == len(ys).
inputs = (
    torch.rand(2, 3, 1),
    [torch.rand(3, 4) for _ in range(2)],
)
model = ModelWithIssue()
model(*inputs)
tensor([[[0.0026, 0.1810, 0.0759, 0.2596],
         [0.0715, 0.0749, 0.1267, 0.0853],
         [0.4246, 0.0620, 0.3842, 0.0689]],

        [[0.0403, 0.0453, 0.2531, 0.1696],
         [0.6776, 0.6923, 0.3622, 0.0931],
         [0.4267, 0.0761, 0.5997, 0.3721]]])

Let’s export.

DYN = torch.export.Dim.DYNAMIC
# Mark the first two axes of x and of each tensor in ys as dynamic.
# Broadcasting inside the model specializes x.shape[0], so this export
# is expected to fail and the error is printed instead.
dyn_shapes = (
    {0: DYN, 1: DYN},
    [{0: DYN, 1: DYN} for _ in range(2)],
)
try:
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    print(ep)
except Exception as e:
    print("-- ERROR:", e, sep="\n")
-- ERROR:
Constraints violated (L['x'].size()[0])! For more information, run with TORCH_LOGS="+dynamic".
  - You marked L['x'].size()[0] as dynamic but your code specialized it to be a constant (2). If you're using mark_dynamic, either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, replace it with either Dim.STATIC or Dim.AUTO.

Framework stack:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/__main__.py", line 7, in <module>
    raise SystemExit(main(sys.argv[1:]))
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/cmd/build.py", line 563, in main
    return build_main(argv)
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/cmd/build.py", line 496, in build_main
    app = Sphinx(
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/application.py", line 295, in __init__
    self._init_builder()
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/application.py", line 369, in _init_builder
    self.events.emit('builder-inited')
  File "??", line 0, in _start
  File "??", line 0, in __libc_start_main
  File "??", line 0, in __libc_init_first
  File "??", line 0, in Py_BytesMain
  File "??", line 0, in Py_RunMain
  File "??", line 0, in _PyInterpreterState_SetRunningMain
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx/events.py", line 404, in emit
    results.append(listener.handler(self.app, *args))
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyDict_DelItemIf
  File "??", line 0, in PyEval_EvalCode
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_gallery.py", line 757, in generate_gallery_rst
    ) = generate_dir_rst(
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _Py_NewReferenceNoTotal
  File "??", line 0, in PyObject_IsTrue
  File "??", line 0, in _PyObject_FastCallDictTstate
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 606, in generate_dir_rst
    results = parallel(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _Py_convert_optional_to_ssize_t
  File "??", line 0, in _PyList_Extend
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 607, in <genexpr>
    p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 1374, in generate_file_rst
    output_blocks, time_elapsed = execute_script(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 1192, in execute_script
    execute_code_block(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 1048, in execute_code_block
    is_last_expr, mem_max = _exec_and_get_memory(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 876, in _exec_and_get_memory
    mem_max, _ = call_memory(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 1725, in _sg_call_memory_noop
    return 0.0, func()
  File "??", line 0, in _PyObject_MakeTpCall
  File "??", line 0, in _PyObject_RealIsInstance
  File "??", line 0, in _PyObject_Call_Prepend
  File "??", line 0, in PyConfig_SetBytesString
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/sphinx_gallery/gen_rst.py", line 794, in __call__
    exec(self.code, self.fake_main.__dict__)
  File "??", line 0, in _PyDict_DelItemIf
  File "??", line 0, in PyEval_EvalCode
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 50, in <module>
    ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
    return _export(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1119, in wrapper
    ep = fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2175, in _export
    ep = _export_for_training(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1119, in wrapper
    ep = fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2036, in _export_for_training
    export_artifact = export_func(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1978, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1769, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1899, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1684, in _make_fx_helper
    gm = make_fx(
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2348, in wrapped
    return make_fx_tracer.trace(f, *args)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2280, in trace
    return self._trace_inner(f, *args)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2251, in _trace_inner
    t = dispatch_trace(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 950, in _fn
    return fn(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1280, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in trace
    res = super().trace(root, concrete_args)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 950, in _fn
    return fn(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1338, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "<string>", line 1, in <lambda>
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1588, in wrapped_fn
    return tuple(flat_fn(*args))
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 1138, in functional_call
    out = mod(*args[params_len:], **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyObject_RealIsInstance
  File "??", line 0, in _PyObject_Call_Prepend
  File "??", line 0, in PyConfig_SetBytesString
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1932, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1883, in forward
    tree_out = mod(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyObject_RealIsInstance
  File "??", line 0, in _PyObject_Call_Prepend
  File "??", line 0, in PyConfig_SetBytesString
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1932, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
    z = x * caty
  File "??", line 0, in PyNumber_Multiply
  File "??", line 0, in PyNumber_Float
  File "??", line 0, in PyBytes_Repr
  File "??", line 0, in PyMember_SetOne
  File "python_variable_methods.cpp", line 0, in _object* torch::autograd::TypeError_to_NotImplemented_<&torch::autograd::THPVariable_mul>(_object*, _object*, _object*)
  File "python_variable_methods.cpp", line 0, in torch::autograd::THPVariable_mul(_object*, _object*, _object*)
  File "", line 0, in torch::handle_torch_function(torch::PythonArgs&, _object*, _object*, _object*, _object*, char const*, char const*)
  File "??", line 0, in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName)
  File "??", line 0, in PyObject_CallMethod
  File "??", line 0, in PyList_Sort
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1386, in __torch_function__
    return func(*args, **kwargs)
  File "??", line 0, in PyMember_SetOne
  File "python_variable_methods.cpp", line 0, in torch::autograd::THPVariable_mul(_object*, _object*, _object*)
  File "", line 0, in torch::handle_torch_function(torch::PythonArgs&, _object*, _object*, _object*, _object*, char const*, char const*)
  File "??", line 0, in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName)
  File "??", line 0, in PyObject_CallMethod
  File "??", line 0, in PyList_Sort
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1433, in __torch_function__
    return func(*args, **kwargs)
  File "??", line 0, in PyMember_SetOne
  File "python_variable_methods.cpp", line 0, in torch::autograd::THPVariable_mul(_object*, _object*, _object*)
  File "", line 0, in torch::handle_torch_function(torch::PythonArgs&, _object*, _object*, _object*, _object*, char const*, char const*)
  File "??", line 0, in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName)
  File "??", line 0, in PyObject_CallMethod
  File "??", line 0, in PyList_Sort
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 1054, in __torch_function__
    return func(*args, **kwargs)
  File "??", line 0, in PyMember_SetOne
  File "python_variable_methods.cpp", line 0, in torch::autograd::THPVariable_mul(_object*, _object*, _object*)
  File "??", line 0, in at::_ops::mul_Tensor::call(at::Tensor const&, at::Tensor const&)
  File "", line 0, in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
  File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 950, in handler
    return torch._library.utils.handle_dispatch_mode(
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1488, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
    out = func(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyObject_RealIsInstance
  File "??", line 0, in _PyObject_Call_Prepend
  File "??", line 0, in PyConfig_SetBytesString
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in PyObject_GenericGetAttr
  File "", line 0, in pybind11::cpp_function::dispatcher(_object*, _object*, _object*)
  File "init.cpp", line 0, in pybind11::cpp_function::initialize<torch::jit::initJITBindings(_object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#2}::operator()(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) const::{lambda(pybind11::args const&, pybind11::kwargs const&)#1}, pybind11::object, pybind11::args const&, pybind11::kwargs const&>(torch::jit::initJITBindings(_object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#2}::operator()(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) const::{lambda(pybind11::args const&, pybind11::kwargs const&)#1}&&, pybind11::object (*)(pybind11::args const&, pybind11::kwargs const&))::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&)
  File "init.cpp", line 0, in torch::jit::initJITBindings(_object*)::{lambda(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)#2}::operator()(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) const::{lambda(pybind11::args const&, pybind11::kwargs const&)#1}::operator()(pybind11::args const&, pybind11::kwargs const&) const
  File "??", line 0, in torch::jit::_get_operation_for_overload_or_packet(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, c10::Symbol, pybind11::args const&, pybind11::kwargs const&, bool, std::optional<c10::DispatchKey>)
  File "??", line 0, in torch::jit::invokeOperatorFromPython(std::vector<std::shared_ptr<torch::jit::Operator>, std::allocator<std::shared_ptr<torch::jit::Operator> > > const&, pybind11::args const&, pybind11::kwargs const&, std::optional<c10::DispatchKey>)
  File "register_c10_ops.cpp", line 0, in c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const [clone .isra.0]
  File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
  File "", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const
  File "PythonFallbackKernel.cpp", line 0, in void c10::BoxedKernel::make_boxed_function<&(anonymous namespace)::pythonTLSSnapshotFallback>(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
  File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
  File "", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const
  File "VariableType_0.cpp", line 0, in c10::impl::make_boxed_from_unboxed_functor<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (c10::DispatchKeySet, at::Tensor const&, at::Tensor const&), &torch::autograd::VariableType::(anonymous namespace)::mul_Tensor>, at::Tensor, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&, at::Tensor const&> >, false>::call(c10::OperatorKernel*, c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
  File "VariableType_0.cpp", line 0, in torch::autograd::VariableType::(anonymous namespace)::mul_Tensor(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
  File "??", line 0, in at::_ops::mul_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
  File "", line 0, in c10::impl::BoxedKernelWrapper<at::Tensor (at::Tensor const&, at::Tensor const&), void>::call(c10::BoxedKernel const&, c10::OperatorHandle const&, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&)
  File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::python_dispatcher(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
  File "", line 0, in c10::OperatorHandle::callBoxedForDispatchKey(c10::DispatchKey, std::vector<c10::IValue, std::allocator<c10::IValue> >&) const
  File "PythonFallbackKernel.cpp", line 0, in (anonymous namespace)::pythonFallback(c10::OperatorHandle const&, c10::DispatchKeySet, std::vector<c10::IValue, std::allocator<c10::IValue> >*)
  File "PyInterpreter.cpp", line 0, in torch::detail::(anonymous namespace)::ConcretePyInterpreterVTable::dispatch(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const
  File "??", line 0, in torch::handle_torch_function_no_python_arg_parser(c10::ArrayRef<_object*>, _object*, _object*, char const*, _object*, char const*, torch::TorchFunctionName)
  File "??", line 0, in PyObject_CallMethod
  File "??", line 0, in PyList_Sort
  File "??", line 0, in Py_CompileStringExFlags
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2068, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2587, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 962, in fast_binary_impl
    final_shape = infer_size(final_shape, shape)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 922, in infer_size
    torch._check(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/__init__.py", line 1702, in _check
    _check_with(RuntimeError, cond, message)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/__init__.py", line 1665, in _check_with
    if expect_true(cond):
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 1712, in expect_true
    return a.node.expect_true(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 559, in expect_true
    return self.shape_env.guard_or_defer_runtime_assert(
  File "??", line 0, in _PyObject_MakeTpCall
  File "??", line 0, in _Py_GetLocaleEncodingObject
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
  File "??", line 0, in _PyObject_Call
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7740, in guard_or_defer_runtime_assert
    self._maybe_guard_rel(expr)
  File "??", line 0, in _PyObject_MakeTpCall
  File "??", line 0, in _Py_GetLocaleEncodingObject
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6878, in _maybe_guard_rel
    self._refine_ranges(expr)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7829, in _refine_ranges
    self._set_replacement(
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6778, in _set_replacement
    CapturedTraceback.extract(cpp=True)
  File "??", line 0, in PyObject_Vectorcall
  File "??", line 0, in _PyEval_EvalFrameDefault
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_traceback.py", line 212, in extract
    torch._C._profiler.gather_traceback(python=True, script=script, cpp=cpp),
  File "??", line 0, in _PyObject_MakeTpCall
  File "??", line 0, in PyObject_GenericGetAttr
  File "", line 0, in pybind11::cpp_function::dispatcher(_object*, _object*, _object*)
  File "", line 0, in pybind11::cpp_function::initialize<std::shared_ptr<torch::CapturedTraceback> (*&)(bool, bool, bool), std::shared_ptr<torch::CapturedTraceback>, bool, bool, bool, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg_v, pybind11::arg_v, pybind11::arg_v>(std::shared_ptr<torch::CapturedTraceback> (*&)(bool, bool, bool), std::shared_ptr<torch::CapturedTraceback> (*)(bool, bool, bool), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::arg_v const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::operator()(pybind11::detail::function_call&) const
  File "??", line 0, in torch::CapturedTraceback::gather(bool, bool, bool)
  File "??", line 0, in torch::unwind::unwind()


The error above occurred when calling torch.export.export. If you would like to view some more information about this error, and get a list of all other errors that may occur in your export call, you can replace your `export()` call with `draft_export()`.

The error shows something like the following (the exact wording depends on the PyTorch version):

Constraints violated (L['args'][0][0].size()[0])!
    For more information, run with TORCH_LOGS="+dynamic".
- Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
    are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).

Where does it happen? That’s a tricky question we need to answer. The message is raised from torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement. One way to find the exact location is to retrieve a stack trace by inserting an assert such as the following:

# Snippet meant to be inserted inside ShapeEnv._set_replacement
# (torch.fx.experimental.symbolic_shapes), where msg, a, tgt and tgt_bound
# are defined. It fails as soon as a dimension's range is refined to a single
# value, so the resulting traceback points at the user code responsible.
assert msg != "range_refined_to_singleton", (
    f"A dynamic dimension becomes static! "
    f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
)

Stop when a dynamic dimension turns static

We use torch_export_patches to replace the torch implementation with a new one that raises the exception mentioned in the previous section.

# stop_if_static=1 replaces ShapeEnv._set_replacement with a patched version
# asserting when a dynamic dimension turns static; verbose=1 prints which
# functions are patched and restored.
with torch_export_patches(stop_if_static=1, verbose=1):
    try:
        torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
    # The patched assert surfaces as an AssertionError here; TorchRuntimeError
    # is presumably how the same failure is wrapped on other code paths (verify).
    except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e:
        print("-- It failed as excepted.")
        print(f"-- final error is {e}")
        print("-- Stack Trace")
        print(traceback.format_exc())

# The stack trace is quite long, but the frame that matters is the one located
# in the ``forward`` method of this example. It points out the line turning a
# dynamic dimension into a static one.
#
# .. code-block::
#
#   File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
#       z = x * caty
[torch_export_patches] replace torch.jit.isinstance, torch._dynamo.mark_static_address
[torch_export_patches] sympy.__version__='1.13.3'
[torch_export_patches] patch sympy
[torch_export_patches] torch.__version__='2.9.0.dev20250628+cu126'
[torch_export_patches] stop_if_static=1
[torch_export_patches] patch pytorch
[torch_export_patches] modifies shape constraints
[torch_export_patches] assert when a dynamic dimension turns static
[torch_export_patches] replaces ShapeEnv._set_replacement
[torch_export_patches] replaces ShapeEnv._log_guard
[torch_export_patches] done patching
-- It failed as excepted.
-- final error is patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]
-- Stack Trace
Traceback (most recent call last):
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 90, in <module>
    torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
    return _export(
           ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1153, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1119, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2175, in _export
    ep = _export_for_training(
         ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1153, in wrapper
    raise e
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1119, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2036, in _export_for_training
    export_artifact = export_func(
                      ^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1978, in _non_strict_export
    aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1769, in _export_to_aten_ir_make_fx
    gm, graph_signature = transform(_make_fx_helper)(
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1899, in _aot_export_non_strict
    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1684, in _make_fx_helper
    gm = make_fx(
         ^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2348, in wrapped
    return make_fx_tracer.trace(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2280, in trace
    return self._trace_inner(f, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2251, in _trace_inner
    t = dispatch_trace(
        ^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 53, in inner
    return disable_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 950, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1280, in dispatch_trace
    graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in trace
    res = super().trace(root, concrete_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 950, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
    (self.create_arg(fn(*args)),),
                     ^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1338, in wrapped
    out = f(*tensors)  # type:ignore[call-arg]
          ^^^^^^^^^^^
  File "<string>", line 1, in <lambda>
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1588, in wrapped_fn
    return tuple(flat_fn(*args))
                 ^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 1138, in functional_call
    out = mod(*args[params_len:], **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1932, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1883, in forward
    tree_out = mod(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
    return self.call_module(mod, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1932, in call_module
    return Tracer.call_module(self, m, forward, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
    ret_val = forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
    return _orig_module_call(mod, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 35, in forward
    z = x * caty
        ~~^~~~~~
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1386, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1433, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 1054, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 950, in handler
    return torch._library.utils.handle_dispatch_mode(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_library/utils.py", line 296, in handle_dispatch_mode
    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1488, in __torch_dispatch__
    return proxy_call(self, func, self.pre_dispatch, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 974, in proxy_call
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 829, in __call__
    return self._op(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/utils/_stats.py", line 28, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1352, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2068, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 1487, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2587, in _dispatch_impl
    return maybe_propagate_real_tensors(fast_impl(self, *args, **kwargs))
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/_subclasses/fake_impls.py", line 962, in fast_binary_impl
    final_shape = infer_size(final_shape, shape)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 119, in patched_infer_size
    b3 = guard_size_oblivious(sizeA == sizeB)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 476, in guard_size_oblivious
    return expr.node.guard_size_oblivious("", 0)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 596, in guard_size_oblivious
    r = self.evaluate(size_oblivious=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/sym_node.py", line 512, in evaluate
    return self.shape_env.evaluate_sym_node(self, size_oblivious)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7233, in evaluate_sym_node
    return self.evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7333, in evaluate_expr
    return self._inner_evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/recording.py", line 272, in wrapper
    return retlog(fn(*args, **kwargs))
                  ^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7356, in _inner_evaluate_expr
    return self._evaluate_expr(
           ^^^^^^^^^^^^^^^^^^^^
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7625, in _evaluate_expr
    self._maybe_guard_rel(g)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6878, in _maybe_guard_rel
    self._refine_ranges(expr)
  File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 7829, in _refine_ranges
    self._set_replacement(
  File "~/github/onnx-diagnostic/onnx_diagnostic/torch_export_patches/patches/patch_torch.py", line 343, in _set_replacement
    assert msg != "range_refined_to_singleton", (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: patched_ShapeEnv: A dynamic dimension becomes static! a=s77, tgt=2, msg='range_refined_to_singleton', tgt_bound=VR[2, 2]

[torch_export_patches] remove patches
[torch_export_patches] restored sympy functions
[torch_export_patches] restored pytorch functions
[torch_export_patches] restored ShapeEnv._set_replacement
[torch_export_patches] restored ShapeEnv._log_guard
[torch_export_patches] restored shape constraints
# Draws the legend image for this gallery example (presumably what
# onnx_diagnostic.doc.plot_legend renders — verify); the text summarizes the
# error met above: a dynamic dimension was inferred to be a constant.
doc.plot_legend(
    "dynamic dimension\nwas inferred\nto be a constant", "torch.export.export", "tomato"
)
plot export locate issue

Total running time of the script: (0 minutes 3.791 seconds)

Related examples

Export Tiny-LLM with patches

Export Tiny-LLM with patches

Export microsoft/phi-2

Export microsoft/phi-2

Intermediate results with (ONNX) ReferenceEvaluator

Intermediate results with (ONNX) ReferenceEvaluator

Gallery generated by Sphinx-Gallery