Source code for yobx.xshape.simplify_expressions
import ast
from collections import Counter
from typing import Dict, List, Optional, Tuple, Union
def _dump_node(n: ast.AST) -> str:
    """
    Returns a textual dump of an AST node without position attributes.

    Line and column attributes are left out of the dump, so two nodes
    with the same structure give the same string wherever they sit
    in the parsed text.

    :param n: node to dump
    :return: result of :func:`ast.dump` for this node
    """
    dumped = ast.dump(n, include_attributes=False)
    return dumped
class _Common:
    """
    State shared by the visitors and transformers of this module.

    :param expr: expression being processed, only used to build
        error messages
    """

    def __init__(self, expr: Optional[str] = None):
        self.expr = expr

    def get_debug_msg(self) -> str:
        """
        Returns a suffix for error messages.

        :return: `` expression='...'`` when an expression is known
            and not empty, an empty string otherwise
        """
        return f" expression={self.expr!r}" if self.expr else ""
[docs]
class CommonVisitor(ast.NodeVisitor, _Common):
    """
    Root class of the AST visitors defined in this module.

    :param expr: expression being visited, only used for error messages
    """

    def __init__(self, expr: Optional[str] = None):
        # Each base class is initialised explicitly, the two calls
        # are independent from each other.
        _Common.__init__(self, expr)
        ast.NodeVisitor.__init__(self)
[docs]
class CommonTransformer(ast.NodeTransformer, _Common):
    """
    Root class of the AST transformers defined in this module.

    :param expr: expression being rewritten, only used for error messages
    """

    def __init__(self, expr: Optional[str] = None):
        # Each base class is initialised explicitly, the two calls
        # are independent from each other.
        _Common.__init__(self, expr)
        ast.NodeTransformer.__init__(self)
[docs]
class SimpleSimpliflyTransformer(CommonTransformer):
    """
    Drops neutral operations: ``batch^batch`` becomes ``batch``,
    ``x+0`` becomes ``x``, ``x*1`` becomes ``x``.
    """

    def visit_BinOp(self, node):
        """
        Simplifies the operands first, then the operation itself.

        :param node: binary operation
        :return: one of the two operands when the operation is neutral,
            the node itself otherwise
        """
        # The children are rewritten in place.
        self.generic_visit(node)
        lhs, rhs, op = node.left, node.right, node.op

        if isinstance(op, ast.BitXor):
            # Only the xor of a name with the very same name is removed.
            same_name = (
                isinstance(lhs, ast.Name) and isinstance(rhs, ast.Name) and lhs.id == rhs.id
            )
            return lhs if same_name else node

        if isinstance(op, ast.Add):
            neutral = 0
        elif isinstance(op, ast.Mult):
            neutral = 1
        else:
            return node

        # The left operand is tested first.
        if isinstance(lhs, ast.Constant) and lhs.value == neutral:
            return rhs
        if isinstance(rhs, ast.Constant) and rhs.value == neutral:
            return lhs
        return node
