From an LLM to processing a prompt¶
Method generate produces the model answer for a given prompt.
Let's implement our own version to better understand how it works,
and then apply the same logic to an ONNX model.
Example with Phi 1.5¶
microsoft/Phi-1.5 is a small LLM. The example below runs it with the standard generate method; when the script runs in unit-test mode, it switches to a tiny untrained model instead to keep it fast.
import os
import time
import sys
import pandas
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
from onnx_diagnostic.tasks import random_input_kwargs
from onnx_diagnostic.export.api import to_onnx
device = "cuda" if torch.cuda.is_available() else "cpu"
data = []
print("-- load the model...")
if unit_test_going():
    # unit_test_going() returns True if UNITTEST_GOING is 1.
    # The example then switches to a faster scenario with a tiny untrained model.
    model_id = "arnir0/Tiny-LLM"
    data_export = get_untrained_model_with_inputs(model_id)
    model = data_export["model"]
    export_inputs = data_export["inputs"]
    export_shapes = data_export["dynamic_shapes"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model_id = "microsoft/phi-1_5"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = get_pretrained_config(model_id)
    task = task_from_id(model_id)
    kwargs, fct = random_input_kwargs(config, task)
    res = fct(model, config, add_second_input=False, **kwargs)
    export_inputs = res["inputs"]
    export_shapes = res["dynamic_shapes"]
model = model.to(device)
print("-- done.")
print("-- tokenize the prompt...")
inputs = tokenizer(
    '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
    return_tensors="pt",
    return_attention_mask=False,
).to(device)
print("-- done.")
print("-- compute the answer...")
begin = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="generate", duration=duration))
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- load the model...
-- done.
-- tokenize the prompt...
-- done.
-- compute the answer...
-- done in 3.7011363160017936
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)
print_prime(20)
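The output contains 123 tokens: the prompt tokens followed by the newly generated ones. A small check, not part of the original script, makes the split explicit:

# Not in the original script: 123 = prompt tokens + newly generated tokens.
print("prompt tokens:", inputs.input_ids.shape[1])
print("generated tokens:", outputs.shape[1] - inputs.input_ids.shape[1])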
eos_token_id?¶
This token marks the end of the answer: generation stops when the model produces it.
print("eos_token_id=", tokenizer.eos_token_id)
eos_token_id= 50256
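As a quick check, not part of the original script, the last generated token tells whether generation stopped because the model emitted eos_token_id or because max_new_tokens was reached:

# Not in the original script: did generate stop on eos or on max_new_tokens?
last_token = outputs[0, -1].item()
print("stopped on eos:", last_token == tokenizer.eos_token_id)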
Custom method generate¶
Let's implement a simple function replicating what method
generate does.
def simple_generate_with_cache(
    model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
):
    # First call: prefill
    outputs = model(input_ids, use_cache=True)
    # Next calls: decode
    for _ in tqdm(list(range(max_new_tokens))):
        next_token_logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values
        # The most probable next token is chosen.
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # But we could select it using a multinomial law
        # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
        # <<< top_probs, top_indices = torch.topk(probs, top_k)
        # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]
        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        # Feed only the new token, but with the cache
        outputs = model(next_token_id, use_cache=True, past_key_values=past_key_values)
    return input_ids
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = simple_generate_with_cache(
    model, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="custom", duration=duration))
print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- compute the answer with custom generate...
  0%|          | 0/100 [00:00<?, ?it/s]
...
100%|██████████| 100/100 [00:05<00:00, 18.83it/s]
-- done in 5.487349785998958
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)
print_prime(20)
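The loop above always keeps the most probable token (greedy decoding), which is why it reproduces the answer of generate. The commented lines in the function hint at sampling instead; a minimal sketch of temperature and top-k sampling, assuming the same next_token_logits tensor, could replace the argmax:

# Hedged sketch, not part of the measured runs above: temperature + top-k sampling.
def sample_next_token(next_token_logits, temperature=0.8, top_k=50):
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
    choice = torch.multinomial(top_probs, num_samples=1)
    return top_indices.gather(-1, choice)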
Method generate for ONNX models¶
We first need to export the model into ONNX.
ONNX Conversion¶
if "position_ids" in export_inputs:
del export_inputs["position_ids"]
del export_shapes["position_ids"]
dtype = get_weight_type(model)
print("-- model dtype:", dtype)
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom"
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
if not os.path.exists(model_name):
# This step is slow so let's skip it if it was already done.
print("-- conversion to ONNX.")
begin = time.perf_counter()
with torch_export_patches(patch_transformers=True):
to_onnx(
model,
(),
kwargs=to_any(export_inputs, device),
dynamic_shapes=export_shapes,
filename=model_name,
verbose=1,
exporter=exporter,
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
-- model dtype: torch.float16
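Before generating, it can help to look at the signature of the exported model. The following check is not part of the original script and assumes onnxruntime is installed:

# Not in the original script: list the model inputs (input_ids and the
# past_key_values tensors) to see what the session expects.
import onnxruntime

sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
for inp in sess.get_inputs():
    print(inp.name, inp.type, inp.shape)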
onnx_generate¶
Then we can call method generate for two tokens.
This function is part of onnx_diagnostic and follows the implementation
seen earlier for a torch model.
Let's first ask the function to return the session so that it does not have to be created again on the second call.
_res, session, _feeds = onnx_generate(
    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
)
# And now the full answer.
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="onnx", duration=duration))
print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- compute the answer with custom generate...
-- done in 2.70270987999902
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)
print_prime(20)
Plots¶
df = pandas.DataFrame(data).set_index("name")
print(df)
          duration
name
generate  3.701136
custom    5.487350
onnx      2.702710
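The table can also be turned into a bar chart. The following sketch is not part of the original script and assumes matplotlib is installed:

# Not in the original script: plot the durations gathered above.
ax = df["duration"].plot.barh(title=f"generate duration for {model_id}")
ax.figure.tight_layout()
ax.figure.savefig("plot_generate_durations.png")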
Total running time of the script: (0 minutes 34.322 seconds)