import os
import numpy as np
import onnx
from typing import Any, Dict, List, Optional, Union
from ..export_helpers import torch_export
def get_fused_aten_ops_dispatcher():
"""
Returns a dispatcher with additional conversion functions turning
fused operators into ATen ops that onnxruntime can call.
"""
from ..torch_interpreter import Dispatcher
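# The converter below maps aten._scaled_dot_product_efficient_attention onto a
# single ATen node in the "org.pytorch.aten" domain; onnxruntime executes such
# nodes by calling back into the corresponding PyTorch kernel at runtime.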
def onnx_scaled_dot_product_efficient_attention(
g: "GraphBuilder", # noqa: F821
sts: Dict[str, Any],
outputs: List[str],
query,
key,
value,
attn_bias,
compute_log_sumexp: bool,
dropout_p: float,
is_causal: bool,
scale: float = 1.0,
**kwargs,
):
assert len(outputs) == 4, f"Unexpected number of outputs {outputs}{g.get_debug_msg()}"
assert len(kwargs) == 0, (
f"Unexpected kwargs {kwargs} in "
f"onnx_scaled_dot_product_efficient_attention{g.get_debug_msg()}"
)
t_compute_log_sumexp = g.make_initializer(
"",
np.array(compute_log_sumexp, dtype=np.bool_),
source="onnx_scaled_dot_product_efficient_attention.t_compute_log_sumexp",
)
t_dropout_p = g.make_initializer(
"",
np.array(dropout_p, dtype=np.float32),
source="onnx_scaled_dot_product_efficient_attention.t_dropout_p",
)
t_is_causal = g.make_initializer(
"",
np.array(is_causal, dtype=np.bool_),
source="onnx_scaled_dot_product_efficient_attention.t_is_causal",
)
t_scale = g.make_initializer(
"",
np.array(scale or 1.0, dtype=np.float32),
source="onnx_scaled_dot_product_efficient_attention.t_scale",
)
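# Emit a single ATen node in the "org.pytorch.aten" domain; the scalar
# arguments are passed as the initializers created above.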
output, log_sumexp, philox_seed, philox_offset = g.make_node(
"ATen",
[
query,
key,
value,
attn_bias or "",
t_compute_log_sumexp,
t_dropout_p,
t_is_causal,
t_scale,
],
outputs=[
g.unique_name(n) for n in ["output", "log_sumexp", "philox_seed", "philox_offset"]
],
operator="_scaled_dot_product_efficient_attention",
domain="org.pytorch.aten",
name="scaled_dot_product_efficient_attention",
)
g.set_type(output, g.get_type(query))
g.set_type(log_sumexp, onnx.TensorProto.FLOAT)
g.set_rank(output, g.get_rank(query))
g.set_rank(log_sumexp, g.get_rank(query))
g.set_type(philox_seed, onnx.TensorProto.INT64)
g.set_type(philox_offset, onnx.TensorProto.INT64)
g.set_shape(philox_seed, tuple())
g.set_shape(philox_offset, tuple())
g.add_domain("org.pytorch.aten")
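# Bind the ATen outputs to the requested output names with Identity nodes;
# the _DONOTREMOVE_ marker in the node name tells optimization passes to keep them.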
res = []
for i, (name, rename) in enumerate(
zip([output, log_sumexp, philox_seed, philox_offset], outputs)
):
res.append(
g.op.Identity(
name,
outputs=[rename],
name=(
"onnx_scaled_dot_product_efficient_attention"
if i < 0
else "_DONOTREMOVE_onnx_scaled_dot_product_efficient_attention"
),
)
)
return tuple(res)
def onnx_scaled_dot_product_attention_backward(
g: "GraphBuilder", # noqa: F821
sts: Dict[str, Any],
outputs: List[str],
grad,
query,
key,
value,
attn_bias,
output,
logsumexp,
philox_seed,
philox_offset,
dropout_p,
grad_input_mask,
is_causal: bool,
scale: float = 1.0,
**kwargs,
):
assert len(outputs) == 4, f"Unexpected number of outputs {outputs}{g.get_debug_msg()}"
assert len(kwargs) == 0, (
f"Unexpected kwargs {kwargs} in "
f"onnx_scaled_dot_product_attention_backward{g.get_debug_msg()}"
)
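# The backward pass follows the same pattern: a single ATen node calling
# _scaled_dot_product_efficient_attention_backward, with the scalar arguments
# passed as initializers.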
t_scale = g.make_initializer(
"",
np.array(scale or 1.0, dtype=np.float32),
source="onnx_scaled_dot_product_attention_backward.t_scale",
)
t_dropout_p = g.make_initializer(
"",
np.array(dropout_p, dtype=np.float32),
source="onnx_scaled_dot_product_attention_backward.t_dropout_p",
)
t_is_causal = g.make_initializer(
"",
np.array(is_causal, dtype=np.bool_),
source="onnx_scaled_dot_product_attention_backward.t_is_causal",
)
t_grad_input_mask = g.make_initializer(
"",
np.array(grad_input_mask, dtype=np.int64),
source="onnx_scaled_dot_product_attention_backward.t_grad_input_mask",
)
# onnxruntime sometimes fails here with "type inference failed";
# the expected element types are encoded in the node name to ease debugging.
dt = g.get_type(grad)
helper = ",".join(map(str, [dt, dt, dt, dt]))
node_name = f"scaled_dot_product_attention_backward[{helper}]"
grad_query, grad_key, grad_value, grad_attn_bias = g.make_node(
"ATen",
[
grad,
query,
key,
value,
attn_bias or "",
output,
logsumexp,
philox_seed,
philox_offset,
t_dropout_p,
t_grad_input_mask,
t_is_causal,
t_scale,
],
outputs=outputs,
operator="_scaled_dot_product_efficient_attention_backward",
domain="org.pytorch.aten",
name=node_name,
)
g.add_domain("org.pytorch.aten")
return grad_query, grad_key, grad_value, grad_attn_bias
dispatcher = Dispatcher(
{
"_scaled_dot_product_efficient_attention_default": onnx_scaled_dot_product_efficient_attention, # noqa: E501
"_scaled_dot_product_efficient_attention_backward_default": onnx_scaled_dot_product_attention_backward, # noqa: E501
}
)
return dispatcher
def create_compiled_model(
model: Any,
backend: str,
target_opset: int,
use_dynamic: bool = False,
verbose: int = 0,
enable_pattern: Union[str, List[str]] = "default",
disable_pattern: Optional[Union[str, List[str]]] = None,
return_storage: bool = False,
rename_inputs: bool = True,
dump_prefix: Optional[str] = None,
dump_patterns: Optional[str] = None,
optimize: bool = True,
ort_optimize: bool = True,
use_fused_aten_ops: bool = False,
processor: str = "CPU",
order_algorithm: str = "NONE",
) -> Any:
"""
Creates the compiled model.
:param model: module to compile
:param backend: kind of backend (``ort``, ``ort+``, ``plug``, ``inductor``, ``trt``,
``eager``, ``custom``, ``backort``, ``debug``, ``dynger``, ``ortmodule``)
:param target_opset: target opset for the onnx conversion
:param use_dynamic: use dynamic shapes
:param verbose: verbosity
:param enable_pattern: to enable optimization pattern
:param disable_pattern: to disable optimization pattern
:param return_storage: return a container for the models,
only works with backends *custom* and *debug*
:param rename_inputs: rename inputs into ``input_{i}``
:param dump_prefix: dumps the models (backends *custom* and *debug*)
:param dump_patterns: dumps the applied optimization patterns if applicable
:param optimize: enable optimizations
:param ort_optimize: enables onnxruntime optimization
:param use_fused_aten_ops: use fused operators when converting the model,
it only works with the backend *custom*
:param processor: optimization should be made for this processor
or this list of processors (comma-separated values)
:param order_algorithm: algorithm optimizing the order of the onnx nodes,
none by default
:return: compiled model
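
A minimal usage sketch (the model, opset and backend below are only
placeholders, not a recommendation)::

    import torch

    compiled = create_compiled_model(
        torch.nn.Linear(4, 2), backend="custom", target_opset=18
    )
    print(compiled(torch.randn(3, 4)))
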
"""
import torch
from torch._dynamo.backends.common import aot_autograd
from experimental_experiment.torch_models.training_helper import make_aot_ort
from experimental_experiment.torch_dynamo import (
get_decomposition_table,
get_decomposition_table_dynamo,
dynger_backend,
onnx_custom_backend,
onnx_debug_backend,
)
assert dump_patterns is None or isinstance(
dump_patterns, str
), f"Unexpected type {type(dump_patterns)} for dump_patterns."
assert isinstance(
ort_optimize, bool
), f"Unexpected type={type(ort_optimize)} for ort_optimize={ort_optimize}"
ort_optimization_level = "ORT_ENABLE_ALL" if ort_optimize else "ORT_DISABLE_ALL"
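# When fused ATen ops are requested, register onnxruntime's ATen op executor so
# that the ATen nodes produced by the dispatcher can call back into PyTorch
# kernels at runtime.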
if use_fused_aten_ops and backend in {
"ort",
"custom",
"custom-fallback",
"backort",
"plug",
"ort+",
}:
from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor
from onnxruntime.capi import _pybind_state as _C
_C.register_aten_op_executor(
str(aten_op_executor.is_tensor_argument_address()),
str(aten_op_executor.execute_aten_operator_address()),
)
dispatcher = get_fused_aten_ops_dispatcher()
else:
dispatcher = None
if backend == "ort":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
assert (
optimize
), f"Optimization with onnxscript cannot be disabled with backend={backend!r}."
_local_aot_ort, local_ort = make_aot_ort(dynamic=use_dynamic, verbose=verbose)
return torch.compile(model, backend=local_ort)
if backend == "ort+":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
_local_aot_ort, local_ort = make_aot_ort(
dynamic=use_dynamic,
rewrite=True,
verbose=verbose,
rewrite_more=optimize,
enable_pattern=enable_pattern,
disable_pattern=disable_pattern,
processor=processor,
ort_optimization_level=ort_optimization_level,
order_algorithm=order_algorithm,
dump_patterns=dump_patterns,
)
return torch.compile(model, backend=local_ort)
if backend == "plug":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
os.environ["ONNXRT_CHANGE_REWRITER"] = "1"
_local_aot_ort, local_ort = make_aot_ort(
dynamic=use_dynamic,
rewrite=True,
verbose=verbose,
rewrite_more=True,
enable_pattern=enable_pattern,
disable_pattern=disable_pattern,
processor=processor,
order_algorithm=order_algorithm,
dump_prefix=dump_prefix,
dump_patterns=dump_patterns,
)
return torch.compile(model, backend=local_ort)
if backend == "inductor":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
return torch.compile(model, backend="inductor", dynamic=use_dynamic)
if backend == "trt":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
assert not use_dynamic, (
"TensorRT is not implemented when use_dynamic is False. "
"In that case, inputs should be a list of torch_tensorrt.Input objects. "
)
# TODO: create a specific backend,
# https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/_compiler.py#L44
import torch_tensorrt
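# Compile lazily with torch_tensorrt on the first call and reuse the compiled
# module for subsequent calls.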
class trt_backend:
def __init__(self, model):
self.model = model
self.trt = None
# verbosity captured from the enclosing create_compiled_model call
self.verbose = verbose
def __call__(self, *args):
if self.trt is None:
if self.verbose:
print("[create_compiled_model] run torch.export.export")
exp_program = torch_export(self.model, args)
if self.verbose:
print("[create_compiled_model] run torch_tensorrt.dynamo.compile")
self.trt = torch_tensorrt.dynamo.compile(exp_program, args)
if self.verbose:
print("[create_compiled_model] done")
return self.trt(*args)
return trt_backend(model)
if backend == "eager":
assert not return_storage, f"return_storage=True not implemented with backend={backend!r}"
return model
if backend == "custom":
storage = {} if return_storage else None
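# aot_autograd traces separate forward and backward graphs; each one is handed
# to onnx_custom_backend, which converts it to ONNX and runs it with onnxruntime.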
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_custom_backend(
*args,
target_opset=target_opset,
verbose=verbose,
enable_pattern=enable_pattern,
disable_pattern=disable_pattern,
storage=storage,
rename_inputs=rename_inputs,
dump_prefix=dump_prefix,
dump_patterns=dump_patterns,
optimize=optimize,
dispatcher=dispatcher,
processor=processor,
ort_optimization_level=ort_optimization_level,
order_algorithm=order_algorithm,
**kwargs,
),
decompositions=get_decomposition_table(),
)
cc = torch.compile(model, backend=aot_compiler, fullgraph=True, dynamic=use_dynamic)
if return_storage:
return cc, storage
return cc
if backend == "backort":
storage = {} if return_storage else None
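# Same pipeline as backend "custom" but using the dynamo exporter and the
# dynamo decomposition table.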
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_custom_backend(
*args,
target_opset=target_opset,
verbose=verbose,
enable_pattern=enable_pattern,
disable_pattern=disable_pattern,
storage=storage,
rename_inputs=rename_inputs,
dump_prefix=dump_prefix,
dump_patterns=dump_patterns,
optimize=optimize,
exporter="dynamo",
dispatcher=dispatcher,
processor=processor,
ort_optimization_level=ort_optimization_level,
order_algorithm=order_algorithm,
**kwargs,
),
decompositions=get_decomposition_table_dynamo(),
)
cc = torch.compile(model, backend=aot_compiler, fullgraph=True, dynamic=use_dynamic)
if return_storage:
return cc, storage
return cc
if backend == "debug":
storage = {} if return_storage else None
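# Same pipeline but executed with the reference evaluator (backend "ref"),
# intended for debugging.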
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: onnx_debug_backend(
*args,
target_opset=target_opset,
backend="ref",
enable_pattern=enable_pattern,
disable_pattern=disable_pattern,
storage=storage,
rename_inputs=rename_inputs,
verbose=verbose,
dump_prefix=dump_prefix,
dump_patterns=dump_patterns,
optimize=optimize,
dispatcher=dispatcher,
processor=processor,
ort_optimization_level=ort_optimization_level,
order_algorithm=order_algorithm,
**kwargs,
),
decompositions=get_decomposition_table(),
)
cc = torch.compile(model, backend=aot_compiler, fullgraph=True, dynamic=use_dynamic)
if return_storage:
return cc, storage
return cc
if backend == "dynger":
aot_compiler = aot_autograd(
fw_compiler=lambda *args, **kwargs: dynger_backend(
*args, verbose=verbose, optimize=optimize, **kwargs
),
decompositions=get_decomposition_table(),
)
cc = torch.compile(model, backend=aot_compiler, fullgraph=True, dynamic=use_dynamic)
if return_storage:
return cc, None
return cc
if backend == "ortmodule":
from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel
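# ORTModule wraps the model directly; when dump_prefix is set, the exported
# ONNX files use it as prefix and ORTMODULE_CACHE_DIR points to its folder.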
if dump_prefix:
os.environ["ORTMODULE_CACHE_DIR"] = os.path.dirname(dump_prefix)
opts = DebugOptions(
save_onnx=True,
log_level=LogLevel.VERBOSE if verbose else LogLevel.ERROR,
onnx_prefix=dump_prefix,
)
else:
opts = DebugOptions(
log_level=LogLevel.VERBOSE if verbose else LogLevel.ERROR,
)
return ORTModule(model, opts)
raise ValueError(f"Unexpected backend={backend!r}.")