From an LLM to processing a prompt

The generate method produces the model's answer to a given prompt. Let's implement our own version to better understand how it works, and then apply the same logic to an ONNX model.
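
Before reimplementing it, here is a minimal sketch of the high-level API we are about to replicate. It uses the small arnir0/Tiny-LLM checkpoint referenced below for unit tests; the prompt is arbitrary and any causal LLM from the HuggingFace hub would do.

# Minimal sketch of the high-level generate API (assumed setup, not part of
# the benchmark below).
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
llm = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
ids = tok("def f(x):", return_tensors="pt")
# max_new_tokens bounds the length of the answer; generation stops earlier
# if the model emits the end-of-sequence token.
out = llm.generate(**ids, max_new_tokens=20)
print(tok.batch_decode(out)[0])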

Example with Phi 1.5

microsoft/Phi-1.5 is a small LLM. The example below uses it to generate a short piece of code from a prompt.

import os
import time
import sys
import pandas
from tqdm import tqdm
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic.ext_test_case import unit_test_going
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.torch_helper import to_any, get_weight_type
from onnx_diagnostic.helpers.rt_helper import onnx_generate
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config, task_from_id
from onnx_diagnostic.tasks import random_input_kwargs
from onnx_diagnostic.export.api import to_onnx


device = "cuda" if torch.cuda.is_available() else "cpu"
data = []

print("-- load the model...")
if unit_test_going():
    # unit_test_going() returns True if UNITTEST_GOING is 1
    # The example switches to a faster scenario.
    model_id = "arnir0/Tiny-LLM"
    data_export = get_untrained_model_with_inputs(model_id)
    model = data_export["model"]
    export_inputs = data_export["inputs"]
    export_shapes = data_export["dynamic_shapes"]
    tokenizer = AutoTokenizer.from_pretrained(model_id)
else:
    model_id = "microsoft/phi-1_5"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    config = get_pretrained_config(model_id)
    task = task_from_id(model_id)
    kwargs, fct = random_input_kwargs(config, task)
    res = fct(model, config, add_second_input=False, **kwargs)
    export_inputs = res["inputs"]
    export_shapes = res["dynamic_shapes"]
model = model.to(device)
print("-- done.")

print("-- tokenize the prompt...")
inputs = tokenizer(
    '''def print_prime(n):
   """
   Print all primes between 1 and n
   """''',
    return_tensors="pt",
    return_attention_mask=False,
).to(device)
print("-- done.")

print("-- compute the answer...")
begin = time.perf_counter()
outputs = model.generate(**inputs, max_new_tokens=100)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="generate", duration=duration))
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- load the model...
-- done.
-- tokenize the prompt...
-- done.
-- compute the answer...
-- done in 3.7011363160017936
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``

eos_token_id?

This token marks the end of the answer: generation stops as soon as the model produces it.

print("eos_token_id=", tokenizer.eos_token_id)
eos_token_id= 50256
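
To see which literal string this id corresponds to, we can decode it; for GPT-2 style vocabularies it is usually the <|endoftext|> marker (an assumption worth verifying with other tokenizers).

# Decode the end-of-sequence id to see the literal string it stands for.
print(tokenizer.decode([tokenizer.eos_token_id]))
# The same information is exposed as an attribute of the tokenizer.
print("eos_token=", tokenizer.eos_token)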

Custom method generate

Let's implement a simple function replicating what the generate method does.

def simple_generate_with_cache(
    model, input_ids: torch.Tensor, eos_token_id: int, max_new_tokens: int = 100
):
    # First call: prefill
    outputs = model(input_ids, use_cache=True)

    # Next calls: decode
    for _ in tqdm(list(range(max_new_tokens))):
        next_token_logits = outputs.logits[:, -1, :]
        past_key_values = outputs.past_key_values

        # The most probable next token is chosen.
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # But we could instead sample it, e.g. with temperature and top-k:
        #   probs = torch.softmax(next_token_logits / temperature, dim=-1)
        #   top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
        #   next_token_id = torch.gather(top_indices, -1, torch.multinomial(top_probs, 1))

        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)

        # Feed only the new token, but with the cache
        outputs = model(next_token_id, use_cache=True, past_key_values=past_key_values)

    return input_ids


print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = simple_generate_with_cache(
    model, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="custom", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- compute the answer with custom generate...

  0%|          | 0/100 [00:00<?, ?it/s]
  3%|▎         | 3/100 [00:00<00:05, 19.07it/s]
  5%|▌         | 5/100 [00:00<00:04, 19.17it/s]
  8%|▊         | 8/100 [00:00<00:04, 22.04it/s]
 11%|█         | 11/100 [00:00<00:03, 24.31it/s]
 14%|█▍        | 14/100 [00:00<00:03, 25.60it/s]
 17%|█▋        | 17/100 [00:00<00:03, 24.67it/s]
 20%|██        | 20/100 [00:01<00:05, 15.61it/s]
 23%|██▎       | 23/100 [00:01<00:04, 17.52it/s]
 26%|██▌       | 26/100 [00:01<00:04, 18.18it/s]
 29%|██▉       | 29/100 [00:01<00:03, 19.23it/s]
 32%|███▏      | 32/100 [00:01<00:03, 20.55it/s]
 35%|███▌      | 35/100 [00:01<00:02, 22.07it/s]
 38%|███▊      | 38/100 [00:01<00:02, 23.35it/s]
 41%|████      | 41/100 [00:01<00:02, 21.95it/s]
 44%|████▍     | 44/100 [00:02<00:02, 19.05it/s]
 47%|████▋     | 47/100 [00:02<00:02, 21.05it/s]
 50%|█████     | 50/100 [00:02<00:02, 18.57it/s]
 53%|█████▎    | 53/100 [00:02<00:02, 15.88it/s]
 55%|█████▌    | 55/100 [00:02<00:02, 16.32it/s]
 58%|█████▊    | 58/100 [00:02<00:02, 18.44it/s]
 61%|██████    | 61/100 [00:03<00:02, 18.14it/s]
 63%|██████▎   | 63/100 [00:03<00:02, 17.78it/s]
 66%|██████▌   | 66/100 [00:03<00:01, 19.01it/s]
 68%|██████▊   | 68/100 [00:03<00:01, 17.76it/s]
 70%|███████   | 70/100 [00:03<00:01, 18.08it/s]
 72%|███████▏  | 72/100 [00:03<00:01, 17.25it/s]
 74%|███████▍  | 74/100 [00:03<00:01, 17.39it/s]
 77%|███████▋  | 77/100 [00:04<00:01, 18.93it/s]
 80%|████████  | 80/100 [00:04<00:00, 20.74it/s]
 83%|████████▎ | 83/100 [00:04<00:00, 19.49it/s]
 85%|████████▌ | 85/100 [00:04<00:00, 17.42it/s]
 87%|████████▋ | 87/100 [00:04<00:00, 16.01it/s]
 89%|████████▉ | 89/100 [00:04<00:00, 15.99it/s]
 91%|█████████ | 91/100 [00:04<00:00, 15.82it/s]
 93%|█████████▎| 93/100 [00:04<00:00, 16.44it/s]
 95%|█████████▌| 95/100 [00:05<00:00, 17.04it/s]
 98%|█████████▊| 98/100 [00:05<00:00, 19.21it/s]
100%|██████████| 100/100 [00:05<00:00, 18.83it/s]
-- done in 5.487349785998958
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``
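
simple_generate_with_cache always picks the most probable token (greedy decoding). The commented lines in the loop hint at sampling instead; below is a standalone sketch of a temperature / top-k sampler that could replace the argmax line. The function name and default values are ours, not part of the library.

import torch


def sample_next_token(next_token_logits, temperature=0.8, top_k=50):
    # Temperature < 1 sharpens the distribution, > 1 flattens it.
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    # Restrict the choice to the top_k most probable tokens.
    top_probs, top_indices = torch.topk(probs, top_k, dim=-1)
    # Draw one token among them, then map back to vocabulary indices.
    choice = torch.multinomial(top_probs, num_samples=1)
    return torch.gather(top_indices, -1, choice)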

Method generate for ONNX models

We first need to export the model into ONNX.

ONNX Conversion

if "position_ids" in export_inputs:
    del export_inputs["position_ids"]
    del export_shapes["position_ids"]
dtype = get_weight_type(model)
print("-- model dtype:", dtype)
export_inputs["past_key_values"] = to_any(export_inputs["past_key_values"], dtype)
exporter = "onnx-dynamo" if "dynamo" in sys.argv else "custom"
model_name = f"model_{model_id.replace('/', '-')}.{exporter}.onnx"
if not os.path.exists(model_name):
    # This step is slow so let's skip it if it was already done.
    print("-- conversion to ONNX.")
    begin = time.perf_counter()
    with torch_export_patches(patch_transformers=True):
        to_onnx(
            model,
            (),
            kwargs=to_any(export_inputs, device),
            dynamic_shapes=export_shapes,
            filename=model_name,
            verbose=1,
            exporter=exporter,
        )
    duration = time.perf_counter() - begin
    print(f"-- done in {duration}")
-- model dtype: torch.float16
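
Before generating with the exported model, it is worth checking which input and output names the graph exposes: the key/value cache typically appears as extra inputs and outputs, and the exact names depend on the exporter. A quick inspection with onnxruntime:

import onnxruntime

sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
# The graph usually takes input_ids plus one tensor per cached key/value.
print("inputs :", [i.name for i in sess.get_inputs()])
# It usually returns the logits followed by the updated cache.
print("outputs:", [o.name for o in sess.get_outputs()])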

onnx_generate

Then we call it a first time for just two tokens. onnx_generate is part of onnx_diagnostic and follows the implementation shown earlier for a torch model. On this first call we ask the function to return the inference session so that it does not have to be created again on the second, full call.

_res, session, _feeds = onnx_generate(
    model_name, inputs.input_ids, 2, max_new_tokens=2, return_session=True
)

# And now the full answer.
print("-- compute the answer with custom generate...")
begin = time.perf_counter()
outputs = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
duration = time.perf_counter() - begin
print(f"-- done in {duration}")
data.append(dict(name="onnx", duration=duration))

print("-- done.")
print("output shape:", string_type(outputs, with_shape=True, with_min_max=True))
print("-- decode the answer...")
text = tokenizer.batch_decode(outputs)[0]
print("-- done.")
print(text)
-- compute the answer with custom generate...
-- done in 2.70270987999902
-- done.
output shape: T7s1x123[7,50285:A10138.878048780489]
-- decode the answer...
-- done.
def print_prime(n):
   """
   Print all primes between 1 and n
   """
   primes = []
   for num in range(2, n+1):
       is_prime = True
       for i in range(2, int(math.sqrt(num))+1):
           if num % i == 0:
               is_prime = False
               break
       if is_prime:
           primes.append(num)
   print(primes)

print_prime(20)
``
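
Both loops are greedy, so they should produce exactly the same token ids. Assuming onnx_generate returns a torch tensor, as the printed output shape above suggests, a cheap sanity check is to compare it with the output of the torch loop:

# Sanity check: the ONNX loop should reproduce the torch greedy loop token for token.
torch_ids = simple_generate_with_cache(
    model, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
onnx_ids = onnx_generate(
    session, inputs.input_ids, eos_token_id=tokenizer.eos_token_id, max_new_tokens=100
)
print("same tokens:", torch.equal(torch_ids.cpu(), torch.as_tensor(onnx_ids).cpu()))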

Plots

df = pandas.DataFrame(data).set_index("name")
print(df)
          duration
name
generate  3.701136
custom    5.487350
onnx      2.702710
ax = df.plot(kind="bar", title="Time (s) comparison to generate a prompt.", rot=45)
ax.figure.tight_layout()
ax.figure.savefig("plot_generate.png")
[figure: Time (s) comparison to generate a prompt.]
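
A relative view is sometimes easier to read than raw seconds; the next lines add a speedup column relative to the stock generate call.

# Speedup of each method with respect to the stock generate call.
df["speedup"] = df.loc["generate", "duration"] / df["duration"]
print(df)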

Total running time of the script: (0 minutes 34.322 seconds)

Related examples

Reproducible Parallelized Reduction is difficult

LayerNormalization implementation cannot be exchanged

Dynamic Shapes and Broadcasting
