from typing import Any, Dict, Optional
import onnx
import torch
from . import OpRunKernel, OpRunTensor
class OpRunControlFlow(OpRunKernel):
    """Common ancestor for control flows.

    Every attribute of the node whose type is ``GRAPH`` is turned into a new
    evaluator of the same class as ``parent``. That evaluator reuses the
    parent's providers, opsets, local functions, verbosity and custom kernels,
    and is stored on the kernel under the attribute's name (for instance
    ``body`` for Loop or ``then_branch`` / ``else_branch`` for If).
    """

    @classmethod
    def has_subgraphs(cls) -> bool:
        """Returns True if the kernel has subgraphs."""
        return True

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        """
        :param node: node to run, each of its graph attributes becomes a
            sub-evaluator
        :param version: operator version, forwarded to the base kernel
        :param parent: evaluator running this node, mandatory because its
            settings are reused to build the sub-evaluators
        :param verbose: verbosity level
        """
        super().__init__(node, version, verbose=verbose)
        assert (
            parent is not None
        ), f"parent must be specified for operator {self.__class__.__name__!r}"
        make_evaluator = parent.__class__
        graph_attributes = [
            attribute
            for attribute in node.attribute
            if attribute.type == onnx.AttributeProto.GRAPH
        ]
        for attribute in graph_attributes:
            sub_evaluator = make_evaluator(
                attribute.g,
                providers=parent.providers,
                opsets=parent.opsets,
                local_functions=parent.functions,
                verbose=parent.verbose,
                custom_kernels=parent.custom_kernels,
            )
            # exposes the subgraph evaluator as self.<attribute name>
            setattr(self, attribute.name, sub_evaluator)
class If_1(OpRunControlFlow):
    "If"

    def run(self, cond, context: Optional[Dict[str, Any]] = None):
        """
        Runs one of the two branches.

        :param cond: wrapped tensor whose single value selects the branch
        :param context: values of the enclosing scope, forwarded to the branch
        :return: whatever the selected branch's ``run_with_values`` returns
        """
        # truthy scalar -> then_branch, falsy scalar -> else_branch
        if cond.tensor.item():
            branch = self.then_branch  # type: ignore[attr-defined]
        else:
            branch = self.else_branch  # type: ignore[attr-defined]
        return branch.run_with_values(context=context)
class Loop_16(OpRunControlFlow):
    "Loop"

    def __init__(
        self,
        node: onnx.NodeProto,
        version: Optional[int] = None,
        parent: Optional["onnx_diagnostic.reference.TorchOnnxEvaluator"] = None,  # noqa: F821
        verbose: int = 0,
    ):
        super().__init__(node, version, parent, verbose=verbose)
        # ``self.body`` is built by OpRunControlFlow from the ``body`` attribute.
        # Its inputs are (iteration number, condition, N loop-carried values).
        # Its outputs are (condition, N loop-carried values, K scan outputs).
        self.output_index = {n: i for i, n in enumerate(self.body.output_names)}
        self.N = len(self.body.input_names) - 2
        self.K = len(self.body.output_names) - self.N - 1
        # Per the ONNX Loop specification the node itself has N + K outputs:
        # the body's condition output is not exposed. Padding in ``run``
        # must target this count, not the body's output count.
        self.n_node_outputs = len(node.output)

    @staticmethod
    def _concat_scan_outputs(scan_values):
        """Concatenates, along axis 0, the per-iteration values of every scan output."""
        results = []
        for values in scan_values:
            if not values:
                # Zero iterations: nothing to infer a shape or dtype from, and
                # torch.cat would raise on an empty list.
                # NOTE(review): ONNX expects shape (0, ...) with the scan
                # output's element type; this placeholder is 1-D float.
                results.append(OpRunTensor(torch.empty((0,))))
                continue
            # Body outputs come back wrapped (the loop reads ``cond.tensor``
            # on the same list), so unwrap them before concatenating.
            # NOTE(review): ONNX Runtime seems to stack scan values along a
            # new leading axis; torch.cat merges along the existing axis 0 --
            # confirm which behavior the exported models expect.
            tensors = [v.tensor if isinstance(v, OpRunTensor) else v for v in values]
            results.append(OpRunTensor(torch.cat(tensors, dim=0)))
        return results

    def run(self, M, cond, *args, context: Optional[Dict[str, Any]] = None):
        """
        Executes the loop.

        :param M: maximum trip count, None (or a wrapper holding None) for no limit
        :param cond: initial condition, None (or a wrapper holding None) means true
        :param args: initial values of the loop-carried dependencies
        :param context: values of the enclosing scope, forwarded to the body
        :return: tuple with the final loop-carried values followed by the
            concatenated scan outputs, padded with empty tensors up to the
            node's number of outputs
        """
        if args:
            v_initial = args[0]
            args = args[1:]
        else:
            v_initial = None
        assert M is None or hasattr(
            M, "dtype"
        ), f"M must be empty or an array but its type is {type(M)}."
        body = self.body
        loop_inputs = body.input_names
        inputs = dict.fromkeys(loop_inputs)
        if v_initial is not None:
            inputs[loop_inputs[2]] = v_initial
        cond_name = body.output_names[0]
        if args:
            # the remaining initial values fill the last body inputs
            begin = len(loop_inputs) - len(args)
            for name, val in zip(loop_inputs[begin:], args):
                inputs[name] = val
        if context is not None:
            inputs.update(context)

        # loop invariants: dtype of the iteration counter, index of the condition
        iter_dtype = None if M is None else M.dtype
        index_cond = self.output_index[cond_name]

        scan_values = [[] for _ in range(self.K)]  # type: ignore
        outputs = []
        it = 0
        while (cond is None or cond.tensor is None or cond.tensor.item()) and (
            M is None or M.tensor is None or it < M.tensor.item()
        ):
            if len(loop_inputs) > 0 and loop_inputs[0] is not None:
                inputs[loop_inputs[0]] = OpRunTensor(torch.tensor(it, dtype=iter_dtype))
            if len(loop_inputs) > 1 and loop_inputs[1] is not None:
                inputs[loop_inputs[1]] = cond
            outputs = list(
                body.run_with_values(*[inputs[k] for k in loop_inputs], context=context)
            )
            # the last K body outputs are the scan outputs of this iteration
            for k in range(self.K):
                scan_values[k].append(outputs[-self.K + k])
            cond = outputs[index_cond]
            assert (
                cond is not None
            ), f"Condition {cond_name!r} returned by the subgraph cannot be None."
            # feed the updated loop-carried values back into the next iteration
            for i, o in zip(loop_inputs[2:], body.output_names[1:]):
                inputs[i] = outputs[self.output_index[o]]
            it += 1

        if it == 0:
            # never entered the loop: final values are the initial ones
            final = [inputs[i] for i in loop_inputs[2:]]
        else:
            final = outputs[1 : 1 + self.N]
        final.extend(self._concat_scan_outputs(scan_values))
        while len(final) < self.n_node_outputs:
            final.append(OpRunTensor(torch.empty(())))
        return tuple(final)