experimental_experiment.torch_bench._bash_bench_benchmark_runner

class experimental_experiment.torch_bench._bash_bench_benchmark_runner.BenchmarkRunner(suite_name: str, device: str, partition_id: int = 0, total_partitions: int = 1, include_model_names: Set[str] | None = None, exclude_model_names: Set[str] | None = None, training: bool = False, use_eval_mode: bool = False, enable_activation_checkpointing: bool = False, dtype: str | dtype | None = None, verbose: int = 0, warmup: int = 10, repeat: int = 30, fake_tensor: bool = False, no_grad: bool = True, target_opset: int = 18, nvtx: bool = False, dump_ort: bool = False)[source]

Class running the benchmark.

Parameters:
  • suite_name – suite name

  • device – device

  • partition_id – partition id

  • total_partitions – total number of partitions

  • include_model_names – models to include

  • exclude_model_names – models to exclude

  • training – training mode (CHECK)

  • use_eval_mode – use eval mode (CHECK)

  • enable_activation_checkpointing – (CHECK)

  • dtype – default dtype (None to change nothing)

  • verbose – verbosity

  • warmup – number of iterations used to warm up the model

  • repeat – number of iterations used to measure the model

  • fake_tensor – use fake_tensor

  • no_grad – use no_grad

  • target_opset – target opset

  • nvtx – add events to profile

  • dump_ort – dumps onnxruntime optimized graph
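The warmup/repeat parameters follow the usual benchmarking pattern: run the model a few times untimed so caches and compilation settle, then average the latency over the timed runs. A minimal self-contained sketch of that pattern (function name `benchmark` is illustrative, not part of this class):

```python
import time


def benchmark(fn, warmup=10, repeat=30):
    """Illustrates the warmup/repeat semantics: run ``fn`` ``warmup``
    times untimed, then return the average latency over ``repeat`` timed
    runs. A sketch, not the runner's actual timing code."""
    for _ in range(warmup):
        fn()
    begin = time.perf_counter()
    for _ in range(repeat):
        fn()
    return (time.perf_counter() - begin) / repeat
```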

enumerate_load_models() Iterator[Tuple[Any, Any]][source]

Loads the models and returns them.

enumerate_model_names(model_names: List[str], start: int = 0, end: int = -1) Iterator[str][source]

Enumerates model names.

Parameters:
  • model_names – list of names

  • start – index of the first model

  • end – index of the last model (excluded) or -1 for the end

Returns:

iterator

The method uses self.include_model_names and self.exclude_model_names to filter in or out the models to run.
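The filtering described above can be sketched as a generator over the name list; this is a simplified standalone version of the behavior, not the library's implementation:

```python
def filter_model_names(model_names, include=None, exclude=None, start=0, end=-1):
    """Yields names from ``model_names[start:end]``, keeping only those in
    ``include`` (when given) and dropping those in ``exclude`` (when given).
    ``end=-1`` means "up to the end". A sketch of the documented behavior."""
    stop = len(model_names) if end == -1 else end
    for name in model_names[start:stop]:
        if include is not None and name not in include:
            continue
        if exclude is not None and name in exclude:
            continue
        yield name
```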

enumerate_run_models() Iterator[Any][source]

Runs the models once.

enumerate_test_models(exporter: str, process: bool = False, folder: str = 'dump_test_models', dynamic: bool = False, optimization: str = '', quiet: bool = True, memory_peak: bool = False, part: int | None = None, pickled_name: str | None = None, rtopt: bool = True, shape_again: bool = False) Iterator[Dict[Any, Any]][source]

Runs the benchmark: runs the model in eager mode, exports it, runs it with ONNX, and measures the speedup.

Parameters:
  • exporter – exporter to run

  • process – unused

  • folder – where to dump the models

  • dynamic – unused now

  • optimization – optimization string to run

  • quiet – True to catch exceptions

  • memory_peak – True to measure the memory peak in a secondary process

  • part – None to run both parts, 1 to run the first part (load the model + eager mode + export), 2 to run the inference

  • pickled_name – name used to store everything on disk when part is specified

  • rtopt – disable onnxruntime optimization

  • shape_again – run shape inference after the export, erases whatever the model already contains

get_benchmark_indices(length)[source]

Returns the model indices in the benchmark to run.
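Given partition_id and total_partitions from the constructor, one plausible way to split the model indices across partitions is a round-robin assignment. This is a hedged sketch of what such a method could compute, not the class's actual code:

```python
def get_benchmark_indices(length, partition_id=0, total_partitions=1):
    """Round-robin split of ``length`` model indices across
    ``total_partitions``; partition ``partition_id`` keeps every
    ``total_partitions``-th index. Illustrative only."""
    return [i for i in range(length) if i % total_partitions == partition_id]
```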

get_model_name_list() List[str][source]

Returns the model list.

classmethod max_diff(expected: Any, got: Any, verbose: int = 0, level: int = 0, flatten: bool = False, debug_info: List[str] | None = None, begin: int = 0, end: int = -1, _index: int = 0) Dict[str, float][source]

Returns the maximum discrepancy.

Parameters:
  • expected – expected values

  • got – values

  • verbose – verbosity level

  • level – for embedded outputs, used for debug purposes

  • flatten – flatten outputs

  • debug_info – debug information

  • begin – first output to consider

  • end – last output to consider (-1 for the last one)

  • _index – used with begin and end

Returns:

dictionary with many values

  • abs: max absolute error

  • rel: max relative error

  • sum: sum of the errors

  • n: number of output values; if there is one output, this is the number of elements in that output
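For two flat sequences of floats, the four statistics above can be computed as follows. This is a simplified sketch (no flattening, no nested outputs, and the `1e-15` guard against division by zero is an assumption), not the classmethod's implementation:

```python
def max_diff(expected, got):
    """Returns the abs/rel/sum/n statistics described above for two flat
    lists of floats. Simplified sketch of the documented return value."""
    abs_err = rel_err = total = 0.0
    for e, g in zip(expected, got):
        d = abs(e - g)
        abs_err = max(abs_err, d)
        rel_err = max(rel_err, d / (abs(e) + 1e-15))  # assumed epsilon guard
        total += d
    return {"abs": abs_err, "rel": rel_err, "sum": total, "n": len(expected)}
```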

ort_run(sess: WrapInferenceSessionForTorch, feeds: List[Tensor]) List[Tensor][source]

Runs with onnxruntime.