import cProfile
import json
import math
import os
import site
from collections import OrderedDict, deque
from io import StringIO
from pstats import SortKey, Stats
from typing import Any, Callable, Dict, List, Optional
class ProfileNode:
    """
    Graph structure to represent a profiling.

    Each node stores the statistics of one profiled function. Edges are
    kept in both directions: ``calls_to`` (callees, with the timing
    elements of every edge stored at the same index in
    ``calls_to_elements``) and ``called_by`` (callers).

    :param filename: filename
    :param line: line number
    :param func_name: function name
    :param nc1: number of calls 1 (first call counter of a pstats entry)
    :param nc2: number of calls 2 (second call counter of a pstats entry)
    :param tin: time spent in the function
    :param tall: time spent in the function and in the sub functions
    :raises RuntimeError: if *func_name* is the method disabling the profiler
    """

    def __init__(
        self,
        filename: str,
        line: int,
        func_name: str,
        nc1: int,
        nc2: int,
        tin: float,
        tall: float,
    ):
        if "method 'disable' of '_lsprof.Profiler'" in func_name:
            raise RuntimeError(f"Function not allowed in the profiling: {func_name!r}.")
        self.filename = filename
        self.line = line
        self.func_name = func_name
        self.nc1 = nc1
        self.nc2 = nc2
        self.tin = tin
        self.tall = tall
        self.called_by = []
        self.calls_to = []
        # calls_to_elements[i] holds the timings of the edge self -> calls_to[i]
        self.calls_to_elements = []

    def add_called_by(self, pnode: "ProfileNode"):
        "This function is called by these lines."
        self.called_by.append(pnode)

    def add_calls_to(self, pnode: "ProfileNode", time_elements):
        """
        This function calls this node.

        Registers the edge ``self -> pnode`` with its timing elements and
        the reverse edge on *pnode* so both directions stay consistent.
        """
        self.calls_to.append(pnode)
        self.calls_to_elements.append(time_elements)
        pnode.add_called_by(self)

    @staticmethod
    def _key(filename: str, line: int, fct: str) -> str:
        "Builds the key ``filename:line:function`` identifying a node."
        return "%s:%d:%s" % (filename, line, fct)

    @property
    def key(self) -> str:
        "Returns `file:line:function`."
        return ProfileNode._key(self.filename, self.line, self.func_name)

    def get_root(self):
        """
        Returns the root of the graph.

        Walks up the ``called_by`` links and follows every
        (callee, caller) edge at most once, so recursive calls and cycles
        cannot make it loop forever. The first node without any caller is
        returned. If every explored path loops back before such an entry
        point is found, the visited node with the largest ``tall`` is
        returned instead.
        """
        done = set()
        visited = []
        node = self
        while True:
            visited.append(node)
            if not node.called_by:
                return node
            nxt = None
            for caller in node.called_by:
                if (id(node), id(caller)) not in done:
                    nxt = caller
                    break
            if nxt is None:
                # All paths have been explored and no entry point was found.
                # Choosing the most consuming function among the visited ones
                # (the key avoids comparing ProfileNode instances on ties).
                return max(visited, key=lambda n: n.tall)
            done.add((id(node), id(nxt)))
            node = nxt

    def __repr__(self) -> str:
        "Displays the node attributes, then ``#callers-#callees``."
        return "%s(%r, %r, %r, %r, %r, %r, %r) # %d-%d" % (
            self.__class__.__name__,
            self.filename,
            self.line,
            self.func_name,
            self.nc1,
            self.nc2,
            self.tin,
            self.tall,
            len(self.called_by),
            len(self.calls_to),
        )

    def __iter__(self):
        """
        Returns all nodes in the graph reachable from this one through
        ``calls_to`` (breadth first), every node (identified by its key)
        exactly once.
        """
        done = set()
        stack = deque([self])
        while stack:
            node = stack.popleft()
            if node.key in done:
                continue
            yield node
            done.add(node.key)
            stack.extend(node.calls_to)

    # Files whose functions filter_node_ hides when their impact is small:
    # "~" is the file name pstats gives to built-in functions, the others
    # belong to the frozen import machinery.
    _modules_ = {
        "~",
        "<frozen importlib._bootstrap>",
        "<frozen importlib._bootstrap_external>",
    }

    @staticmethod
    def filter_node_(node, info=None) -> bool:
        """
        Filters out node to be displayed by default.

        A node whose file is listed in ``_modules_`` is dropped when both
        call counters are at most 10 and its cumulated time is at most 1e-4.
        Every other node is kept.

        :param node: node
        :param info: if the node is called by a function,
            this dictionary can be used to overwrite the attributes
            held by the node (keys ``nc1``, ``nc2``, ``tall``)
        :return: boolean (True to keep, False to forget)
        """
        if node.filename in ProfileNode._modules_:
            if info is None:
                if node.nc1 <= 10 and node.nc2 <= 10 and node.tall <= 1e-4:
                    return False
            elif info["nc1"] <= 10 and info["nc2"] <= 10 and info["tall"] <= 1e-4:
                return False
        return True

    @staticmethod
    def _sort_key_function(sort_key):
        """
        Returns the key function sorting nodes or ``(node, time_elements)``
        pairs.

        :param sort_key: ``SortKey.LINE`` (file then line),
            ``SortKey.CUMULATIVE`` (decreasing cumulated time) or
            ``SortKey.TIME`` (decreasing inner time)
        :raises NotImplementedError: for any other key
        """

        def sort_key_line(dr):
            if isinstance(dr, tuple):
                return (dr[0].filename, dr[0].line)
            return (dr.filename, dr.line)

        def sort_key_tin(dr):
            if isinstance(dr, tuple):
                return -dr[1][2]
            return -dr.tin

        def sort_key_tall(dr):
            if isinstance(dr, tuple):
                return -dr[1][3]
            return -dr.tall

        if sort_key == SortKey.LINE:
            return sort_key_line
        if sort_key == SortKey.CUMULATIVE:
            return sort_key_tall
        if sort_key == SortKey.TIME:
            return sort_key_tin
        raise NotImplementedError(
            f"Unable to sort subcalls with this key {sort_key!r}."
        )

    @staticmethod
    def _width(value) -> int:
        """
        Returns the number of characters used to print the call counter
        *value*; values below 1 are treated as 1 to stay in the domain of
        the logarithm.
        """
        return int(math.log(max(value, 1)) / math.log(10) + 1.5)

    @staticmethod
    def _align_text(text: str, size: int) -> str:
        """
        Pads *text* with spaces up to *size* characters, or shortens it to
        at most *size* characters by replacing its middle with ``...``.
        A non-positive *size* leaves *text* unchanged.
        """
        if size <= 0:
            return text
        if len(text) <= size:
            return text + " " * (size - len(text))
        head = size // 2 - 1
        if head <= 0:
            # too narrow to insert "...", plain truncation
            return text[:size]
        tail = head - 1
        return text[:head] + "..." + (text[-tail:] if tail > 0 else "")

    def _roots(self):
        "Returns the reachable nodes with zero or several callers."
        return [node for node in self if len(node.called_by) != 1]

    def as_dict(self, filter_node=None, sort_key=SortKey.LINE):
        """
        Renders the results of a profiling interpreted with
        function @fn profile2graph. It can then be loaded with
        a dataframe.

        Every node with zero or several callers starts a tree. A call to
        such a node from another tree is shown as a single row flagged with
        ``"more": "+"`` instead of being expanded again.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by ``SortKey.LINE``,
            ``SortKey.CUMULATIVE`` or ``SortKey.TIME``
        :return: rows (dictionaries with keys ``fct``, ``where``, ``nc1``,
            ``nc2``, ``tin``, ``tall``, ``indent``, ``ncalls``, ``debug``)
        :raises NotImplementedError: for an unsupported *sort_key*
        """
        sortk = ProfileNode._sort_key_function(sort_key)
        if filter_node is None:
            filter_node = ProfileNode.filter_node_

        def depth_first(node, roots_keys, indent=0):
            yield {
                "fct": node.func_name,
                "where": node.key,
                "nc1": node.nc1,
                "nc2": node.nc2,
                "tin": node.tin,
                "tall": node.tall,
                "indent": indent,
                "ncalls": len(node.calls_to),
                "debug": "A",
            }
            for n, nel in sorted(zip(node.calls_to, node.calls_to_elements), key=sortk):
                if n.key in roots_keys:
                    # n is expanded in its own tree, only the edge is shown
                    text = {
                        "fct": n.func_name,
                        "where": n.key,
                        "nc1": nel[0],
                        "nc2": nel[1],
                        "tin": nel[2],
                        "tall": nel[3],
                        "indent": indent + 1,
                        "ncalls": len(n.calls_to),
                        "more": "+",
                        "debug": "B",
                    }
                    if not filter_node(n, info=text):
                        continue
                    yield text
                else:
                    if not filter_node(n):
                        continue
                    yield from depth_first(n, roots_keys, indent + 1)

        roots = self._roots()
        roots_key = {r.key: r for r in roots}
        rows = []
        for root in sorted(roots, key=sortk):
            if not filter_node(root):
                continue
            rows.extend(depth_first(root, roots_key))
        return rows

    def to_text(self, filter_node=None, sort_key=SortKey.LINE, fct_width=60) -> str:
        """
        Prints the profiling to text.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by...
        :param fct_width: width of the function name column
        :return: one line per row returned by :meth:`as_dict`,
            an empty string if there is none
        """
        dicts = self.as_dict(filter_node=filter_node, sort_key=sort_key)
        if not dicts:
            return ""
        max_nc = max(max(row["nc1"], row["nc2"]) for row in dicts)
        dg = ProfileNode._width(max_nc)
        line_format = (
            "{indent}{fct} -- {nc1: %dd} {nc2: %dd} -- {tin:1.5f} {tall:1.5f}"
            " -- {name} ({fct2})" % (dg, dg)
        )
        text = []
        for row in dicts:
            line = line_format.format(
                indent=" " * (row["indent"] * 4),
                fct=ProfileNode._align_text(row["fct"], fct_width - row["indent"] * 4),
                nc1=row["nc1"],
                nc2=row["nc2"],
                tin=row["tin"],
                tall=row["tall"],
                name=row["where"],
                fct2=row["fct"],
            )
            if row.get("more", "") == "+":
                line += " +++"
            text.append(line)
        return "\n".join(text)

    def to_json(
        self, filter_node=None, sort_key=SortKey.LINE, as_str=True, **kwargs
    ) -> str:
        """
        Renders the results of a profiling interpreted with
        function @fn profile2graph as :epkg:`JSON`.

        Keys are ``"<ncalls>-<tall>:<function>"`` strings (``:::`` for the
        top level), the number of calls being left aligned on a common width.

        :param filter_node: display only the nodes for which
            this function returns True, if None, the default function
            removes built-in function with small impact
        :param sort_key: sort sub nodes by...
        :param as_str: converts the json into a string
        :param kwargs: see :func:`json.dumps`
        :return: a JSON string if *as_str* is True,
            the dictionary ``{"profile": ...}`` otherwise
        :raises NotImplementedError: for an unsupported *sort_key*
        """
        sortk = ProfileNode._sort_key_function(sort_key)
        if filter_node is None:
            filter_node = ProfileNode.filter_node_

        def rename(items):
            # (ncalls, label) -> "ncalls-label" with a common width for ncalls
            form = f"%-{ProfileNode._width(max(k[0] for k in items))}d-%s"
            return OrderedDict((form % k, v) for k, v in items.items())

        def walk(node, roots_keys, indent=0):
            item = {
                "details": {
                    "fct": node.func_name,
                    "where": node.key,
                    "nc1": node.nc1,
                    "nc2": node.nc2,
                    "tin": node.tin,
                    "tall": node.tall,
                    "indent": indent,
                    "ncalls": len(node.calls_to),
                }
            }
            child = OrderedDict()
            for n, nel in sorted(zip(node.calls_to, node.calls_to_elements), key=sortk):
                key = (nel[0], f"{nel[3]:1.5f}:{n.func_name}")
                if n.key in roots_keys:
                    # same level and call count as the expanded children get
                    details = {
                        "fct": n.func_name,
                        "where": n.key,
                        "nc1": nel[0],
                        "nc2": nel[1],
                        "tin": nel[2],
                        "tall": nel[3],
                        "indent": indent + 1,
                        "ncalls": len(n.calls_to),
                    }
                    if not filter_node(n, info=details):
                        continue
                    child[key] = {"details": details}
                else:
                    if not filter_node(n):
                        continue
                    child[key] = walk(n, roots_keys, indent + 1)
            if child:
                item["calls"] = rename(child)
            return item

        roots = self._roots()
        roots_key = {r.key: r for r in roots}
        rows = OrderedDict()
        for root in sorted(roots, key=sortk):
            if not filter_node(root):
                continue
            key = (root.nc1, f"{root.tall:1.5f}:::{root.func_name}")
            rows[key] = walk(root, roots_key)
        result = {"profile": rename(rows) if rows else OrderedDict()}
        if as_str:
            return json.dumps(result, **kwargs)
        return result
