onnx_diagnostic.export.validate

onnx_diagnostic.export.validate.compare_modules(modep: Module, mod: Module | None = None, args: Tuple[Any, ...] | None = None, kwargs: Dict[str, Any] | None = None, copy: bool = False, exc: bool = True, verbose: int = 0, atol: float = 0.01, rtol: float = 0.1) → Dict[str, Any]

Compares two torch modules, usually one coming from an exported program and the other being the original model.

Parameters:
  • modep – first module, usually the one obtained from the exported program

  • mod – second module (it produces the expected values)

  • args – positional arguments

  • kwargs – named arguments

  • copy – copy the inputs before executing the model (the model may modify them in place)

  • exc – raise an exception if the discrepancies are too high

  • verbose – verbosity level

  • atol – absolute tolerance

  • rtol – relative tolerance

Returns:

dictionary with inputs, outputs and tolerance
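The description above does not spell out how the comparison is performed. The following pure-Python sketch illustrates one plausible flow under stated assumptions: both callables run on inputs that are optionally deep-copied (mirroring copy), the maximum absolute and relative differences are computed, and an exception is raised (mirroring exc) when either exceeds its tolerance. The names compare_callables, max_discrepancies, add_inplace and add are hypothetical, the inputs are flat lists of floats rather than tensors, and the failure rule (abs > atol or rel > rtol) is an assumption; the library's exact formula may differ.

```python
import copy as copy_module  # renamed so the `copy` flag below does not shadow it
from typing import Any, Callable, Dict, List, Sequence, Tuple


def max_discrepancies(expected: Sequence[float], got: Sequence[float]) -> Tuple[float, float]:
    """Hypothetical helper: max absolute and relative differences of two flat sequences."""
    if len(expected) != len(got):
        raise ValueError(f"length mismatch: {len(expected)} != {len(got)}")
    abs_err, rel_err = 0.0, 0.0
    for e, g in zip(expected, got):
        diff = abs(e - g)
        abs_err = max(abs_err, diff)
        # small epsilon avoids a division by zero when the expected value is 0
        rel_err = max(rel_err, diff / (abs(e) + 1e-10))
    return abs_err, rel_err


def compare_callables(
    first: Callable[..., List[float]],
    second: Callable[..., List[float]],
    args: Tuple[Any, ...],
    copy: bool = False,
    exc: bool = True,
    atol: float = 0.01,
    rtol: float = 0.1,
) -> Dict[str, Any]:
    """Hypothetical stand-in for compare_modules: `second` produces the expected values."""
    # deep copies keep each run independent of in-place changes made by the other one
    got = first(*(copy_module.deepcopy(args) if copy else args))
    expected = second(*(copy_module.deepcopy(args) if copy else args))
    abs_err, rel_err = max_discrepancies(expected, got)
    # assumed failure rule: either maximum above its tolerance
    if exc and (abs_err > atol or rel_err > rtol):
        raise AssertionError(f"discrepancies too high: abs={abs_err}, rel={rel_err}")
    return dict(args=args, expected=expected, got=got, abs=abs_err, rel=rel_err)


def add_inplace(x: List[float], y: List[float]) -> List[float]:
    # mutates x, like a model that modifies its inputs in place
    for i in range(len(x)):
        x[i] += y[i]
    return list(x)


def add(x: List[float], y: List[float]) -> List[float]:
    return [a + b for a, b in zip(x, y)]


print(compare_callables(add_inplace, add, ([1.0, 2.0], [3.0, 4.0]), copy=True)["abs"])  # 0.0
```

With copy=False, add_inplace overwrites the first input before add runs, so the second call sees different values and the comparison fails; this is the situation the copy parameter is meant to prevent.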

Example:

<<<

import torch
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to make sure it runs

ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
    ep,
    model,
    args=(x, y),
    verbose=2,
    copy=True,
    dynamic_shapes=ds,
    values_to_try={"a": [5, 10], "b": [10, 20]},
)

>>>

    [compare_modules] check ep with args=(CT1s5x6,CT1s1x6), kwargs={}...
    [compare_modules] done in 0.002474856999469921 with output=T1s5x6
    [compare_modules] run torch module...
    [compare_modules] done in 0.005618618000880815 with output=T1s5x6
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 0/4: {'a': 5, 'b': 10}
    [compare_modules] check ep with args=(CT1s5x10,CT1s1x10), kwargs={}...
    [compare_modules] done in 0.0008011559984879568 with output=T1s5x10
    [compare_modules] run torch module...
    [compare_modules] done in 0.00019003800116479397 with output=T1s5x10
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 1/4: {'a': 5, 'b': 20}
    [compare_modules] check ep with args=(CT1s5x20,CT1s1x20), kwargs={}...
    [compare_modules] done in 0.0004989680019207299 with output=T1s5x20
    [compare_modules] run torch module...
    [compare_modules] done in 0.00012752799739246257 with output=T1s5x20
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 2/4: {'a': 10, 'b': 10}
    [compare_modules] check ep with args=(CT1s10x10,CT1s1x10), kwargs={}...
    [compare_modules] done in 0.0004152670007897541 with output=T1s10x10
    [compare_modules] run torch module...
    [compare_modules] done in 0.00011316999734845012 with output=T1s10x10
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 3/4: {'a': 10, 'b': 20}
    [compare_modules] check ep with args=(CT1s10x20,CT1s1x20), kwargs={}...
    [compare_modules] done in 0.0004114879993721843 with output=T1s10x20
    [compare_modules] run torch module...
    [compare_modules] done in 0.000274800000624964 with output=T1s10x20
    [compare_modules] discrepancies=abs=0.0, rel=0.0
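The four tries in the log follow a simple pattern: each combination of the values in values_to_try, enumerated as a Cartesian product, replaces the dynamic dimensions named in dynamic_shapes, while static dimensions keep their original size. The sketch below reproduces the input shapes printed in the log under that assumption. expand_shapes is a hypothetical helper, not part of onnx_diagnostic, and validate_ep may build its inputs differently.

```python
import itertools
from typing import Dict, List, Sequence, Tuple


def expand_shapes(
    shapes: Sequence[Tuple[int, ...]],
    dynamic_shapes: Sequence[Dict[int, str]],
    values_to_try: Dict[str, List[int]],
) -> List[Tuple[Dict[str, int], List[Tuple[int, ...]]]]:
    """Hypothetical helper: one (assignment, new input shapes) pair per combination."""
    names = list(values_to_try)
    results = []
    # product iterates over the last name fastest: a=5,b=10 then a=5,b=20, ...
    for combo in itertools.product(*(values_to_try[n] for n in names)):
        assignment = dict(zip(names, combo))
        new_shapes = [
            # dynamic axes take the value chosen for their name, static axes are kept
            tuple(assignment[ds[axis]] if axis in ds else dim for axis, dim in enumerate(shape))
            for shape, ds in zip(shapes, dynamic_shapes)
        ]
        results.append((assignment, new_shapes))
    return results


for i, (assignment, new_shapes) in enumerate(
    expand_shapes([(5, 6), (1, 6)], [{0: "a", 1: "b"}, {1: "b"}], {"a": [5, 10], "b": [10, 20]})
):
    print(f"try {i}/4: {assignment} -> {new_shapes}")
```

The printed shapes match the log: (5, 10) and (1, 10) for try 0, up to (10, 20) and (1, 20) for try 3. The second input keeps its static first dimension of 1 because only axis 1 is declared dynamic.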