Export Times

fx_mode

symbolic

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

print("creating model")
# Build a 2-layer Llama decoder with llama-7b-like dimensions; the helper also
# returns a collection of example input tuples matching input_dims.
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=2,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Clear dynamo's compilation caches so the measurement is a cold export.
torch._dynamo.reset()
begin = time.perf_counter()
torch._dynamo.export(model, tracing_mode="symbolic")(*example_args_collection[0])
print(f"time to export symbolic --- {time.perf_counter() - begin}")

>>>

    creating model
    time to export symbolic --- 0.70448063799995

fake

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

print("creating model")
# Build the same 2-layer Llama decoder used for the symbolic-mode measurement
# so the two tracing modes are compared on identical models.
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=2,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Clear dynamo's compilation caches so the measurement is a cold export.
torch._dynamo.reset()
begin = time.perf_counter()
torch._dynamo.export(model, tracing_mode="fake")(*example_args_collection[0])
print(f"time to export fake --- {time.perf_counter() - begin}")

>>>

    creating model
    time to export fake --- 0.33011539899598574

Custom Exporter

With a very simple model:

<<<

import time
from experimental_experiment.checks import print_import_time

# Report how long each heavy dependency (onnx, torch, onnxruntime, ...) takes
# to import, before timing the export itself.
print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """Minimal model used to benchmark export: one linear layer + sigmoid."""

    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        # Affine projection followed by a sigmoid squashing to (0, 1).
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

# Export twice: the first run pays one-time warm-up costs, the second shows
# the steady-state export time.
for attempt in (1, 2):
    begin = time.perf_counter()
    onx = experimental_experiment.torch_interpreter.to_onnx(model, (x,))
    print(f"time to export {attempt}x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.7595461450036964
    time to import onnx_array_api --- 0.00019352699746377766
    time to import torch --- 2.4495818650029832
    'torch.export' already imported
    time to import torch.export --- 2.2509993868879974e-06
    time to import onnxscript --- 0.19496394599991618
    time to import onnxruntime --- 2.391540961994906
    'torch.onnx' already imported
    time to import torch.onnx --- 1.69100530911237e-06
    time to import torch._dynamo --- 1.021102607002831
    time to import experimental_experiment.torch_interpreter --- 0.024783433000266086
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.007664771001145709
    time to export 1x --- 0.19662757599871838
    time to export 2x --- 0.03059308699448593

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

# One-layer Llama decoder with llama-7b-like dimensions; the helper also
# returns example input tuples matching input_dims.
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Export twice: the first run includes warm-up costs, the second measures the
# steady-state export time.
for attempt in (1, 2):
    begin = time.perf_counter()
    onx = experimental_experiment.torch_interpreter.to_onnx(
        model, example_args_collection[0]
    )
    print(f"time to export {attempt}x --- {time.perf_counter() - begin}")

>>>

    time to export 1x --- 2.1147469930001535
    time to export 2x --- 2.036533844002406

Dynamo Exporter

<<<

import time
import warnings

from experimental_experiment.checks import print_import_time

# Report how long each heavy dependency (onnx, torch, onnxruntime, ...) takes
# to import, before timing the export itself.
print_import_time()

import torch
import experimental_experiment.torch_interpreter


class Neuron(torch.nn.Module):
    """Tiny benchmark model: a single linear layer whose output is passed
    through a sigmoid."""

    def __init__(self, n_dims: int, n_targets: int):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)

    def forward(self, x):
        # sigmoid(W @ x + b), elementwise in (0, 1)
        return torch.sigmoid(self.linear(x))


model = Neuron(3, 1)
x = torch.rand(5, 3)

# Silence exporter warnings so they do not clutter the timing output.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # First export pays warm-up costs; second run shows steady-state time.
    for attempt in (1, 2):
        begin = time.perf_counter()
        onx = torch.onnx.dynamo_export(model, x)
        print(f"time to export {attempt}x --- {time.perf_counter() - begin}")

>>>

    time to import onnx --- 0.5649671029968886
    time to import onnx_array_api --- 0.00019720000273082405
    time to import torch --- 1.9314632070017979
    'torch.export' already imported
    time to import torch.export --- 2.6959969545714557e-06
    time to import onnxscript --- 0.11423240799922496
    time to import onnxruntime --- 1.6460423349999473
    'torch.onnx' already imported
    time to import torch.onnx --- 1.734995748847723e-06
    time to import torch._dynamo --- 0.7172609039989766
    time to import experimental_experiment.torch_interpreter --- 0.01781430399569217
    time to import experimental_experiment.torch_interpreter.aten_functions --- 0.0037380009962362237
    time to export 1x --- 2.3869668779952917
    time to export 2x --- 0.2126849129999755

With a bigger model:

<<<

import time
import warnings
import numpy as np
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import onnx
import onnxruntime
import torch
import torch._dynamo
import torch.export
import onnxscript
import torch.onnx
import experimental_experiment
import experimental_experiment.torch_interpreter
import experimental_experiment.torch_interpreter.aten_functions
from experimental_experiment.torch_models.llama_helper import get_llama_decoder

# One-layer Llama decoder with llama-7b-like dimensions; the helper also
# returns example input tuples matching input_dims.
model, example_args_collection = get_llama_decoder(
    input_dims=[(2, 1024)],
    hidden_size=4096,
    num_hidden_layers=1,
    vocab_size=32000,
    intermediate_size=11008,
    max_position_embeddings=2048,
    num_attention_heads=32,
    _attn_implementation="eager",
)

# Silence exporter warnings so they do not clutter the timing output.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    # First export pays warm-up costs; second run shows steady-state time.
    for attempt in (1, 2):
        begin = time.perf_counter()
        onx = torch.onnx.dynamo_export(model, *example_args_collection[0])
        print(f"time to export {attempt}x --- {time.perf_counter() - begin}")

>>>

    Applied 16 of general pattern rewrite rules.
    time to export 1x --- 4.6293873749964405
    Applied 16 of general pattern rewrite rules.
    time to export 2x --- 3.1193229490017984