101: A custom backend for torch

This example builds on the examples introduced on the page Custom Backends. It uses the backend experimental_experiment.torch_dynamo.onnx_custom_backend(), which is based on onnxruntime and runs on CPU or CUDA. It could easily be replaced by experimental_experiment.torch_dynamo.onnx_debug_backend(). That backend, based on the reference implementation from onnx, can show intermediate results if needed, but it is very slow.
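
The debug backend follows the same calling convention, so switching is a matter of changing the backend function. A minimal sketch, not executed here, assuming onnx_debug_backend accepts the same arguments as onnx_custom_backend:

import torch
from experimental_experiment.torch_dynamo import onnx_debug_backend

# Any torch.nn.Module works here; the debug backend runs the exported
# ONNX graph with the slow reference implementation from onnx and can
# expose intermediate results.
model = torch.nn.Linear(10, 1)
compiled = torch.compile(
    model,
    backend=lambda *args, **kwargs: onnx_debug_backend(
        *args, target_opset=18, **kwargs
    ),
    dynamic=False,
    fullgraph=True,
)
print(compiled(torch.randn(3, 10, dtype=torch.float32)))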

A model

import copy
from experimental_experiment.helpers import pretty_onnx
from onnx_array_api.plotting.graphviz_helper import plot_dot
import torch
from torch._dynamo.backends.common import aot_autograd

# from torch._functorch._aot_autograd.utils import make_boxed_func
from experimental_experiment.torch_dynamo import (
    onnx_custom_backend,
    get_decomposition_table,
)
from experimental_experiment.torch_interpreter import ExportOptions


class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(10, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
        )

    def forward(self, x):
        return self.layers(x)


x = torch.randn(3, 10, dtype=torch.float32)

mlp = MLP()
print(mlp(x))
tensor([[ 0.0989],
        [-0.4838],
        [-0.3123]], grad_fn=<AddmmBackward0>)

A custom backend

This backend leverages onnxruntime. It is available through the function experimental_experiment.torch_dynamo.onnx_custom_backend() and implemented by the class OrtBackend.

compiled_model = torch.compile(
    copy.deepcopy(mlp),
    backend=lambda *args, **kwargs: onnx_custom_backend(*args, target_opset=18, **kwargs),
    dynamic=False,
    fullgraph=True,
)

print(compiled_model(x))
tensor([[ 0.0989],
        [-0.4838],
        [-0.3123]])
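
The outputs match the eager model. A quick sanity check, a small sketch using torch.testing.assert_close:

# The compiled model backed by onnxruntime should match the eager
# model up to floating-point tolerance.
torch.testing.assert_close(mlp(x), compiled_model(x))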

Training

The backend can be used for training as well. The compilation may fail if the model uses a function the converter does not know. In that case, there may exist a way to decompose this new function into existing ones. A recommended decomposition table is returned by function get_decomposition_table. An existing table can be filtered to remove some inefficient decompositions with function filter_decomposition_table.

aot_compiler = aot_autograd(
    fw_compiler=lambda *args, **kwargs: onnx_custom_backend(
        *args,
        target_opset=18,
        export_options=ExportOptions(decomposition_table=get_decomposition_table()),
        **kwargs,
    ),
)

compiled_model = torch.compile(
    copy.deepcopy(mlp),
    backend=aot_compiler,
    fullgraph=True,
    dynamic=False,
)

print(compiled_model(x))
tensor([[ 0.0989],
        [-0.4838],
        [-0.3123]], grad_fn=<CompiledFunctionBackward>)

Let’s see an iteration loop.

from sklearn.datasets import load_diabetes


class DiabetesDataset(torch.utils.data.Dataset):
    def __init__(self, X, y):
        self.X = torch.from_numpy(X / 10).to(torch.float32)
        self.y = torch.from_numpy(y).to(torch.float32).reshape((-1, 1))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        return self.X[i], self.y[i]


def trained_model(max_iter=5, dynamic=False, storage=None):
    aot_compiler = aot_autograd(
        fw_compiler=lambda *args, **kwargs: onnx_custom_backend(
            *args, target_opset=18, storage=storage, **kwargs
        ),
        decompositions=get_decomposition_table(),
    )

    compiled_model = torch.compile(
        MLP(),
        backend=aot_compiler,
        fullgraph=True,
        dynamic=dynamic,
    )

    trainloader = torch.utils.data.DataLoader(
        DiabetesDataset(*load_diabetes(return_X_y=True)),
        batch_size=5,
        shuffle=True,
        num_workers=0,
    )

    loss_function = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(compiled_model.parameters(), lr=1e-1)

    for epoch in range(0, max_iter):
        current_loss = 0.0

        for _, data in enumerate(trainloader, 0):
            X, y = data

            optimizer.zero_grad()
            p = compiled_model(X)
            loss = loss_function(p, y)
            loss.backward()

            optimizer.step()

            current_loss += loss.item()

        print(f"Loss after epoch {epoch+1}: {current_loss}")

    print("Training process has finished.")
    return compiled_model


trained_model(3)
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
Loss after epoch 1: 7305.781032562256
Loss after epoch 2: 5443.838684082031
Loss after epoch 3: 5179.002107620239
Training process has finished.

OptimizedModule(
  (_orig_mod): MLP(
    (layers): Sequential(
      (0): Linear(in_features=10, out_features=32, bias=True)
      (1): ReLU()
      (2): Linear(in_features=32, out_features=1, bias=True)
    )
  )
)
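
The returned OptimizedModule behaves like any other module for inference. A small sketch, reusing the trained_model helper defined above:

# Train a new model and run it in inference mode.
trained = trained_model(3)
with torch.no_grad():
    print(trained(torch.randn(5, 10, dtype=torch.float32)))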

What about the ONNX model?

The backend converts the model into ONNX then runs it with onnxruntime. Let’s see what it looks like.

storage = {}

trained_model(3, storage=storage)

print(f"{len(storage['instance'])} were created.")

for i, inst in enumerate(storage["instance"][:2]):
    print()
    print(f"-- model {i} running on {inst['providers']}")
    print(pretty_onnx(inst["onnx"]))
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
Loss after epoch 1: 7516.51252746582
Loss after epoch 2: 5555.762649536133
Loss after epoch 3: 5202.970405578613
Training process has finished.
4 were created.

-- model 0 running on ['CPUExecutionProvider']
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input0' type=dtype('float32') shape=[32, 10]
input: name='input1' type=dtype('float32') shape=[32]
input: name='input2' type=dtype('float32') shape=[5, 10]
input: name='input3' type=dtype('float32') shape=[1, 32]
input: name='input4' type=dtype('float32') shape=[1]
Gemm(input2, input0, input1, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
  Relu(addmm) -> output_2
    Gemm(output_2, input3, input4, transA=0, transB=1, alpha=1.00, beta=1.00) -> output_0
Transpose(input3, perm=[1,0]) -> output_3
Identity(input2) -> output_1
output: name='output_0' type=dtype('float32') shape=[5, 1]
output: name='output_1' type=dtype('float32') shape=[5, 10]
output: name='output_2' type=dtype('float32') shape=[5, 32]
output: name='output_3' type=dtype('float32') shape=[32, 1]

-- model 1 running on ['CPUExecutionProvider']
opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input0' type=dtype('float32') shape=[5, 10]
input: name='input1' type=dtype('float32') shape=[5, 32]
input: name='input2' type=dtype('float32') shape=[32, 1]
input: name='input3' type=dtype('float32') shape=[5, 1]
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Constant(value_float=0.0) -> output_NONE_2
Gemm(input3, input1, transA=1, transB=0) -> output_3
Gemm(input3, input2, transA=0, transB=1) -> mm
ReduceSum(input3, init7_s1_0, keepdims=0) -> output_4
LessOrEqual(input1, init1_s1_) -> _onx_lessorequal0
  Where(_onx_lessorequal0, init1_s1_, mm) -> threshold_backward
    Gemm(threshold_backward, input0, transA=1, transB=0) -> output_0
ReduceSum(threshold_backward, init7_s1_0, keepdims=0) -> output_1
output: name='output_0' type=dtype('float32') shape=[32, 10]
output: name='output_1' type=dtype('float32') shape=[32]
output: name='output_NONE_2' type=dtype('float32') shape=None
output: name='output_3' type=dtype('float32') shape=[1, 32]
output: name='output_4' type=dtype('float32') shape=[1]
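
The stored protos can also be dumped to disk and opened with any ONNX tool, netron for instance. A small sketch, relying on the fact that each entry of storage["instance"] holds the ModelProto under the key "onnx" as shown above:

import onnx

# Save every graph captured by the backend (forward and backward)
# into its own file.
for i, inst in enumerate(storage["instance"]):
    onnx.save(inst["onnx"], f"custom_backend_model_{i}.onnx")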

The forward graph.

plot_dot(storage["instance"][0]["onnx"])
[figure: plot torch custom backend 101]
<Axes: >

The backward graph.

plot_dot(storage["instance"][1]["onnx"])
[figure: plot torch custom backend 101]
<Axes: >

What about dynamic shapes?

Any input or output having _dim_ in its name is a dynamic dimension. Any output having _NONE_ in its name is replaced by None, as required by pytorch.

storage = {}

trained_model(3, storage=storage, dynamic=True)

print(f"{len(storage['instance'])} were created.")

for i, inst in enumerate(storage["instance"]):
    print()
    print(f"-- model {i} running on {inst['providers']}")
    print()
    print(pretty_onnx(inst["onnx"]))
/home/xadupre/vv/this/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py:130: UserWarning: Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale.
  warnings.warn(
Loss after epoch 1: 7669.064441680908
Loss after epoch 2: 5634.684083938599
Loss after epoch 3: 5237.810531616211
Training process has finished.
2 were created.

-- model 0 running on ['CPUExecutionProvider']

opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input0' type=dtype('float32') shape=[32, 10]
input: name='input1' type=dtype('float32') shape=[32]
input: name='input_dim_2' type=dtype('int64') shape=[1]
input: name='input3' type=dtype('float32') shape=['s0', 10]
input: name='input4' type=dtype('float32') shape=[1, 32]
input: name='input5' type=dtype('float32') shape=[1]
Gemm(input3, input0, input1, transA=0, transB=1, alpha=1.00, beta=1.00) -> addmm
  Relu(addmm) -> output_2
    Gemm(output_2, input4, input5, transA=0, transB=1, alpha=1.00, beta=1.00) -> output_0
Transpose(input4, perm=[1,0]) -> output_3
Identity(input3) -> output_1
Identity(input_dim_2) -> output_dim_4
output: name='output_0' type=dtype('float32') shape=['s0', 1]
output: name='output_1' type=dtype('float32') shape=['s0', 10]
output: name='output_2' type=dtype('float32') shape=['s0', 32]
output: name='output_3' type=dtype('float32') shape=[32, 1]
output: name='output_dim_4' type=dtype('int64') shape=[1]

-- model 1 running on ['CPUExecutionProvider']

opset: domain='' version=18
doc_string: large_model=False, inline=False, external_threshold=102...
input: name='input_dim_0' type=dtype('int64') shape=[1]
input: name='input1' type=dtype('float32') shape=['s0', 10]
input: name='input2' type=dtype('float32') shape=['s0', 32]
input: name='input3' type=dtype('float32') shape=[32, 1]
input: name='input4' type=dtype('float32') shape=['s0', 1]
init: name='init7_s1_0' type=dtype('int64') shape=(1,) -- array([0])
init: name='init1_s1_' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Constant(value_float=0.0) -> output_NONE_2
  Identity(output_NONE_2) -> output_NONE_3
Gemm(input4, input2, transA=1, transB=0) -> output_4
Gemm(input4, input3, transA=0, transB=1) -> mm
ReduceSum(input4, init7_s1_0, keepdims=0) -> output_5
LessOrEqual(input2, init1_s1_) -> _onx_lessorequal0
  Where(_onx_lessorequal0, init1_s1_, mm) -> threshold_backward
    Gemm(threshold_backward, input1, transA=1, transB=0) -> output_0
ReduceSum(threshold_backward, init7_s1_0, keepdims=0) -> output_1
output: name='output_0' type=dtype('float32') shape=[32, 10]
output: name='output_1' type=dtype('float32') shape=[32]
output: name='output_NONE_2' type=dtype('float32') shape=None
output: name='output_NONE_3' type=dtype('float32') shape=None
output: name='output_4' type=dtype('float32') shape=[1, 32]
output: name='output_5' type=dtype('float32') shape=[1]
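
The naming convention can be checked programmatically on the stored protos. A small sketch classifying the outputs of the backward graph:

# '_dim_' marks a dynamic dimension, '_NONE_' marks an output replaced
# by None when handed back to pytorch.
backward = storage["instance"][1]["onnx"]
for out in backward.graph.output:
    if "_dim_" in out.name:
        kind = "dynamic dimension"
    elif "_NONE_" in out.name:
        kind = "returned as None"
    else:
        kind = "regular tensor"
    print(out.name, "->", kind)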

The forward graph.

plot_dot(storage["instance"][0]["onnx"])
[figure: plot torch custom backend 101]
<Axes: >

The backward graph.

plot_dot(storage["instance"][1]["onnx"])
[figure: plot torch custom backend 101]
<Axes: >

Pattern Optimizations

By default, once exported into onnx, a model is optimized by looking for patterns. Each pattern locally replaces a few nodes with a more efficient equivalent (see experimental_experiment.xoptim.patterns and experimental_experiment.xoptim.patterns_ort).
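
One way to observe the effect of these optimizations is to count the operator types in the exported graphs: the addmm calls above, for instance, end up as single Gemm nodes. A small sketch over the models captured in storage:

from collections import Counter

# Count ONNX operator types in each captured graph to see which
# operators the pattern optimizer produced.
for i, inst in enumerate(storage["instance"]):
    ops = Counter(node.op_type for node in inst["onnx"].graph.node)
    print(f"model {i}: {dict(ops)}")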

Total running time of the script: (0 minutes 2.331 seconds)
