onnx_diagnostic.ci_models.ci_helpers
- onnx_diagnostic.ci_models.ci_helpers.check_for_discrepancies_and_log_everything_into_a_json_file(agg_stat_file: str, stat_file: str, export_duration: float, device: str, model_file: str, cached_inputs: str, cached_expected_outputs: str, main_info: Dict[str, Any], atol: float, mismatch01: float)
Checks discrepancies for a specific model.
Imports are delayed so that displaying the command-line help stays fast.
- Parameters:
agg_stat_file – a file where the discrepancies are collected; it is used to produce a table that makes it easier to compare results across types, devices, …
stat_file – the file into which the discrepancy results are dumped
export_duration – duration of the export
device – targeted device (used to select the onnxruntime execution provider)
model_file – the ONNX model file
cached_inputs – file of inputs saved with torch.save() and restored with torch.load(); it must contain export_inputs (to check the model is valid) and other_inputs, other sets of inputs used to measure the discrepancies and the speedup (rough estimation)
cached_expected_outputs – file of expected outputs saved with torch.save() and restored with torch.load(); it must contain export_expected (to check the model is valid) and other_expected, other sets of outputs used to measure the discrepancies and the speedup (rough estimation)
main_info – a dictionary with values identifying the version, device, …
atol – an assertion fails if the discrepancy is above this tolerance
mismatch01 – an assertion fails if the ratio of mismatches is above this threshold
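The roles of atol and mismatch01 can be illustrated with a small pure-Python sketch. It assumes that an element counts as a mismatch when its absolute error exceeds atol, and that mismatch01 is the largest acceptable ratio of such elements; the real function may combine the two thresholds differently. discrepancy_stats and check_discrepancies are hypothetical names, and outputs are flattened lists of floats rather than tensors.

```python
import json
from typing import Dict, Sequence


def discrepancy_stats(
    expected: Sequence[float], got: Sequence[float], atol: float
) -> Dict[str, float]:
    """Hypothetical helper comparing two flattened outputs element by element."""
    if len(expected) != len(got):
        raise ValueError(f"size mismatch: {len(expected)} != {len(got)}")
    diffs = [abs(e - g) for e, g in zip(expected, got)]
    # Assumption: an element is a mismatch when its absolute error exceeds atol.
    n_mismatch = sum(1 for d in diffs if d > atol)
    return {
        "max_abs": max(diffs, default=0.0),
        "mismatch_ratio": n_mismatch / len(diffs) if diffs else 0.0,
    }


def check_discrepancies(
    expected: Sequence[float], got: Sequence[float], atol: float, mismatch01: float
) -> str:
    """Fails when too many elements exceed atol; returns the stats as one JSON line."""
    stats = discrepancy_stats(expected, got, atol)
    # Assumption: mismatch01 is the largest acceptable ratio of mismatched elements.
    assert stats["mismatch_ratio"] <= mismatch01, f"too many mismatches: {stats}"
    # One JSON object per line, ready to be appended to a stat file.
    return json.dumps(stats, sort_keys=True)
```

Returning one self-contained JSON object per check is one convenient way to append records to stat_file and aggregate them later into agg_stat_file.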
- onnx_diagnostic.ci_models.ci_helpers.compute_expected_outputs(output_filename: str, model_to_export: torch.nn.Module, input_filename: str) → Tuple[Any, List[Any], List[float]]
Computes the expected outputs for a model.
The expected outputs are cached in a file: they are restored if the file exists, otherwise they are computed and saved.
Imports are delayed so that displaying the command-line help stays fast.
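The cache-or-compute behaviour described above follows a common pattern, sketched here with a hypothetical helper load_or_compute. pickle stands in for torch.save()/torch.load() so the sketch runs without torch; compute is any callable producing the outputs.

```python
import os
import pickle
from typing import Any, Callable


def load_or_compute(filename: str, compute: Callable[[], Any]) -> Any:
    """Hypothetical helper: restores cached results if filename exists, else computes and saves them."""
    if os.path.exists(filename):
        # Cache hit: restore the previously saved outputs.
        with open(filename, "rb") as f:
            return pickle.load(f)
    # Cache miss: compute the outputs, then save them for the next run.
    result = compute()
    with open(filename, "wb") as f:
        pickle.dump(result, f)
    return result
```

With torch available, the two pickle calls would become torch.load(filename) and torch.save(result, filename).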
- onnx_diagnostic.ci_models.ci_helpers.get_parser(name: str) → ArgumentParser
Creates a default argument parser shared by many models.
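A minimal sketch of a shared parser factory, assuming it is built on argparse. The function name make_default_parser and all three options are hypothetical, chosen to mirror concepts documented on this page; they are not the options the real get_parser defines.

```python
import argparse


def make_default_parser(name: str) -> argparse.ArgumentParser:
    """Hypothetical stand-in for get_parser: every option below is illustrative."""
    parser = argparse.ArgumentParser(prog=name, description=f"Command line for model {name!r}.")
    # Hypothetical options mirroring concepts used elsewhere on this page.
    parser.add_argument("--device", default="cpu", help="targeted device, e.g. cpu or cuda")
    parser.add_argument("--dtype", default="float32", help="torch dtype name")
    parser.add_argument("--atol", type=float, default=1e-3, help="absolute tolerance")
    return parser
```

Each model script can then call add_argument() on the returned parser to register its own options before calling parse_args().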
- onnx_diagnostic.ci_models.ci_helpers.get_torch_dtype_from_command_line_args(dtype: str) → torch.dtype
Returns the torch dtype based on the argument provided on the command line.
Imports are delayed so that displaying the command-line help stays fast.
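A sketch of converting a command-line string into a torch dtype while delaying the torch import. The alias table and both function names are hypothetical; the spellings the real function accepts are not documented here. Only torch_dtype_from_arg imports torch, so building the parser and printing the help never load it.

```python
# Hypothetical alias table: the spellings the real function accepts are not documented here.
_DTYPE_ALIASES = {
    "float32": "float32", "fp32": "float32",
    "float16": "float16", "fp16": "float16",
    "bfloat16": "bfloat16", "bf16": "bfloat16",
}


def normalize_dtype_name(dtype: str) -> str:
    """Maps a command-line spelling to the name of a torch dtype attribute."""
    key = dtype.strip().lower()
    if key not in _DTYPE_ALIASES:
        raise ValueError(f"unsupported dtype {dtype!r}, expected one of {sorted(_DTYPE_ALIASES)}")
    return _DTYPE_ALIASES[key]


def torch_dtype_from_arg(dtype: str):
    """Returns torch.<name>; torch is imported only when this function runs."""
    import torch  # delayed import: printing the help never pays for loading torch

    return getattr(torch, normalize_dtype_name(dtype))
```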
- onnx_diagnostic.ci_models.ci_helpers.get_versions()
Returns the versions of the packages currently used, as a dictionary. Imports are delayed to keep startup fast.
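A sketch of collecting package versions into a dictionary with the standard importlib.metadata module. The helper name get_package_versions and the idea of passing an explicit list of distribution names are assumptions; which packages the real get_versions reports is not documented here.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def get_package_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Hypothetical sketch: maps each distribution name to its installed version, or None."""
    versions: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            # Missing packages are reported as None instead of raising.
            versions[name] = None
    return versions
```

For example, get_package_versions(["torch", "onnx", "onnxruntime"]) would report the versions of those three distributions, with None for any that is not installed.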
- onnx_diagnostic.ci_models.ci_helpers.remove_inplace_body_last_input_output_type_for_loop_because_they_might_be_sequences(filename: str)
Modifies an ONNX file in place. It wipes out the shapes provided in
model.graph.value_info because they are wrong when a Loop outputs a sequence. It also removes the types in the body attribute of a Loop operator because a type may be declared as a tensor when a sequence is expected. This should not be needed in the future.