# Source code for experimental_experiment.xshape.simplify_expressions
import ast
from collections import Counter
from typing import Dict, List, Optional, Tuple, Union
def _dump_node(n: ast.AST) -> str:
    """Returns a textual dump of ``n`` without position attributes.

    Two structurally identical subtrees give the same string wherever they
    appear in the source, so the result can be used as a comparison or
    counting key.
    """
    dumped = ast.dump(n, annotate_fields=True, include_attributes=False)
    return dumped
class _Common:
    """Mixin keeping the expression being processed.

    The stored expression is read by :meth:`get_debug_msg`.
    """

    def __init__(self, expr: Optional[str] = None):
        self.expr = expr

    def get_debug_msg(self) -> str:
        """Returns ``" expression=<repr of expr>"``, or ``""`` when expr is empty or None."""
        if not self.expr:
            return ""
        return f" expression={self.expr!r}"
[docs]
class CommonVisitor(ast.NodeVisitor, _Common):
    """Base class for the visitors of this module.

    An :class:`ast.NodeVisitor` which also stores the processed expression.
    """

    def __init__(self, expr: Optional[str] = None):
        # ast.NodeVisitor defines no __init__, so the MRO resolves this call
        # to _Common.__init__, which stores ``expr``.
        super().__init__(expr)
[docs]
class CommonTransformer(ast.NodeTransformer, _Common):
    """Base class for the tree rewriting passes of this module.

    An :class:`ast.NodeTransformer` which also stores the processed expression.
    """

    def __init__(self, expr: Optional[str] = None):
        # Neither ast.NodeTransformer nor ast.NodeVisitor defines __init__,
        # so the MRO resolves this call to _Common.__init__.
        super().__init__(expr)
[docs]
class SimpleSimpliflyTransformer(CommonTransformer):
    """Simplifies expressions such as ``batch^batch``, ``x+0``, ``x*1``.

    Rules applied bottom-up on every binary operation:

    * ``e^e`` -> ``e``: in these expressions ``^`` stands for ``max``
      (see :class:`MaxToXorTransformer`). Operands are compared
      structurally, so ``(n+1)^(n+1)`` is reduced as well as ``batch^batch``;
    * ``0+e`` and ``e+0`` -> ``e``;
    * ``1*e`` and ``e*1`` -> ``e``.
    """

    @staticmethod
    def _is_constant(node: ast.AST, value: int) -> bool:
        """Tells if ``node`` is a literal equal to ``value``."""
        return isinstance(node, ast.Constant) and node.value == value

    def visit_BinOp(self, node):
        # Children first, so nested patterns such as ``(x+0)*1`` collapse fully.
        self.generic_visit(node)
        if isinstance(node.op, ast.BitXor):
            # max(e, e) == e for any sub-expression e, not only plain names.
            if _dump_node(node.left) == _dump_node(node.right):
                return node.left
        elif isinstance(node.op, ast.Add):
            if self._is_constant(node.left, 0):
                return node.right
            if self._is_constant(node.right, 0):
                return node.left
        elif isinstance(node.op, ast.Mult):
            if self._is_constant(node.left, 1):
                return node.right
            if self._is_constant(node.right, 1):
                return node.left
        return node
