onnx_diagnostic.torch_test_helper

More complex helpers used in unit tests.

onnx_diagnostic.torch_test_helper.check_model_ort(onx: ModelProto, providers: str | List[Any] | None = None, dump_file: str | None = None) → InferenceSession

Loads a model with onnxruntime.

Parameters:
  • onx – ModelProto

  • providers – list of providers; None or cpu for CPU, cuda for CUDA

  • dump_file – if not empty, the model is saved to this file when an error occurs

Returns:

InferenceSession
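
The parameters above can be illustrated with a minimal sketch. It is not the library's implementation: `resolve_providers` and `load_or_dump` are hypothetical names, the mapping from `cpu`/`cuda` to onnxruntime provider names is an assumption, and so is re-raising the error after the model has been dumped.

```python
from typing import Any, List, Optional, Union


def resolve_providers(providers: Union[str, List[Any], None] = None) -> List[Any]:
    """Hypothetical helper: expand a shorthand into onnxruntime provider names."""
    if providers is None or providers == "cpu":
        return ["CPUExecutionProvider"]
    if providers == "cuda":
        # Assumption: CPU is kept as a fallback for operators CUDA cannot run.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if isinstance(providers, str):
        raise ValueError(f"Unexpected provider shorthand {providers!r}")
    # Anything else is taken to be an explicit list of providers.
    return list(providers)


def load_or_dump(onx, providers=None, dump_file: Optional[str] = None):
    """Hypothetical sketch: build a session, save the model if loading fails."""
    # Imported lazily so that resolve_providers works without these packages.
    import onnx
    import onnxruntime

    try:
        return onnxruntime.InferenceSession(
            onx.SerializeToString(), providers=resolve_providers(providers)
        )
    except Exception:
        if dump_file:
            # Keep the failing model on disk for later inspection.
            onnx.save(onx, dump_file)
        raise  # assumption: the original error is propagated to the test
```

In a unit test, the real helper would presumably be called in the same way, for example `check_model_ort(onx, providers="cpu", dump_file="failing_model.onnx")`, where the file name is only an illustration.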

onnx_diagnostic.torch_test_helper.dummy_llm(cls_name: str | None = None, dynamic_shapes: bool = False) → Tuple[Module, Tuple[Tensor, ...]] | Tuple[Module, Tuple[Tensor, ...], Any]

Creates a dummy LLM for test purposes.

Parameters:
  • cls_name – None for the whole model, or the name of one of its pieces

  • dynamic_shapes – if True, also returns the dynamic shapes

<<<

from onnx_diagnostic.torch_test_helper import dummy_llm

print(dummy_llm())

>>>

    (LLM(
      (embedding): Embedding(
        (embedding): Embedding(1024, 16)
        (pe): Embedding(1024, 16)
      )
      (decoder): DecoderLayer(
        (attention): MultiAttentionBlock(
          (attention): ModuleList(
            (0-1): 2 x AttentionBlock(
              (query): Linear(in_features=16, out_features=16, bias=False)
              (key): Linear(in_features=16, out_features=16, bias=False)
              (value): Linear(in_features=16, out_features=16, bias=False)
            )
          )
          (linear): Linear(in_features=32, out_features=16, bias=True)
        )
        (feed_forward): FeedForward(
          (linear_1): Linear(in_features=16, out_features=128, bias=True)
          (relu): ReLU()
          (linear_2): Linear(in_features=128, out_features=16, bias=True)
        )
        (norm_1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
        (norm_2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
      )
    ), (tensor([[ 638,  345,  939,  436,  769,  411,  873,  433,  574,  642,  145,  533,
              739,  944,  750,  891,  303,  431,  989,  686, 1009,  675,  924,  567,
               16,   32,  677,  270,  219,  398]]),))

onnx_diagnostic.torch_test_helper.to_numpy(tensor: Tensor)

Converts a torch tensor to numpy.
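
As a rough illustration of what such a conversion involves, here is a minimal sketch. It is not the library's code: `to_numpy_sketch` and `demo` are hypothetical names, and the sketch does not handle dtypes that numpy lacks, such as `torch.bfloat16`.

```python
def to_numpy_sketch(tensor):
    """Hypothetical equivalent of a tensor-to-numpy conversion."""
    # detach() drops the autograd graph (numpy() refuses tensors that require grad),
    # cpu() moves the data off any accelerator, numpy() exposes it as an array.
    return tensor.detach().cpu().numpy()


def demo():
    # torch is imported here so that the sketch above has no hard dependency.
    import torch

    x = torch.arange(6, dtype=torch.float32).reshape(2, 3).requires_grad_()
    return to_numpy_sketch(x)
```

Calling `demo()` with torch installed returns a float32 array of shape (2, 3), even though the input tensor requires gradients.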