From an LLM to processing a prompt

Method generate produces the model’s answer to a given prompt. Let’s implement our own version to better understand how it works, and then apply it to an ONNX model.

Example with Phi 1.5

microsoft/Phi-1.5 is a small LLM. The example below uses it, or the much smaller arnir0/Tiny-LLM when the environment variable UNITTEST_GOING is set to 1.

import os
import time
import sys
import pandas
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
from onnx_diagnostic.tasks import random_input_kwargs
from onnx_diagnostic.export.api import to_onnx


device = "cuda" if torch.cuda.is_available() else "cpu"
# One record (name, duration) per measured generation method.
data = []

print("-- load the model...")
if unit_test_going():
    # unit_test_going() returns True if UNITTEST_GOING is 1
    # The example switches to a faster scenario: a tiny model,
    # presumably without pretrained weights (see the helper's name).
    model_id = "arnir0/Tiny-LLM"
    data_export = get_untrained_model_with_inputs(model_id)
    model = data_export["model"]
    export_inputs = data_export["inputs"]
    export_shapes = data_export["dynamic_shapes"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model_id = "microsoft/phi-1_5"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = get_pretrained_config(model_id)
    task = task_from_id(model_id)
    # ``fct`` builds dummy inputs and the matching dynamic shapes,
    # both used later for the ONNX export.
    kwargs, fct = random_input_kwargs(config, task)
    res = fct(model, config, add_second_input=False, **kwargs)
    export_inputs = res["inputs"]
    export_shapes = res["dynamic_shapes"]
model = model.to(device)
print("-- done.")

print("-- tokenize the prompt...")
# The prompt is the header and docstring of a Python function; the model
# is asked to write the body. Only the token ids are requested.
inputs = tokenizer(
    text='''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
    return_attention_mask=False,
    return_tensors="pt",
).to(device)
print("-- done.")

print("-- compute the answer...")
# Reference run: the generation loop implemented by transformers itself.
begin = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append({"name": "generate", "duration": duration})
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
# A single sequence was generated: decode only that one.
text = tokenizer.decode(outputs[0])
print("-- done.")
print(text)
-- load the model...
-- done.
-- tokenize the prompt...
-- done.
-- compute the answer...
-- done in 3.6174923379985557
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``

eos_token_id?

This token marks the end of the answer: generation stops as soon as the model predicts it.

print(f"eos_token_id= {tokenizer.eos_token_id}")
eos_token_id= 50256

Custom method generate

Let’s implement a simple function replicating what method generate does.

def simple_generate_with_cache(
    model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
) -> torch.Tensor:
    """Greedy decoding loop reusing the key/value cache, a minimal ``generate``.

    The first call processes the whole prompt (prefill) and returns a cache.
    Every following call only feeds the last generated token together with
    that cache (decode).

    :param model: causal language model, called as
        ``model(ids, use_cache=True, past_key_values=...)`` and returning an
        object with attributes ``logits`` and ``past_key_values``
    :param input_ids: prompt token ids of shape ``(batch, seq_length)``;
        the end-of-sequence test relies on ``.item()``, so only a batch
        of size 1 is supported
    :param eos_token_id: generation stops as soon as this token is predicted;
        it is not appended to the result
    :param max_new_tokens: maximum number of tokens appended to the prompt
    :return: the prompt followed by the generated tokens
    """
    if max_new_tokens <= 0:
        # Nothing to generate, the prefill would be wasted.
        return input_ids

    # Without this context, every forward call records an autograd graph and
    # the cache (built by concatenation step after step) keeps the graphs of
    # all previous steps alive: memory grows and each step gets slower.
    with torch.no_grad():
        # First call: prefill
        outputs = model(input_ids, use_cache=True)

        # Next calls: decode
        for step in tqdm(range(max_new_tokens)):
            next_token_logits = outputs.logits[:, -1, :]

            # The most probable next token is chosen.
            next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
            # But we could sample it instead (top-k sampling):
            # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
            # <<< top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
            # <<< sampled = torch.multinomial(top_probs, 1)
            # <<< next_token_id = torch.gather(top_indices, -1, sampled)

            if next_token_id.item() == eos_token_id:
                break
            input_ids = torch.cat([input_ids, next_token_id], dim=-1)

            # Once the budget is exhausted, another forward pass would only
            # produce logits nobody reads.
            if step + 1 < max_new_tokens:
                # Feed only the new token, but with the cache
                outputs = model(
                    next_token_id,
                    use_cache=True,
                    past_key_values=outputs.past_key_values,
                )

    return input_ids


print("-- compute the answer with custom generate...")
# Same prompt and token budget as the reference run, so both can be compared.
begin = time.perf_counter()
outputs = simple_generate_with_cache(
    model,
    inputs.input_ids,
    max_new_tokens=100,
    eos_token_id=tokenizer.eos_token_id,
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append({"name": "custom", "duration": duration})

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
# A single sequence was generated: decode only that one.
text = tokenizer.decode(outputs[0])
print("-- done.")
print(text)
-- compute the answer with custom generate...

  0%|          | 0/100 [00:00<?, ?it/s]
  3%|▎         | 3/100 [00:00<00:03, 26.15it/s]
  6%|▌         | 6/100 [00:00<00:03, 25.17it/s]
  9%|▉         | 9/100 [00:00<00:03, 26.32it/s]
 12%|█▏        | 12/100 [00:00<00:03, 27.39it/s]
 15%|█▌        | 15/100 [00:00<00:03, 25.97it/s]
 18%|█▊        | 18/100 [00:00<00:03, 26.44it/s]
 21%|██        | 21/100 [00:00<00:03, 25.19it/s]
 24%|██▍       | 24/100 [00:00<00:03, 24.73it/s]
 27%|██▋       | 27/100 [00:01<00:03, 23.05it/s]
 30%|███       | 30/100 [00:01<00:03, 21.12it/s]
 33%|███▎      | 33/100 [00:01<00:03, 19.19it/s]
 35%|███▌      | 35/100 [00:01<00:03, 18.20it/s]
 37%|███▋      | 37/100 [00:01<00:03, 16.31it/s]
 39%|███▉      | 39/100 [00:01<00:03, 15.32it/s]
 41%|████      | 41/100 [00:02<00:03, 15.19it/s]
 43%|████▎     | 43/100 [00:02<00:03, 14.69it/s]
 45%|████▌     | 45/100 [00:02<00:03, 14.93it/s]
 47%|████▋     | 47/100 [00:02<00:03, 14.14it/s]
 49%|████▉     | 49/100 [00:02<00:03, 14.14it/s]
 51%|█████     | 51/100 [00:02<00:03, 14.15it/s]
 53%|█████▎    | 53/100 [00:02<00:03, 13.51it/s]
 55%|█████▌    | 55/100 [00:03<00:03, 13.14it/s]
 57%|█████▋    | 57/100 [00:03<00:03, 13.27it/s]
 59%|█████▉    | 59/100 [00:03<00:03, 13.48it/s]
 61%|██████    | 61/100 [00:03<00:02, 13.81it/s]
 63%|██████▎   | 63/100 [00:03<00:02, 14.26it/s]
 65%|██████▌   | 65/100 [00:03<00:02, 14.13it/s]
 67%|██████▋   | 67/100 [00:03<00:02, 13.90it/s]
 69%|██████▉   | 69/100 [00:04<00:02, 12.72it/s]
 71%|███████   | 71/100 [00:04<00:02, 12.35it/s]
 73%|███████▎  | 73/100 [00:04<00:02, 13.32it/s]
 75%|███████▌  | 75/100 [00:04<00:01, 14.54it/s]
 77%|███████▋  | 77/100 [00:04<00:01, 14.39it/s]
 79%|███████▉  | 79/100 [00:04<00:01, 15.28it/s]
 81%|████████  | 81/100 [00:04<00:01, 15.31it/s]
 83%|████████▎ | 83/100 [00:05<00:01, 15.78it/s]
 85%|████████▌ | 85/100 [00:05<00:00, 15.95it/s]
 87%|████████▋ | 87/100 [00:05<00:00, 14.89it/s]
 89%|████████▉ | 89/100 [00:05<00:00, 14.50it/s]
 91%|█████████ | 91/100 [00:05<00:00, 14.45it/s]
 93%|█████████▎| 93/100 [00:05<00:00, 14.28it/s]
 95%|█████████▌| 95/100 [00:05<00:00, 15.24it/s]
 97%|█████████▋| 97/100 [00:05<00:00, 14.65it/s]
 99%|█████████▉| 99/100 [00:06<00:00, 14.80it/s]
100%|██████████| 100/100 [00:06<00:00, 16.19it/s]
-- done in 6.383380365001358
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``

Method generate for onnx models

We first need to export the model into ONNX.

ONNX Conversion

# position_ids is removed from the example inputs; the dynamic shapes must
# describe exactly the same inputs, so it is removed there as well.
# pop(..., None) tolerates the key being absent from either dictionary.
export_inputs.pop("position_ids", None)
export_shapes.pop("position_ids", None)
dtype = get_weight_type(model)
print("-- model dtype:", dtype)
# Cast the cache of the example inputs to the model dtype
# (presumably to_any converts every tensor it contains).
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
# Passing "dynamo" on the command line selects the onnx-dynamo exporter.
exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom"
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
if not os.path.exists(model_name):
    # This step is slow so let's skip it if it was already done.
    print("-- conversion to ONNX.")
    begin = time.perf_counter()
    with torch_export_patches(patch_transformers=True):
        to_onnx(
            model,
            (),
            kwargs=to_any(export_inputs, device),
            dynamic_shapes=export_shapes,
            filename=model_name,
            verbose=1,
            exporter=exporter,
        )
    duration = time.perf_counter() - begin
    print(f"-- done in {duration}")
-- model dtype: torch.float16
-- conversion to ONNX.
[to_onnx] build the graph module from <class 'transformers.models.phi.modeling_phi.PhiForCausalLM'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seq_length'}, 'attention_mask': {0: 'batch', 1: 'cache+seq'}, 'past_key_values': [{0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}]}
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=()
[_make_builder_interpreter] input kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#24[T10r4,...], value_cache=#24[T10r4,...]))
[_make_builder_interpreter] dynamic_shapes={'input_ids': {0: 'batch', 1: 'seq_length'}, 'attention_mask': {0: 'batch', 1: 'cache+seq'}, 'past_key_values': [{0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}, {0: 'batch', 2: 'cache_length'}]}
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'PhiForCausalLM'
[ExportOptions.export] aten_as_function=('aten.index_copy.default', 'aten.index_put.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes={'input_ids': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'attention_mask': {0: DimHint(DYNAMIC), 1: DimHint(DYNAMIC)}, 'past_key_values': [{0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: 
DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC), 2: DimHint(DYNAMIC)}]}
[ExportOptions.export] args=()
[ExportOptions.export] kwargs=dict(input_ids:T7r2,attention_mask:T7r2,past_key_values:DynamicCache(key_cache=#24[T10r4,...], value_cache=#24[T10r4,...]))
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious=None
[torch_export] export starts with backed_size_oblivious=None
[ExportOptions.export] export done in 15.05451744399761
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[ExportOptions.export] slices: 3 slices nodes were removed
[CustomTracer.remove_inplace] starts with 1708 nodes
[CustomTracer.remove_inplace] S1: 1 inplace nodes
[CustomTracer.remove_inplace] S2: 1 inplace nodes and 10 iterations
[CustomTracer.remove_inplace] end with 8 iterations and 1705 nodes
[ExportOptions.export] inplaces: 1 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.05087083700345829, modified=1
[ExportOptions.export] done with no decomposition in 0.051558062001276994
[to_onnx] graph module done in 15.123618322999391 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)

  0%|          | 0/1705 [00:00<?, ?it/s]
 25%|██▍       | 426/1705 [00:00<00:00, 4072.18it/s]
 49%|████▉     | 834/1705 [00:00<00:00, 1928.31it/s]
 64%|██████▎   | 1084/1705 [00:00<00:00, 1436.18it/s]
 74%|███████▍  | 1264/1705 [00:00<00:00, 1356.30it/s]
 83%|████████▎ | 1419/1705 [00:00<00:00, 1304.26it/s]
 91%|█████████▏| 1560/1705 [00:01<00:00, 1212.33it/s]
 99%|█████████▉| 1687/1705 [00:01<00:00, 1140.25it/s]
100%|██████████| 1705/1705 [00:01<00:00, 1342.43it/s]
[to_onnx] 2308 onnx nodes done in 1.2720247670004028 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True]
[inline_functions] begin graph 132676290500304
[inline_functions] skip_functions=set()
[_inline_functions_iterations] inline function 'submod_3' domain 'local_functions'
[_inline_functions_iterations] 9 new nodes for 'submod_3', 'local_functions'
[inline_functions] done graph 132676290500304 in 0.031224165999446996
[GraphBuilder-TNS._add_shape_information] dynamic shapes replacements={'cache_length': 'cache_length', 'batch': 'batch', 'seq_length': 'seq_length', 's79': 'batch', 's26': 'batch', 's67': 'batch', 's77': 'batch', 's82': 'batch', 's34': 'batch', 's57': 'batch', 'batch^s97^batch^s10': 'batch', 'batch^s64^batch^s86': 'batch', 'batch^s3^batch^s41': 'batch', 's69': 'batch', 's60': 'batch', 'batch^s100^batch^s102': 'batch', 's84': 'batch', 's86': 'batch', 's8': 'batch', 's75': 'batch', 's92': 'batch', 's64': 'batch', 's102': 'batch', 's83': 'batch', 's29': 'batch', 'batch^s48^batch^s59': 'batch', 'batch^s67^batch^s61': 'batch', 's45': 'batch', 's62': 'batch', 's91': 'batch', 's98': 'batch', 's106': 'batch', 'batch^s29^batch^s8': 'batch', 's59': 'batch', 'batch^s52^batch^s93': 'batch', 's43': 'batch', 'batch^s36^batch^s13': 'batch', 's23': 'batch', 'batch^s34^batch^s77': 'batch', 'batch^s84^batch^s91': 'batch', 's93': 'batch', 's89': 'batch', 's47': 'batch', 'batch^s30^batch^s89': 'batch', 'batch^s82^batch^s62': 'batch', 's104': 'batch', 's52': 'batch', 'batch^s49^batch^s26': 'batch', 's100': 'batch', 's72': 'batch', 's39': 'batch', 'batch^s1^batch^s75': 'batch', 's61': 'batch', 's71': 'batch', 'batch^s45^batch^s47': 'batch', 'batch^s69^batch^s56': 'batch', 'batch^s90^batch^s57': 'batch', 's13': 'batch', 'batch^s104^batch^s106': 'batch', 's49': 'batch', 'batch^s87^batch^s23': 'batch', 'batch^s98^batch^s79': 'batch', 's30': 'batch', 's1': 'batch', 'batch^s92^batch^s83': 'batch', 'batch^s39^batch^s71': 'batch', 's97': 'batch', 's87': 'batch', 's56': 'batch', 's3': 'batch', 's48': 'batch', 's10': 'batch', 's90': 'batch', 'batch^s35^batch^s60': 'batch', 's36': 'batch', 's35': 'batch', 's41': 'batch', 's70': 'seq_length', 's105': 'cache_length', 's73': 'cache_length', 's65': 'cache_length', 's88': 'cache_length', 's32': 'cache_length', 's50': 'cache_length', 's27': 'cache_length', 's55': 'cache_length', 's33': 'cache_length', 's24': 'cache_length', 's46': 'cache_length', 
's28': 'cache_length', 's22': 'cache_length', 's107': 'cache_length', 's21': 'cache_length', 's42': 'cache_length', 's99': 'cache_length', 's94': 'cache_length', 's7': 'cache_length', 's103': 'cache_length', 's81': 'cache_length', 's11': 'cache_length', 's96': 'cache_length', 's15': 'cache_length', 's38': 'cache_length', 's74': 'cache_length', 's9': 'cache_length', 's44': 'cache_length', 's78': 'cache_length', 's31': 'cache_length', 's25': 'cache_length', 's85': 'cache_length', 's2': 'cache_length', 's95': 'cache_length', 's4': 'cache_length', 's66': 'cache_length', 's18': 'cache_length', 's37': 'cache_length', 's14': 'cache_length', 's101': 'cache_length', 's68': 'cache_length', 's80': 'cache_length', 's40': 'cache_length', 's54': 'cache_length', 's58': 'cache_length', 's51': 'cache_length', 's63': 'cache_length', 's76': 'cache_length'}
[GraphBuilder-TNS.optimize] start with 2316 nodes
[GraphBuilder-TNS.optimize] #patterns=102
[GraphBuilder-TNS.optimize] start with subgraphs
[GraphBuilder-TNS.optimize] done with subgraphs
[GraphBuilderPatternOptimization-TNS.optimize] start with 1987 nodes, 461 initializers, 102 patterns, priorities=[0, 1, 2, 3], max_iter=7948
[GraphBuilderPatternOptimization-TNS.optimize] same children={'SameChildrenPattern', 'SameChildrenFromInputPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] iteration 0: 1987 nodes, priority=0
[GraphBuilderPatternOptimization-TNS.optimize] applies 226 matches, 75*CastPattern, 2*IdentityPattern, 3*ShapeBasedReshapeIsSqueezePattern, 96*ShapeBasedEditDistanceReshapePattern, 18*ShapeBasedIdentityPattern, 5*SameChildrenPattern, 1*SqueezeAddPattern, 1*SqueezeUnsqueezePattern, 1*UnsqueezeUnsqueezePattern, 24*FunctionAttentionPattern - time=0.239 | max_time=IdentityPattern:0.097
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=209, n_removed=264, n_applied=281 applied patterns, 1595 nodes left with 23 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 1
[GraphBuilderPatternOptimization-TNS.optimize] iteration 1: 1595 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] applies 201 matches, 2*ConcatTwiceUnaryPattern, 49*DropoutPattern, 25*LayerNormalizationPattern, 1*ShapeBasedExpandBroadcastPattern, 1*ShapeBasedExpandSwapPattern, 96*SlicesSplitPattern, 3*SqueezeUnsqueezePattern, 24*GeluOrtPattern - time=0.177 | max_time=GeluOrtPattern:0.012
[GraphBuilderPatternOptimization-TNS.optimize] iteration 2: 1128 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] applies 101 matches, 2*ConcatTwiceUnaryPattern, 25*LayerNormalizationScalePattern, 2*ShapeBasedExpandSwapPattern, 48*FunctionHalfRotaryEmbeddingPattern, 24*FastGeluPattern - time=0.117 | max_time=ShapeBasedEditDistanceReshapePattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] iteration 3: 912 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] applies 26 matches, 1*ShapeBasedExpandBroadcastPattern, 1*FunctionCausalMaskPattern, 24*SkipLayerNormalizationPattern - time=0.094 | max_time=ShapeBasedEditDistanceReshapePattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] iteration 4: 886 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] applies 2 matches, 1*ShapeBasedConcatExpandPattern, 1*FunctionCausalMaskMulAddPattern - time=0.091 | max_time=ShapeBasedEditDistanceReshapePattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 5: 880 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] applies 1 matches, [0]=MatchResult: FunctionCosSinCachePattern replaces ['Squeeze', 'Squeeze', 'Range', 'Unsqueeze', 'Cast', 'Reshape', 'Mul', 'Cos', 'Cast', 'Sin', 'Cast'] - time=0.087 | max_time=ShapeBasedEditDistanceReshapePattern:0.009
[GraphBuilderPatternOptimization-TNS.optimize] iteration 6: 869 nodes, priority=1
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 2
[GraphBuilderPatternOptimization-TNS.optimize] iteration 7: 869 nodes, priority=2
[GraphBuilderPatternOptimization-TNS.optimize] applies 1 matches, [0]=MatchResult: ContribRotaryEmbeddingPattern replaces ['Concat', 'Concat', 'Split', 'HalfRotaryEmbedding', 'Concat'] - time=0.089 | max_time=ShapeBasedEditDistanceReshapePattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 8: 874 nodes, priority=2
[GraphBuilderPatternOptimization-TNS.optimize] applies 3 matches, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern - time=0.098 | max_time=ShapeBasedEditDistanceReshapePattern:0.009
[GraphBuilderPatternOptimization-TNS.optimize] iteration 9: 878 nodes, priority=2
[GraphBuilderPatternOptimization-TNS.optimize] applies 6 matches, 2*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.096 | max_time=ShapeBasedEditDistanceReshapePattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=24, n_removed=32, n_applied=630 applied patterns, 876 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 10: 876 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 5 matches, 1*ShapeBasedEditDistanceReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.104 | max_time=IdentityPattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] iteration 11: 882 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 9 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ReshapeReshapePattern, 3*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.103 | max_time=IdentityPattern:0.013
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=653 applied patterns, 875 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 12: 875 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 8 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.092 | max_time=ShapeBasedEditDistanceReshapePattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] iteration 13: 878 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 14 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 3*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.091 | max_time=ShapeBasedEditDistanceReshapePattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=684 applied patterns, 866 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 14: 866 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.100 | max_time=ShapeBasedEditDistanceReshapePattern:0.009
[GraphBuilderPatternOptimization-TNS.optimize] iteration 15: 865 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.101 | max_time=ShapeBasedEditDistanceReshapePattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=720 applied patterns, 852 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 16: 852 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.115 | max_time=IdentityPattern:0.011
[GraphBuilderPatternOptimization-TNS.optimize] iteration 17: 851 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.091 | max_time=ShapeBasedEditDistanceReshapePattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=756 applied patterns, 838 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 18: 838 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.090 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 19: 837 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.085 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=792 applied patterns, 824 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 20: 824 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.088 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 21: 823 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.081 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=828 applied patterns, 810 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 22: 810 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.086 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 23: 809 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.089 | max_time=IdentityPattern:0.010
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=864 applied patterns, 796 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 24: 796 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.094 | max_time=IdentityPattern:0.011
[GraphBuilderPatternOptimization-TNS.optimize] iteration 25: 795 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.114 | max_time=ShapeBasedEditDistanceReshapePattern:0.010
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=900 applied patterns, 782 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 26: 782 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.099 | max_time=IdentityPattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 27: 781 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.094 | max_time=ShapeBasedEditDistanceReshapePattern:0.008
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=936 applied patterns, 768 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 28: 768 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.089 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 29: 767 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.116 | max_time=ShapeBasedEditDistanceReshapePattern:0.016
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=972 applied patterns, 754 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 30: 754 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.087 | max_time=IdentityPattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 31: 753 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.078 | max_time=ShapeBasedEditDistanceReshapePattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1008 applied patterns, 740 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 32: 740 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.085 | max_time=ShapeBasedEditDistanceReshapePattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 33: 739 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.079 | max_time=IdentityPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1044 applied patterns, 726 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 34: 726 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.086 | max_time=IdentityPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] iteration 35: 725 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 14 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 3*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.111 | max_time=ShapeBasedEditDistanceReshapePattern:0.011
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1079 applied patterns, 713 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 36: 713 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.084 | max_time=IdentityPattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 37: 712 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.086 | max_time=IdentityPattern:0.014
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1115 applied patterns, 699 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 38: 699 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.085 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 39: 698 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.075 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1151 applied patterns, 685 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 40: 685 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.087 | max_time=Reshape2Of3Pattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 41: 684 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.082 | max_time=IdentityPattern:0.012
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1187 applied patterns, 671 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 42: 671 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.088 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 43: 670 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.075 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1223 applied patterns, 657 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 44: 657 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.097 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 45: 656 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.074 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1259 applied patterns, 643 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 46: 643 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.073 | max_time=IdentityPattern:0.004
[GraphBuilderPatternOptimization-TNS.optimize] iteration 47: 642 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.116 | max_time=SameChildrenPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1295 applied patterns, 629 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 48: 629 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.075 | max_time=IdentityPattern:0.007
[GraphBuilderPatternOptimization-TNS.optimize] iteration 49: 628 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.068 | max_time=IdentityPattern:0.004
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1331 applied patterns, 615 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 50: 615 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.069 | max_time=IdentityPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] iteration 51: 614 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.066 | max_time=IdentityPattern:0.004
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1367 applied patterns, 601 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 52: 601 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.068 | max_time=IdentityPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] iteration 53: 600 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 15 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern - time=0.058 | max_time=IdentityPattern:0.004
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=26, n_removed=35, n_applied=1403 applied patterns, 587 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 54: 587 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 12 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbeddingPattern, 1*ContribRotaryEmbedding3DPattern, 1*MultiHeadAttention3DPattern - time=0.063 | max_time=IdentityPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] iteration 55: 584 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 14 matches, 1*ShapeBasedEditDistanceReshapePattern, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 4*SameChildrenPattern, 2*SqueezeUnsqueezePattern, 1*ContribRotaryEmbedding3DPattern - time=0.098 | max_time=IdentityPattern:0.016
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=15, n_removed=20, n_applied=1434 applied patterns, 569 nodes left with 4 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 56: 569 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 8 matches, 1*ShapeBasedEditDistanceReshapePattern, 5*ShapedBasedReshapePattern, 1*ReshapeReshapePattern, 1*MultiHeadAttention3DPattern - time=0.067 | max_time=IdentityPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 57: 560 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 7 matches, 1*ShapedBasedReshapePattern, 5*ReshapeReshapePattern, 1*SameChildrenPattern - time=0.060 | max_time=SameChildrenPattern:0.005
[GraphBuilderPatternOptimization-TNS.optimize] reapply {'SameChildrenPattern'}
[GraphBuilderPatternOptimization-TNS.optimize] n_added=0, n_removed=0, n_applied=1449 applied patterns, 553 nodes left with 1 iterations
[GraphBuilderPatternOptimization-TNS.optimize] increase priority to 3
[GraphBuilderPatternOptimization-TNS.optimize] iteration 58: 553 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] applies 5 matches, 5*ShapedBasedReshapePattern - time=0.061 | max_time=SameChildrenPattern:0.006
[GraphBuilderPatternOptimization-TNS.optimize] iteration 59: 548 nodes, priority=3
[GraphBuilderPatternOptimization-TNS.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-TNS.optimize] done after 60 iterations with 548 nodes in 14.804
[GraphBuilder-TNS.optimize] done with 473 nodes in 18.106
[GraphBuilder-TNS.to_onnx] make_model 499 inits 341 params
[GraphBuilder-TNS.time_evaluation_constants_] 0.0009653799970692489
[GraphBuilder-TNS._build_initializers] start with 499 initializers, large_model=True, external_threshold=1024
[GraphBuilder-TNS._build_initializers] switch low/high order
[GraphBuilder-TNS._build_initializers] done in 3.2619973353575915e-06s with 353 initializers, 341 large initializers
[GraphBuilder-TNS._add_shape_information] dynamic shapes replacements={'cache_length': 'cache_length', 'batch': 'batch', 'seq_length': 'seq_length', 's79': 'batch', 's26': 'batch', 's67': 'batch', 's77': 'batch', 's82': 'batch', 's34': 'batch', 's57': 'batch', 'batch^s97^batch^s10': 'batch', 'batch^s64^batch^s86': 'batch', 'batch^s3^batch^s41': 'batch', 's69': 'batch', 's60': 'batch', 'batch^s100^batch^s102': 'batch', 's84': 'batch', 's86': 'batch', 's8': 'batch', 's75': 'batch', 's92': 'batch', 's64': 'batch', 's102': 'batch', 's83': 'batch', 's29': 'batch', 'batch^s48^batch^s59': 'batch', 'batch^s67^batch^s61': 'batch', 's45': 'batch', 's62': 'batch', 's91': 'batch', 's98': 'batch', 's106': 'batch', 'batch^s29^batch^s8': 'batch', 's59': 'batch', 'batch^s52^batch^s93': 'batch', 's43': 'batch', 'batch^s36^batch^s13': 'batch', 's23': 'batch', 'batch^s34^batch^s77': 'batch', 'batch^s84^batch^s91': 'batch', 's93': 'batch', 's89': 'batch', 's47': 'batch', 'batch^s30^batch^s89': 'batch', 'batch^s82^batch^s62': 'batch', 's104': 'batch', 's52': 'batch', 'batch^s49^batch^s26': 'batch', 's100': 'batch', 's72': 'batch', 's39': 'batch', 'batch^s1^batch^s75': 'batch', 's61': 'batch', 's71': 'batch', 'batch^s45^batch^s47': 'batch', 'batch^s69^batch^s56': 'batch', 'batch^s90^batch^s57': 'batch', 's13': 'batch', 'batch^s104^batch^s106': 'batch', 's49': 'batch', 'batch^s87^batch^s23': 'batch', 'batch^s98^batch^s79': 'batch', 's30': 'batch', 's1': 'batch', 'batch^s92^batch^s83': 'batch', 'batch^s39^batch^s71': 'batch', 's97': 'batch', 's87': 'batch', 's56': 'batch', 's3': 'batch', 's48': 'batch', 's10': 'batch', 's90': 'batch', 'batch^s35^batch^s60': 'batch', 's36': 'batch', 's35': 'batch', 's41': 'batch', 's70': 'seq_length', 's105': 'cache_length', 's73': 'cache_length', 's65': 'cache_length', 's88': 'cache_length', 's32': 'cache_length', 's50': 'cache_length', 's27': 'cache_length', 's55': 'cache_length', 's33': 'cache_length', 's24': 'cache_length', 's46': 'cache_length', 
's28': 'cache_length', 's22': 'cache_length', 's107': 'cache_length', 's21': 'cache_length', 's42': 'cache_length', 's99': 'cache_length', 's94': 'cache_length', 's7': 'cache_length', 's103': 'cache_length', 's81': 'cache_length', 's11': 'cache_length', 's96': 'cache_length', 's15': 'cache_length', 's38': 'cache_length', 's74': 'cache_length', 's9': 'cache_length', 's44': 'cache_length', 's78': 'cache_length', 's31': 'cache_length', 's25': 'cache_length', 's85': 'cache_length', 's2': 'cache_length', 's95': 'cache_length', 's4': 'cache_length', 's66': 'cache_length', 's18': 'cache_length', 's37': 'cache_length', 's14': 'cache_length', 's101': 'cache_length', 's68': 'cache_length', 's80': 'cache_length', 's40': 'cache_length', 's54': 'cache_length', 's58': 'cache_length', 's51': 'cache_length', 's63': 'cache_length', 's76': 'cache_length'}
[to_onnx] to_onnx done in 18.346887952000543s and 473 nodes, 353 initializers, 50 inputs, 49 outputs
-- done in 43.6935249800008

onnx_generate

Then we can call function onnx_generate for two tokens. This function is part of onnx_diagnostic but follows the implementation seen earlier for a torch model. Let’s first ask the function to return the session so that it does not have to be created again on the second call.

# Warm-up call: generate only 2 new tokens. With return_session=True the
# function also returns the inference session it built, which is unpacked
# into `session` so the timed call below can reuse it.
# NOTE(review): `model_name` is not defined in this chunk; presumably it is the
# path of the ONNX model exported earlier in the script — confirm.
# NOTE(review): the positional `2` is presumably `eos_token_id` — confirm
# against the onnx_generate signature.
_res, session = onnx_generate(
    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
)

# And now the full answer.
# The session from the warm-up call is passed instead of the model path, so the
# measured duration is meant to exclude building a new session.
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
# Record the timing; `data` feeds the comparison DataFrame built further below.
data.append(dict(name="onnx", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
# batch_decode returns one string per sequence; keep the first one.
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- compute the answer with custom generate...
-- done in 2.5727611210022587
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``

Plots

# Gather every recorded (name, duration) measurement into a table indexed by
# the name of the generation method, then display it.
df = pandas.DataFrame(data)
df = df.set_index("name")
print(df)
          duration
name
generate  3.617492
custom    6.383380
onnx      2.572761
# Draw one bar per generation method and write the chart to disk.
ax = df.plot(
    kind="bar",
    title="Time (s) comparison to generate a prompt.",
    rot=45,
)
fig = ax.figure
fig.tight_layout()
fig.savefig("plot_generate.png")
Time (s) comparison to generate a prompt.

Total running time of the script: (1 minutes 1.175 seconds)

Related examples

Reproducible Parallelized Reduction is difficult

Reproducible Parallelized Reduction is difficult

LayerNormalization implementation cannot be exchanged

LayerNormalization implementation cannot be exchanged

Dynamic Shapes and Broadcasting

Dynamic Shapes and Broadcasting

Gallery generated by Sphinx-Gallery