[docs]
class MulDivCancellerTransformer(CommonTransformer):
    """Simplifies ``2*x//2`` into ``x``.

    Every ``*`` / ``//`` node is split into numerator and denominator
    factors. Factors present on both sides are cancelled, and the node is
    rebuilt as ``numerator // denominator`` (or just ``numerator``), with
    products re-associated from the left.

    ``//`` is floor division, not exact division, so only identities
    preserving the integer result are used:

    * ``(P*Q) // (Q*R) == P // R`` for ``Q != 0``: the ratio is unchanged,
      and so is its floor;
    * ``(A // B) // C == A // (B*C)``, which holds for ``B > 0`` and
      ``C > 0``. Divisors are assumed to be positive (dimension-like)
      integers.

    A floor division used as a factor (``(a//b)*c``) or as a divisor
    (``a//(b//c)``) is kept as an opaque factor: ``(5//2)*2 == 4`` while
    ``5*2//2 == 5``.
    """

    @classmethod
    def _flatten_product(cls, node: ast.AST) -> List[ast.AST]:
        """Returns the factors of a chain of ``*``; any other node is one factor."""
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            return cls._flatten_product(node.left) + cls._flatten_product(node.right)
        return [node]

    @classmethod
    def _flatten_mul_div(cls, node: ast.AST) -> Tuple[List[ast.AST], List[ast.AST]]:
        """Splits ``node`` into ``(numerator factors, denominator factors)``.

        Only the left operand of ``//`` may itself be a fraction:
        ``(A//B)//C`` gives numerator ``A`` and denominator ``B*C``. The
        divisor and the operands of ``*`` are split as plain products.
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.FloorDiv):
            numer, denom = cls._flatten_mul_div(node.left)
            return numer, denom + cls._flatten_product(node.right)
        return cls._flatten_product(node), []

    @classmethod
    def _rebuild_from_factors(cls, numer: List[ast.AST], denom: List[ast.AST]) -> ast.AST:
        """Builds ``n1*n2*... // (d1*d2*...)``; an empty product is the constant 1."""

        def _product(factors: List[ast.AST]) -> ast.AST:
            if not factors:
                return ast.Constant(value=1)
            node = factors[0]
            for f in factors[1:]:
                node = ast.BinOp(left=node, op=ast.Mult(), right=f)
            return node

        numer_node = _product(numer)
        if not denom:
            return numer_node
        denom_node = _product(denom)
        return ast.BinOp(left=numer_node, op=ast.FloorDiv(), right=denom_node)

    @staticmethod
    def _keep_remaining(
        factors: List[ast.AST], keys: List[str], cancelled: Counter
    ) -> List[ast.AST]:
        """Keeps, for every key, its first ``count - cancelled[key]`` occurrences."""
        budget = Counter(keys) - cancelled
        kept = []
        for factor, key in zip(factors, keys):
            if budget[key] > 0:
                kept.append(factor)
                budget[key] -= 1
        return kept

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        # Bottom-up: nested products and divisions are simplified first.
        node = self.generic_visit(node)
        if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Mult, ast.FloorDiv))):
            return node
        numer, denom = self._flatten_mul_div(node)
        # Factors are compared structurally, independently of their position.
        numer_keys = [_dump_node(n) for n in numer]
        denom_keys = [_dump_node(d) for d in denom]
        # Counter intersection = occurrences shared by both sides, i.e. how
        # many times each factor can be cancelled.
        cancelled = Counter(numer_keys) & Counter(denom_keys)
        new_node = self._rebuild_from_factors(
            self._keep_remaining(numer, numer_keys, cancelled),
            self._keep_remaining(denom, denom_keys, cancelled),
        )
        return ast.copy_location(new_node, node)
[docs]
class MaxToXorTransformer(CommonTransformer):
    """Replaces ``Max(a,b)`` by ``a^b``.

    Calls with more than two positional arguments are folded from the left:
    ``max(a, b, c)`` becomes ``a^b^c``. Calls with keyword arguments
    (``key=``...) or unpacked arguments (``*args``) are left untouched,
    because dropping those arguments would change the meaning of the call.
    """

    def visit_Call(self, node):
        self.generic_visit(node)
        if (
            isinstance(node.func, ast.Name)
            and node.func.id in ("max", "Max")
            and len(node.args) >= 2
            and not node.keywords
            and not any(isinstance(arg, ast.Starred) for arg in node.args)
        ):
            result = node.args[0]
            for arg in node.args[1:]:
                result = ast.BinOp(left=result, op=ast.BitXor(), right=arg)
            return ast.copy_location(result, node)
        return node
[docs]
class SimplifyParensTransformer(CommonTransformer):
    """Pass meant to simplify parentheses.

    Both handlers only recurse into the children and hand back the node
    they received, so the tree currently goes through unchanged.
    """

    def visit_BinOp(self, node):
        # NodeTransformer.generic_visit returns the node it was given.
        return self.generic_visit(node)

    def visit_Expr(self, node):
        return self.generic_visit(node)
