onnx_diagnostic.reference

ExtendedReferenceEvaluator

class onnx_diagnostic.reference.ExtendedReferenceEvaluator(proto: Any, opsets: Dict[str, int] | None = None, functions: List[ReferenceEvaluator | FunctionProto] | None = None, verbose: int = 0, new_ops: List[type[OpRun]] | None = None, **kwargs)

This class replaces the Python implementation of some operators with a custom implementation and adds operators the reference evaluator does not provide. It makes it possible to test scenarios that an ONNX backend bound to the official ONNX operator definitions cannot handle, such as optimization patterns involving onnxruntime contrib operators.

from onnx_diagnostic.reference import ExtendedReferenceEvaluator
ref = ExtendedReferenceEvaluator(...)

The class overloads or adds the following operators by default:

<<<

import pprint
from onnx_diagnostic.reference import ExtendedReferenceEvaluator

pprint.pprint(ExtendedReferenceEvaluator.default_ops)

>>>

    [<class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.AddAdd'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.AddMul'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.AddSharedInput'>,
     <class 'onnx_diagnostic.reference.ops.op_attention.Attention'>,
     <class 'onnx_diagnostic.reference.ops.op_average_pool_grad.AveragePoolGrad'>,
     <class 'onnx_diagnostic.reference.ops.op_bias_softmax.BiasSoftmax'>,
     <class 'onnx_diagnostic.reference.ops.op_concat.Concat'>,
     <class 'onnx_diagnostic.reference.ops.op_cast_like.CastLike_15'>,
     <class 'onnx_diagnostic.reference.ops.op_cast_like.CastLike_19'>,
     <class 'onnx_diagnostic.reference.ops.op_complex.ComplexModule'>,
     <class 'onnx_diagnostic.reference.ops.op_constant_of_shape.ConstantOfShape'>,
     <class 'onnx_diagnostic.reference.ops.op_fused_matmul.FusedMatMul'>,
     <class 'onnx_diagnostic.reference.ops.op_gather.Gather'>,
     <class 'onnx_diagnostic.reference.ops.op_gather_elements.GatherElements'>,
     <class 'onnx_diagnostic.reference.ops.op_gather_grad.GatherGrad'>,
     <class 'onnx_diagnostic.reference.ops.op_scatternd_of_shape.MaskedScatterNDOfShape'>,
     <class 'onnx_diagnostic.reference.ops.op_memcpy_host.MemcpyFromHost'>,
     <class 'onnx_diagnostic.reference.ops.op_memcpy_host.MemcpyToHost'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.MulAdd'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.MulMul'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.MulSharedInput'>,
     <class 'onnx_diagnostic.reference.ops.op_mul_sigmoid.MulSigmoid'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.MulSub'>,
     <class 'onnx_diagnostic.reference.ops.op_negxplus1.NegXplus1'>,
     <class 'onnx_diagnostic.reference.ops.op_qlinear_conv.QLinearConv'>,
     <class 'onnx_diagnostic.reference.ops.op_qlinear_average_pool.QLinearAveragePool'>,
     <class 'onnx_diagnostic.reference.ops.op_quick_gelu.QuickGelu'>,
     <class 'onnx_diagnostic.reference.ops.op_replace_zero.ReplaceZero'>,
     <class 'onnx_diagnostic.reference.ops.op_rotary.Rotary'>,
     <class 'onnx_diagnostic.reference.ops.op_scatter_elements.ScatterElements'>,
     <class 'onnx_diagnostic.reference.ops.op_scatternd_of_shape.ScatterNDOfShape'>,
     <class 'onnx_diagnostic.reference.ops.op_simplified_layer_normalization.SimplifiedLayerNormalization'>,
     <class 'onnx_diagnostic.reference.ops.op_skip_layer_normalization.SkipLayerNormalization'>,
     <class 'onnx_diagnostic.reference.ops.op_slice.Slice_1'>,
     <class 'onnx_diagnostic.reference.ops.op_slice.Slice_10'>,
     <class 'onnx_diagnostic.reference.ops.op_add_add_mul_mul.SubMul'>,
     <class 'onnx_diagnostic.reference.ops.op_complex.ToComplex'>,
     <class 'onnx_diagnostic.reference.ops.op_transpose_cast.Transpose2DCastFP16'>,
     <class 'onnx_diagnostic.reference.ops.op_transpose_cast.Transpose2DCastFP32'>,
     <class 'onnx_diagnostic.reference.ops.op_tri_matrix.TriMatrix'>]
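
Several class names in this list carry a numeric suffix (`Slice_1`, `Slice_10`, `CastLike_15`, `CastLike_19`). As a minimal sketch, and assuming that a trailing `_<digits>` always denotes an opset version, the hypothetical helper below (not part of onnx_diagnostic) groups operator classes by operator type so that coverage of a given operator is easy to check. The three classes defined in the block are stand-ins that keep it dependency-free; in practice the argument would be `ExtendedReferenceEvaluator.default_ops`.

```python
import re
from collections import defaultdict
from typing import Dict, Iterable, List


def group_by_op_type(classes: Iterable[type]) -> Dict[str, List[str]]:
    """Groups class names by operator type.

    Assumption: a trailing ``_<digits>`` in a class name is an opset
    version suffix, so ``Slice_10`` is an implementation of ``Slice``.
    """
    grouped: Dict[str, List[str]] = defaultdict(list)
    for cls in classes:
        op_type = re.sub(r"_\d+$", "", cls.__name__)
        grouped[op_type].append(cls.__name__)
    return dict(grouped)


# Stand-in classes, only here to make the sketch self-contained.
class Gather:
    pass


class Slice_1:
    pass


class Slice_10:
    pass


covered = group_by_op_type([Gather, Slice_1, Slice_10])
print(covered)
```

The same call on `ExtendedReferenceEvaluator.default_ops` would give one entry per operator type, with every versioned implementation listed under it.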
run(*args, **kwargs)

See onnx.reference.ReferenceEvaluator.run().
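
To illustrate `run`, here is a minimal sketch that builds a one-node model and evaluates it. It assumes onnx, numpy and onnx_diagnostic are installed. The function name `run_add_example`, the tensor names, and the opset and IR versions are choices made for this example only. The imports sit inside the function so that defining it does not require the packages.

```python
def run_add_example():
    """Builds Z = X + Y and evaluates it with ExtendedReferenceEvaluator."""
    import numpy as np
    import onnx
    import onnx.helper as oh
    from onnx_diagnostic.reference import ExtendedReferenceEvaluator

    tfloat = onnx.TensorProto.FLOAT
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Add", ["X", "Y"], ["Z"])],
            "add_example",
            [
                oh.make_tensor_value_info("X", tfloat, [None, 2]),
                oh.make_tensor_value_info("Y", tfloat, [None, 2]),
            ],
            [oh.make_tensor_value_info("Z", tfloat, [None, 2])],
        ),
        # Opset 18 and IR version 9 are arbitrary choices for the example.
        opset_imports=[oh.make_opsetid("", 18)],
        ir_version=9,
    )

    ref = ExtendedReferenceEvaluator(model)
    x = np.array([[1.0, 2.0]], dtype=np.float32)
    y = np.array([[10.0, 20.0]], dtype=np.float32)
    # Same calling convention as onnx.reference.ReferenceEvaluator.run:
    # requested output names first (None for all), then the input feeds.
    return ref.run(None, {"X": x, "Y": y})
```

Calling `run_add_example()` should return a list holding a single array, the element-wise sum of `x` and `y`.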

OnnxruntimeEvaluator

class onnx_diagnostic.reference.OnnxruntimeEvaluator(proto: str | FunctionProto | ModelProto | GraphProto | NodeProto | OnnxruntimeEvaluator, session_options: SessionOptions | None = None, providers: str | List[str] | None = None, nvtx: bool = False, enable_profiling: bool = False, graph_optimization_level: GraphOptimizationLevel | bool = None, log_severity_level: int | None = None, log_verbosity_level: int | None = None, optimized_model_filepath: str | None = None, disable_aot_function_inlining: bool | None = None, use_training_api: bool = False, verbose: int = 0, local_functions: Dict[Tuple[str, str], FunctionProto | ModelProto | GraphProto | NodeProto | OnnxruntimeEvaluator] | None = None, ir_version: int = 10, opsets: int | Dict[str, int] | None = None)

This class loads an ONNX model and then executes its nodes one by one with onnxruntime. It is mostly meant for debugging.

Parameters:
  • proto – proto or filename

  • session_options – options

  • providers – None, “CPU”, “CUDA” or a list of providers

  • nvtx – enable NVIDIA NVTX events

  • graph_optimization_level – see onnxruntime.SessionOptions

  • log_severity_level – see onnxruntime.SessionOptions

  • log_verbosity_level – see onnxruntime.SessionOptions

  • optimized_model_filepath – see onnxruntime.SessionOptions

  • disable_aot_function_inlining – see onnxruntime.SessionOptions

  • use_training_api – use the onnxruntime-training API

  • verbose – verbosity

  • local_functions – additional local functions

  • ir_version – ir version to use when unknown

  • opsets – opsets to use when unknown
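
The `providers` argument accepts a shorthand string or an explicit list. How the class resolves the shorthand internally is not described here, so the following is only a sketch of one plausible normalization. The function `normalize_providers` is hypothetical; the defaults it picks (CPU when `None` is given, CPU as a fallback after CUDA) are assumptions. Only the names `CPUExecutionProvider` and `CUDAExecutionProvider` come from onnxruntime itself.

```python
from typing import List, Optional, Union


def normalize_providers(providers: Optional[Union[str, List[str]]]) -> List[str]:
    """Expands None, "CPU", "CUDA" or a list into onnxruntime provider names.

    Assumption: None means CPU only, and "CUDA" keeps CPU as a fallback.
    """
    if providers is None:
        return ["CPUExecutionProvider"]
    if isinstance(providers, str):
        key = providers.upper()
        if key == "CPU":
            return ["CPUExecutionProvider"]
        if key == "CUDA":
            return ["CUDAExecutionProvider", "CPUExecutionProvider"]
        raise ValueError(f"Unexpected providers shorthand {providers!r}")
    # An explicit list is passed through unchanged (copied, not aliased).
    return list(providers)


print(normalize_providers("CUDA"))
```

A list such as `["CPUExecutionProvider"]` is returned as a copy, so callers may modify the result without affecting their own argument.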

property input_names: List[str]

Returns input names.

property input_types: List[TypeProto]

Returns input types.

property output_names: List[str]

Returns output names.

property output_types: List[TypeProto]

Returns output types.

run(outputs: List[str] | None, feed_inputs: Dict[str, Any], intermediate: bool = False) → Dict[str, Any] | List[Any]

Runs the model. It only works with numpy arrays.

Parameters:
  • outputs – required outputs or None for all

  • feed_inputs – inputs

  • intermediate – returns all outputs, including intermediate results, instead of only the final ones

Returns:

outputs, as a list if intermediate is False, as a dictionary if intermediate is True
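
Because `run` returns either a list or a dictionary depending on the `intermediate` flag, calling code may want a single shape to work with. The sketch below shows one way to do that. Both `as_named_results` and `debug_run` are hypothetical helpers, not part of onnx_diagnostic, and `debug_run` assumes onnx_diagnostic and onnxruntime are installed (the import is local so the first helper works without them).

```python
from typing import Any, Dict, List, Sequence, Union


def as_named_results(
    output_names: Sequence[str], results: Union[Dict[str, Any], List[Any]]
) -> Dict[str, Any]:
    """Returns a name -> value mapping whatever shape ``run`` produced.

    A list is paired with ``output_names`` in order, which assumes that
    all outputs were requested (``outputs=None``).
    """
    if isinstance(results, dict):
        return dict(results)
    if len(results) != len(output_names):
        raise ValueError(
            f"Got {len(results)} results for {len(output_names)} output names."
        )
    return dict(zip(output_names, results))


def debug_run(model_path: str, feeds: Dict[str, Any]) -> Dict[str, Any]:
    """Runs a model node by node and returns every available result."""
    from onnx_diagnostic.reference import OnnxruntimeEvaluator

    evaluator = OnnxruntimeEvaluator(model_path, providers="CPU", verbose=1)
    # feeds must map input names to numpy arrays.
    results = evaluator.run(None, feeds, intermediate=True)
    return as_named_results(evaluator.output_names, results)


print(as_named_results(["Z"], [3.0]))
```

With `intermediate=True` the dictionary makes it possible to compare each intermediate result against another runtime, which fits the debugging purpose of the class.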

Other functions