onnx_diagnostic.export.validate
- onnx_diagnostic.export.validate.compare_modules(modep: Module, mod: Module | None = None, args: Tuple[Any, ...] | None = None, kwargs: Dict[str, Any] | None = None, copy: bool = False, exc: bool = True, verbose: int = 0, atol: float = 0.01, rtol: float = 0.1) → Dict[str, Any]
Compares two torch modules, usually one coming from an exported program and the other being the original model.
- Parameters:
modep – first module
mod – second module (it produces the expected values)
args – positional arguments
kwargs – named arguments
copy – copy the inputs before executing the models, since they may modify them in place
exc – raise an exception if the discrepancies exceed the tolerances
verbose – verbosity level
atol – absolute tolerance
rtol – relative tolerance
- Returns:
dictionary with inputs, outputs and tolerance
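The comparison described above can be sketched in plain Python. This is a minimal, hypothetical sketch, not the library's implementation: `compare_callables` and `max_discrepancies` are illustrative names, flat lists of floats stand in for tensors, and the failure rule (either the absolute or the relative error above its tolerance) is an assumption. It also shows why `copy` matters when the first model modifies its inputs in place.

```python
import copy
from typing import Any, Callable, Dict, Sequence, Tuple


def max_discrepancies(expected: Sequence[float], got: Sequence[float]) -> Dict[str, float]:
    """Returns the largest absolute and relative differences between two flat sequences."""
    if len(expected) != len(got):
        raise ValueError(f"length mismatch: {len(expected)} != {len(got)}")
    abs_err, rel_err = 0.0, 0.0
    for e, g in zip(expected, got):
        d = abs(e - g)
        abs_err = max(abs_err, d)
        # the small epsilon avoids a division by zero when the expected value is 0
        rel_err = max(rel_err, d / (abs(e) + 1e-12))
    return {"abs": abs_err, "rel": rel_err}


def compare_callables(
    f1: Callable[..., Sequence[float]],
    f2: Callable[..., Sequence[float]],
    args: Tuple[Any, ...],
    copy_inputs: bool = False,
    exc: bool = True,
    atol: float = 0.01,
    rtol: float = 0.1,
) -> Dict[str, Any]:
    """Hypothetical sketch: f1 plays the exported program, f2 the original model."""
    # deep copies keep in-place modifications made by f1 from leaking into f2's inputs
    args1 = copy.deepcopy(args) if copy_inputs else args
    args2 = copy.deepcopy(args) if copy_inputs else args
    got = f1(*args1)
    expected = f2(*args2)  # f2 produces the expected values
    diff = max_discrepancies(expected, got)
    # assumed rule: fail when either tolerance is exceeded
    if exc and (diff["abs"] > atol or diff["rel"] > rtol):
        raise AssertionError(f"discrepancies too high: {diff}")
    return {"args": args, "expected": expected, "got": got, "diff": diff, "atol": atol, "rtol": rtol}


def add(x, y):
    return [a + b for a, b in zip(x, y)]


def add_inplace(x, y):
    # same result as add, but it also overwrites x
    for i in range(len(x)):
        x[i] += y[i]
    return list(x)


x, y = [1.0, 2.0], [0.5, 0.5]
print(compare_callables(add_inplace, add, (x, y), copy_inputs=True)["diff"])
# {'abs': 0.0, 'rel': 0.0}
```

Without `copy_inputs=True`, `add` receives the `x` already overwritten by `add_inplace`, the outputs differ by 0.5 and the sketch raises an `AssertionError`.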
Example:
<<<
import torch
from onnx_diagnostic.export import validate_ep, CoupleInputsDynamicShapes


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Model()
x = torch.randn((5, 6))
y = torch.randn((1, 6))
model(x, y)  # checks that the model runs

# dimensions sharing a name ("b") must have the same value across inputs
ds = ({0: "a", 1: "b"}, {1: "b"})
cpl = CoupleInputsDynamicShapes((x, y), {}, ds)
ep = torch.export.export(model, (x, y), dynamic_shapes=cpl.replace_string_by())
validate_ep(
    ep,
    model,
    args=(x, y),
    verbose=2,
    copy=True,
    dynamic_shapes=ds,
    values_to_try={"a": [5, 10], "b": [10, 20]},
)
>>>
[compare_modules] check ep with args=(CT1s5x6,CT1s1x6), kwargs={}...
[compare_modules] done in 0.0020388610000736662 with output=T1s5x6
[compare_modules] run torch module...
[compare_modules] done in 0.00046975700024631806 with output=T1s5x6
[compare_modules] discrepancies=abs=0.0, rel=0.0,amax=0,0
[validate_ep] try 0/4: {'a': 5, 'b': 10}
[compare_modules] check ep with args=(CT1s5x10,CT1s1x10), kwargs={}...
[compare_modules] done in 0.0009382839998579584 with output=T1s5x10
[compare_modules] run torch module...
[compare_modules] done in 0.0002730609994614497 with output=T1s5x10
[compare_modules] discrepancies=abs=0.0, rel=0.0,amax=0,0
[validate_ep] try 1/4: {'a': 5, 'b': 20}
[compare_modules] check ep with args=(CT1s5x20,CT1s1x20), kwargs={}...
[compare_modules] done in 0.0008782439999777125 with output=T1s5x20
[compare_modules] run torch module...
[compare_modules] done in 0.00024511400079063606 with output=T1s5x20
[compare_modules] discrepancies=abs=0.0, rel=0.0,amax=0,0
[validate_ep] try 2/4: {'a': 10, 'b': 10}
[compare_modules] check ep with args=(CT1s10x10,CT1s1x10), kwargs={}...
[compare_modules] done in 0.000775865999457892 with output=T1s10x10
[compare_modules] run torch module...
[compare_modules] done in 0.00024127100004989188 with output=T1s10x10
[compare_modules] discrepancies=abs=0.0, rel=0.0,amax=0,0
[validate_ep] try 3/4: {'a': 10, 'b': 20}
[compare_modules] check ep with args=(CT1s10x20,CT1s1x20), kwargs={}...
[compare_modules] done in 0.0007653570000911714 with output=T1s10x20
[compare_modules] run torch module...
[compare_modules] done in 0.00031477199991059024 with output=T1s10x20
[compare_modules] discrepancies=abs=0.0, rel=0.0,amax=0,0
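In the output of this example, `values_to_try` expands into four tries, one per combination of `a` and `b`, and every dynamic dimension of the inputs takes the value assigned to its name (the log shows shapes 5x10 and 1x10 for `{'a': 5, 'b': 10}`). The sketch below reproduces that enumeration under the assumption that the tries follow a Cartesian product in insertion order, which matches the order in the log; `enumerate_tries` and `new_shapes` are hypothetical helpers, not part of `validate_ep`.

```python
import itertools
from typing import Dict, Iterator, List, Sequence, Tuple


def enumerate_tries(values_to_try: Dict[str, Sequence[int]]) -> Iterator[Dict[str, int]]:
    """Hypothetical helper: yields one assignment per combination of dimension values."""
    names = list(values_to_try)
    for combo in itertools.product(*(values_to_try[n] for n in names)):
        yield dict(zip(names, combo))


def new_shapes(
    shapes: Sequence[Tuple[int, ...]],
    dynamic_shapes: Sequence[Dict[int, str]],
    assignment: Dict[str, int],
) -> List[Tuple[int, ...]]:
    """Hypothetical helper: replaces every dynamic dimension by the value assigned to its name."""
    result = []
    for shape, ds in zip(shapes, dynamic_shapes):
        dims = list(shape)
        for axis, name in ds.items():
            dims[axis] = assignment[name]
        result.append(tuple(dims))
    return result


values_to_try = {"a": [5, 10], "b": [10, 20]}
ds = ({0: "a", 1: "b"}, {1: "b"})
tries = list(enumerate_tries(values_to_try))
for i, assignment in enumerate(tries):
    print(f"try {i}/{len(tries)}: {assignment} -> {new_shapes([(5, 6), (1, 6)], ds, assignment)}")
# try 0/4: {'a': 5, 'b': 10} -> [(5, 10), (1, 10)]
# try 1/4: {'a': 5, 'b': 20} -> [(5, 20), (1, 20)]
# try 2/4: {'a': 10, 'b': 10} -> [(10, 10), (1, 10)]
# try 3/4: {'a': 10, 'b': 20} -> [(10, 20), (1, 20)]
```

Because `y` only declares dimension 1 as `"b"`, its first dimension stays at 1 in every try, which is what the log shows.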