onnx_diagnostic.export.validate

onnx_diagnostic.export.validate.compare_modules(modep: Module, mod: Module | None = None, args: Tuple[Any, ...] | None = None, kwargs: Dict[str, Any] | None = None, copy: bool = False, exc: bool = True, verbose: int = 0, atol: float = 0.01, rtol: float = 0.1) -> Dict[str, Any]

Compares two torch modules, usually one coming from an exported program and the other being the original model.

Parameters:
  • modep – first module (usually the one coming from the exported program)

  • mod – second module (it produces the expected values)

  • args – positional arguments

  • kwargs – named arguments

  • copy – copy the inputs before running each module (a module may modify them in place)

  • exc – raise an exception if the discrepancies are too high

  • verbose – verbosity level

  • atol – absolute tolerance

  • rtol – relative tolerance
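The role of atol and rtol can be illustrated with a minimal sketch in plain Python. It is not the library's implementation: the helper names `max_discrepancies` and `within_tolerance` are hypothetical, the relative-difference formula is an assumption, and whether the library requires both thresholds or only one is not stated here (this sketch requires both).

```python
def max_discrepancies(expected, got, eps=1e-12):
    """Return the largest absolute and relative differences between two flat sequences.

    Hypothetical helper: the formula used by the library may differ.
    """
    if len(expected) != len(got):
        raise ValueError("sequences must have the same length")
    max_abs = 0.0
    max_rel = 0.0
    for e, g in zip(expected, got):
        diff = abs(e - g)
        max_abs = max(max_abs, diff)
        # eps avoids a division by zero when the expected value is 0
        max_rel = max(max_rel, diff / (abs(e) + eps))
    return {"abs": max_abs, "rel": max_rel}


def within_tolerance(disc, atol=0.01, rtol=0.1):
    """Assumption: both thresholds must hold for the outputs to be accepted."""
    return disc["abs"] <= atol and disc["rel"] <= rtol


expected = [1.0, 2.0, 4.0]  # values from the reference module
close = [1.0, 2.004, 4.0]   # small deviation, inside the default tolerances
far = [1.0, 2.5, 4.0]       # large deviation, outside the default tolerances

disc_close = max_discrepancies(expected, close)
disc_far = max_discrepancies(expected, far)
print(within_tolerance(disc_close), within_tolerance(disc_far))
```

With exc=True, a result like the second one would be the kind of case in which an exception is raised rather than a dictionary being returned silently.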

Returns:

a dictionary with the inputs, the outputs, and the tolerances

Example (through validate_ep, which calls compare_modules for each set of inputs, as the output shows):

<<<

import torch
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # make sure the model runs

ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
    ep,
    model,
    args=(x, y),
    verbose=2,
    copy=True,
    dynamic_shapes=ds,
    values_to_try={"a": [5, 10], "b": [10, 20]},
)

>>>

    [compare_modules] check ep with args=(CT1s5x6,CT1s1x6), kwargs={}...
    [compare_modules] done in 0.0010219119940302335 with output=T1s5x6
    [compare_modules] run torch module...
    [compare_modules] done in 0.00425627199729206 with output=T1s5x6
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 0/4: {'a': 5, 'b': 10}
    [compare_modules] check ep with args=(CT1s5x10,CT1s1x10), kwargs={}...
    [compare_modules] done in 0.0006602429930353537 with output=T1s5x10
    [compare_modules] run torch module...
    [compare_modules] done in 0.00016992900054901838 with output=T1s5x10
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 1/4: {'a': 5, 'b': 20}
    [compare_modules] check ep with args=(CT1s5x20,CT1s1x20), kwargs={}...
    [compare_modules] done in 0.0006451710069086403 with output=T1s5x20
    [compare_modules] run torch module...
    [compare_modules] done in 0.00015772099868627265 with output=T1s5x20
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 2/4: {'a': 10, 'b': 10}
    [compare_modules] check ep with args=(CT1s10x10,CT1s1x10), kwargs={}...
    [compare_modules] done in 0.0005811890005134046 with output=T1s10x10
    [compare_modules] run torch module...
    [compare_modules] done in 0.003368394995050039 with output=T1s10x10
    [compare_modules] discrepancies=abs=0.0, rel=0.0
    [validate_ep] try 3/4: {'a': 10, 'b': 20}
    [compare_modules] check ep with args=(CT1s10x20,CT1s1x20), kwargs={}...
    [compare_modules] done in 0.0008506960002705455 with output=T1s10x20
    [compare_modules] run torch module...
    [compare_modules] done in 0.00037577900366159156 with output=T1s10x20
    [compare_modules] discrepancies=abs=0.0, rel=0.0
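
The four tries in the output above are consistent with taking the Cartesian product of the lists in values_to_try and substituting each value into the named dynamic dimensions. The sketch below reproduces that enumeration in plain Python, without torch. The helpers `enumerate_trials` and `resolve_shapes` are hypothetical names introduced for illustration; the library may build its inputs differently.

```python
import itertools


def enumerate_trials(values):
    """Return one dict per combination of dimension values (Cartesian product)."""
    names = list(values)
    combos = itertools.product(*(values[name] for name in names))
    return [dict(zip(names, combo)) for combo in combos]


def resolve_shapes(shapes, dynamic_shapes, trial):
    """Replace every named dynamic dimension with the value chosen for this trial."""
    resolved = []
    for shape, dyn in zip(shapes, dynamic_shapes):
        resolved.append(
            tuple(trial[dyn[axis]] if axis in dyn else dim for axis, dim in enumerate(shape))
        )
    return resolved


# Same dynamic shapes and values as in the example above.
ds = ({0: "a", 1: "b"}, {1: "b"})
values_to_try = {"a": [5, 10], "b": [10, 20]}
original_shapes = [(5, 6), (1, 6)]  # shapes of x and y

trials = enumerate_trials(values_to_try)
all_shapes = [resolve_shapes(original_shapes, ds, trial) for trial in trials]
for i, (trial, shapes) in enumerate(zip(trials, all_shapes)):
    print(f"try {i}/{len(trials)}: {trial} -> shapes {shapes}")
```

Each resolved pair of shapes matches the argument shapes reported in the log, for instance 5x10 and 1x10 for the first try. The static dimension of y (size 1) is left untouched because it has no entry in ds.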