Export microsoft/phi-2

This example exports a smaller untrained model with the same architecture. It is much faster to create and run than the pretrained model. Once the export works, the untrained model can be replaced by the trained one.

microsoft/phi-2 is not a big model, but it is still quite big when it comes to writing unit tests. Function onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs() can be used to create a reduced untrained version of a model coming from HuggingFace. It downloads the configuration from the website but creates a dummy model with 1 or 2 hidden layers in order to reduce the size and get a fast execution. The goal is usually to test the export or to compare performance; the relevance of the outputs does not matter.

Create the dummy model

import copy
import pprint
import warnings
import torch
import onnxruntime
from onnx_diagnostic import doc
from onnx_diagnostic.helpers import max_diff, string_diff, string_type
from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
from onnx_diagnostic.helpers.rt_helper import make_feeds
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub import (
    get_untrained_model_with_inputs,
)

warnings.simplefilter("ignore")

# another tiny id: arnir0/Tiny-LLM
data = get_untrained_model_with_inputs("microsoft/phi-2")
untrained_model, inputs, dynamic_shapes, config, size, n_weights = (
    data["model"],
    data["inputs"],
    data["dynamic_shapes"],
    data["configuration"],
    data["size"],
    data["n_weights"],
)

print(f"model {size / 2**20:1.1f} Mb with {n_weights // 1000} thousands of parameters.")
model 432.3 Mb with 113332 thousands of parameters.

The original model has 2.7 billion parameters; the parameter count was divided by more than 10. The full pretrained weights can still be retrieved with get_untrained_model_with_inputs("microsoft/phi-2", same_as_pretrained=True).
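As a minimal sketch of that call (nothing beyond the argument mentioned above; note it downloads the full 2.7-billion-parameter weights, so it is slow):

# Same entry point, but returns the real pretrained model.
# Beware: this downloads the full microsoft/phi-2 weights.
data_full = get_untrained_model_with_inputs(
    "microsoft/phi-2", same_as_pretrained=True
)
pretrained_model = data_full["model"]

Let's see the configuration.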

print(config)
PhiConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "PhiForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "head_dim": 80,
  "hidden_act": "gelu_new",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 6144,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "phi",
  "num_attention_heads": 32,
  "num_hidden_layers": 2,
  "num_key_value_heads": 32,
  "partial_rotary_factor": 0.4,
  "qk_layernorm": false,
  "resid_pdrop": 0.1,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "subfolder": null,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.52.0.dev0",
  "use_cache": true,
  "vocab_size": 51200
}

Inputs:

print(string_type(inputs, with_shape=True))
dict(input_ids:T7s2x3,attention_mask:T7s2x33,position_ids:T7s2x3,past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80,T1s2x32x30x80], value_cache=#2[T1s2x32x30x80,T1s2x32x30x80]))

With min/max values.

print(string_type(inputs, with_shape=True, with_min_max=True))
dict(input_ids:T7s2x3[5868,42369:A28806.5],attention_mask:T7s2x33[1,1:A1.0],position_ids:T7s2x3[30,32:A31.0],past_key_values:DynamicCache(key_cache=#2[T1s2x32x30x80[-4.339683532714844,4.168929100036621:A-0.00015688629796453135],T1s2x32x30x80[-4.691745758056641,4.495841979980469:A-0.0017857432339830719]], value_cache=#2[T1s2x32x30x80[-4.3160271644592285,4.298698902130127:A-0.0037800918842078777],T1s2x32x30x80[-4.749629974365234,5.0046281814575195:A0.00033762750999332054]]))

And the dynamic shapes:

pprint.pprint(dynamic_shapes)
{'attention_mask': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'},
 'input_ids': {0: Dim('batch', min=1, max=1024), 1: 'seq_length'},
 'past_key_values': [[{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}],
                     [{0: Dim('batch', min=1, max=1024), 2: 'cache_length'},
                      {0: Dim('batch', min=1, max=1024), 2: 'cache_length'}]],
 'position_ids': {0: Dim('batch', min=1, max=1024), 1: 'cache+seq'}}

We execute the model to produce expected outputs.

expected = untrained_model(**copy.deepcopy(inputs))
print(f"expected: {string_type(expected, with_shape=True, with_min_max=True)}")
expected: CausalLMOutputWithPast(logits:T1s2x3x51200[-2.425957202911377,2.6620113849639893:A-0.00048338043225195786],past_key_values:DynamicCache(key_cache=#2[T1s2x32x33x80[-4.339683532714844,4.168929100036621:A-0.0001798408414748936],T1s2x32x33x80[-4.691745758056641,4.495841979980469:A-0.001198434475900593]], value_cache=#2[T1s2x32x33x80[-4.3160271644592285,4.298698902130127:A-0.0041093787574674355],T1s2x32x33x80[-4.749629974365234,5.0046281814575195:A-0.0005269664030454032]]))

Export to fx.Graph

torch.export.export() is the first step before converting a model into ONNX. The inputs are duplicated (with copy.deepcopy) because the model may modify them in place (a cache for example), and shapes may then not match on a second call with the modified inputs.

with torch_export_patches(patch_transformers=True):

    # Two optional sanity checks, useful in case of an error.
    # We check the cache is registered.
    assert is_cache_dynamic_registered()

    # We check there are no discrepancies when the patches are applied.
    d = max_diff(expected, untrained_model(**copy.deepcopy(inputs)))
    assert (
        d["abs"] < 1e-5
    ), f"The model with patches produces different outputs: {string_diff(d)}"

    # Then we export: the only important line in this section.
    ep = torch.export.export(
        untrained_model,
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
        strict=False,  # mandatory for torch==2.6
    )

    # We check the exported program produces the same results as well.
    # This step is again unnecessary.
    d = max_diff(expected, ep.module()(**copy.deepcopy(inputs)))
    assert d["abs"] < 1e-5, f"The exported model produces different outputs: {string_diff(d)}"

Export to ONNX

The export works. We can now export to ONNX with torch.onnx.export(). Patches are still needed because the export applies torch.export.ExportedProgram.run_decompositions(), which may export local pieces of the model again.

with torch_export_patches(patch_transformers=True):
    epo = torch.onnx.export(
        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
    )
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 53 of general pattern rewrite rules.

We can save it.

epo.save("plot_export_tiny_phi2.onnx", external_data=True)

# Or directly get the :class:`onnx.ModelProto`.
onx = epo.model_proto

Discrepancies

Then we check the conversion to ONNX. Let's make sure the ONNX model produces the same outputs. It takes flattened inputs.

feeds = make_feeds(onx, copy.deepcopy(inputs), use_numpy=True, copy=True)

print(f"torch inputs: {string_type(inputs)}")
print(f"onxrt inputs: {string_type(feeds)}")
torch inputs: dict(input_ids:T7r2,attention_mask:T7r2,position_ids:T7r2,past_key_values:DynamicCache(key_cache=#2[T1r4,T1r4], value_cache=#2[T1r4,T1r4]))
onxrt inputs: dict(input_ids:A7r2,attention_mask:A7r2,position_ids:A7r2,past_key_values_key_cache_0:A1r4,past_key_values_key_cache_1:A1r4,past_key_values_value_cache_0:A1r4,past_key_values_value_cache_1:A1r4)

We then create an onnxruntime.InferenceSession.

sess = onnxruntime.InferenceSession(
    onx.SerializeToString(), providers=["CPUExecutionProvider"]
)

Let’s run.

got = sess.run(None, feeds)

And finally the discrepancies.

diff = max_diff(expected, got, flatten=True)
print(f"onnx discrepancies: {string_diff(diff)}")
onnx discrepancies: abs=2.205371856689453e-06, rel=0.0012052791646639817, n=983040.0

It looks good.

doc.plot_legend("export\nuntrained smaller\nmicrosoft/phi-2", "torch.onnx.export", "orange")
(Figure: plot export tiny phi2)

Possible Issues

Unknown task

Function onnx_diagnostic.torch_models.hghub.get_untrained_model_with_inputs() is unable to guess the task associated with the model. A different set of dummy inputs is defined for every task, so the user needs to give that information to the function explicitly. Tasks are the same as the ones defined by HuggingFace/models.
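A minimal sketch of what that looks like; the parameter name task is an assumption, to be checked against the function's signature in your version of onnx_diagnostic:

# Hypothetical sketch: pass the task explicitly when it cannot be guessed.
# The parameter name ``task`` is an assumption, check the actual signature.
data = get_untrained_model_with_inputs(
    "microsoft/phi-2", task="text-generation"
)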

Inputs are incorrect

Example Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM) explains how to retrieve that information. If you cannot guess the dynamic shapes (a cache can be tricky sometimes), follow example Dynamic Shapes for *args, **kwargs.
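As a rough sketch, custom inputs and dynamic shapes can also be written by hand and passed to torch.export.export(); this assumes the model only needs input_ids and attention_mask, which is not true for every configuration:

# Hand-written inputs and dynamic shapes (cache-free, for illustration).
batch = torch.export.Dim("batch")
seq = torch.export.Dim("seq")
custom_inputs = dict(
    input_ids=torch.randint(0, 1000, (2, 8)),
    attention_mask=torch.ones((2, 8), dtype=torch.int64),
)
custom_shapes = dict(
    input_ids={0: batch, 1: seq},
    attention_mask={0: batch, 1: seq},
)
# Then export as in the section above:
# torch.export.export(model, (), kwargs=custom_inputs,
#                     dynamic_shapes=custom_shapes, strict=False)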

DynamicCache or any other cache cannot be exported

That’s the role of onnx_diagnostic.torch_export_patches.torch_export_patches(). It registers the necessary information into pytorch to make the export work with these classes. The need for it should slowly disappear as transformers includes the serialization functions.
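For illustration, a rough sketch of the kind of pytree registration involved, assuming DynamicCache stores key_cache and value_cache lists; the actual patches in onnx_diagnostic handle more cases and transformers versions:

import torch.utils._pytree as pytree
from transformers.cache_utils import DynamicCache

def _flatten_cache(cache):
    # children first, then a context (none needed here)
    return [cache.key_cache, cache.value_cache], None

def _unflatten_cache(children, context):
    cache = DynamicCache()
    cache.key_cache, cache.value_cache = children
    return cache

pytree.register_pytree_node(
    DynamicCache,
    _flatten_cache,
    _unflatten_cache,
    serialized_type_name="transformers.cache_utils.DynamicCache",
)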

Control Flow

Every mixture of models goes through a control flow (an if test). This also happens when a cache is truncated. The code of the model then needs to be changed. See example Export a model with a control flow (If). Loops are not supported yet.
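A minimal illustration (not taken from this model) of rewriting a data-dependent test with torch.cond so that both branches stay in the exported graph:

class WithBranch(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        # A plain python ``if x.sum() > 0`` would fail or specialize;
        # torch.cond keeps both branches in the graph.
        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

ep_branch = torch.export.export(WithBranch(), (torch.randn(4),))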

Issue with dynamic shapes

Example Do not use python int with dynamic shapes gives one reason this process may fail, but that's not the only one. Example Find and fix an export issue due to dynamic shapes gives a way to locate the cause, but that does not cover all possible causes. Raising an issue on GitHub is the recommended option until it is fixed.
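As a quick illustration of the first pitfall, converting a symbolic dimension to a python int freezes it in the exported graph:

def forward_bad(x):
    n = int(x.shape[0])  # int() specializes the dimension to a constant
    return x.reshape(n, -1)

def forward_good(x):
    return x.reshape(x.shape[0], -1)  # keeps the dimension symbolic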

Total running time of the script: (0 minutes 23.760 seconds)

Related examples

Test the export on untrained models

Steel method forward to guess inputs and dynamic shapes (with Tiny-LLM)

Export Tiny-LLM with patches