experimental_experiment.xoptim

GraphBuilderPatternOptimization

class experimental_experiment.xoptim.GraphBuilderPatternOptimization(builder: GraphBuilder, patterns: List[PatternOptimization] | None = None, recursive: bool = False, verifies: bool = False, verbose: int = 0, dump_applied_patterns: str | None = None, processor: str = 'CPU')

Implements optimization after the conversion is done. The differences between the two models can be displayed with a command line such as:

python -m onnx_array_api compare -m1 <model.onnx> -m2 <optimized.onnx> -m nodes -c 80

This class assumes a pattern cannot reuse an existing name.

Parameters:
  • builder – GraphBuilder

  • patterns – list of patterns to apply

  • recursive – goes through subgraphs

  • verifies – verifies the model, which takes time

  • verbose – verbosity

  • dump_applied_patterns – dumps the applied patterns into this folder; users can check every dumped pattern as a FunctionProto

  • processor – optimizes for this processor or this list of processors (comma-separated values)

apply_match(match: MatchResult) -> List[NodeProto]

Applies one match. Returns the new nodes.

do_not_remove(node: NodeProto) -> bool

Tells if a node must not be removed.

get_attribute(node: NodeProto, att_name: str, exc: bool = True) -> AttributeProto | None

Returns an attribute for a node.

get_attributes_with_default(node: NodeProto, **default_values) -> Dict[str, Any]

Returns integer or float values for attributes.

get_axis(node: NodeProto, default_axis: int | None = None) -> int

Retrieves the axis for many operators.

get_computed_constant(name: str, statistics: List[str] | None = None) -> Any

Returns the value for the constant name.

get_constant_or_attribute(node: NodeProto, attribute: str, input_index: int, cvt: Callable | None = None) -> Any

Returns an input or the value of an attribute. Some attributes became inputs in more recent opsets. The function checks both.

Parameters:
  • node – node

  • attribute – attribute name

  • input_index – input index

  • cvt – if not None, this conversion function is applied to the result before returning it

Returns:

value
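A minimal sketch of this lookup, assuming a hypothetical ToyNode that stores attributes in a dict and a plain dict of known constants; neither is the library's NodeProto-based implementation, and returning None when neither the attribute nor the input is present is also an assumption. ONNX Squeeze is a real example of such a change: axes is an attribute before opset 13 and an input since opset 13.

```python
from typing import Any, Callable, Dict, List, Optional


class ToyNode:
    """Hypothetical stand-in for NodeProto: attributes are kept in a dict."""

    def __init__(self, op_type: str, inputs: List[str], attributes: Optional[Dict[str, Any]] = None):
        self.op_type = op_type
        self.input = inputs
        self.attributes = attributes or {}


def get_constant_or_attribute(
    node: ToyNode,
    attribute: str,
    input_index: int,
    constants: Dict[str, Any],
    cvt: Optional[Callable] = None,
) -> Any:
    """Returns the attribute if the node has it, otherwise the constant input."""
    if attribute in node.attributes:
        value = node.attributes[attribute]
    elif input_index < len(node.input) and node.input[input_index]:
        # an empty name means the optional input is missing
        value = constants[node.input[input_index]]
    else:
        value = None
    if cvt is not None and value is not None:
        value = cvt(value)
    return value


# Squeeze stores axes as an attribute before opset 13 and as an input since then.
old = ToyNode("Squeeze", ["X"], {"axes": [0]})
new = ToyNode("Squeeze", ["X", "axes_cst"])
constants = {"axes_cst": [0]}
print(get_constant_or_attribute(old, "axes", 1, constants, cvt=tuple))  # (0,)
print(get_constant_or_attribute(new, "axes", 1, constants, cvt=tuple))  # (0,)
```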

get_constant_scalar(name: str, broadcast: bool = False) -> int | float

Returns a scalar as a constant.

Parameters:
  • name – name

  • broadcast – consider [1], [[1]], [[[1]]] as scalar constants as well

Returns:

int or float
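The broadcast flag can be illustrated with nested Python lists standing in for tensors. This is a toy sketch: the real method works on constants stored in the graph, and raising ValueError on failure is an assumption.

```python
from typing import Any, Union


def get_constant_scalar(value: Any, broadcast: bool = False) -> Union[int, float]:
    """Toy version working on Python numbers and nested lists instead of tensors."""
    if isinstance(value, (int, float)):
        return value
    if broadcast:
        # unwrap single-element containers: [1], [[1]], [[[1]]], ...
        while isinstance(value, list) and len(value) == 1:
            value = value[0]
        if isinstance(value, (int, float)):
            return value
    raise ValueError(f"{value!r} is not a scalar constant")


print(get_constant_scalar(3))                          # 3
print(get_constant_scalar([[[2.5]]], broadcast=True))  # 2.5
```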

get_constant_shape(name: str, exc: bool = True) -> Tuple[int, ...] | None

Returns the shape of a constant.

Parameters:
  • name – name

  • exc – raises an exception if this is not possible

Returns:

shape

get_rank(name: str) -> int

Returns the rank of a result.

get_shape(name: str) -> int

Returns the shape of a result.

get_type(name: str) -> int

Returns the type of a result.

has_processor(processor: str) -> bool

Checks whether the processor is in the list of processors the optimization targets.
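Since processor accepts a comma-separated value, the check can be sketched as below. parse_processors and this has_processor are hypothetical helpers, and the exact, case-sensitive comparison is an assumption.

```python
from typing import List


def parse_processors(processor: str) -> List[str]:
    """Splits a comma-separated value such as "CPU,CUDA" into a list."""
    return [p.strip() for p in processor.split(",") if p.strip()]


def has_processor(processors: List[str], processor: str) -> bool:
    """Checks the processor is one of the targeted processors."""
    return processor in processors


targets = parse_processors("CPU,CUDA")
print(targets)                         # ['CPU', 'CUDA']
print(has_processor(targets, "CUDA"))  # True
```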

has_rank(name: str) -> int

Tells if a result has a rank.

has_shape(name: str) -> bool

Tells if a result has a shape.

has_type(name: str) -> bool

Tells if a result has a type.

property input_names: List[str]

property inputs: List[Any]

is_constant(name: str) -> bool

Tells if a result is a constant.

is_constant_scalar(name: str, value: Any | None = None, broadcast: bool = False) -> bool

Tells if a constant is a scalar.

Parameters:
  • name – name

  • broadcast – if True, consider 1, [1], [[1]], [[[1]]], … as scalar as well

  • value – value to compare to if specified

Returns:

boolean

is_output(name: str) -> bool

Tells if a result is an output.

is_used(name: str) -> bool

Tells if a result is used or not, including as an output of the graph.

is_used_by_subgraph(name: str) -> bool

Tells if a result is used by a subgraph.

is_used_more_than_once(name: str) -> bool

Tells if a result is used more than once in the current graph or in a subgraph, or if it is an output.

is_used_only_by(name, *nodes: List[NodeProto]) -> bool

Tells if a result is only used by a specific set of nodes.

iter_nodes() -> Iterator

Iterates over the nodes of the graph.

property main_opset

Returns the opset for the main domain (assuming it is used).

make_node(op_type: str, inputs: str | List[str], outputs: int | List[str] | str = 1, domain: str = '', attributes: List[AttributeProto] | None = None, name: str | None = None, **kwargs) -> NodeProto

Creates a node without adding it to the graph.

Parameters:
  • op_type – operator type

  • inputs – input names

  • outputs – output names: if an integer n, creates n unique names; if a str, creates one unique name; if a list, uses these names

  • domain – node domain

  • attributes – list of attributes

  • name – node name

  • kwargs – other attributes

Returns:

a node
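The three accepted forms of outputs can be sketched with a hypothetical NameGenerator mimicking unique_name. How the library actually builds the names (prefix, counter format) is an assumption here, in particular that a string is used as the prefix of the new name.

```python
from typing import Iterable, List, Union


class NameGenerator:
    """Hypothetical helper mimicking unique_name and the outputs argument of make_node."""

    def __init__(self, existing: Iterable[str]):
        self.existing = set(existing)

    def unique_name(self, prefix: str) -> str:
        # append a counter until the name is not used yet
        name, i = prefix, 0
        while name in self.existing:
            i += 1
            name = f"{prefix}_{i}"
        self.existing.add(name)
        return name

    def resolve_outputs(self, op_type: str, outputs: Union[int, str, List[str]]) -> List[str]:
        if isinstance(outputs, int):
            # n new names derived from the operator type
            return [self.unique_name(op_type) for _ in range(outputs)]
        if isinstance(outputs, str):
            # one new name derived from the given string
            return [self.unique_name(outputs)]
        # a list is used as it is
        return list(outputs)


gen = NameGenerator({"Add", "y"})
print(gen.resolve_outputs("Add", 2))      # ['Add_1', 'Add_2']
print(gen.resolve_outputs("Mul", "y"))    # ['y_1']
print(gen.resolve_outputs("Mul", ["z"]))  # ['z']
```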

make_node_check_opset(op_type: str, inputs: str | List[str], outputs: int | List[str] | str = 1, domain: str = '', attributes: List[AttributeProto] | None = None, name: str | None = None, **kwargs)

Creates a node without adding it to the graph, but adapts it for some known operators whose definition changed across opsets.

Parameters:
  • op_type – operator type

  • inputs – input names

  • outputs – output names: if an integer n, creates n unique names; if a str, creates one unique name; if a list, uses these names

  • domain – node domain

  • attributes – list of attributes

  • name – node name

  • kwargs – other attributes

Returns:

a node

next_node(name: str) -> NodeProto

Returns the next node if it is unique, otherwise fails.

next_nodes(name: str) -> List[NodeProto]

Returns the nodes consuming the given result.

node_before(name: str) -> NodeProto

Returns the node producing this output. Returns None if it is an input or an initializer.
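Methods such as node_before, next_nodes, next_node and is_used_more_than_once become cheap lookups once producers and consumers are precomputed, which is what the optimization loop does when it builds all successors and predecessors. A toy sketch with a hypothetical Node dataclass and ToyGraph class; unlike the real methods, it ignores subgraphs.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class Node:
    op_type: str
    input: List[str]
    output: List[str]


class ToyGraph:
    """Hypothetical graph precomputing producers and consumers of every result."""

    def __init__(self, nodes: List[Node], outputs: List[str]):
        self.outputs = set(outputs)
        self.producer: Dict[str, Node] = {}
        self.consumers: Dict[str, List[Node]] = defaultdict(list)
        for node in nodes:
            for name in node.output:
                self.producer[name] = node
            # dict.fromkeys removes duplicates (Mul(x, x)) and keeps the order
            for name in dict.fromkeys(node.input):
                if name:  # an empty name is a missing optional input
                    self.consumers[name].append(node)

    def node_before(self, name: str) -> Optional[Node]:
        # None for graph inputs and initializers
        return self.producer.get(name)

    def next_nodes(self, name: str) -> List[Node]:
        return list(self.consumers.get(name, []))

    def next_node(self, name: str) -> Node:
        nodes = self.next_nodes(name)
        if len(nodes) != 1:
            raise ValueError(f"{name!r} has {len(nodes)} consumers, expected 1")
        return nodes[0]

    def is_used_more_than_once(self, name: str) -> bool:
        # being a graph output counts as one more use
        uses = len(self.consumers.get(name, [])) + int(name in self.outputs)
        return uses > 1


# x -> Transpose -> t -> Relu -> r, then Add(t, r) -> y
g = ToyGraph(
    [Node("Transpose", ["x"], ["t"]), Node("Relu", ["t"], ["r"]), Node("Add", ["t", "r"], ["y"])],
    outputs=["y"],
)
print(g.node_before("x"))             # None
print(g.is_used_more_than_once("t"))  # True
print(g.next_node("r").op_type)       # Add
```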

property nodes: List[NodeProto]

property opsets

optimize(max_iter=-1, remove_identity: bool = True, stop_after: int = -1) -> List[Dict[str, Any]]

Optimizes the graph based on the given list of patterns.

Parameters:
  • max_iter – maximum number of iterations

  • remove_identity – removes Identity nodes; it is better to keep it True, since keeping Identity nodes might prevent other patterns from finding a set of nodes to optimize

  • stop_after – stops after this number of replacements (to debug), -1 not to stop

Returns:

the method returns information about the applied processes.

The algorithm runs multiple iterations until the graph stops evolving or max_iter is reached. By default, max_iter is equal to the number of nodes. An iteration is:

matches = []

builds all successors and predecessors

for all patterns p:

    for all nodes n:

        r = p.match(n)
        if r:
            if no node already scheduled to be rewritten by another match:
                matches.append(r)

    for all matches r:
        apply the match r

This algorithm may apply more than one rewriting at each iteration, but it guarantees that the local structure matched by a rewriting has not been altered by another rewriting when it is applied.
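The iteration above can be sketched on a toy graph reduced to a linear chain of operator names. It is one reading of the pseudocode, not the library's implementation: the matches of all patterns are collected first, a match is skipped if it shares a node with an already scheduled one, and all kept matches are applied at the end of the iteration. double_neg and identity are hypothetical patterns.

```python
from typing import Callable, List, Optional, Tuple

# Toy graph: a linear chain of operator names such as ["Neg", "Neg", "Relu"].
# A match is (start position, number of nodes, replacement operators).
Match = Tuple[int, int, List[str]]
Pattern = Callable[[List[str], int], Optional[Match]]


def double_neg(chain: List[str], i: int) -> Optional[Match]:
    # Neg(Neg(x)) == x, both nodes disappear
    return (i, 2, []) if chain[i:i + 2] == ["Neg", "Neg"] else None


def identity(chain: List[str], i: int) -> Optional[Match]:
    return (i, 1, []) if chain[i] == "Identity" else None


def optimize(chain: List[str], patterns: List[Pattern], max_iter: int = -1) -> List[str]:
    if max_iter < 0:
        max_iter = len(chain)  # default: the number of nodes
    for _ in range(max_iter):
        matches: List[Match] = []
        scheduled = set()  # positions already claimed by a match in this iteration
        for pattern in patterns:
            for i in range(len(chain)):
                r = pattern(chain, i)
                if r is None:
                    continue
                claimed = set(range(r[0], r[0] + r[1]))
                if claimed & scheduled:
                    continue  # another match already rewrites one of these nodes
                scheduled |= claimed
                matches.append(r)
        if not matches:
            break  # the graph is not evolving anymore
        # rewrite from the end so that earlier positions remain valid
        for start, length, replacement in sorted(matches, key=lambda m: m[0], reverse=True):
            chain = chain[:start] + replacement + chain[start + length:]
    return chain


print(optimize(["Neg", "Identity", "Neg", "Relu"], [double_neg, identity]))  # ['Relu']
print(optimize(["Neg", "Neg", "Neg"], [double_neg, identity]))              # ['Neg']
```

In the second call, the matches starting at positions 0 and 1 overlap, so only the first one is applied; the remaining Neg cannot be simplified.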

property output_names: List[str]

property outputs: List[Any]

try_infer_shape(name: str, exc: bool = False) -> int

Tries to infer the shape of a result.

Parameters:
  • name – name of the result for which to infer the shape

  • exc – if True, raises an exception if something goes wrong

Returns:

shape

try_infer_type(name: str, exc: bool = False) -> int

Tries to infer the type of a result.

Parameters:
  • name – name of the result for which to infer the type

  • exc – if True, raises an exception if something goes wrong

Returns:

type

unique_name(prefix: str) -> str

Returns a unique name.

MatchResult

class experimental_experiment.xoptim.MatchResult(pattern: PatternOptimization, nodes: List[NodeProto], apply: Callable, insert_at: NodeProto | None = None)

Holds the result of a successful match.

Parameters:
  • pattern – object detecting the pattern

  • nodes – nodes to be replaced

  • apply – function computing the replacements

  • insert_at – insert the new nodes at this point if specified

debug_string(g: GraphBuilder | None = None) -> str

Returns a string showing the matched nodes.

PatternOptimization

class experimental_experiment.xoptim.PatternOptimization(verbose: int = 0, priority: int = 1, min_opset: int = 1)

Defines an optimization pattern. Method match should return None if the match does not happen, or better, self.none(node, inspect.currentframe().f_lineno). This lets the user know which line rejected a specific pattern by setting the environment variable LOG_PATTERN_OPTIMIZE=10.

Parameters:
  • verbose – determines the verbosity; it can also be set with the environment variable LOG_PATTERN_OPTIMIZE=10

  • priority – at each iteration, all patterns whose priority is below a threshold are executed; if none of them matches, the threshold is increased

  • min_opset – can be applied if main opset is > min_opset

apply(g: GraphBuilder, *nodes: Sequence[NodeProto]) -> List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and that any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match, they are then removed

Returns:

nodes to add to graph.
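The contract of apply (received nodes are removed, returned nodes are added) can be sketched on a list of unique node names. replace_nodes is hypothetical, and inserting the new nodes where the first removed node was is an assumption; the real optimizer may rely on insert_at.

```python
from typing import List, Optional


def replace_nodes(graph: List[str], received: List[Optional[str]], returned: List[str]) -> List[str]:
    """Toy version of the contract: received nodes are removed, returned nodes are added.

    Nodes are unique strings here; new nodes are inserted where the first
    removed node was (an assumption, the real optimizer may use insert_at).
    """
    removed = [n for n in received if n is not None]  # match may pass None values
    position = min(graph.index(n) for n in removed)
    kept = [n for n in graph if n not in removed]
    # every node before `position` is kept, so kept[:position] == graph[:position]
    return kept[:position] + returned + kept[position:]


graph = ["A", "B", "C", "D"]
print(replace_nodes(graph, ["B", None, "C"], ["BC"]))  # ['A', 'BC', 'D']
print(replace_nodes(graph, ["B", "C"], ["B", "C2"]))   # ['A', 'B', 'C2', 'D'], B is kept
print(replace_nodes(graph, ["B", "C"], ["C2"]))        # ['A', 'C2', 'D'], B is lost
```

The last call shows why a node that must survive has to appear in the returned list.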

enumerate_matches(g: GraphBuilderPatternOptimization) -> Iterator

Enumerates all the matches.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, or the nodes before and after a given one.

  • node – the matching must determine if some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched – usually unused; it holds the nodes already matching a pattern

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
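A match method following these rules can be sketched for two consecutive Transpose nodes cancelling each other. TransposeTransposePattern appears in the default list, but this is not its implementation: Node, ToyMatchResult, the consumers dictionary and replacing both nodes by an Identity are hypothetical, and the toy ignores the case where the intermediate result is a graph output.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional


@dataclass
class Node:
    op_type: str
    input: List[str]
    output: List[str]
    perm: Optional[List[int]] = None  # only used by Transpose in this toy


@dataclass
class ToyMatchResult:
    nodes: List[Node]  # nodes involved in the rewriting
    apply: Callable    # function doing the rewriting
    insert_at: Optional[Node] = None


def compose(p1: List[int], p2: List[int]) -> List[int]:
    # Transpose(Transpose(x, p1), p2) == Transpose(x, [p1[j] for j in p2])
    return [p1[j] for j in p2]


def apply_transpose_transpose(first: Node, second: Node) -> List[Node]:
    # both transpositions cancel each other out
    return [Node("Identity", [first.input[0]], [second.output[0]])]


def match_transpose_transpose(consumers: Dict[str, List[Node]], node: Node) -> Optional[ToyMatchResult]:
    """Reads the graph, never modifies it, returns None when there is no match."""
    if node.op_type != "Transpose":
        return None
    nexts = consumers.get(node.output[0], [])
    if len(nexts) != 1 or nexts[0].op_type != "Transpose":
        return None  # the intermediate result must feed only one Transpose
    second = nexts[0]
    perm = compose(node.perm, second.perm)
    if perm != list(range(len(perm))):
        return None
    return ToyMatchResult([node, second], apply_transpose_transpose, insert_at=second)


t1 = Node("Transpose", ["x"], ["t"], perm=[1, 2, 0])
t2 = Node("Transpose", ["t"], ["y"], perm=[2, 0, 1])
match = match_transpose_transpose({"t": [t2]}, t1)
print(match.apply(*match.nodes))
# [Node(op_type='Identity', input=['x'], output=['y'], perm=None)]
```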

none(node: NodeProto | None = None, lineno: int | None = None, msg: Callable | str | None = None)

It may be useful to know why a pattern failed to match. Instead of returning None, method match can return the following expression:

return self.none(node, inspect.currentframe().f_lineno)

By setting the verbosity (see next Section), the user may then know which lines in the code returned None and which condition failed.
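The idiom relies on the standard inspect.currentframe().f_lineno, which returns the line being executed in the current frame. A hypothetical pattern class recording the rejecting line; reading LOG_PATTERN_OPTIMIZE as an integer and printing at verbosity 10 mimic the description above but are assumptions about the actual logging.

```python
import inspect
import os
from typing import List, Optional


class ToyNegPattern:
    """Hypothetical pattern showing the `return self.none(...)` idiom."""

    def __init__(self, verbose: int = 0):
        # assumption: the environment variable raises the verbosity
        env = int(os.environ.get("LOG_PATTERN_OPTIMIZE", "0"))
        self.verbose = max(verbose, env)
        self.rejected_at: List[Optional[int]] = []

    def none(self, node=None, lineno: Optional[int] = None, msg: str = ""):
        # remember (and possibly print) which line rejected the node, return None
        self.rejected_at.append(lineno)
        if self.verbose >= 10:
            print(f"[{type(self).__name__}] {node!r} rejected at line {lineno} {msg}")
        return None

    def match(self, op_type: str):
        if op_type != "Neg":
            # f_lineno is the line number of this very statement
            return self.none(op_type, inspect.currentframe().f_lineno, "not a Neg")
        return op_type


pattern = ToyNegPattern(verbose=10)
pattern.match("Relu")  # prints the line number of the `return self.none(...)` statement
```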

EasyPatternOptimization

class experimental_experiment.xoptim.EasyPatternOptimization(verbose: int = 0, priority: int = 0, min_opset: int = 1)

Implements a pattern optimization for quick experimentation. The current implementation does not match on domain name. It does not compare attributes either.

add_validate_param(key: str, value: Any)

Stores a value to retrieve when apply_pattern is called.

apply(g: GraphBuilder, *nodes: Sequence[NodeProto]) -> List[NodeProto]

The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since they are passed as a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph and that any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.

Parameters:

nodes – nodes returned by method match, they are then removed

Returns:

nodes to add to graph.

apply_pattern(g: GraphBuilder, *args, **kwargs)

Applies the replacement.

display_pattern(g, fct) -> str

Shows the pattern to match or to apply.

match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) -> MatchResult | None

Determines nodes around node which can be rewritten.

Parameters:
  • g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about types, shapes, or the nodes before and after a given one.

  • node – the matching must determine if some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.

  • matched – usually unused; it holds the nodes already matching a pattern

The method must not modify the graph. The method returns None if no match is found or an instance of class MatchResult. It must contain:

  • A list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.

  • A function doing the rewriting (usually method apply of the pattern class).

  • An existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.

match_pattern(g: GraphBuilder, *args: List[str], **kwargs: Dict[str, Any])

Builds the pattern to match.

validate_attribute_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) -> bool

Validates the mapping of the attributes.

Parameters:
  • g – GraphBuilder

  • deleted_nodes – matched nodes from the model (to be deleted)

  • pattern_nodes – matched nodes coming from the pattern

Returns:

True to validate the mapping, False otherwise (the default is True)

validate_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) -> bool

Validates the mapping.

Parameters:
  • g – GraphBuilder

  • deleted_nodes – matched nodes from the model (to be deleted)

  • pattern_nodes – matched nodes coming from the pattern

Returns:

True to validate the mapping, False otherwise (the default is True)

Other functions

experimental_experiment.xoptim.get_pattern(obj: PatternOptimization | str, as_list: bool = False, verbose: int = 0) -> PatternOptimization

Returns an optimization pattern based on its name.

experimental_experiment.xoptim.get_pattern_list(positive_list: str | List[str | type] | None = 'default', negative_list: str | List[str | type] | None = None, verbose: int = 0)

Builds a list of patterns based on two lists, negative and positive.

<<<

import pprint
from experimental_experiment.xoptim import get_pattern_list

pprint.pprint(get_pattern_list("default", ["Cast"]))

>>>

    [BatchNormalizationPattern(),
     BatchNormalizationTrainingPattern(),
     CastLayerNormalizationCastPattern(),
     CastCastBinaryPattern(),
     CastOpCastPattern(),
     ComputationCastOpCastPattern(),
     ConvBiasNullPattern(),
     DropoutPattern(),
     ExpandPattern(),
     ExpandBroadcastPattern(),
     ExpandSwapPattern(),
     GeluPattern(),
     IdentityPattern(),
     LayerNormalizationPattern(),
     LayerNormalizationScalePattern(),
     LeakyReluPattern(),
     MulMulMulScalarPattern(),
     ReduceReshapePattern(),
     ReduceSumNormalizePattern(),
     ReshapePattern(),
     ReshapeMatMulReshapePattern(),
     Reshape2Of3Pattern(),
     ReshapeReshapeBinaryPattern(),
     MatMulAddPattern(),
     GemmTransposePattern(),
     MatMulReshape2Of3Pattern(),
     MulMulMatMulPattern(),
     ReshapeReshapePattern(),
     RotaryConcatPartPattern(),
     SameChildrenPattern(),
     SlicesSplitPattern(),
     SoftmaxCrossEntropyLossCastPattern(),
     Sub1MulPattern(),
     SwitchOrderBinaryPattern(),
     TransposeMatMulPattern(),
     TransposeReshapeMatMulPattern(),
     TransposeReshapeTransposePattern(),
     TransposeTransposePattern(),
     UnsqueezeEqualPattern(),
     UnsqueezeUnsqueezePattern()]
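The output above still contains CastCastBinaryPattern and CastOpCastPattern, so the negative entry "Cast" does not remove every pattern whose name contains Cast; this suggests names are compared exactly (presumably removing a pattern named CastPattern). A toy sketch of that filtering on names rather than pattern instances; DEFAULT_PATTERNS, short_name, toy_get_pattern_list and the exact-match rule are assumptions.

```python
from typing import List, Sequence, Union

# hypothetical subset, the real default list is much longer
DEFAULT_PATTERNS = ["CastPattern", "CastCastBinaryPattern", "IdentityPattern", "ReshapePattern"]


def short_name(name: str) -> str:
    # "CastPattern" -> "Cast"
    return name[: -len("Pattern")] if name.endswith("Pattern") else name


def toy_get_pattern_list(
    positive_list: Union[str, Sequence[str]] = "default",
    negative_list: Union[str, Sequence[str], None] = None,
) -> List[str]:
    if isinstance(positive_list, str):
        positive_list = DEFAULT_PATTERNS if positive_list == "default" else [positive_list]
    if isinstance(negative_list, str):
        negative_list = [negative_list]
    excluded = {short_name(n) for n in (negative_list or [])}
    # exact comparison: "Cast" removes CastPattern but keeps CastCastBinaryPattern
    return [p for p in positive_list if short_name(p) not in excluded]


print(toy_get_pattern_list("default", ["Cast"]))
# ['CastCastBinaryPattern', 'IdentityPattern', 'ReshapePattern']
```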