Export OptiMind-SFT with InputObserver

This example reuses the recipe introduced in the example Export a LLM with InputObserver (with Tiny-LLM) and applies it to the model microsoft/OptiMind-SFT. Only the class GptOssExperts is exported.

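Let’s create a random model. The configuration is truncated to two hidden layers, and the weights are randomly initialized from that configuration rather than loaded from the checkpoint.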

import pandas
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)
from onnx_diagnostic.investigate.input_observer import InputObserver

device = "cuda"
model_id = "microsoft/OptiMind-SFT"
print(f"get tokenizer {model_id!r}")
tokenizer = AutoTokenizer.from_pretrained(model_id)
print(f"get config {model_id!r}")
config = AutoConfig.from_pretrained(model_id)
config.num_hidden_layers = 2
config.layer_types = config.layer_types[:2]
print(f"create model from config for {model_id!r}")
model = AutoModelForCausalLM.from_config(config)
print(f"the model is created with {len(list(model.named_modules()))} submodules.")
model = model.to(device)
get tokenizer 'microsoft/OptiMind-SFT'
get config 'microsoft/OptiMind-SFT'
create model from config for 'microsoft/OptiMind-SFT'
the model is created with 29 submodules.

We only need to export the class GptOssExperts, so let’s look for a submodule of that class.

export_module = None
for _name, sub in model.named_modules():
    if sub.__class__.__name__ == "GptOssExperts":
        export_module = sub

assert export_module is not None, (
    f"Unable to find a submodule from class GptOssExperts in "
    f"{set(sub.__class__.__name__ for _, sub in model.named_modules())}"
)
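
The loop above has no break, so it keeps the last matching submodule. The sketch below isolates that lookup in a small helper. The helper name `find_last_by_class_name` and the stand-in classes are hypothetical; they only exist so the snippet runs without torch or transformers. With a real model, the first argument would be `model.named_modules()`.

```python
def find_last_by_class_name(named_modules, class_name):
    """Return (name, module) for the last module whose class is called
    class_name, or None when nothing matches."""
    found = None
    for name, sub in named_modules:
        if type(sub).__name__ == class_name:
            found = (name, sub)
    return found


# Stand-in classes (hypothetical), only used to exercise the helper.
class GptOssExperts:
    pass


class Other:
    pass


modules = [
    ("layers.0.mlp", Other()),
    ("layers.0.mlp.experts", GptOssExperts()),
    ("layers.1.mlp.experts", GptOssExperts()),
]
match = find_last_by_class_name(modules, "GptOssExperts")
print(match[0])  # layers.1.mlp.experts
```

Returning the name along with the module makes it easier to report which layer was picked.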

Let’s run the model and capture the inputs and outputs of that submodule.

def generate_text(
    prompt,
    model,
    tokenizer,
    max_length=50,
    temperature=0.01,
    top_k=50,
    top_p=0.95,
    do_sample=True,
):
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        do_sample=do_sample,
    )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text


prompt = "Continue: it rains, what should I do?"
observer = InputObserver()
with (
    register_additional_serialization_functions(patch_transformers=True),
    observer(export_module),
):
    generate_text(prompt, model, tokenizer)
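
To give an intuition of what observing a submodule means, here is a minimal sketch of one common technique: temporarily wrapping `forward` to record every call. It is not the implementation of InputObserver, whose internals are not shown here. The names `record_calls` and `Doubler` are hypothetical and the toy class replaces a real torch module.

```python
from contextlib import contextmanager


@contextmanager
def record_calls(module, store):
    """Temporarily wrap module.forward so that every call appends
    (args, kwargs, output) to store. The wrapper is removed on exit."""
    had_own = "forward" in vars(module)
    original = module.forward

    def wrapper(*args, **kwargs):
        out = original(*args, **kwargs)
        store.append((args, kwargs, out))
        return out

    module.forward = wrapper
    try:
        yield store
    finally:
        if had_own:
            module.forward = original
        else:
            del module.forward  # fall back to the method defined on the class


# Toy stand-in for a module (hypothetical).
class Doubler:
    def forward(self, x):
        return [2 * v for v in x]

    def __call__(self, x):
        return self.forward(x)


toy = Doubler()
calls = []
with record_calls(toy, calls):
    toy([1, 2, 3])
    toy([4])
print(len(calls), calls[0][2])  # 2 [2, 4, 6]
```

Recording several calls matters because the shapes seen across calls are what allows dynamic dimensions to be told apart from static ones.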

Export

First, let’s look at what the observer inferred.

args = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes()
print(f"args={string_type(args, with_shape=True, with_device=True)}")
print(f"dynamic_shapes={dynamic_shapes}")
args=(GT16s10x2880,GT7s10x4,GT16s10x4)
dynamic_shapes=({0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)})
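
The result marks the first axis of every input as dynamic. One plausible way to reach such a conclusion is to compare the shapes recorded across calls and flag every axis whose size changes. The sketch below illustrates that idea only; it is not the code of `infer_dynamic_shapes`. The function name `infer_dynamic_dims` is hypothetical, the string "DYNAMIC" stands in for the `DimHint(DYNAMIC)` object, and the shapes of the second call are an assumption (a later generation step with a single row).

```python
def infer_dynamic_dims(observed_shapes):
    """observed_shapes is a list of calls, each call being a tuple of
    shapes (one per positional input). Returns one dict per input that
    maps every axis whose size varies across calls to "DYNAMIC"."""
    n_inputs = len(observed_shapes[0])
    result = []
    for i in range(n_inputs):
        shapes = [call[i] for call in observed_shapes]
        rank = len(shapes[0])
        assert all(len(s) == rank for s in shapes), "ranks must match"
        dynamic = {
            axis: "DYNAMIC"
            for axis in range(rank)
            if len({s[axis] for s in shapes}) > 1
        }
        result.append(dynamic)
    return tuple(result)


observed = [
    ((10, 2880), (10, 4), (10, 4)),  # shapes printed in args above
    ((1, 2880), (1, 4), (1, 4)),     # assumed shapes of a later call
]
print(infer_dynamic_dims(observed))
# ({0: 'DYNAMIC'}, {0: 'DYNAMIC'}, {0: 'DYNAMIC'})
```

An axis that keeps the same size in every recorded call, such as 2880 here, stays static with this rule.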

Next, the export.

filename = "plot_export_optimind_experts_input_observer.onnx"
with torch_export_patches(patch_transformers=True):
    to_onnx(
        export_module,
        args=args,
        filename=filename,
        dynamic_shapes=dynamic_shapes,
        exporter="custom",
        verbose=1,
    )
[to_onnx] build the graph module from <class 'transformers.models.gpt_oss.modeling_gpt_oss.GptOssExperts'>, type(args)=<class 'tuple'>
[to_onnx] dynamic_shapes=({0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)})
[_make_builder_interpreter] export_options=ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>))
[_make_builder_interpreter] input args=(T16r2,T7r2,T16r2)
[_make_builder_interpreter] input kwargs=None
[_make_builder_interpreter] dynamic_shapes=({0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)})
[_make_builder_interpreter] same_signature=True, tracing_mode=symbolic
[ExportOptions.export] ExportOptions(aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)) - torch._dynamo.export 'GptOssExperts'
[ExportOptions.export] aten_as_function=('aten.histc.default', 'aten.index_copy.default', 'aten.index_put.default', 'aten._grouped_mm.default', 'aten.setitem', <built-in function setitem>)
[ExportOptions.export] torch_export strict=False, verbose=1
[ExportOptions.export] dynamic_shapes=({0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)}, {0: DimHint(DYNAMIC)})
[ExportOptions.export] args=(T16r2,T7r2,T16r2)
[ExportOptions.export] kwargs=None
[ExportOptions.export] export start with strict=False...
[ExportOptions.export] export with backed_size_oblivious=auto
[torch_export] backed_size_oblivious='auto'
[torch_export] inferred backed_size_oblivious=None
[torch_export] export starts with backed_size_oblivious=None
[ExportOptions.export] export done in 0.31827711400183034
[ExportOptions.export] post_process_exported_program with decomposition_table=None
[ExportOptions.export] remove inplace nodes
[CustomTracer.remove_inplace] starts with 47 nodes (n_inplace_submobules=0)
[CustomTracer.remove_inplace] S1: 1 inplace nodes
[CustomTracer.remove_inplace] S2: 1 inplace nodes and 100 iterations
[CustomTracer.remove_inplace] end with 100 iterations and 45 nodes (n_inplace=1)
[ExportOptions.export] inplaces: 1 inplaced nodes were removed
[ExportOptions.export] done remove inplace in 0.0014837279995845165, modified=1
[ExportOptions.export] done with no decomposition in 0.0015589950016874354
[to_onnx] graph module done in 0.6925063050002791 s
[to_onnx] start creating the onnx nodes
[to_onnx] interpreter.function_options=FunctionOptions(export_as_function=True, name='*', domain='*', external_threshold=256, move_initializer_to_constant=True, return_initializer=True, merge_allowed=True, rename_allowed=True)
[to_onnx] 49 onnx nodes done in 0.09824645799744758 s
[to_onnx] start conversion to onnx (before optimization) mask_outputs=[True]
[GraphBuilder-UNW.inline_functions] begin inlining graph
[GraphBuilder-UNW.inline_functions] skip_functions={('aten', 'aten__grouped_mm_default'), ('aten', 'aten_histc_default')}
[GraphBuilder-UNW.inline_functions] done inlining graph 133174008248992 in 0.0009099499984586146
[GraphBuilder-UNW._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'batch_1': 'batch_1', 'batch_2': 'batch_2', 'DYN0': 'batch', 's47': 'batch', 'DYN1': 'batch_1', 's81': 'batch_1', 'DYN2': 'batch_2', 's46': 'batch_2'}
[GraphBuilder-UNW.optimize] start with 49 nodes
[GraphBuilder-UNW.optimize] #patterns=113
[GraphBuilder-UNW.optimize] start with subgraphs
[GraphBuilder-UNW.optimize] done with subgraphs
[GraphBuilderPatternOptimization-UNW.optimize] start with 44 nodes, 18 initializers, 113 patterns, priorities=[0, 1, 2, 3], max_iter=176
[GraphBuilderPatternOptimization-UNW.optimize] same children={'SameChildrenFromInputPattern', 'SameChildrenPattern'}
[GraphBuilderPatternOptimization-UNW.optimize] iteration 0: 44 nodes, priority=0
[GraphBuilderPatternOptimization-UNW.optimize] applies 6 matches, 3*CastPattern, 1*IdentityPattern, 1*ShapeBasedEditDistanceReshapePattern, 1*SqueezeUnsqueezePattern - time=0.007 | max_time=SoftmaxCrossEntropyLossCastPattern:0.002
[GraphBuilderPatternOptimization-UNW.optimize] iteration 1: 38 nodes, priority=0
[GraphBuilderPatternOptimization-UNW.optimize] increase priority to 1
[GraphBuilderPatternOptimization-UNW.optimize] iteration 2: 38 nodes, priority=1
[GraphBuilderPatternOptimization-UNW.optimize] applies 1 matches, [0]=MatchResult: QuickGeluPattern replaces ['Sigmoid', 'Mul'] - time=0.005 | max_time=CastOpCastPattern:0.000
[GraphBuilderPatternOptimization-UNW.optimize] iteration 3: 37 nodes, priority=1
[GraphBuilderPatternOptimization-UNW.optimize] increase priority to 2
[GraphBuilderPatternOptimization-UNW.optimize] iteration 4: 37 nodes, priority=2
[GraphBuilderPatternOptimization-UNW.optimize] increase priority to 3
[GraphBuilderPatternOptimization-UNW.optimize] iteration 5: 37 nodes, priority=3
[GraphBuilderPatternOptimization-UNW.optimize] stops current_priority_index=4, priorities=[0, 1, 2, 3]
[GraphBuilderPatternOptimization-UNW.optimize] done after 6 iterations with 37 nodes in 0.039
[OrderOptimization.optimize] ALGO-2
[OrderOptimization.random_order] -- starts with 37 nodes, 16 initializers
[OrderOptimization.shape_order] done after in 0.00023336400045081973s with changed=2 scale=3
[GraphBuilder-UNW.optimize] done with 37 nodes in 0.047
[GraphBuilder-UNW.to_onnx] make_model 20 inits 4 params
[GraphBuilder-UNW.time_evaluation_constants_] 0
[GraphBuilder-UNW._build_initializers] start with 20 initializers, large_model=True, external_threshold=1024
[GraphBuilder-UNW._build_initializers] switch low/high order
[GraphBuilder-UNW._build_initializers] done in 1.5259975043591112e-06s with 16 initializers, 4 large initializers
[GraphBuilder-UNW._add_shape_information] dynamic shapes replacements={'batch': 'batch', 'batch_1': 'batch_1', 'batch_2': 'batch_2', 'DYN0': 'batch', 's47': 'batch', 'DYN1': 'batch_1', 's81': 'batch_1', 'DYN2': 'batch_2', 's46': 'batch_2'}
[to_onnx] to_onnx done in 0.05885478799973498s and 37 nodes, 16 initializers, 3 inputs, 1 outputs

Let’s measure the discrepancies.

data = observer.check_discrepancies(filename, progress_bar=True, atol=1e-2, include_io=True)
df = pandas.DataFrame(data)
df.to_excel("plot_export_optimind_input_observer.xlsx")
print(df)
100%|██████████| 3/3 [00:18<00:00,  6.26s/it]
                                               error  SUCCESS  ...    outputs_torch  outputs_ort
0  Unable to create a session stored in '_debug_I...    False  ...  #1[T16s10x2880]         None
1  Unable to create a session stored in '_debug_I...    False  ...   #1[T16s1x2880]         None
2  Unable to create a session stored in '_debug_I...    False  ...   #1[T16s1x2880]         None

[3 rows x 11 columns]
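
For readers unfamiliar with this kind of check, the sketch below shows how a tolerance-based comparison is usually computed: the maximum absolute difference between the expected and the obtained values is compared with `atol`. This is an assumption about the general approach, not the code of `check_discrepancies`. The function name `compare_outputs` and the key "abs" are hypothetical; the keys "SUCCESS" and "error" mirror the columns printed above. Outputs are plain lists of floats so the snippet runs without torch or onnxruntime.

```python
def compare_outputs(expected, got, atol=1e-2):
    """Compare two flat sequences of floats and report whether the
    maximum absolute difference stays within atol."""
    if got is None:
        # Mirrors the rows above where no onnxruntime output was produced.
        return {"SUCCESS": False, "abs": None, "error": "no output to compare"}
    if len(expected) != len(got):
        return {"SUCCESS": False, "abs": None, "error": "length mismatch"}
    abs_err = max((abs(e - g) for e, g in zip(expected, got)), default=0.0)
    return {"SUCCESS": abs_err <= atol, "abs": abs_err}


close = compare_outputs([1.0, 2.0], [1.0, 2.005])
far = compare_outputs([1.0], [1.5])
missing = compare_outputs([1.0], None)
print(close["SUCCESS"], far["SUCCESS"], missing["SUCCESS"])  # True False False
```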

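All three checks failed before any comparison could take place: onnxruntime was unable to create an inference session with CUDAExecutionProvider, which is why outputs_ort is None. Let’s show the errors.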

for row in data:
    if not row["SUCCESS"] and "error" in row:
        print(row["error"])
Unable to create a session stored in '_debug_InferenceSession_last_failure.onnx'), providers=['CUDAExecutionProvider']
Unable to create a session stored in '_debug_InferenceSession_last_failure.onnx'), providers=['CUDAExecutionProvider']
Unable to create a session stored in '_debug_InferenceSession_last_failure.onnx'), providers=['CUDAExecutionProvider']

doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
[Figure: plot export optimind input observer]

Total running time of the script: (2 minutes 8.048 seconds)

Related examples

Export attention from arnir0/Tiny-LLM with InputObserver

Export Gemma3 tiny random with InputObserver

Export a LLM with InputObserver (with Tiny-LLM)