onnx_diagnostic.helpers.torch_test_helper
- onnx_diagnostic.helpers.torch_test_helper.dummy_llm(cls_name: str | None = None, dynamic_shapes: bool = False) → Tuple[Module, Tuple[Tensor, ...]] | Tuple[Module, Tuple[Tensor, ...], Any]
Creates a dummy LLM for test purposes.
- Parameters:
cls_name – None to get the whole model, otherwise the class name of one of its pieces
dynamic_shapes – if True, the dynamic shapes are returned as well
<<<
from onnx_diagnostic.helpers.torch_test_helper import dummy_llm

print(dummy_llm())
>>>
(LLM(
  (embedding): Embedding(
    (embedding): Embedding(1024, 16)
    (pe): Embedding(1024, 16)
  )
  (decoder): DecoderLayer(
    (attention): MultiAttentionBlock(
      (attention): ModuleList(
        (0-1): 2 x AttentionBlock(
          (query): Linear(in_features=16, out_features=16, bias=False)
          (key): Linear(in_features=16, out_features=16, bias=False)
          (value): Linear(in_features=16, out_features=16, bias=False)
        )
      )
      (linear): Linear(in_features=32, out_features=16, bias=True)
    )
    (feed_forward): FeedForward(
      (linear_1): Linear(in_features=16, out_features=128, bias=True)
      (relu): ReLU()
      (linear_2): Linear(in_features=128, out_features=16, bias=True)
    )
    (norm_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
    (norm_2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  )
), (tensor([[ 387, 517, 880, 269, 644, 583, 600, 275, 538, 664, 308, 419, 122, 431, 545, 647, 291, 938, 142, 192, 772, 1011, 252, 27, 131, 833, 273, 285, 455, 808]]),))
- onnx_diagnostic.helpers.torch_test_helper.is_torchdynamo_exporting() → bool
Tells if torch is exporting a model.
- onnx_diagnostic.helpers.torch_test_helper.replace_string_by_dynamic(dynamic_shapes: Any) → Any
Replaces strings by torch.export.Dim.DYNAMIC.
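The replacement can be pictured with a minimal sketch. It is not the library's implementation: `replace_strings` is a hypothetical name, and `DYNAMIC` is a stand-in object used in place of `torch.export.Dim.DYNAMIC` so that the snippet runs without torch installed. It assumes dynamic shapes are nested dictionaries, lists, and tuples whose leaves may be strings.

```python
from typing import Any

# Stand-in for torch.export.Dim.DYNAMIC (assumption: lets the sketch run without torch).
DYNAMIC = object()


def replace_strings(shapes: Any, sentinel: Any = DYNAMIC) -> Any:
    """Hypothetical helper: returns a copy of ``shapes`` where every string leaf
    is replaced by ``sentinel``."""
    if isinstance(shapes, str):
        return sentinel
    if isinstance(shapes, dict):
        return {k: replace_strings(v, sentinel) for k, v in shapes.items()}
    if isinstance(shapes, list):
        return [replace_strings(v, sentinel) for v in shapes]
    if isinstance(shapes, tuple):
        return tuple(replace_strings(v, sentinel) for v in shapes)
    # None, integers and any other object are returned unchanged.
    return shapes


shapes = {"input_ids": {0: "batch", 1: "seq_length"}, "mask": None}
converted = replace_strings(shapes)
```

With torch available, passing `torch.export.Dim.DYNAMIC` as `sentinel` would presumably give a structure close to what the documented function returns.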
- onnx_diagnostic.helpers.torch_test_helper.steel_forward(model: Module, with_shape: bool = True, with_min_max: bool = False)
Applies the modification needed to steal the forward method and print its inputs and outputs. See the example Steel method forward to guess the dynamic shapes (with Tiny-LLM).
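One way to intercept a forward method is a context manager that swaps it for a printing wrapper and restores it on exit. The sketch below only illustrates that idea under stated assumptions: `intercept_forward` and `describe` are hypothetical names, the `with_shape` and `with_min_max` options are not modelled, and a plain Python class stands in for a torch module so the snippet runs without torch.

```python
import contextlib
from typing import Any, Iterator


def describe(value: Any) -> str:
    """Short description of a value: its type, plus its shape when it has one."""
    shape = getattr(value, "shape", None)
    name = type(value).__name__
    return f"{name}{tuple(shape)}" if shape is not None else name


@contextlib.contextmanager
def intercept_forward(model: Any) -> Iterator[None]:
    """Hypothetical sketch: temporarily wraps ``model.forward`` so that every
    call prints a description of its inputs and of its output."""
    original = model.forward

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        print("inputs:", [describe(a) for a in args],
              {k: describe(v) for k, v in kwargs.items()})
        output = original(*args, **kwargs)
        print("output:", describe(output))
        return output

    model.forward = wrapper  # the instance attribute shadows the class method
    try:
        yield
    finally:
        model.forward = original  # put the original bound method back


class Doubler:
    """Plain stand-in for a module exposing a ``forward`` method."""

    def forward(self, x):
        return [2 * v for v in x]


model = Doubler()
with intercept_forward(model):
    result = model.forward([1, 2, 3])
```

Restoring the method in a `finally` clause keeps the model usable even if the wrapped call raises.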
- onnx_diagnostic.helpers.torch_test_helper.to_any(value: Any, to_value: dtype | device) → Any
Applies torch.to if applicable. Goes recursively through nested values.
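The recursion can be sketched as follows. This is an illustration rather than the library's code: `apply_to` is a hypothetical name, the set of containers it walks (dict, list, tuple) is an assumption, and `FakeTensor` is a stand-in for a tensor so the snippet runs without torch.

```python
from typing import Any


def apply_to(value: Any, target: Any) -> Any:
    """Hypothetical sketch: calls ``value.to(target)`` when the object has a
    ``to`` method, and recurses into dicts, lists and tuples."""
    if isinstance(value, dict):
        return {k: apply_to(v, target) for k, v in value.items()}
    if isinstance(value, list):
        return [apply_to(v, target) for v in value]
    if isinstance(value, tuple):
        return tuple(apply_to(v, target) for v in value)
    if callable(getattr(value, "to", None)):
        return value.to(target)
    # Integers, strings, None, ... are returned unchanged.
    return value


class FakeTensor:
    """Stand-in for torch.Tensor: only remembers a dtype name."""

    def __init__(self, dtype: str):
        self.dtype = dtype

    def to(self, dtype: str) -> "FakeTensor":
        return FakeTensor(dtype)


inputs = {"x": FakeTensor("float32"), "n": 3, "rest": [FakeTensor("float32")]}
moved = apply_to(inputs, "float16")
```

The call returns a new structure with the same layout, leaving the original containers untouched.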