onnx_diagnostic.export.validate
- onnx_diagnostic.export.validate.compare_modules(modep: Module, mod: Module | None = None, args: Tuple[Any, ...] | None = None, kwargs: Dict[str, Any] | None = None, copy: bool = False, exc: bool = True, verbose: int = 0, atol: float = 0.01, rtol: float = 0.1) -> Dict[str, Any]
Compares two torch modules, usually one coming from an exported program and the other being the original model.
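As a rough illustration of what "comparing" means here, the stdlib-only sketch below computes the largest absolute and relative discrepancies between the expected outputs and the outputs under test, flattened to sequences of floats. The helper name `max_discrepancies` and the relative-error formula (absolute difference divided by the magnitude of the expected value, plus a small epsilon) are assumptions for illustration; onnx_diagnostic may define the relative discrepancy differently.

```python
from typing import Sequence, Tuple


def max_discrepancies(
    expected: Sequence[float], got: Sequence[float], eps: float = 1e-12
) -> Tuple[float, float]:
    """Hypothetical helper: returns (max absolute, max relative) discrepancy."""
    if len(expected) != len(got):
        raise ValueError(f"length mismatch: {len(expected)} != {len(got)}")
    abs_err = 0.0
    rel_err = 0.0
    for e, g in zip(expected, got):
        diff = abs(e - g)
        abs_err = max(abs_err, diff)
        # Assumed definition: error relative to the expected value.
        rel_err = max(rel_err, diff / (abs(e) + eps))
    return abs_err, rel_err


# One element differs by 0.5 against an expected value of 2.0.
print(max_discrepancies([1.0, 2.0, 4.0], [1.0, 2.5, 4.0]))
```

Identical outputs give `(0.0, 0.0)`, which corresponds to the `discrepancies=abs=0.0, rel=0.0` lines in the log of the example below.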
- Parameters:
modep – first module (usually the one coming from the exported program)
mod – second module (it produces the expected values)
args – positional arguments
kwargs – named arguments
copy – copy the inputs before executing the models (a model may modify its inputs in place)
exc – raise an exception if the discrepancies are too high
verbose – verbosity level
atol – absolute tolerance
rtol – relative tolerance
- Returns:
a dictionary with the inputs, the outputs and the tolerance
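To show why the `copy`, `exc`, `atol` and `rtol` parameters matter, here is a stdlib-only sketch that applies the same idea to plain Python callables returning lists of floats instead of torch modules. The function `compare_callables`, the helpers `doubled_inplace` and `doubled`, the relative-error formula and the failure rule (either tolerance exceeded) are all hypothetical; this is not the implementation of `compare_modules`. It does follow the order seen in the log below: the module under test (`modep`) runs first, then `mod`, which produces the expected values.

```python
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple


def compare_callables(
    modep: Callable[..., List[float]],
    mod: Callable[..., List[float]],
    args: Tuple[Any, ...] = (),
    kwargs: Optional[Dict[str, Any]] = None,
    copy: bool = False,
    exc: bool = True,
    atol: float = 0.01,
    rtol: float = 0.1,
) -> Dict[str, Any]:
    """Hypothetical analogue of compare_modules for float-list callables."""
    kwargs = {} if kwargs is None else kwargs

    def run(fn: Callable[..., List[float]]) -> List[float]:
        # With copy=True each call receives its own deep copy of the inputs,
        # so an in-place change made by one callable cannot reach the other.
        if copy:
            return fn(*deepcopy(args), **deepcopy(kwargs))
        return fn(*args, **kwargs)

    got = run(modep)
    expected = run(mod)  # mod produces the expected values
    diffs = [abs(e - g) for e, g in zip(expected, got)]
    abs_err = max(diffs, default=0.0)
    rel_err = max((d / (abs(e) + 1e-12) for d, e in zip(diffs, expected)), default=0.0)
    # Assumed failure rule: raise if either tolerance is exceeded.
    if exc and (abs_err > atol or rel_err > rtol):
        raise AssertionError(f"discrepancies too high: abs={abs_err}, rel={rel_err}")
    return {
        "args": args, "kwargs": kwargs, "expected": expected, "got": got,
        "abs": abs_err, "rel": rel_err, "atol": atol, "rtol": rtol,
    }


def doubled_inplace(xs: List[float]) -> List[float]:
    # Modifies its input in place, like a module that mutates an input tensor.
    for i in range(len(xs)):
        xs[i] *= 2.0
    return list(xs)


def doubled(xs: List[float]) -> List[float]:
    return [2.0 * x for x in xs]


# copy=True: both callables see [1.0, 2.0] and agree.
print(compare_callables(doubled_inplace, doubled, args=([1.0, 2.0],), copy=True)["abs"])
# copy=False: doubled sees the already-doubled list, so the outputs differ.
print(compare_callables(doubled_inplace, doubled, args=([1.0, 2.0],), copy=False, exc=False)["abs"])
```

With `copy=False` and `exc=True`, the second call would raise instead of returning, since an absolute discrepancy of 4.0 is far above `atol=0.01`.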
Example:
<<<
import torch
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # to make sure it runs
ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
    ep,
    model,
    args=(x, y),
    verbose=2,
    copy=True,
    dynamic_shapes=ds,
    values_to_try={"a": [5, 10], "b": [10, 20]},
)
>>>
[compare_modules] check ep with args=(CT1s5x6,CT1s1x6), kwargs={}...
[compare_modules] done in 0.0010219119940302335 with output=T1s5x6
[compare_modules] run torch module...
[compare_modules] done in 0.00425627199729206 with output=T1s5x6
[compare_modules] discrepancies=abs=0.0, rel=0.0
[validate_ep] try 0/4: {'a': 5, 'b': 10}
[compare_modules] check ep with args=(CT1s5x10,CT1s1x10), kwargs={}...
[compare_modules] done in 0.0006602429930353537 with output=T1s5x10
[compare_modules] run torch module...
[compare_modules] done in 0.00016992900054901838 with output=T1s5x10
[compare_modules] discrepancies=abs=0.0, rel=0.0
[validate_ep] try 1/4: {'a': 5, 'b': 20}
[compare_modules] check ep with args=(CT1s5x20,CT1s1x20), kwargs={}...
[compare_modules] done in 0.0006451710069086403 with output=T1s5x20
[compare_modules] run torch module...
[compare_modules] done in 0.00015772099868627265 with output=T1s5x20
[compare_modules] discrepancies=abs=0.0, rel=0.0
[validate_ep] try 2/4: {'a': 10, 'b': 10}
[compare_modules] check ep with args=(CT1s10x10,CT1s1x10), kwargs={}...
[compare_modules] done in 0.0005811890005134046 with output=T1s10x10
[compare_modules] run torch module...
[compare_modules] done in 0.003368394995050039 with output=T1s10x10
[compare_modules] discrepancies=abs=0.0, rel=0.0
[validate_ep] try 3/4: {'a': 10, 'b': 20}
[compare_modules] check ep with args=(CT1s10x20,CT1s1x20), kwargs={}...
[compare_modules] done in 0.0008506960002705455 with output=T1s10x20
[compare_modules] run torch module...
[compare_modules] done in 0.00037577900366159156 with output=T1s10x20
[compare_modules] discrepancies=abs=0.0, rel=0.0
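The log shows `validate_ep` first checking the original inputs (5x6 and 1x6) and then running four tries, one per combination of the values listed in `values_to_try`. The order of the tries is consistent with a Cartesian product over the dimension names, each chosen value replacing the axis mapped to that name in `dynamic_shapes`. The stdlib-only sketch below reproduces the shapes printed above; the helper `shapes_to_try` is hypothetical and only mirrors the observed behavior, it is not how `validate_ep` is implemented.

```python
from itertools import product
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def shapes_to_try(
    shapes: Sequence[Shape],
    dynamic_shapes: Sequence[Dict[int, str]],
    values_to_try: Dict[str, List[int]],
) -> List[Tuple[Dict[str, int], List[Shape]]]:
    """Hypothetical helper: one (values, new input shapes) pair per combination."""
    names = list(values_to_try)
    tries = []
    # Cartesian product over the candidate values of every dimension name.
    for combo in product(*(values_to_try[n] for n in names)):
        values = dict(zip(names, combo))
        new_shapes = []
        for shape, ds in zip(shapes, dynamic_shapes):
            dims = list(shape)
            # Replace each dynamic axis by the value chosen for its name.
            for axis, name in ds.items():
                dims[axis] = values[name]
            new_shapes.append(tuple(dims))
        tries.append((values, new_shapes))
    return tries


tries = shapes_to_try(
    [(5, 6), (1, 6)],
    [{0: "a", 1: "b"}, {1: "b"}],
    {"a": [5, 10], "b": [10, 20]},
)
for i, (values, new_shapes) in enumerate(tries):
    print(f"try {i}/{len(tries)}: {values} -> {new_shapes}")
```

The printed shapes match the `args=(...)` shapes in the log for each try, for example 5x10 and 1x10 for `{'a': 5, 'b': 10}`.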