.torch_interpreter.oxs_dispatcher

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)[source]

Tries the fallback even when it is not necessary, in order to check that it is working.

Parameters:
  • verbose – verbosity

  • raise_exc – whether to raise an exception when the fallback fails

The class can be used in the following way.

<<<

import torch
from experimental_experiment.torch_models.llama_helper import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )

>>>

    
    [runpythonerror]
    Traceback (most recent call last):
      File "<stdin>", line 19, in <module>
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 957, in to_onnx
        graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
                                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 573, in _make_builder_interpreter
        exported_program = export_options.export(
                           ^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/export_options.py", line 471, in export
        exported_program = torch.export.export(
                           ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 319, in export
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/__init__.py", line 286, in export
        return _export(
               ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2172, in _export
        ep = _export_for_training(
             ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1159, in wrapper
        raise e
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1125, in wrapper
        ep = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/exported_program.py", line 123, in wrapper
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 2033, in _export_for_training
        export_artifact = export_func(
                          ^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1975, in _non_strict_export
        aten_export_artifact = _to_aten_func(  # type: ignore[operator]
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1760, in _export_to_aten_ir_make_fx
        gm, graph_signature = transform(_make_fx_helper)(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1901, in _aot_export_non_strict
        gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1679, in _make_fx_helper
        gm = make_fx(
             ^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2295, in wrapped
        return make_fx_tracer.trace(f, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2233, in trace
        return self._trace_inner(f, *args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 2204, in _trace_inner
        t = dispatch_trace(
            ^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_compile.py", line 51, in inner
        return disable_fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 893, in _fn
        return fn(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1221, in dispatch_trace
        graph = tracer.trace(root, concrete_args)  # type: ignore[arg-type]
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1792, in trace
        res = super().trace(root, concrete_args)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 850, in trace
        (self.create_arg(fn(*args)),),
                         ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1279, in wrapped
        out = f(*tensors)  # type:ignore[call-arg]
              ^^^^^^^^^^^
      File "<string>", line 1, in <lambda>
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1583, in wrapped_fn
        return tuple(flat_fn(*args))
                     ^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
        tree_out = fn(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 906, in functional_call
        out = mod(*args[params_len:], **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/export/_trace.py", line 1885, in forward
        tree_out = mod(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/experimental-experiment/experimental_experiment/torch_models/llama_helper.py", line 75, in forward
        model_output = self.model(
                       ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 825, in module_call_wrapper
        return self.call_module(mod, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1862, in call_module
        return Tracer.call_module(self, m, forward, args, kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 542, in call_module
        ret_val = forward(*args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/_symbolic_trace.py", line 818, in forward
        return _orig_module_call(mod, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1767, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1778, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/utils/generic.py", line 943, in wrapper
        output = func(self, *args, **kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/models/llama/modeling_llama.py", line 422, in forward
        causal_mask = create_causal_mask(
                      ^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 719, in create_causal_mask
        causal_mask = mask_interface(
                      ^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 466, in eager_mask
        mask = sdpa_mask(
               ^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 350, in sdpa_mask_recent_torch
        causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/apis.py", line 202, in wrapped
        return vmap_impl(
               ^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 334, in vmap_impl
        return _flat_vmap(
               ^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/vmap.py", line 484, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 48, in and_mask
        result = result & mask(batch_idx, head_idx, q_idx, kv_idx)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/github/transformers/src/transformers/masking_utils.py", line 118, in inner_mask
        return padding_mask[batch_idx, kv_idx]
               ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 141, in __torch_function__
        return mod_index(args[0], index_args)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/autograd/function.py", line 589, in apply
        return custom_function_call(cls, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1327, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/fx/experimental/proxy_tensor.py", line 1374, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 492, in wrapper
        return torch.overrides.handle_torch_function(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/overrides.py", line 1725, in handle_torch_function
        result = mode.__torch_function__(public_api, types, args, kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_export/non_strict_utils.py", line 976, in __torch_function__
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/autograd_function.py", line 49, in __call__
        return super().__call__(autograd_function, *args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 501, in __call__
        return wrapper()
               ^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 497, in wrapper
        return self.dispatch(
               ^^^^^^^^^^^^^^
      File "~/vv/this312/lib/python3.12/site-packages/torch/_ops.py", line 458, in dispatch
        assert type(curr_mode) in self.python_key_table, (
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    AssertionError: Current active mode <torch.fx.experimental.proxy_tensor.ProxyTorchDispatchMode object at 0x766f5a336000> not registered
fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the lookup for the function converting an aten function into ONNX. fct is that converting function; it is only set when a mapping was found and may be None otherwise. The method can return a different function to replace it.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)[source]

If DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variable and functions op, Rank, IsScalar are replaced by op = OwsOpset(), op.Rank, op.Scalar. onnxscript may have multiple overloaded functions for the same operator. Right now, the dispatcher takes the first one.

Parameters:

verbose – verbosity

fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) → Callable | None[source]

This method is called after the lookup for the function converting an aten function into ONNX. fct is that converting function; it is only set when a mapping was found and may be None otherwise. The method can return a different function to replace it.

Parameters:
  • name – object or str

  • fct – function found so far

  • args – known arguments coming from the graph module

  • kwargs – known named arguments coming from the graph module

  • builder – GraphBuilder

Returns:

callable

property submodules: Dict[str, Callable]

Returns the submodules implementing torch functions.