.torch_interpreter.oxs_dispatcher
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDebugDispatcher(verbose: int = 0, raise_exc: bool = True)
Tries the fallback even when it is not necessary, in order to check that it works.
- Parameters:
verbose – verbosity
raise_exc – whether to raise an exception on failure
The class can be used as follows.
<<<
import torch
from experimental_experiment.ext_test_case import get_llama_model
from experimental_experiment.xbuilder import OptimizationOptions
from experimental_experiment.torch_interpreter import to_onnx
from experimental_experiment.torch_interpreter.oxs_dispatcher import (
    OxsDebugDispatcher,
)

with torch.no_grad():
    model, input_tensors = get_llama_model()
    input_tensors = input_tensors[0]

    to_onnx(
        model,
        input_tensors,
        input_names=[f"input{i}" for i in range(len(input_tensors))],
        options=OptimizationOptions(patterns=None),
        verbose=0,
        dispatcher=OxsDebugDispatcher(verbose=2, raise_exc=False),
    )
>>>
[runpythonerror]
use_kernel_func_from_hub is not available in the installed kernels version. Please upgrade kernels to use this feature.
Traceback (most recent call last):
  File "<stdin>", line 19, in <module>
  File "~/github/experimental-experiment/experimental_experiment/torch_interpreter/onnx_export.py", line 1025, in to_onnx
    graph_module, builder, interpreter, mask_outputs = _make_builder_interpreter(
  [intermediate frames in experimental_experiment and torch.export omitted]
  File "~/vv/this312/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/utils.py", line 199, in flat_fn
    raise RuntimeError(
RuntimeError: Found <class 'transformers.cache_utils.DynamicCache'> in output, which is not a known type. If this type holds tensors, you need to register a pytree for it. See https://github.com/pytorch/functorch/issues/475 for a brief explanation why. If you don't need to register a pytree, please leave a comment explaining your use case and we'll make this more ergonomic to deal with
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None
This method is called after the lookup for the function converting an aten function into ONNX. fct is that converting function; it is only set when a mapping was found. The method may return it unchanged or replace it.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
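The debug behaviour described for this class can be pictured with a small self-contained sketch. It does not reproduce the library's implementation: SketchDebugDispatcher and its lookup argument are hypothetical names, and keeping fct when it is already set is an assumption. Only the fallback signature and the verbose and raise_exc parameters come from the documentation above.

```python
from typing import Any, Callable, Dict, List, Optional


class SketchDebugDispatcher:
    """Hypothetical stand-in, not the library's OxsDebugDispatcher."""

    def __init__(
        self,
        lookup: Callable[[Any], Optional[Callable]],  # hypothetical fallback lookup
        verbose: int = 0,
        raise_exc: bool = True,
    ):
        self.lookup = lookup
        self.verbose = verbose
        self.raise_exc = raise_exc
        self.failures: List[str] = []

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: Any,
    ) -> Optional[Callable]:
        # The fallback is exercised even when `fct` is already set.
        try:
            candidate = self.lookup(name)
            if candidate is None:
                raise LookupError(f"no fallback converter for {name!r}")
        except Exception as exc:
            if self.raise_exc:
                raise
            # Assumption: with raise_exc=False the failure is only recorded.
            self.failures.append(f"{name}: {exc}")
            if self.verbose:
                print(f"[sketch] fallback failed for {name!r}: {exc}")
            return fct
        # Assumption: the function found so far wins when there is one.
        return fct if fct is not None else candidate
```

With raise_exc=False, a failing fallback is recorded and the conversion continues with the function found so far, which is the mode used in the example above.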
- class experimental_experiment.torch_interpreter.oxs_dispatcher.OxsDispatcher(verbose: int = 0)
If DynamoInterpreter cannot find any converting function for a specific function, it tries to find an existing one in onnxscript. The converting function from onnxscript is run in trace-only mode. The variables and functions op, Rank, IsScalar are replaced by op = OwsOpset(), op.Rank, op.Scalar. onnxscript may have multiple overloaded functions. Right now, it takes the first one.
- Parameters:
verbose – verbosity
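A minimal sketch of the "take the first overload" rule, assuming a registry that maps an aten name to a list of candidate converters. The registry layout and the names pick_first_overload, add_tensor and add_scalar are hypothetical; how onnxscript actually stores its overloads is not shown here.

```python
from typing import Callable, Dict, List, Optional


def pick_first_overload(
    registry: Dict[str, List[Callable]], name: str
) -> Optional[Callable]:
    """Return the first registered overload for `name`, or None if there is none."""
    overloads = registry.get(name, [])
    return overloads[0] if overloads else None


def add_tensor(x, y):
    """Hypothetical overload for tensor + tensor."""
    return ("add_tensor", x, y)


def add_scalar(x, y):
    """Hypothetical overload for tensor + scalar."""
    return ("add_scalar", x, y)


# Hypothetical registry: one aten name, two overloads.
registry = {"aten::add": [add_tensor, add_scalar]}
chosen = pick_first_overload(registry, "aten::add")
```

Taking the first overload is simple but ignores argument types, so a later overload that fits the inputs better is never selected.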
- fallback(name: Any, fct: Callable | None, args: List[Any], kwargs: Dict[str, Any], builder: GraphBuilder) Callable | None
This method is called after the lookup for the function converting an aten function into ONNX. fct is that converting function; it is only set when a mapping was found. The method may return it unchanged or replace it.
- Parameters:
name – object or str
fct – function found so far
args – known arguments coming from the graph module
kwargs – known named arguments coming from the graph module
builder – GraphBuilder
- Returns:
callable
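The calling side of this contract could look like the sketch below: the interpreter first searches its own converters, then hands the result (possibly None) to the dispatcher. RegistryDispatcher, find_converter and the converter names are hypothetical and only mirror the documented fallback signature; they are not part of the library.

```python
from typing import Any, Callable, Dict, List, Optional


class RegistryDispatcher:
    """Hypothetical dispatcher exposing the documented fallback signature."""

    def __init__(self, extra: Dict[str, Callable]):
        self.extra = extra  # additional converters known only to the dispatcher

    def fallback(
        self,
        name: Any,
        fct: Optional[Callable],
        args: List[Any],
        kwargs: Dict[str, Any],
        builder: Any,
    ) -> Optional[Callable]:
        # Keep the function found so far; otherwise try the extra registry.
        if fct is not None:
            return fct
        return self.extra.get(str(name))


def find_converter(name, own_converters, dispatcher, args, kwargs, builder=None):
    """Look up a converter, then let the dispatcher confirm or replace it."""
    fct = own_converters.get(name)  # None when no mapping was found
    if dispatcher is not None:
        fct = dispatcher.fallback(name, fct, args, kwargs, builder)
    return fct


def convert_relu(*args, **kwargs):
    return "own relu"


def convert_gelu(*args, **kwargs):
    return "fallback gelu"


own = {"aten::relu": convert_relu}
dispatcher = RegistryDispatcher({"aten::gelu": convert_gelu})
```

A return value of None means that neither the interpreter nor the dispatcher knows how to convert the function, and the caller has to decide how to report it.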