[docs]
class ExpressionSimplifierAddVisitor(CommonVisitor):
    """Simplifies expression such as ``2*x-x``.

    The visited expression is collected as the linear combination
    ``sum(coeffs[term] * term) + const``. ``+``, ``-``, unary ``-``/``+``
    and products by a literal constant are expanded. Any other
    sub-expression is a term identified by its :func:`ast.unparse` text.
    """

    # Nodes which never need parentheses next to ``+``, ``-`` or ``k*``.
    _ATOMIC_NODES = (ast.Name, ast.Constant, ast.Call, ast.Attribute, ast.Subscript)
    # Binary operators sharing the precedence of ``*`` (left-associative).
    _SAME_LEVEL_AS_MULT = (ast.Mult, ast.MatMult, ast.Div, ast.FloorDiv, ast.Mod)

    def __init__(self, expr: Optional[str] = None):
        super().__init__(expr)
        # term (unparsed text) -> coefficient, in order of first appearance
        self.coeffs = {}
        # sum of the constant parts
        self.const = 0

    def _collect(self, node: ast.AST) -> "ExpressionSimplifierAddVisitor":
        """Visits ``node`` with a fresh visitor and returns that visitor."""
        sub = ExpressionSimplifierAddVisitor(self.expr)
        sub.visit(node)
        return sub

    def _add_scaled(self, other: "ExpressionSimplifierAddVisitor", factor) -> None:
        """Adds ``factor`` times the combination collected by ``other``."""
        for term, coeff in other.coeffs.items():
            self.coeffs[term] = self.coeffs.get(term, 0) + factor * coeff
        self.const += other.const * factor

    def visit_BinOp(self, node):
        if isinstance(node.op, ast.Add):
            self.visit(node.left)
            self.visit(node.right)
        elif isinstance(node.op, ast.Sub):
            self.visit(node.left)
            self._add_scaled(self._collect(node.right), -1)
        elif isinstance(node.op, ast.Mult) and (
            isinstance(node.left, ast.Constant) or isinstance(node.right, ast.Constant)
        ):
            # k*e or e*k: expand e, then scale it by k.
            if isinstance(node.right, ast.Constant):
                factor, operand = node.right.value, node.left
            else:
                factor, operand = node.left.value, node.right
            self._add_scaled(self._collect(operand), factor)
        else:
            self.generic_visit(node)

    def visit_UnaryOp(self, node):
        # -e and +e are expanded so that ``x + -1`` becomes ``x-1`` and
        # ``-x + x`` cancels; ``~e`` / ``not e`` remain opaque terms.
        if isinstance(node.op, ast.USub):
            self._add_scaled(self._collect(node.operand), -1)
        elif isinstance(node.op, ast.UAdd):
            self.visit(node.operand)
        else:
            self.generic_visit(node)

    def visit_Constant(self, node):
        self.const += node.value

    def generic_visit(self, node):
        """Records a non-expanded sub-expression as a term keyed by its unparsed text."""
        s = ast.unparse(node)
        self.coeffs[s] = self.coeffs.get(s, 0) + 1

    @classmethod
    def _binds_tight(cls, node: ast.AST, strict: bool) -> bool:
        """Tells if ``node`` can be written without parentheses.

        With ``strict=False`` the term sits between binary ``+``/``-``. With
        ``strict=True`` it follows a unary ``-`` or ``coeff*``.
        """
        if isinstance(node, ast.BinOp):
            if isinstance(node.op, ast.Pow):
                return True
            if not strict:
                return isinstance(node.op, cls._SAME_LEVEL_AS_MULT)
            if not isinstance(node.op, ast.Mult):
                return False
            # ``-a*b`` is ``(-a)*b == -(a*b)``: a product is safe when its
            # leftmost factor is, but ``-a//b*c`` reads ``((-a)//b)*c``.
            left = node.left
            if isinstance(left, ast.BinOp) and isinstance(left.op, cls._SAME_LEVEL_AS_MULT):
                return cls._binds_tight(left, strict=True)
            # Any other left operand binds tighter than ``*`` or was
            # parenthesized by ast.unparse.
            return True
        if isinstance(node, ast.UnaryOp):
            return not isinstance(node.op, ast.Not)
        return isinstance(node, cls._ATOMIC_NODES)

    @classmethod
    def _protect(cls, term: str, strict: bool) -> str:
        """Returns ``term``, or ``(term)`` when it would be misparsed otherwise."""
        try:
            node = ast.parse(term, mode="eval").body
        except SyntaxError:
            # Not a Python expression: nothing to decide, keep it as it is.
            return term
        return term if cls._binds_tight(node, strict) else f"({term})"

    def make_simplified(self) -> str:
        """Returns the collected combination as a compact string.

        Null coefficients are dropped, spaces are removed, and ``"0"`` is
        returned when nothing is left. A term is parenthesized only when
        writing it next to ``+``, ``-`` or ``coeff*`` would change how it
        parses. For example, ``x + (a ^ b)`` gives ``x+(a^b)``, not
        ``x+a^b``, which Python reads as ``(x+a)^b``.
        """
        items = [(var, coeff) for var, coeff in self.coeffs.items() if coeff != 0]
        # A lone term with coefficient 1 is the whole result: no neighbour,
        # hence no precedence issue.
        alone = len(items) == 1 and self.const == 0
        terms = []
        for index, (var, coeff) in enumerate(items):
            if coeff == 1:
                text = var if alone else self._protect(var, strict=False)
                terms.append(f"+{text}")
            elif coeff == -1:
                # A leading ``-`` is unary and binds tighter than ``*``/``//``;
                # afterwards it is a binary minus.
                terms.append(f"-{self._protect(var, strict=index == 0)}")
            else:
                sign = "+" if coeff > 0 else ""
                terms.append(f"{sign}{coeff}*{self._protect(var, strict=True)}")
        if self.const != 0:
            terms.append(f"{'+' if self.const > 0 else ''}{self.const}")
        result = "".join(terms)
        res = result[1:] if result.startswith("+") else (result if result else "0")
        return res.replace(" ", "")