[docs]
class MulDivCancellerTransformer(CommonTransformer):
    """
    Cancels the factors shared by the numerator and the denominator
    of a floor division, ``2*x//2`` becomes ``x``.

    Floor division truncates, so a quotient cannot be handled as an exact
    fraction everywhere. A quotient is only merged with a division applied
    on top of it (``a//b//c`` is seen as ``a//(b*c)``). A quotient used as
    a factor of a product or inside a divisor is kept as one opaque factor:
    ``x//2*2`` (not ``x`` when ``x`` is odd) and ``a//(b//c)``
    (not ``a*c//b``) are left unchanged.

    NOTE(review): ``a//b//c == a//(b*c)`` and the cancellation of a symbolic
    factor hold for positive operands; presumably the expressions only
    involve positive dimensions -- confirm against callers.
    """

    @classmethod
    def _flatten_mul_div(cls, node: ast.AST) -> Tuple[List[ast.AST], List[ast.AST]]:
        """
        Splits an expression into the factors of its numerator and
        the factors of its denominator.

        :param node: expression to split
        :return: ``(numerator factors, denominator factors)``,
            the denominator is empty unless *node* is a floor division
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            factors: List[ast.AST] = []
            for side in (node.left, node.right):
                num, den = cls._flatten_mul_div(side)
                # (p // q) * r is not (p * r) // q: a quotient stays
                # a single factor of the product.
                factors.extend([side] if den else num)
            return factors, []
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.FloorDiv):
            lnum, lden = cls._flatten_mul_div(node.left)
            rnum, rden = cls._flatten_mul_div(node.right)
            if rden:
                # a // (p // q) is not (a * q) // p: the divisor stays whole.
                rnum = [node.right]
            return lnum, lden + rnum
        return [node], []

    @classmethod
    def _rebuild_from_factors(
        cls, numerator: List[ast.AST], denominator: List[ast.AST]
    ) -> ast.AST:
        """
        Builds ``product(numerator) // product(denominator)``.

        :param numerator: factors of the numerator, an empty list means ``1``
        :param denominator: factors of the denominator, no division is
            created when the list is empty
        :return: the new expression
        """

        def _product(factors: List[ast.AST]) -> ast.AST:
            if not factors:
                return ast.Constant(value=1)
            node = factors[0]
            for f in factors[1:]:
                node = ast.BinOp(left=node, op=ast.Mult(), right=f)  # type: ignore[arg-type]
            return node

        numer_node = _product(numerator)
        if not denominator:
            return numer_node
        denom_node = _product(denominator)
        return ast.BinOp(left=numer_node, op=ast.FloorDiv(), right=denom_node)  # type: ignore[arg-type]

    @staticmethod
    def _take(factors: List[ast.AST], keys: List[str], budget: Counter) -> List[ast.AST]:
        """
        Keeps, in their original order, as many factors per key
        as *budget* allows. *budget* is consumed.
        """
        kept = []
        for factor, key in zip(factors, keys):
            if budget[key] > 0:
                kept.append(factor)
                budget[key] -= 1
        return kept

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        """
        Simplifies the operands first, then removes the factors which
        appear on both sides of a floor division.

        :param node: binary operation
        :return: the rebuilt expression for a product or a floor division,
            the node itself for any other operator
        """
        node = self.generic_visit(node)  # type: ignore[assignment]
        if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Mult, ast.FloorDiv))):
            return node

        numerator, denominator = self._flatten_mul_div(node)
        # Factors are compared through their dump, positions are ignored.
        numer_keys = [_dump_node(n) for n in numerator]
        denom_keys = [_dump_node(d) for d in denominator]
        num_counter = Counter(numer_keys)
        den_counter = Counter(denom_keys)

        # Counter subtraction only keeps positive counts: what is left
        # on each side once the common factors are cancelled.
        remaining_numer = self._take(numerator, numer_keys, num_counter - den_counter)
        remaining_denom = self._take(denominator, denom_keys, den_counter - num_counter)

        new_node = self._rebuild_from_factors(remaining_numer, remaining_denom)
        return ast.copy_location(new_node, node)
[docs]
class MaxToXorTransformer(CommonTransformer):
    """
    Replaces ``Max(a,b)`` by ``a^b``.

    Only plain calls to ``max`` or ``Max`` with exactly two positional
    arguments are rewritten. A call with keyword arguments or starred
    arguments is left unchanged because ``a^b`` cannot carry them.
    """

    def visit_Call(self, node):
        """
        Rewrites the arguments first, then the call itself.

        :param node: call node
        :return: ``a ^ b`` for ``max(a, b)`` or ``Max(a, b)``,
            the node itself otherwise
        """
        node = self.generic_visit(node)
        func = node.func
        if (
            isinstance(func, ast.Name)
            and func.id in ("max", "Max")
            and len(node.args) == 2
            and not node.keywords
            and not any(isinstance(arg, ast.Starred) for arg in node.args)
        ):
            a, b = node.args
            new_node = ast.BinOp(left=a, op=ast.BitXor(), right=b)
            # The new node takes the position of the call it replaces.
            return ast.copy_location(new_node, node)
        return node
[docs]
class SimplifyParensTransformer(CommonTransformer):
    """
    To simplify parenthesis.

    Both hooks only walk through the children of the node,
    no node is replaced by this class itself.
    """

    def visit_BinOp(self, node):
        """Visits the operands and returns the binary operation."""
        self.generic_visit(node)
        return node

    def visit_Expr(self, node):
        """Visits the children and returns what the generic visit returns."""
        visited = self.generic_visit(node)
        return visited
[docs]
class ExpressionSimplifierAddVisitor(CommonVisitor):
    """
    Simplifies expression such as ``2*x-x``.

    The visitor expands additions, subtractions, unary signs and products
    by a numeric literal. Any other sub-expression is an opaque term
    identified by its unparsed text. After a visit, attribute ``coeffs``
    maps every term to its coefficient and attribute ``const`` holds
    the sum of the numeric literals.

    :param expr: used only for error messages.
    """

    # Operators whose result can follow a binary ``+`` or ``-``
    # without parentheses.
    _SUM_SAFE_OPS = (ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.MatMult, ast.Pow)
    # Operators whose result can follow ``coefficient*`` or a leading unary
    # minus without parentheses: ``2*a//b`` is ``(2*a)//b`` and ``-a//b``
    # is ``(-a)//b``, so ``//`` and ``%`` are not in this list.
    _PRODUCT_SAFE_OPS = (ast.Mult, ast.Div, ast.MatMult, ast.Pow)
    # Nodes which never need parentheses.
    _ATOMS = (ast.Name, ast.Constant, ast.Call, ast.Attribute, ast.Subscript)

    def __init__(self, expr: Optional[str] = None):
        super().__init__(expr)
        self.coeffs: Dict[str, int] = {}
        self.const = 0

    @staticmethod
    def _scalar(node: ast.AST):
        """
        Returns the value of a numeric literal, possibly preceded by
        unary signs (``2``, ``-2``), or None for anything else.
        """
        sign = 1
        while isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
            if isinstance(node.op, ast.USub):
                sign = -sign
            node = node.operand
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float, complex)):
            return sign * node.value
        return None

    def _merge(self, node: ast.AST, factor) -> None:
        """Adds ``factor * node`` to the terms collected so far."""
        sub = ExpressionSimplifierAddVisitor(self.expr)
        sub.visit(node)
        for var, coeff in sub.coeffs.items():
            self.coeffs[var] = self.coeffs.get(var, 0) + factor * coeff
        self.const += factor * sub.const

    def visit_BinOp(self, node):
        """
        Expands ``a+b``, ``a-b`` and the product of an expression by
        a numeric literal. Any other operation is one opaque term.
        """
        if isinstance(node.op, ast.Add):
            self.visit(node.left)
            self.visit(node.right)
        elif isinstance(node.op, ast.Sub):
            self.visit(node.left)
            self._merge(node.right, -1)
        elif isinstance(node.op, ast.Mult):
            right_value = self._scalar(node.right)
            left_value = self._scalar(node.left)
            if right_value is not None:
                self._merge(node.left, right_value)
            elif left_value is not None:
                self._merge(node.right, left_value)
            else:
                self.generic_visit(node)
        else:
            self.generic_visit(node)

    def visit_UnaryOp(self, node):
        """
        Expands ``-a`` and ``+a`` so that ``-x+x`` cancels out.
        Any other unary operation is one opaque term.
        """
        if isinstance(node.op, ast.USub):
            self._merge(node.operand, -1)
        elif isinstance(node.op, ast.UAdd):
            self.visit(node.operand)
        else:
            self.generic_visit(node)

    def visit_Constant(self, node):
        """
        Adds a numeric literal to ``const``. A literal which is not
        a number (a string for example) is one opaque term.
        """
        if isinstance(node.value, (int, float, complex)):
            self.const += node.value
        else:
            self.generic_visit(node)

    def generic_visit(self, node):
        """Registers *node* as one opaque term, its children are not visited."""
        s = ast.unparse(node)
        self.coeffs[s] = self.coeffs.get(s, 0) + 1

    @classmethod
    def _protect(cls, var: str, safe_ops) -> str:
        """
        Wraps a term into parentheses unless it is an atom, a unary
        operation other than ``not``, or a binary operation whose operator
        belongs to *safe_ops*.

        :param var: text of the term
        :param safe_ops: operators allowed without parentheses
        :return: the term, inside parentheses if needed
        """
        try:
            node = ast.parse(var, mode="eval").body
        except SyntaxError:
            # Not expected for a key produced by ast.unparse.
            return f"({var})"
        if isinstance(node, ast.BinOp):
            bare = isinstance(node.op, safe_ops)
        elif isinstance(node, ast.UnaryOp):
            bare = not isinstance(node.op, ast.Not)
        else:
            bare = isinstance(node, cls._ATOMS)
        return var if bare else f"({var})"

    def make_simplified(self) -> str:
        """
        Builds the simplified expression from the collected terms.

        Terms with a null coefficient are dropped. A term is wrapped into
        parentheses whenever the text would otherwise be parsed as another
        expression: ``2*(a//b)`` and not ``2*a//b``, ``(a^b)+c`` and not
        ``a^b+c``. A term which is the whole result is left as it is.

        NOTE(review): every space is removed from the result, which is only
        safe for terms without keywords (``not``, ``if``/``else``).

        :return: the expression without any space, ``"0"`` if nothing is left
        """
        live = [(var, coeff) for var, coeff in self.coeffs.items() if coeff != 0]
        # A single term with coefficient 1 and no constant is the whole result.
        lone = len(live) == 1 and self.const == 0
        terms: List[str] = []
        for var, coeff in live:
            if coeff == 1:
                text = var if lone else self._protect(var, self._SUM_SAFE_OPS)
                terms.append(f"+{text}")
            elif coeff == -1:
                # A leading ``-`` is a unary minus, it binds tighter
                # than ``//`` and ``%``; elsewhere it is a binary minus.
                safe_ops = self._SUM_SAFE_OPS if terms else self._PRODUCT_SAFE_OPS
                terms.append(f"-{self._protect(var, safe_ops)}")
            else:
                text = self._protect(var, self._PRODUCT_SAFE_OPS)
                terms.append(f"{'+' if coeff > 0 else ''}{coeff}*{text}")
        if self.const != 0:
            terms.append(f"{'+' if self.const > 0 else ''}{self.const}")
        result = "".join(terms)
        res = result[1:] if result.startswith("+") else (result if result else "0")
        return res.replace(" ", "")
[docs]
class ReorderCommutativeOpsTransformer(CommonTransformer):
    """
    Sorts the operands of additions and multiplications,
    ``b+a`` becomes ``a+b``.
    """

    def visit_BinOp(self, node: ast.BinOp):
        """
        Reorders the operands of the children first, then the operands
        of the node itself if it is an addition or a multiplication.

        :param node: binary operation
        :return: a left-nested chain with sorted operands for ``+`` and
            ``*``, the node itself for any other operator
        """
        self.generic_visit(node)
        if not isinstance(node.op, (ast.Add, ast.Mult)):
            return node
        ordered = sorted(self._flatten(node, type(node.op)), key=self._expr_key)
        return self._rebuild(ordered, node.op)

    def _flatten(self, node: ast.AST, op_type) -> List[ast.AST]:
        """Returns, from left to right, the operands of a chain of operations of type *op_type*."""
        leaves: List[ast.AST] = []
        pending = [node]
        while pending:
            current = pending.pop()
            if isinstance(current, ast.BinOp) and isinstance(current.op, op_type):
                # The left operand is pushed last so it is handled first.
                pending.append(current.right)
                pending.append(current.left)
            else:
                leaves.append(current)
        return leaves

    def _rebuild(self, operands: List[ast.AST], op: ast.operator) -> ast.AST:
        """Chains the operands with *op*, nesting on the left side."""
        chain = operands[0]
        for nxt in operands[1:]:
            chain = ast.BinOp(left=chain, op=op, right=nxt)  # type: ignore[arg-type]
        return chain

    def _expr_key(self, node: ast.AST) -> str:
        """Returns the sort key of an operand: its unparsed text."""
        return ast.unparse(node)
[docs]
class ExactMulDivConstantFolderTransformer(CommonTransformer):
    """
    Folds integer constants in multiplicative chains with floor division,
    but only when exact (no remainder).

    Example: ``1024*a//2`` -> ``512*a``

    A quotient used as a factor of a product or inside a divisor is kept
    as one opaque factor: ``a//2*1024`` is not folded into ``512*a``
    because ``a//2`` may truncate.
    """

    def visit_BinOp(self, node: ast.BinOp):
        """
        Folds the operands first, then the integer constants of the node
        if it is a product or a floor division.

        :param node: binary operation
        :return: the folded product, or the node itself when the operator
            is not handled, when the denominator holds a non constant factor
            or is null, or when the division is not exact
        """
        node = self.generic_visit(node)  # type: ignore[assignment]
        if not isinstance(node, ast.BinOp):
            return node
        if not isinstance(node.op, (ast.Mult, ast.FloorDiv)):
            return node

        numerator, denominator = self._flatten_mul_div(node)

        num_const = 1
        den_const = 1
        num_other: List[ast.AST] = []
        den_other: List[ast.AST] = []
        for n in numerator:
            if isinstance(n, ast.Constant) and isinstance(n.value, int):
                num_const *= n.value
            else:
                num_other.append(n)
        for d in denominator:
            if isinstance(d, ast.Constant) and isinstance(d.value, int):
                den_const *= d.value
            else:
                den_other.append(d)

        # Only fold exact integer divisions and avoid division by zero.
        if den_const == 0:
            return node
        if den_other:
            return node
        if num_const % den_const != 0:
            return node

        folded = num_const // den_const
        factors: List[ast.AST] = []
        # The constant is dropped when it is 1 unless nothing else is left.
        if folded != 1 or not num_other:
            factors.append(ast.Constant(value=folded))
        factors.extend(num_other)
        new_node = self._build_product(factors)
        return ast.copy_location(new_node, node)

    def _flatten_mul_div(self, node: ast.AST) -> Tuple[List[ast.AST], List[ast.AST]]:
        """
        Splits an expression into the factors of its numerator and
        the factors of its denominator.

        :param node: expression to split
        :return: ``(numerator factors, denominator factors)``,
            the denominator is empty unless *node* is a floor division
        """
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.Mult):
            factors: List[ast.AST] = []
            for side in (node.left, node.right):
                num, den = self._flatten_mul_div(side)
                # (p // q) * r is not (p * r) // q: a quotient stays
                # a single factor of the product.
                factors.extend([side] if den else num)
            return factors, []
        if isinstance(node, ast.BinOp) and isinstance(node.op, ast.FloorDiv):
            lnum, lden = self._flatten_mul_div(node.left)
            rnum, rden = self._flatten_mul_div(node.right)
            if rden:
                # a // (p // q) is not (a * q) // p: the divisor stays whole.
                rnum = [node.right]
            return lnum, lden + rnum
        return [node], []

    def _build_product(self, factors: List[ast.AST]) -> ast.AST:
        """Multiplies the factors, nesting on the left side. *factors* must not be empty."""
        out = factors[0]
        for f in factors[1:]:
            out = ast.BinOp(left=out, op=ast.Mult(), right=f)  # type: ignore[arg-type]
        return out
[docs]
def simplify_expression(expr: Union[str, int]) -> Union[str, int]:
    """
    Simplifies an expression.

    The expression is parsed, rewritten by every transformer of this
    module, one after the other, then its terms are collected and
    written back as a string.

    :param expr: expression to simplify, an integer is returned as it is
    :return: the integer received, or the simplified expression as a string
    """
    if isinstance(expr, int):
        return expr
    assert isinstance(expr, str), f"Unexpected type {expr} for the expression."

    # Order matters: factors are cancelled before and after
    # the constants are folded.
    passes = (
        SimpleSimpliflyTransformer,
        MulDivCancellerTransformer,
        ExactMulDivConstantFolderTransformer,
        MulDivCancellerTransformer,
        MaxToXorTransformer,
        SimplifyParensTransformer,
        ReorderCommutativeOpsTransformer,
    )
    tree = ast.parse(expr, mode="eval")
    for transformer_class in passes:
        tree = transformer_class(expr=expr).visit(tree)

    # Nodes created by the transformers may have no position yet.
    ast.fix_missing_locations(tree.body)
    rewritten = ast.unparse(tree)
    collector = ExpressionSimplifierAddVisitor(expr=rewritten)
    collector.visit(tree.body)
    return collector.make_simplified()
[docs]
def simplify_two_expressions(expr1: str, expr2: str) -> Dict[str, int]:
    """
    Simplifies an expression exp1 == exp2.

    The function collects the terms of ``(expr1)-(expr2)``.

    :param expr1: left side of the equality
    :param expr2: right side of the equality
    :return: terms with a non null coefficient in ``expr1 - expr2``,
        the numeric constant is not part of the result
    """
    # Both sides are wrapped into parentheses: ``-`` binds tighter than
    # operators such as ``^``, ``a^b-(c)`` would be read as ``a^(b-(c))``.
    # An empty left side is kept as it is, ``()`` would be an empty tuple.
    left = f"({expr1})" if str(expr1).strip() else f"{expr1}"
    expr = f"{left}-({expr2})"
    simp1 = ExpressionSimplifierAddVisitor(expr)
    simp1.visit(ast.parse(expr, mode="eval").body)
    return {k: v for k, v in simp1.coeffs.items() if v != 0}