.xoptim
GraphBuilderPatternOptimization
- class experimental_experiment.xoptim.GraphBuilderPatternOptimization(builder: GraphBuilder, patterns: List[PatternOptimization] | None = None, recursive: bool = False, verifies: bool = False, verbose: int = 0, dump_applied_patterns: str | None = None, processor: str = 'CPU')
Implements optimization after the conversion is done. The differences between the original and the optimized model can be displayed with a command line such as:
python -m onnx_array_api compare -m1 <model.onnx> -m2 <optimized.onnx> -m nodes -c 80
This class assumes a pattern cannot reuse an existing name.
- Parameters:
builder – GraphBuilder
patterns – list of patterns to apply
recursive – goes through subgraphs
verifies – verifies the model, which takes time
verbose – verbosity
dump_applied_patterns – dumps the applied patterns in a folder so that users can check every applied pattern, each one dumped as a FunctionProto
processor – the optimization targets this processor or this list of processors (comma-separated values)
- apply_match(match: MatchResult) → List[NodeProto]
Applies one match. Returns the new nodes.
- get_attribute(node: NodeProto, att_name: str, exc: bool = True) → AttributeProto | None
Returns an attribute for a node.
- get_attributes_with_default(node: NodeProto, **default_values) → Dict[str, Any]
Returns integer or float values for attributes.
- get_axis(node: NodeProto, default_axis: int | None = None) → int
Retrieves the axis for many operators.
- get_computed_constant(name: str, statistics: List[str] | None = None) → Any
Returns the value for the constant name.
- get_constant_or_attribute(node: NodeProto, attribute: str, input_index: int, cvt: Callable | None = None) → Any
Returns an input or the value of an attribute. Some attributes became inputs in more recent opsets. The function checks both.
- Parameters:
node – node
attribute – attribute name
input_index – input index
cvt – if not None, this conversion function is called on the result before it is returned
- Returns:
value
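For instance, the axes of ReduceSum moved from an attribute to a second input in opset 13, so a pattern supporting both opsets has to look in two places. The sketch below mimics that lookup on a toy node stored as a plain dict. The helper `constant_or_attribute` and the dict layout are hypothetical illustrations, not the library's implementation, which works on NodeProto objects and computed constants.

```python
from typing import Any, Callable, Dict, Optional


def constant_or_attribute(
    node: Dict[str, Any],
    constants: Dict[str, Any],
    attribute: str,
    input_index: int,
    cvt: Optional[Callable[[Any], Any]] = None,
) -> Any:
    """Hypothetical helper: checks the input first, then the attribute."""
    inputs = node["inputs"]
    if input_index < len(inputs) and inputs[input_index]:
        # Recent opsets: the value is given as a constant input.
        value = constants[inputs[input_index]]
    else:
        # Older opsets (or omitted input ""): the value is an attribute, None if absent.
        value = node["attributes"].get(attribute)
    if cvt is not None and value is not None:
        value = cvt(value)
    return value


constants = {"axes": [1]}
# ReduceSum before opset 13: axes is an attribute.
old_node = {"op_type": "ReduceSum", "inputs": ["X"], "attributes": {"axes": [1]}}
# ReduceSum from opset 13: axes is the second input.
new_node = {"op_type": "ReduceSum", "inputs": ["X", "axes"], "attributes": {}}
print(constant_or_attribute(old_node, constants, "axes", 1, cvt=tuple))  # (1,)
print(constant_or_attribute(new_node, constants, "axes", 1, cvt=tuple))  # (1,)
```

Both node forms lead to the same value, which is what lets a single pattern cover several opsets.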
- get_constant_scalar(name: str, broadcast: bool = False) → int | float
Returns the value of a scalar constant.
- Parameters:
name – name
broadcast – consider [1], [[1]], [[[1]]], … as scalar constants as well
- Returns:
int or float
- get_constant_shape(name: str, exc: bool = True) → Tuple[int, ...] | None
Returns the shape of a constant.
- Parameters:
name – name
exc – raises an exception if the shape cannot be determined
- Returns:
shape
- get_registered_constraints() → Dict[str, Set[str | int]]
Returns the constraints registered so far.
- is_constant_scalar(name: str, value: Any | None = None, broadcast: bool = False) → bool
Tells if a constant is a scalar.
- Parameters:
name – name
broadcast – if True, consider 1, [1], [[1]], [[[1]]], … as scalar as well
value – value to compare to if specified
- Returns:
boolean
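The effect of broadcast can be sketched on nested Python lists standing in for constant tensors. The functions below are hypothetical and only model the documented semantics; the real methods look the constant up by name in the graph.

```python
from typing import Any, Optional


def is_scalar_like(value: Any, broadcast: bool = False) -> bool:
    """Tells if value is a scalar, optionally accepting [x], [[x]], ..."""
    if not isinstance(value, list):
        return True  # a plain int or float
    if not broadcast:
        return False  # [x] is not a scalar unless broadcast is True
    # Unwrap one level at a time; every level must hold exactly one element.
    while isinstance(value, list):
        if len(value) != 1:
            return False
        value = value[0]
    return True


def unwrap(value: Any) -> Any:
    # Returns the single element of [x], [[x]], ... or value itself.
    while isinstance(value, list):
        value = value[0]
    return value


def is_constant_scalar_sketch(value: Any, expected: Optional[Any] = None, broadcast: bool = False) -> bool:
    if not is_scalar_like(value, broadcast):
        return False
    # When expected is given, the scalar must also be equal to it.
    return expected is None or unwrap(value) == expected


print(is_constant_scalar_sketch([[1.0]]))                     # False
print(is_constant_scalar_sketch([[1.0]], broadcast=True))     # True
print(is_constant_scalar_sketch([[1.0]], 2, broadcast=True))  # False
```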
- is_used(name: str) → bool
Tells if a result is used or not, including as an output of the graph.
- is_used_more_than_once(name: str) → bool
Tells if a result is used more than once in the current graph or in a subgraph or if it is an output.
- is_used_only_by(name, *nodes: List[NodeProto]) → bool
Tells if a result is only used by a specific set of nodes.
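A toy model of these three checks, assuming a graph is a flat list of nodes plus a list of output names. Subgraphs are ignored here, although the real is_used_more_than_once also looks into them. ToyNode and ToyGraph are hypothetical classes, not part of experimental_experiment, and treating a graph output as never "only used by" some nodes is an assumption.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)  # compare nodes by identity, like distinct graph nodes
class ToyNode:
    op_type: str
    inputs: List[str]
    outputs: List[str]


@dataclass
class ToyGraph:
    nodes: List[ToyNode]
    outputs: List[str] = field(default_factory=list)

    def consumers(self, name: str) -> List[ToyNode]:
        # Every node reading `name` as one of its inputs.
        return [n for n in self.nodes if name in n.inputs]

    def is_used(self, name: str) -> bool:
        # Used by at least one node or exposed as a graph output.
        return name in self.outputs or len(self.consumers(name)) > 0

    def is_used_more_than_once(self, name: str) -> bool:
        # Follows the documented wording: a graph output always counts as True.
        return name in self.outputs or len(self.consumers(name)) > 1

    def is_used_only_by(self, name: str, *nodes: ToyNode) -> bool:
        # Assumption: a graph output is never "only used by" some nodes.
        if name in self.outputs:
            return False
        return all(c in nodes for c in self.consumers(name))


relu = ToyNode("Relu", ["X"], ["A"])
neg = ToyNode("Neg", ["A"], ["B"])
abs_ = ToyNode("Abs", ["A"], ["C"])
g = ToyGraph([relu, neg, abs_], outputs=["B", "C"])
print(g.is_used("A"), g.is_used_more_than_once("A"))  # True True
print(g.is_used_only_by("A", neg, abs_), g.is_used_only_by("A", neg))  # True False
```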
- property main_opset
Returns the opset for the main domain (assuming it is used).
- make_node(op_type: str, inputs: str | List[str], outputs: int | List[str] | str = 1, domain: str = '', attributes: List[AttributeProto] | None = None, name: str | None = None, **kwargs) → NodeProto
Creates a node without adding it to the graph.
- Parameters:
op_type – operator type
inputs – input names
outputs – output names: an integer n creates n unique names, a string creates one unique name, a list is used as given
domain – node domain
attributes – list of attributes
name – node name
kwargs – other attributes
- Returns:
a node
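The three accepted forms of outputs can be illustrated with a hypothetical name factory. This section does not describe how unique names are actually derived, so the prefix-plus-counter scheme below (and the use of op_type as the prefix for the integer form) is an assumption.

```python
import itertools
from typing import List, Set, Union


class NameFactory:
    """Hypothetical helper producing names not yet present in a graph."""

    def __init__(self, existing: Set[str]):
        self.existing = set(existing)
        self.counter = itertools.count()

    def unique_name(self, prefix: str) -> str:
        # Appends a counter until the name is free, then reserves it.
        name = prefix
        while name in self.existing:
            name = f"{prefix}_{next(self.counter)}"
        self.existing.add(name)
        return name

    def resolve_outputs(self, op_type: str, outputs: Union[int, str, List[str]]) -> List[str]:
        if isinstance(outputs, int):
            # An integer n: n fresh names derived from the operator type.
            return [self.unique_name(op_type) for _ in range(outputs)]
        if isinstance(outputs, str):
            # A string: one fresh name using the string as a prefix.
            return [self.unique_name(outputs)]
        # A list: the names are used as given.
        return list(outputs)


factory = NameFactory({"Add"})
print(factory.resolve_outputs("Add", 2))      # ['Add_0', 'Add_1']
print(factory.resolve_outputs("Mul", "sum"))  # ['sum']
print(factory.resolve_outputs("Mul", ["Y"]))  # ['Y']
```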
- make_node_check_opset(op_type: str, inputs: str | List[str], outputs: int | List[str] | str = 1, domain: str = '', attributes: List[AttributeProto] | None = None, name: str | None = None, **kwargs)
Creates a node without adding it to the graph, but adapts it for some known operators whose definition changed across opsets.
- Parameters:
op_type – operator type
inputs – input names
outputs – output names: an integer n creates n unique names, a string creates one unique name, a list is used as given
domain – node domain
attributes – list of attributes
name – node name
kwargs – other attributes
- Returns:
a node
- node_before(name: str) → NodeProto
Returns the node producing this output. Returns None if it is an input or an initializer.
- property opsets
- optimize(max_iter=-1, remove_identity: bool = True, stop_after: int = -1) → List[Dict[str, Any]]
Optimizes the graph based on the given list of patterns.
- Parameters:
max_iter – maximum number of iterations
remove_identity – removes Identity nodes; it is better to keep it True, since keeping them might prevent other patterns from finding a set of nodes to optimize
stop_after – stops after this number of replacements (for debugging), -1 not to stop
- Returns:
the method returns information about the applied processes.
The algorithm runs multiple iterations until the graph stops evolving or max_iter is reached. By default (max_iter=-1), max_iter is equal to the number of nodes. An iteration is:
matches = []
builds all successors and predecessors
for all patterns p:
    for all nodes n:
        r = p.match(n)
        if r:
            if no node already scheduled to be rewritten by another match:
                matches.append(r)
for all matches r:
    apply the match r
This algorithm may apply more than one rewriting at each iteration, but it guarantees that the local structure a rewriting relies on was not altered by another rewriting when it is applied.
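The iteration can be sketched on a toy graph reduced to a chain of operators, which makes predecessors and successors implicit. Node, Match, DoubleNegPattern and optimize_chain are hypothetical names for this illustration; the real optimizer works on a GraphBuilder with MatchResult objects. The example shows the conflict rule: with three consecutive Neg nodes, the second candidate match overlaps the first one and is skipped.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass(eq=False)  # identity semantics: two Neg nodes are different nodes
class Node:
    op_type: str


@dataclass
class Match:
    nodes: List[Node]  # nodes to rewrite
    apply: Callable[[List[Node]], List[Node]]  # returns the replacement nodes


class DoubleNegPattern:
    """Toy pattern: Neg followed by Neg cancels out."""

    def match(self, chain: List[Node], node: Node) -> Optional[Match]:
        i = chain.index(node)
        if node.op_type == "Neg" and i + 1 < len(chain) and chain[i + 1].op_type == "Neg":
            return Match([node, chain[i + 1]], lambda nodes: [])
        return None


def optimize_chain(chain: List[Node], patterns, max_iter: int = -1) -> List[Node]:
    if max_iter < 0:
        max_iter = len(chain)  # default: the number of nodes
    for _ in range(max_iter):
        matches: List[Match] = []
        scheduled = set()  # ids of nodes already claimed by a match
        for p in patterns:
            for n in chain:
                r = p.match(chain, n)
                if r and not any(id(x) in scheduled for x in r.nodes):
                    matches.append(r)
                    scheduled.update(id(x) for x in r.nodes)
        if not matches:
            break  # the graph no longer evolves
        for r in matches:
            # Insert the new nodes where the first matched node was.
            pos = chain.index(r.nodes[0])
            removed = {id(x) for x in r.nodes}
            chain = [x for x in chain if id(x) not in removed]
            chain[pos:pos] = r.apply(r.nodes)
    return chain


chain = [Node(t) for t in ["Neg", "Neg", "Neg", "Relu"]]
print([n.op_type for n in optimize_chain(chain, [DoubleNegPattern()])])
# ['Neg', 'Relu']
```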
- try_infer_shape(name: str, exc: bool = False) → int
Tries to infer the type of a result.
- Parameters:
name – name of the result for which to infer the type
exc – if True, raises an exception if something goes wrong
- Returns:
type
MatchResult
- class experimental_experiment.xoptim.MatchResult(pattern: PatternOptimization, nodes: List[NodeProto], apply: Callable, insert_at: NodeProto | None = None)
Holds the result of a match.
- Parameters:
pattern – object detecting the pattern
nodes – nodes to be replaced
apply – function computing the replacements
insert_at – insert the new nodes at this point if specified
- debug_string(g: GraphBuilder | None = None) → str
Returns a string showing the matched nodes.
PatternOptimization
- class experimental_experiment.xoptim.PatternOptimization(verbose: int = 0, priority: int = 1, min_opset: int = 1)
Defines an optimization pattern. Method match should return None if the match does not happen or, better,
self.none(node, inspect.currentframe().f_lineno)
That allows the user to know which line rejected a specific pattern by setting the environment variable LOG_PATTERN_OPTIMIZE=10. An environment variable equal to the class name can be set as well to track this specific pattern.
- Parameters:
verbose – verbosity level; it can also be set through the environment variable LOG_PATTERN_OPTIMIZE=10
priority – at each iteration, all patterns whose priority is below a threshold are executed; if none of them matches, the threshold is increased
min_opset – the pattern can be applied only if the main opset is > min_opset
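The priority mechanism can be sketched as follows, assuming lower values run first and that "below the threshold" includes equality. The function patterns_to_run and the (name, priority, match) tuples are hypothetical; a real pattern matches against the graph rather than calling a zero-argument function.

```python
from typing import Callable, List, Tuple

# Hypothetical toy pattern: (name, priority, match), where match() returns
# True when the pattern finds something to rewrite.
ToyPattern = Tuple[str, int, Callable[[], bool]]


def patterns_to_run(patterns: List[ToyPattern]) -> List[str]:
    """Raises the threshold until at least one admitted pattern matches."""
    threshold = min(p[1] for p in patterns)
    max_priority = max(p[1] for p in patterns)
    while threshold <= max_priority:
        # Assumption: "below the threshold" means priority <= threshold.
        active = [p for p in patterns if p[1] <= threshold]
        matched = [name for name, _, match in active if match()]
        if matched:
            return matched
        threshold += 1  # nothing matched: admit higher-priority patterns
    return []


patterns = [("A", 0, lambda: False), ("B", 1, lambda: False), ("C", 2, lambda: True)]
print(patterns_to_run(patterns))  # ['C']
```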
- apply(g: GraphBuilder, *nodes: Sequence[NodeProto]) → List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match, they are then removed
- Returns:
nodes to add to graph.
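This contract (received nodes are removed, returned nodes are added) can be illustrated on a graph reduced to a list of unique node names. The function replace_nodes and the two toy rewritings are hypothetical, and inserting the new nodes at the position of the first matched node is only one possible choice.

```python
from typing import Callable, List


def replace_nodes(graph: List[str], matched: List[str], apply: Callable[..., List[str]]) -> List[str]:
    """Removes every matched node and adds every node returned by apply."""
    new_nodes = apply(*matched)
    pos = min(graph.index(n) for n in matched)
    kept = [n for n in graph if n not in matched]  # assumes unique names
    # Nodes before pos were not matched, so kept[:pos] == graph[:pos].
    return kept[:pos] + new_nodes + kept[pos:]


def fuse_matmul_add(matmul: str, add: str) -> List[str]:
    # Both received nodes disappear, a single new node replaces them.
    return ["Gemm"]


def drop_identity(identity: str, relu: str) -> List[str]:
    # Relu must be kept, so it is returned: only Identity is removed.
    return [relu]


print(replace_nodes(["MatMul", "Add", "Relu"], ["MatMul", "Add"], fuse_matmul_add))
# ['Gemm', 'Relu']
print(replace_nodes(["X", "Identity", "Relu", "Y"], ["Identity", "Relu"], drop_identity))
# ['X', 'Relu', 'Y']
```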
- enumerate_matches(g: GraphBuilderPatternOptimization) → Iterator
Enumerates all the matches.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about type, shape, or the node before or after another one.
node – the matching must determine whether some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused, the list of results already matching a pattern
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.
a function doing the rewriting (usually method apply of the pattern class).
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- none(node: NodeProto | None = None, lineno: int | None = None, msg: Callable | str | None = None)
It may be useful to know which reason made a pattern matching fail. Instead of returning None, method match can return the following expression:
return self.none(node, inspect.currentframe().f_lineno)
By setting the verbosity (see next Section), the user may then know which lines in the code returned None and which condition failed.
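The mechanism relies on inspect.currentframe().f_lineno, which is the number of the line being executed in match. The sketch below shows how such a none method could record and print the rejecting line. ToyPattern is hypothetical, and the verbosity threshold of 10 used to decide when to print is an assumption based on the documented LOG_PATTERN_OPTIMIZE=10.

```python
import inspect
import os
from typing import List, Optional


class ToyPattern:
    """Hypothetical pattern recording which line rejected a match."""

    def __init__(self, verbose: int = 0):
        # Assumption: LOG_PATTERN_OPTIMIZE raises the verbosity.
        self.verbose = max(verbose, int(os.environ.get("LOG_PATTERN_OPTIMIZE", "0")))
        self.rejected_at: List[Optional[int]] = []

    def none(self, node: Optional[str] = None, lineno: Optional[int] = None):
        # Keeps the line number, prints it when verbose enough, returns None.
        self.rejected_at.append(lineno)
        if self.verbose >= 10:
            print(f"[{type(self).__name__}] no match for {node!r} at line {lineno}")
        return None

    def match(self, op_type: str) -> Optional[str]:
        if op_type != "Transpose":
            # f_lineno is the number of the line being executed: this return.
            return self.none(op_type, inspect.currentframe().f_lineno)
        return "matched"


pattern = ToyPattern(verbose=10)
pattern.match("Relu")  # prints the line number of the rejecting return
```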
EasyPatternOptimization
- class experimental_experiment.xoptim.EasyPatternOptimization(verbose: int = 0, priority: int = 0, min_opset: int = 1)
Implements a pattern optimization for quick experimentation. The current implementation does not match on domain name. It does not compare attributes either. The environment variable AMBIGUITIES can be set to 1 to raise an exception when such an ambiguity happens.
- add_validate_param(key: str, value: Any)
Stores a value to retrieve when apply_pattern is called.
- apply(g: GraphBuilder, *nodes: Sequence[NodeProto]) → List[NodeProto]
The method does the rewriting. It assumes the rewriting can happen. It takes the list of nodes impacted by the rewriting and assumes no other pattern optimizer will modify them. It receives the list of nodes returned by method match. Since it is a list of arguments, method match can include None values. The method returns the new nodes. The optimizer considers that any node given to this function is removed from the graph, and any node returned by it is added. If a received node must be kept, it must be added to the list of returned nodes.
- Parameters:
nodes – nodes returned by method match, they are then removed
- Returns:
nodes to add to graph.
- apply_pattern(g: GraphBuilder, *args, **kwargs)
Applies the replacement.
- match(g: GraphBuilderPatternOptimization, node: NodeProto, matched: List[MatchResult]) → MatchResult | None
Determines nodes around node which can be rewritten.
- Parameters:
g – a GraphBuilderPatternOptimization; it holds all the existing nodes and can return any information about type, shape, or the node before or after another one.
node – the matching must determine whether some nodes around this one are part of a set of nodes this pattern optimizer can rewrite. From there, the function explores wherever it needs, checking any condition it needs.
matched – usually unused, the list of results already matching a pattern
The method must not modify the graph. It returns None if no match is found, or an instance of class MatchResult, which must contain:
a list of nodes involved in the rewriting. It does not mean all of them will be removed, but all of them are needed to do the rewriting and must not be impacted by other pattern optimizers.
a function doing the rewriting (usually method apply of the pattern class).
an existing node where the rewritten nodes can be inserted. Knowing it makes the rewriting faster. If not specified, the optimizer automatically determines the position of the new nodes.
- match_pattern(g: GraphBuilder, *args: List[str], **kwargs: Dict[str, Any])
Builds the pattern to match.
- post_apply_pattern(g, *nodes)
Method to overload to apply a step after the pattern was applied.
- validate_attribute_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) → bool
Validates the mapping of the attributes.
- Parameters:
g – GraphBuilder
deleted_nodes – matched nodes from the model (to be deleted)
pattern_nodes – matched nodes coming from the pattern
- Returns:
True to validate the mapping, False otherwise; the default is True
- validate_mapping(g: GraphBuilderPatternOptimization, deleted_nodes: List[NodeProto], pattern_nodes: List[NodeProto] | None = None) → bool
Validates the mapping.
- Parameters:
g – GraphBuilder
deleted_nodes – matched nodes from the model (to be deleted)
pattern_nodes – matched nodes coming from the pattern
- Returns:
True to validate the mapping, False otherwise; the default is True
Other functions
- experimental_experiment.xoptim.get_pattern(obj: PatternOptimization | str, as_list: bool = False, verbose: int = 0) → PatternOptimization
Returns an optimization pattern based on its name.
- experimental_experiment.xoptim.get_pattern_list(positive_list: str | List[str | type] | None = 'default', negative_list: str | List[str | type] | None = None, verbose: int = 0)
Builds a list of patterns based on two lists, negative and positive.
<<<
import pprint
from experimental_experiment.xoptim import get_pattern_list
pprint.pprint(get_pattern_list("default", ["Cast"]))
>>>
[BatchNormalizationPattern(),
 BatchNormalizationTrainingPattern(),
 CastLayerNormalizationCastPattern(),
 CastCastBinaryPattern(),
 CastOpCastPattern(),
 ClipClipPattern(),
 ComputationCastOpCastPattern(),
 ConvBiasNullPattern(),
 DropoutPattern(),
 ExpandPattern(),
 ExpandBroadcastPattern(),
 ExpandSwapPattern(),
 GeluPattern(),
 IdentityPattern(),
 LayerNormalizationPattern(),
 LayerNormalizationScalePattern(),
 LeakyReluPattern(),
 MulMulMulScalarPattern(),
 ReduceReshapePattern(),
 ReduceSumNormalizePattern(),
 ReshapePattern(),
 ReshapeMatMulReshapePattern(),
 Reshape2Of3Pattern(),
 ReshapeReshapeBinaryPattern(),
 MatMulAddPattern(),
 GemmTransposePattern(),
 MatMulReshape2Of3Pattern(),
 MulMulMatMulPattern(),
 ReshapeReshapePattern(),
 RotaryConcatPartPattern(),
 SameChildrenPattern(),
 SequenceConstructAtPattern(),
 SliceSlicePattern(),
 SlicesSplitPattern(),
 SoftmaxCrossEntropyLossCastPattern(),
 SplitConcatPattern(),
 Sub1MulPattern(),
 SwitchOrderBinaryPattern(),
 SwitchReshapeActivationPattern(),
 TransposeEqualReshapePattern(),
 TransposeMatMulPattern(),
 TransposeReshapeMatMulPattern(),
 TransposeReshapeTransposePattern(),
 TransposeTransposePattern(),
 UnsqueezeEqualPattern(),
 UnsqueezeUnsqueezePattern()]