[docs]
def simplify_expression(expr: Union[str, int]) -> Union[str, int]:
    """Simplifies an expression.

    :param expr: an integer, returned unchanged, or a Python expression
        such as ``"2*x//2 + max(a, a) - x"``
    :return: the simplified expression as a string without spaces

    ``max``/``Max`` calls are turned into ``^`` first, so the ``^`` rule of
    :class:`SimpleSimpliflyTransformer` also applies to them (``max(a, a)``
    becomes ``a``). Neutral elements and common factors are then removed,
    and the result is rewritten as a linear combination of terms.
    """
    if isinstance(expr, int):
        return expr
    assert isinstance(expr, str), f"Unexpected type {expr} for the expression."
    tree = ast.parse(expr, mode="eval")
    transformers = [
        MaxToXorTransformer(expr=expr),
        SimpleSimpliflyTransformer(expr=expr),
        MulDivCancellerTransformer(expr=expr),
        SimplifyParensTransformer(expr=expr),
    ]
    for tr in transformers:
        tree = tr.visit(tree)
    # The transformers create nodes without source positions.
    ast.fix_missing_locations(tree.body)
    expr = ast.unparse(tree)
    simp = ExpressionSimplifierAddVisitor(expr=expr)
    simp.visit(tree.body)
    return simp.make_simplified()
[docs]
def simplify_two_expressions(expr1: str, expr2: str) -> Dict[str, int]:
    """Simplifies an expression exp1 == exp2.

    The equality is rewritten as ``expr1 - (expr2)`` and expanded into a
    linear combination of terms.

    :return: mapping ``term -> coefficient`` restricted to the non-zero
        coefficients. The constant part of the difference is not included,
        so ``"x+1"`` and ``"x"`` give an empty dictionary.
    """
    difference = f"{expr1}-({expr2})"
    body = ast.parse(difference, mode="eval").body
    visitor = ExpressionSimplifierAddVisitor(difference)
    visitor.visit(body)
    return {term: coeff for term, coeff in visitor.coeffs.items() if coeff != 0}