def _process_pstats(
ps: Stats,
clean_text: Optional[Callable] = None,
verbose: bool = False,
fLOG: Optional[Callable] = None,
) -> List[Dict[str, Any]]:
Converts class `Stats <
profile.html#pstats.Stats>`_ into something
readable for a dataframe.
:param ps: instance of type :func:`pstats.Stats`
:param clean_text: function to clean function names
:param verbose: change verbosity
:param fLOG: logging function
:return: list of rows
if clean_text is None:
clean_text = lambda x: x
def add_rows(rows, d):
tt1, tt2 = 0, 0
for k, v in d.items():
stin = 0
stall = 0
if verbose and fLOG is not None:
"[pstats] %s=%r"
% ((clean_text(k[0].replace("\\", "/")),) + k[1:], v)
if len(v) < 5:
row = {
"file": "%s:%d" % (clean_text(k[0].replace("\\", "/")), k[1]),
"fct": k[2],
"ncalls1": v[0],
"ncalls2": v[1],
"tin": v[2],
"tall": v[3],
stin += v[2]
stall += v[3]
if len(v) == 5:
t1, t2 = add_rows(rows, v[-1])
stin += t1
stall += t2
row["cum_tin"] = stin
row["cum_tall"] = stall
tt1 += stin
tt2 += stall
return tt1, tt2
rows = []
add_rows(rows, ps.stats)
return rows
def profile2df(
    ps: Stats,
    as_df: bool = True,
    clean_text: Optional[Callable] = None,
    verbose: bool = False,
    fLOG: Optional[Callable] = None,
):
    """
    Converts profiling statistics into a Dataframe.

    Rows sharing the same function name and file are summed and the
    result is sorted by decreasing ``cum_tall``.

    :param ps: an instance of `pstats
        <https://docs.python.org/3/library/profile.html#pstats.Stats>`_
    :param as_df: returns the results as a dataframe (True)
        or a list of dictionaries (False)
    :param clean_text: function to clean function names
    :param verbose: verbosity
    :param fLOG: logging function
    :return: a DataFrame (an empty one if there is no statistics)

    ::

        import pstats
        from onnx_array_api.profiling import profile2df

        ps = pstats.Stats('profile.prof')
        df = profile2df(ps)
        print(df)
    """
    rows = _process_pstats(ps, clean_text, verbose=verbose, fLOG=fLOG)
    if not as_df:
        return rows

    import pandas

    # explicit columns keep the selection valid even when rows is empty
    df = pandas.DataFrame(
        rows,
        columns=["fct", "file", "ncalls1", "ncalls2", "tin", "cum_tin", "tall", "cum_tall"],
    )
    df = (
        df.groupby(["fct", "file"], as_index=False)
        .sum()
        .sort_values("cum_tall", ascending=False)
        .reset_index(drop=True)
    )
    return df.copy()
def profile(
    fct: Callable,
    sort: str = "cumulative",
    rootrem: Optional[str] = None,
    as_df: bool = False,
    return_results: bool = False,
    **kwargs,
) -> tuple:
    """
    Profiles the execution of a function.

    :param fct: function to profile
    :param sort: see `sort_stats <https://docs.python.org/3/library/
        profile.html#pstats.Stats.sort_stats>`_
    :param rootrem: root to remove in filenames, a string or a sequence of
        strings (removed) and ``(old, new)`` tuples (replaced)
    :param as_df: return the results as a dataframe and not text
    :param return_results: if True, return results as well
        (in the first position)
    :param kwargs: additional parameters used to create the profiler,
        see :epkg:`cProfile.Profile`
    :return: ``(ps, text)`` where *ps* is the :class:`pstats.Stats` instance
        and *text* the statistics dump (a dataframe if *as_df* is True),
        ``(fct(), ps, text)`` if *return_results* is True
    :raises TypeError: if *rootrem* contains an unexpected element

    .. plot::

        import matplotlib.pyplot as plt
        from onnx_array_api.profiling import profile

        def subf(x):
            return sum(x)

        def fctm():
            x1 = subf([1, 2, 3])
            x2 = subf([1, 2, 3, 4])
            return x1 + x2

        pr, df = profile(lambda: [fctm() for i in range(0, 1000)], as_df=True)
        ax = df[['namefct', 'cum_tall']].head(n=15).set_index(
            'namefct').plot(kind='bar', figsize=(8, 3), rot=30)
        ax.set_title("example of a graph")
        for la in ax.get_xticklabels():
            la.set_horizontalalignment('right')
        plt.show()
    """
    pr = cProfile.Profile(**kwargs)
    pr.enable()
    try:
        fct_res = fct()
    finally:
        # stop profiling even if fct raises, otherwise the profiler
        # stays active after this function returns
        pr.disable()
    s = StringIO()
    ps = Stats(pr, stream=s).sort_stats(sort)
    ps.print_stats()
    res = s.getvalue()
    try:
        pack = site.getsitepackages()
    except AttributeError:
        import numpy

        pack = os.path.normpath(
            os.path.abspath(os.path.join(os.path.dirname(numpy.__file__), ".."))
        )
        pack = [pack]
    pack_ = os.path.normpath(os.path.join(pack[-1], ".."))

    def clean_text(res):
        # shortens installation paths, then applies rootrem
        res = res.replace(pack[-1], "site-packages")
        res = res.replace(pack_, "lib")
        if rootrem is not None:
            if isinstance(rootrem, str):
                res = res.replace(rootrem, "")
            else:
                for sub in rootrem:
                    if isinstance(sub, str):
                        res = res.replace(sub, "")
                    elif isinstance(sub, tuple) and len(sub) == 2:
                        res = res.replace(sub[0], sub[1])
                    else:
                        raise TypeError(
                            f"rootrem must contains strings or tuple not {rootrem!r}."
                        )
        return res

    if as_df:

        def better_name(row):
            # row["file"] is "path:line": long function names are prefixed
            # with the line number, short ones with the file base name
            if len(row["fct"]) > 15:
                return f"{row['file'].split(':')[-1]}-{row['fct']}"
            name = row["file"].replace("\\", "/")
            return f"{name.split('/')[-1]}-{row['fct']}"

        rows = _process_pstats(ps, clean_text)
        import pandas

        # explicit columns keep the selection valid even when rows is empty
        df = pandas.DataFrame(
            rows,
            columns=[
                "fct",
                "file",
                "ncalls1",
                "ncalls2",
                "tin",
                "cum_tin",
                "tall",
                "cum_tall",
            ],
        )
        df["namefct"] = [better_name(row) for row in rows]
        df = (
            df.groupby(["namefct", "file"], as_index=False)
            .sum()
            .sort_values("cum_tall", ascending=False)
            .reset_index(drop=True)
        )
        if return_results:
            return fct_res, ps, df
        return ps, df

    res = clean_text(res)
    if return_results:
        return fct_res, ps, res
    return ps, res
def profile2graph(
    ps: Stats,
    clean_text: Optional[Callable] = None,
    verbose: bool = False,
    fLOG: Optional[Callable] = None,
) -> "tuple[ProfileNode, Dict[str, ProfileNode]]":
    """
    Converts profiling statistics into a graphs.

    :param ps: an instance of `pstats
        <https://docs.python.org/3/library/profile.html#pstats.Stats>`_
    :param clean_text: function to clean function names
    :param verbose: verbosity
    :param fLOG: logging function
    :return: ``(root, nodes)``, the root (an instance of class
        @see cl ProfileNode) and a dictionary of all nodes indexed
        by their key
    :raises RuntimeError: if two entries map to the same key, if a caller
        is missing from the statistics, or if there is no statistics at all

    :epkg:`pyinstrument` has a nice display to show
    time spent and call stack at the same time. This function
    tries to replicate that display based on the results produced
    by module :mod:`cProfile`. Here is an example.

    .. runpython::
        :showcode:

        import time
        from onnx_array_api.profiling import profile, profile2graph


        def fct0(t):
            time.sleep(t)


        def fct1(t):
            time.sleep(t)


        def fct2():
            fct1(0.1)
            fct1(0.01)


        def fct3():
            fct0(0.2)
            fct1(0.5)


        def fct4():
            fct2()
            fct3()


        ps = profile(fct4)[0]
        root, nodes = profile2graph(ps, clean_text=lambda x: x.split('/')[-1])
        text = root.to_text()
        print(text)
    """
    if clean_text is None:
        clean_text = lambda x: x  # noqa: E731
    # the call stopping the profiler is an artefact of the profiling itself
    disable_name = "method 'disable' of '_lsprof.Profiler'"

    def clean_file(filename):
        return clean_text(filename.replace("\\", "/"))

    # first pass: one node per profiled function
    nodes = {}
    for k, v in ps.stats.items():
        if verbose and fLOG is not None:
            fLOG(f"[pstats] {k}={v!r}")
        if len(v) < 5 or disable_name in k[2]:
            continue
        node = ProfileNode(
            filename=clean_file(k[0]),
            line=k[1],
            func_name=k[2],
            nc1=v[0],
            nc2=v[1],
            tin=v[2],
            tall=v[3],
        )
        if node.key in nodes:
            raise RuntimeError(f"Key {node.key!r} is already present, node={node!r}.")
        nodes[node.key] = node

    # second pass: edges, skipping the same entries as the first pass
    for k, v in ps.stats.items():
        if len(v) < 5 or disable_name in k[2]:
            continue
        node = nodes[ProfileNode._key(clean_file(k[0]), k[1], k[2])]
        for f, vv in v[4].items():
            if disable_name in f[2]:
                continue
            key = ProfileNode._key(clean_file(f[0]), f[1], f[2])
            if key not in nodes:
                raise RuntimeError(
                    "Unable to find key %r into\n%s" % (key, "\n".join(sorted(nodes)))
                )
            # f is a caller of k: the caller registers the call to node
            nodes[key].add_calls_to(node, vv)

    if not nodes:
        raise RuntimeError("No profiled function was found in the statistics.")
    root = next(iter(nodes.values())).get_root()
    return root, nodes