Tries with Undocumented Functions

Example with torch._dynamo.export

The snippet below tries to export a small LlamaModel with torch._dynamo.export, building the constraints with the private helper torch.export.dynamic_shapes._process_dynamic_shapes. The attempt fails: this version of export does not recognize the constraints keyword and forwards it to LlamaModel.forward, which raises a TypeError.

<<<

import numpy as np
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
from experimental_experiment.torch_interpreter import to_onnx


def ids_tensor(shape, vocab_size):
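    "Returns a tensor of random token ids with the given shape."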
    total_dims = 1
    for dim in shape:
        total_dims *= dim

    values = []
    for _ in range(total_dims):
        values.append(np.random.randint(0, vocab_size - 1))

    return torch.tensor(data=values, dtype=torch.long).view(shape).contiguous()


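# A tiny LlamaConfig so that the example stays small and fast.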
config = LlamaConfig(
    hidden_size=16,
    num_hidden_layers=1,
    vocab_size=1024,
    intermediate_size=16,
    max_position_embeddings=1024,
    num_attention_heads=2,
)
config._attn_implementation = "eager"

model = LlamaModel(config)

batch, seq, vocab_size = 2, 1024, 1024

input_ids = ids_tensor([batch, seq], vocab_size)
input_mask = torch.tril(torch.ones(batch, seq, dtype=torch.float32))

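# Run the model once in eager mode before trying to export it.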
model(input_ids, input_mask)

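# Private, undocumented helpers; only _process_dynamic_shapes is used below.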
from torch.export.dynamic_shapes import (
    _process_constraints,
    _process_dynamic_shapes,
    Constraint,
    dims,
    dynamic_dim,
)

args = input_ids, input_mask

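# With dynamic_shapes set to None, the helper returns None (see the output below).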
constraints = _process_dynamic_shapes(model, args, {}, None)
print(constraints)

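# Passing constraints is what makes this call fail: the keyword is forwarded
# to LlamaModel.forward, which does not expect it (see the traceback below).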
gm, _ = torch._dynamo.export(
    model,
    aten_graph=True,
    tracing_mode="symbolic",
    decomposition_table={},
    constraints=constraints,
)(*args)

print(gm.graph)

>>>

    [2024-05-08 14:07:22,771] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
    None
    [runpythonerror]
    /home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:1376: UserWarning: export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead.  If you don't migrate, we may break your export call in the future if your user defined kwargs conflict with future kwargs added to export(f).
      warnings.warn(
    Traceback (most recent call last):
      File "<stdin>", line 60, in <module>
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1381, in export
        return inner(*extra_args, **extra_kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1251, in inner
        result_traced = opt_f(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 403, in _fn
        return fn(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/xadupre/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
        return forward_call(*args, **kwargs)
    TypeError: LlamaModel.forward() got an unexpected keyword argument 'constraints'

Example with custom ops in onnxrt

Look into the unit test file test_custom_ops.py.
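
For orientation, here is a minimal, hypothetical sketch of the PyTorch side of such a custom op. The namespace mylib, the operator twice, and the model are illustrative only and do not come from test_custom_ops.py; the ONNX translation that the onnxrt backend needs for the operator is not shown here either.

    import torch
    from torch.library import Library

    # Declare and implement a toy custom operator (illustrative names only).
    _lib = Library("mylib", "DEF")
    _lib.define("twice(Tensor x) -> Tensor")


    def _twice_cpu(x):
        return x * 2


    _lib.impl("twice", _twice_cpu, "CPU")


    class Model(torch.nn.Module):
        def forward(self, x):
            return torch.ops.mylib.twice(x) + 1


    model = Model()
    print(model(torch.ones(2, 3)))  # eager execution uses the CPU implementation

    # Compiling with the onnxrt backend additionally requires an ONNX
    # translation for mylib::twice; that registration is not shown here.
    compiled = torch.compile(model, backend="onnxrt")

The last line only builds the compiled wrapper; running it without an ONNX translation for mylib::twice would likely fail during the export step, which is the situation the unit test deals with.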