Source code for experimental_experiment.torch_interpreter.patches.patch_torch
importtorch
[docs]defpatched_infer_size(a,b):"""Patches ``torch._subclasses.fake_impls.infer_size``."""fromtorch.fx.experimental.symbolic_shapesimportguard_size_obliviousdimsA=len(a)dimsB=len(b)ndim=max(dimsA,dimsB)expandedSizes=[0]*ndimforiinrange(ndim-1,-1,-1):offset=ndim-1-idimA=dimsA-1-offsetdimB=dimsB-1-offsetsizeA=a[dimA]ifdimA>=0else1sizeB=b[dimB]ifdimB>=0else1# NB: It is very important to test for broadcasting, before testing# sizeA == sizeB. This is because the broadcasting tests are likely# to be statically known (in particular, if sizeA/sizeB is unbacked# but size-like, we will unsoundly assume they never equal 1), but# the sizeA == sizeB test may not be statically known. However, once# we have established that no broadcasting is happening, the# sizeA == sizeB is now expect_true and we can defer it as a runtime# assert (this works because Python will return the terminal# expression of an or statement as-is, without bool()'ing it; if this# were not the case, we'd need to write this using torch.sym_or() or# something like that).try:b1=guard_size_oblivious(sizeA==1)excepttorch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:b1=Falsetry:b2=guard_size_oblivious(sizeB==1)excepttorch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:b2=Falsetry:b3=guard_size_oblivious(sizeA==sizeB)excepttorch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:b3=Falseifb1orb2orb3:expandedSizes[i]=sizeBifguard_size_oblivious(sizeA==1)elsesizeAelse:# In this case, the current implementation of torch fails (17/12/2024).# Try model SmolLM.expandedSizes[i]=torch.sym_max(sizeA,sizeB)returntuple(expandedSizes)