diff --git a/EqSat/src/simple_ast.rs b/EqSat/src/simple_ast.rs index 78b8333..87f5bcb 100644 --- a/EqSat/src/simple_ast.rs +++ b/EqSat/src/simple_ast.rs @@ -27,6 +27,7 @@ pub struct AstIdx(pub u32); pub struct Arena { pub elements: Vec<(SimpleAst, AstData)>, ast_to_idx: AHashMap, + isle_cache: AHashMap, // Map a name to it's corresponds symbol index. symbol_ids: Vec<(String, AstIdx)>, @@ -37,6 +38,7 @@ impl Arena { pub fn new() -> Self { let elements = Vec::with_capacity(65536); let ast_to_idx = AHashMap::with_capacity(65536); + let isle_cache = AHashMap::with_capacity(65536); let symbol_ids = Vec::with_capacity(255); let name_to_symbol = AHashMap::with_capacity(255); @@ -44,6 +46,7 @@ impl Arena { Arena { elements: elements, ast_to_idx: ast_to_idx, + isle_cache: isle_cache, symbol_ids: symbol_ids, name_to_symbol: name_to_symbol, @@ -813,6 +816,9 @@ pub fn eval_ast(ctx: &Context, idx: AstIdx, value_mapping: &HashMap // Recursively apply ISLE over an AST. pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { + if ctx.arena.isle_cache.get(&idx).is_some() { + return *ctx.arena.isle_cache.get(&idx).unwrap(); + } let mut ast = ctx.arena.get_node(idx).clone(); match ast { @@ -862,7 +868,9 @@ pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx { ast = result.unwrap(); } - return ctx.arena.ast_to_idx[&ast]; + let result = ctx.arena.ast_to_idx[&ast]; + ctx.arena.isle_cache.insert(idx, result); + result } // Evaluate the current AST for all possible combinations of zeroes and ones as inputs. diff --git a/Mba.Simplifier/Bindings/AstIdx.cs b/Mba.Simplifier/Bindings/AstIdx.cs index 1a5849f..93d1720 100644 --- a/Mba.Simplifier/Bindings/AstIdx.cs +++ b/Mba.Simplifier/Bindings/AstIdx.cs @@ -22,6 +22,11 @@ public override string ToString() return ctx.GetAstString(Idx); } + public override int GetHashCode() + { + return Idx.GetHashCode(); + } + public unsafe static implicit operator uint(AstIdx reg) => reg.Idx; public unsafe static implicit operator AstIdx(uint reg) => new AstIdx(reg); diff --git a/Mba.Simplifier/Minimization/BooleanMinimizer.cs b/Mba.Simplifier/Minimization/BooleanMinimizer.cs index e88f3fa..a7e37c4 100644 --- a/Mba.Simplifier/Minimization/BooleanMinimizer.cs +++ b/Mba.Simplifier/Minimization/BooleanMinimizer.cs @@ -191,7 +191,7 @@ private static AstIdx MinimizeAnf(AstCtx ctx, IReadOnlyList variables, T } var r = ctx.MinimizeAnf(TableDatabase.Instance.db, truthTable, tempVars, MultibitSiMBA.JitPage.Value); - var backSubst = GeneralSimplifier.ApplyBackSubstitution(ctx, r, invSubstMapping); + var backSubst = GeneralSimplifier.BackSubstitute(ctx, r, invSubstMapping); return backSubst; } } diff --git a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs index 97c9ec0..b7dba93 100644 --- a/Mba.Simplifier/Pipeline/GeneralSimplifier.cs +++ b/Mba.Simplifier/Pipeline/GeneralSimplifier.cs @@ -20,8 +20,14 @@ namespace Mba.Simplifier.Pipeline { + + public class GeneralSimplifier { + public static bool DbgLog = false; + + private const bool REDUCE_POLYS = false; + private readonly AstCtx ctx; // For any given node, we store the best possible ISLE result. @@ -88,14 +94,14 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (ctx.IsConstant(id)) return id; - if(linClass != AstClassification.Nonlinear) + if (linClass != AstClassification.Nonlinear) { // Bail out if there are too many variables. var vars = ctx.CollectVariables(id); - if(vars.Count > 11 || vars.Count == 0) + if (vars.Count > 11 || vars.Count == 0) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -116,12 +122,17 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // Discard any vanished substitutions var usedVars = ctx.CollectVariables(withSubstitutions).ToHashSet(); - foreach(var (substValue, substVar) in substMapping.ToList()) + foreach (var (substValue, substVar) in substMapping.ToList()) { if (!usedVars.Contains(substVar)) substMapping.Remove(substValue); } + if (substMapping.Count > 8) + { + Console.WriteLine(substMapping.Count); + //Debugger.Break(); + } // Try to take a guess (MSiMBA) and prove it's equivalence var guess = SimplifyViaGuessAndProve(withSubstitutions, substMapping, ref isSemiLinear); if (guess != null) @@ -140,6 +151,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) withSubstitutions = TryUnmergeLinCombs(withSubstitutions, substMapping, ref isSemiLinear); withSubstitutions = SimplifyViaTermRewriting(withSubstitutions); + // If polynomial parts are present, try to simplify them. var inverseMapping = substMapping.ToDictionary(x => x.Value, x => x.Key); AstIdx? reducedPoly = null; @@ -149,10 +161,10 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) reducedPoly = ReducePolynomials(GetRootTerms(ctx, withSubstitutions), substMapping, inverseMapping); // If we succeeded, reset the state. - if(reducedPoly != null) + if (reducedPoly != null) { // Back substitute the original substitutions. - reducedPoly = ApplyBackSubstitution(ctx, reducedPoly.Value, inverseMapping); + reducedPoly = BackSubstitute(ctx, reducedPoly.Value, inverseMapping); // Reset internal state. substMapping.Clear(); @@ -164,7 +176,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) // If there are any substitutions, we want to try simplifying the polynomial parts. var variables = ctx.CollectVariables(withSubstitutions); - if (polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) + if (REDUCE_POLYS && polySimplify && substMapping.Count > 0 && ctx.GetHasPoly(id)) { var maybeSimplified = TrySimplifyMixedPolynomialParts(withSubstitutions, substMapping, inverseMapping, variables); if (maybeSimplified != null && maybeSimplified.Value != id) @@ -182,7 +194,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) if (variables.Count > 11) { var simplified = SimplifyViaTermRewriting(id); - simbaCache.Add(id, simplified); + simbaCache.TryAdd(id, simplified); return simplified; } @@ -193,7 +205,7 @@ private AstIdx SimplifyViaRecursiveSiMBA(AstIdx id, bool polySimplify = true) var result = withSubstitutions; if (!ctx.IsConstant(withSubstitutions)) result = LinearSimplifier.Run(ctx.GetWidth(withSubstitutions), ctx, withSubstitutions, false, isSemiLinear, false, variables); - var backSub = ApplyBackSubstitution(ctx, result, inverseMapping); + var backSub = BackSubstitute(ctx, result, inverseMapping); // Apply constant folding / term rewriting. var propagated = SimplifyViaTermRewriting(backSub); @@ -235,6 +247,29 @@ private static ulong Pow(ulong bbase, ulong exponent) private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary substitutionMapping, ref bool isSemiLinear, bool inBitwise = false) { + /* + // This is dubious: Do we actually need to run simba here... for some reason performance degrades if not + // TODO: Maybe comment this out + var cls = ctx.GetClass(id); + if (cls == AstClassification.Bitwise) + return SimplifyViaRecursiveSiMBA(id); + if (cls == AstClassification.BitwiseWithConstants) + { + isSemiLinear = true; + return SimplifyViaRecursiveSiMBA(id); + } + + // Note: These two checks seem to hurt performance too! + if (cls == AstClassification.Linear && !inBitwise) + return SimplifyViaRecursiveSiMBA(id); + if (cls == AstClassification.SemiLinear && !inBitwise) + { + isSemiLinear = true; + return SimplifyViaRecursiveSiMBA(id); + } + */ + + // Sometimes we perform constant folding in this method. // To make sure that we correctly track whether the expression is semi-linear, we use this method to process replacements. var visitReplacement = (AstIdx replacementIdx, bool inBitwise, ref bool isSemiLinear) => GetAstWithSubstitutions(replacementIdx, substitutionMapping, ref isSemiLinear, inBitwise); @@ -308,6 +343,14 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub var oldSum = sum; var newSum = ctx.SingleSimplify(sum); sum = newSum; + + + if(GeneralSimplifier.DbgLog) + { + Console.WriteLine($"ConstTerm = {DagFormatter.Format(ctx, v0)}"); + Console.WriteLine($"\nWhole dag: {DagFormatter.Format(ctx, sum)}\n\n\n\n\n"); + } + // In this case, we apply constant folding(but we do not search recursively). return GetAstWithSubstitutions(sum, substitutionMapping, ref isSemiLinear, inBitwise); @@ -352,7 +395,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub (and0, and1) = (and1, and0); (id0, id1) = (id1, id0); } - + // Rewrite (a&mask) as `Trunc(a)`, or `Trunc(a & mask)` if mask is not completely a bit mask. // This is a form of adhoc demanded bits based simplification if (ctx.IsConstant(and0) && !ctx.IsConstant(and1)) @@ -379,7 +422,7 @@ private AstIdx GetAstWithSubstitutions(AstIdx id, Dictionary sub return ext; } } - + return ctx.And(and0, and1); } @@ -508,12 +551,12 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) // Skip if this is not a multiplication. var opcode = ctx.GetOpcode(id); - var roots = GetRootMultiplications(ctx,id); + var roots = GetRootMultiplications(ctx, id); ulong coeffSum = 0; Dictionary constantPowers = new(); List others = new(); - foreach(var root in roots) + foreach (var root in roots) { var code = ctx.GetOpcode(root); if (code == AstOp.Constant) @@ -525,12 +568,12 @@ public static PolynomialParts GetPolynomialParts(AstCtx ctx, AstIdx id) constantPowers.TryAdd(root, 0); constantPowers[root]++; } - else if(code == AstOp.Pow) + else if (code == AstOp.Pow) { // If we have a power by a nonconstant, we can't really do much here. var degree = ctx.GetOp1(root); var constPower = ctx.TryGetConstantValue(degree); - if(constPower == null) + if (constPower == null) { others.Add(root); continue; @@ -568,7 +611,7 @@ public static int VarsFirst(AstCtx ctx, AstIdx a, AstIdx b) return comeFirst; if (op1 && !op0) return comeLast; - if(op0 && op1) + if (op0 && op1) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); return comeLast; } @@ -590,7 +633,7 @@ private int CompareTo(AstIdx a, AstIdx b) return comeLast; // Sort symbols alphabetically - if(op0 == AstOp.Symbol && op1 == AstOp.Symbol) + if (op0 == AstOp.Symbol && op1 == AstOp.Symbol) return ctx.GetSymbolName(a).CompareTo(ctx.GetSymbolName(b)); if (op0 == AstOp.Pow) return comeLast; @@ -604,7 +647,7 @@ private AstIdx GetSubstitution(AstIdx id, Dictionary substitutio if (substitutionMapping.TryGetValue(id, out var existing)) return existing; - while(true) + while (true) { var subst = ctx.Symbol($"subst{substCount}", ctx.GetWidth(id)); substCount++; @@ -631,10 +674,10 @@ private AstIdx TryUnmergeLinCombs(AstIdx withSubstitutions, Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary(); var results = new List(); - for(int i = 0 ; i < inputExpressions.Count; i++) + for (int i = 0; i < inputExpressions.Count; i++) { // Substitute all of the nonlinear parts for this expression // Here we share the list of substitutions @@ -701,7 +744,7 @@ private Dictionary UnmergeNegatedParts(Dictionary> vecToExpr = new(); - for(int i = 0; i < results.Count; i++) + for (int i = 0; i < results.Count; i++) { var expr = results[i]; var w = ctx.GetWidth(expr); @@ -717,7 +760,7 @@ private Dictionary UnmergeNegatedParts(Dictionary varToNewSubstValue = new Dictionary(); - foreach(var (key, members) in vecToExpr) + foreach (var (key, members) in vecToExpr) { var temp = results[members.First().index]; var w = ctx.GetWidth(temp); @@ -731,12 +774,12 @@ private Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary UnmergeNegatedParts(Dictionary varToDemandedBits = new(); - foreach(var (expr, substVar) in substitutionMapping) - ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits); + var cache = new HashSet<(AstIdx idx, ulong currDemanded)>(); + int totalDemanded = 0; + foreach (var (expr, substVar) in substitutionMapping) + { + ComputeSymbolDemandedBits(expr, ModuloReducer.GetMask(ctx.GetWidth(expr)), varToDemandedBits, cache, ref totalDemanded); + if (totalDemanded > 12) + break; + } + - // Compute the total number of demanded variable bits in the substituted parts. - ulong totalDemanded = 0; - foreach (var demandedBits in varToDemandedBits.Values) - totalDemanded += (ulong)BitOperations.PopCount(demandedBits); + // Bail if there are too many demanded bits! if (totalDemanded > 12) return null; @@ -849,7 +896,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong if (constrainedIdx == null) { // Simplify the constrained parts. - var withoutSubstitutions = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var r = SimplifyUnconstrained(withoutSubstitutions, varToDemandedBits); if (r == null) return null; @@ -872,7 +919,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong return null; // Simplify unconstrained parts. - var unconstrainedBackSub = ApplyBackSubstitution(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var unconstrainedBackSub = BackSubstitute(ctx, unconstrainedIdx.Value, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var unconstrainedSimpl = SimplifyUnconstrained(unconstrainedBackSub, varToDemandedBits); if (unconstrainedSimpl == null) return null; @@ -888,7 +935,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong { // Construct a result vector for the linear part. var substVars = substitutionMapping.Values.ToList(); - var allVars = ctx.CollectVariables(withSubstitutions); + IReadOnlyList allVars = ctx.CollectVariables(withSubstitutions); var bitSize = ctx.GetWidth(withSubstitutions); var numCombinations = (ulong)Math.Pow(2, allVars.Count); var groupSizes = LinearSimplifier.GetGroupSizes(allVars.Count); @@ -927,10 +974,12 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong List constrainedParts = new(); // Decompose result vector into semi-linear, unconstrained, and constrained parts. + // Upcast variables as necessary! + allVars = LinearSimplifier.CastVariables(ctx, allVars, bitSize); int resultVecIdx = 0; - for(int i = 0; i < linearCombinations.Count; i++) + for (int i = 0; i < linearCombinations.Count; i++) { - foreach(var (coeff, bitMask) in linearCombinations[i]) + foreach (var (coeff, bitMask) in linearCombinations[i]) { if (coeff == 0) goto skip; @@ -966,7 +1015,7 @@ private bool CanFitConstantInUndemandedBits(KnownBits kb, ulong constant, ulong } // TODO: Refactor out! - private static (ulong[], List>) GetAnf(uint width, List variables, List groupSizes, ulong[] resultVector, bool multiBit) + private static (ulong[], List>) GetAnf(uint width, IReadOnlyList variables, List groupSizes, ulong[] resultVector, bool multiBit) { // Get all combinations of variables. var moduloMask = ModuloReducer.GetMask(width); @@ -1033,7 +1082,7 @@ private static (ulong[], List>) GetAnf(uint w private unsafe AstIdx? SimplifyConstrained(AstIdx withSubstitutions, Dictionary substitutionMapping, Dictionary varToDemandedBits) { // Compute a result vector for the original expression - var withoutSubstitutions = ApplyBackSubstitution(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); + var withoutSubstitutions = BackSubstitute(ctx, withSubstitutions, substitutionMapping.ToDictionary(x => x.Value, x => x.Key)); var w = ctx.GetWidth(withoutSubstitutions); var inputVars = ctx.CollectVariables(withoutSubstitutions); var originalResultVec = LinearSimplifier.JitResultVector(ctx, w, ModuloReducer.GetMask(w), inputVars, withoutSubstitutions, true, (ulong)Math.Pow(2, inputVars.Count)); @@ -1042,7 +1091,7 @@ private static (ulong[], List>) GetAnf(uint w var exprToSubstVar = substitutionMapping.OrderBy(x => ctx.GetAstString(x.Value)).ToList(); var allVars = inputVars.Concat(exprToSubstVar.Select(x => x.Value)).ToList(); // Sort them.... var pagePtr = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, true); + new Amd64OptimizingJit(ctx).Compile(withSubstitutions, allVars, pagePtr, false); var jittedWithSubstitutions = (delegate* unmanaged[SuppressGCTransition])pagePtr; // Return null if the expressions are not provably equivalent @@ -1060,7 +1109,7 @@ private static (ulong[], List>) GetAnf(uint w } // Returns true if two expressions are guaranteed to be equivalent - private unsafe bool IsConstrainedExpressionEquivalent(uint width,List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) + private unsafe bool IsConstrainedExpressionEquivalent(uint width, List inputVars, List<(AstIdx demandedVar, ulong demandedMask)> demandedVars, List> exprToSubstVar, delegate* unmanaged[SuppressGCTransition] jittedWithSubstitutions, ulong[] originalResultVec) { int totalDemanded = demandedVars.Sum(x => BitOperations.PopCount(x.demandedMask)); @@ -1153,12 +1202,12 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in // Jit the input expression var pagePtr1 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, true); + new Amd64OptimizingJit(ctx).Compile(withoutSubstitutions, inputVars, pagePtr1, false); var jittedBefore = (delegate* unmanaged[SuppressGCTransition])pagePtr1; // Jit the output expression var pagePtr2 = JitUtils.AllocateExecutablePage(4096); - new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, true); + new Amd64OptimizingJit(ctx).Compile(expectedExpr, inputVars, pagePtr2, false); var jittedAfter = (delegate* unmanaged[SuppressGCTransition])pagePtr2; // Prove that they are equivalent for all possible input combinations @@ -1208,22 +1257,56 @@ private unsafe bool IsConstrainedExpressionEquivalent(uint width,List in } } + JitUtils.FreeExecutablePage(pagePtr1); + JitUtils.FreeExecutablePage(pagePtr2); return expectedExpr; } + public struct DemandedBitsTuple + { + public AstIdx Idx; + + public ulong CurrDemanded; + + public DemandedBitsTuple(AstIdx idx, ulong currDemanded) + { + Idx = idx; + CurrDemanded = currDemanded; + } + + public override int GetHashCode() + { + int hash = 17; + hash = hash * 31 + Idx.GetHashCode(); + hash = hash * 31 + CurrDemanded.GetHashCode(); + return hash; + } + } + // TODO: Cache results to avoid exponentially visiting shared nodes - private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits) + private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionary symbolDemandedBits, HashSet<(AstIdx idx, ulong currDemanded)> seen, ref int totalDemanded) { - var op0 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits); - var op1 = (ulong demanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits); + if (totalDemanded > 12) + return; + if (!seen.Add((idx, currDemanded))) + return; + + totalDemanded += 1; + + var op0 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp0(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); + var op1 = (ulong demanded, ref int totalDemanded) => ComputeSymbolDemandedBits(ctx.GetOp1(idx), demanded, symbolDemandedBits, seen, ref totalDemanded); var opc = ctx.GetOpcode(idx); - switch(opc) + switch (opc) { // If we have a symbol, union the set of demanded bits case AstOp.Symbol: - symbolDemandedBits.TryAdd(idx, 0); - symbolDemandedBits[idx] |= currDemanded; + //symbolDemandedBits.TryAdd(idx, 0); + symbolDemandedBits.TryGetValue(idx, out var oldDemanded); + var newDemanded = oldDemanded | currDemanded; + symbolDemandedBits[idx] = newDemanded; + totalDemanded += BitOperations.PopCount(newDemanded & ~oldDemanded); + break; // If we have a constant, there is nothing to do. case AstOp.Constant: @@ -1233,26 +1316,27 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // For addition by a constant we can also get more precision case AstOp.Add: case AstOp.Mul: + case AstOp.Pow: // If we have addition/multiplication, we only care about bits at and below the highest set bit. var demandedWidth = 64 - (uint)BitOperations.LeadingZeroCount(currDemanded); currDemanded = ModuloReducer.GetMask(demandedWidth); - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.Lshr: var shiftBy = ctx.GetOp1(idx); var shiftByConstant = ctx.TryGetConstantValue(shiftBy); if (shiftByConstant == null) { - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; } // If we know the value we are shifting by, we can truncate the demanded bits. - op0(currDemanded >> (ushort)shiftByConstant.Value); - op1(currDemanded); + op0(currDemanded >> (ushort)shiftByConstant.Value, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; case AstOp.And: @@ -1260,8 +1344,8 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a&b, demandedbits(a) does not include any known zero bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Zeroes & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Zeroes & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } @@ -1270,25 +1354,25 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar // If we have a|b, demandedbits(a) does not include any known one bits from b. Works both ways. var op0Demanded = ~ctx.GetKnownBits(ctx.GetOp1(idx)).Ones & currDemanded; var op1Demanded = ~ctx.GetKnownBits(ctx.GetOp0(idx)).Ones & currDemanded; - op0(op0Demanded); - op1(op1Demanded); + op0(op0Demanded, ref totalDemanded); + op1(op1Demanded, ref totalDemanded); break; } // TODO: We can gain some precision by exploiting XOR known bits. case AstOp.Xor: - op0(currDemanded); - op1(currDemanded); + op0(currDemanded, ref totalDemanded); + op1(currDemanded, ref totalDemanded); break; // TODO: Treat negation as x^-1, then use XOR transfer function case AstOp.Neg: - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Trunc: currDemanded &= ModuloReducer.GetMask(ctx.GetWidth(idx)); - op0(currDemanded); + op0(currDemanded, ref totalDemanded); break; case AstOp.Zext: - op0(currDemanded & ctx.GetWidth(ctx.GetOp0(idx))); + op0(currDemanded & ModuloReducer.GetMask(ctx.GetWidth(ctx.GetOp0(idx))), ref totalDemanded); break; default: throw new InvalidOperationException($"Cannot compute demanded bits for {opc}"); @@ -1298,7 +1382,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar private AstIdx? TrySimplifyMixedPolynomialParts(AstIdx id, Dictionary substMapping, Dictionary inverseSubstMapping, List varList) { // Back substitute in the (possibly) polynomial parts - var newId = ApplyBackSubstitution(ctx, id, inverseSubstMapping); + var newId = BackSubstitute(ctx, id, inverseSubstMapping); // Decompose each term into structured polynomial parts var terms = GetRootTerms(ctx, newId); @@ -1313,14 +1397,14 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar return null; // Add back any banned parts. - if(banned.Any()) + if (banned.Any()) { var sum = ctx.Add(banned.Select(x => GetAstForPolynomialParts(x))); result = ctx.Add(result.Value, sum); } // Do a full back substitution again. - result = ApplyBackSubstitution(ctx, result.Value, inverseSubstMapping); + result = BackSubstitute(ctx, result.Value, inverseSubstMapping); // Bail out if this resulted in a worse result. var cost1 = ctx.GetCost(result.Value); @@ -1336,7 +1420,7 @@ private void ComputeSymbolDemandedBits(AstIdx idx, ulong currDemanded, Dictionar private List UnmergePolynomialParts(Dictionary substitutionMapping, List parts) { // Skip if there is only one substituted part. - if(substitutionMapping.Count <= 1) + if (substitutionMapping.Count <= 1) return parts; // Try to rewrite substituted parts as negations of one another. Exit early if this fails. @@ -1350,7 +1434,7 @@ private List UnmergePolynomialParts(Dictionary var outPowers = new Dictionary(); foreach (var (factor, degree) in part.ConstantPowers) { - var unmerged = ApplyBackSubstitution(ctx, factor, rewriteMapping); + var unmerged = BackSubstitute(ctx, factor, rewriteMapping); outPowers.TryAdd(unmerged, 0); outPowers[unmerged] += degree; } @@ -1369,11 +1453,11 @@ private List UnmergePolynomialParts(Dictionary // Rewrite as a sum of polynomial parts, where the factors are linear MBAs with substitution of nonlinear parts. var bannedParts = new List(); List partsWithSubstitutions = new(); - foreach(var polyPart in polyParts) + foreach (var polyPart in polyParts) { bool isSemiLinear = false; Dictionary powers = new(); - foreach(var factor in polyPart.ConstantPowers) + foreach (var factor in polyPart.ConstantPowers) { var withSubstitutions = GetAstWithSubstitutions(factor.Key, substMapping, ref isSemiLinear); powers.TryAdd(withSubstitutions, 0); @@ -1381,7 +1465,7 @@ private List UnmergePolynomialParts(Dictionary } // TODO: Handle the semi-linear case. - if(isSemiLinear) + if (isSemiLinear) { bannedParts.Add(polyPart); continue; @@ -1443,7 +1527,7 @@ private List UnmergePolynomialParts(Dictionary } // Calculate the max possible size of the resulting expression when multiplied out. - for(ulong i = 0; i < degree; i++) + for (ulong i = 0; i < degree; i++) { size = SaturatingMul(size, numNonZeroes); } @@ -1459,7 +1543,7 @@ private List UnmergePolynomialParts(Dictionary // When the basis element corresponds to the constant offset, we want to make the base bitwise expression be `1`. // Otherwise we just substitute it with a variable. AstIdx basis = ctx.Constant(1, (byte)bitSize); - if(i != 0) + if (i != 0) { if (!basisSubstitutions.TryGetValue((ulong)i, out basis)) { @@ -1481,7 +1565,7 @@ private List UnmergePolynomialParts(Dictionary // If the expanded form would be too large, we want to block this polynomial. // It would take too long. - if(size >= 1000) + if (size >= 1000) { bannedParts.Add(polyPart); continue; @@ -1502,7 +1586,7 @@ private List UnmergePolynomialParts(Dictionary poly = constOffset; } - + polys.Add(poly.Value); } @@ -1514,15 +1598,15 @@ private List UnmergePolynomialParts(Dictionary var linComb = ctx.Add(polys); var reduced = ExpandReduce(linComb, false); // Add back banned parts - if(bannedParts.Any()) + if (bannedParts.Any()) { var sum = ctx.Add(bannedParts.Select(x => GetAstForPolynomialParts(x))); reduced = ctx.Add(reduced, sum); } var invBases = basisSubstitutions.ToDictionary(x => x.Value, x => LinearSimplifier.ConjunctionFromVarMask(ctx, allVars, 1, x.Key)); - var backSub = ApplyBackSubstitution(ctx, reduced, invBases); - backSub = ApplyBackSubstitution(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); + var backSub = BackSubstitute(ctx, reduced, invBases); + backSub = BackSubstitute(ctx, backSub, substMapping.ToDictionary(x => x.Value, x => x.Key)); return backSub; } @@ -1628,16 +1712,16 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Try to decompose into high degree polynomials parts. List polyTerms = new(); List other = new(); - foreach(var term in terms) + foreach (var term in terms) { // Typically this is going to be a multiplication(coefficient over substituted variable), or whole substituted variable. // TODO: Handle negation. var opcode = ctx.GetOpcode(term); - if(opcode != AstOp.Mul && opcode != AstOp.Symbol) + if (opcode != AstOp.Mul && opcode != AstOp.Symbol) goto skip; - + // Search for coeff*subst - if(opcode == AstOp.Mul) + if (opcode == AstOp.Mul) { // If multiplication, we are looking for coeff*(subst), where coeff is a constant. var coeff = ctx.GetOp0(term); @@ -1656,14 +1740,14 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) } // Search for a plain substitution(omitted coefficient of 1) - if(opcode == AstOp.Symbol && IsSubstitutedPolynomialSymbol(term, inverseSubstMapping)) + if (opcode == AstOp.Symbol && IsSubstitutedPolynomialSymbol(term, inverseSubstMapping)) { var invSubst = inverseSubstMapping[term]; polyTerms.Add(GetPolynomialParts(ctx, invSubst)); continue; } - skip: + skip: other.Add(term); continue; } @@ -1680,7 +1764,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) var uniqueBases = new Dictionary(); foreach (var poly in polyTerms) { - foreach(var (_base, degree) in poly.ConstantPowers) + foreach (var (_base, degree) in poly.ConstantPowers) { // Set the default degree to zero. uniqueBases.TryAdd(_base, 0); @@ -1688,7 +1772,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // For each unique base, we want to keep track of the highest degree. var oldDegree = uniqueBases[_base]; var newDeg = degree; - if(newDeg > oldDegree) + if (newDeg > oldDegree) uniqueBases[_base] = newDeg; } } @@ -1706,7 +1790,7 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // If the dense vector size would be greater than 64**3, we bail out. // In those cases, we may consider implementing variable partitioning and simplifying each partition separately. - if (vecSize > 64*64*64) + if (vecSize > 64 * 64 * 64) return null; // For now we only support polynomials up to degree 255, although this is a somewhat arbitrary limit. @@ -1727,10 +1811,10 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) foreach (var poly in polyTerms) { var coeff = poly.coeffSum; - + var constPowers = poly.ConstantPowers; var degrees = new byte[orderedVars.Count]; - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -1756,24 +1840,24 @@ public static List GetRootMultiplications(AstCtx ctx, AstIdx id) // Add back all of the ignored parts. newTerms.AddRange(other); // Add back all of the discarded polynomial parts - foreach(var part in discarded) + foreach (var part in discarded) { var ast = GetAstForPolynomialParts(part); newTerms.Add(ast); } // Then finally convert the sparse polynomial back to an AST. - foreach(var (monom, coeff) in simplified.coeffs) + foreach (var (monom, coeff) in simplified.coeffs) { if (coeff == 0) continue; List factors = new(); factors.Add(ctx.Constant(coeff, width)); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); - if(deg == 0) + if (deg == 0) { factors.Add(ctx.Constant(1, width)); continue; @@ -1813,17 +1897,17 @@ private bool IsSubstitutedPolynomialSymbol(AstIdx id, IReadOnlyDictionary terms = new(); var width = ctx.GetWidth(id); - foreach(var (monom, coeff) in result.coeffs) + foreach (var (monom, coeff) in result.coeffs) { List factors = new(); factors.Add(ctx.Constant(coeff, width)); - foreach(var (varIdx, degree) in monom.varDegrees) + foreach (var (varIdx, degree) in monom.varDegrees) { // Skip a constant factor of 1 if (degree == 0) continue; - if(degree == 1) + if (degree == 1) { factors.Add(varIdx); continue; @@ -1879,7 +1963,7 @@ public AstIdx ExpandReduce(AstIdx id, bool polySimplify = true) // Back substitute the substitute variables. var inverseMapping = substMapping.ToDictionary(x => x.Value, x => x.Key); - sum = ApplyBackSubstitution(ctx, sum, inverseMapping); + sum = BackSubstitute(ctx, sum, inverseMapping); // Try to simplify using the general simplifier. sum = ctx.RecursiveSimplify(sum); @@ -1909,12 +1993,12 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa return poly; }; - switch(opcode) + switch (opcode) { case AstOp.Mul: var factors = GetRootMultiplications(ctx, id); var facPolys = factors.Select(x => TryExpand(x, substMapping, false)).ToList(); - var product = IntermediatePoly.Mul(ctx,facPolys); + var product = IntermediatePoly.Mul(ctx, facPolys); resultPoly = product; // In this case we should probably distribute the coefficient down always. @@ -1981,7 +2065,7 @@ private IntermediatePoly TryExpand(AstIdx id, Dictionary substMa // If this is the root of a polynomial part, we want to try and reduce it. // Alternatively we may apply a reduction if there are too many terms. bool shouldReduce = isRoot || resultPoly?.coeffs?.Count > 10; - if(shouldReduce && resultPoly != null) + if (shouldReduce && resultPoly != null) { resultPoly = TryReduce(resultPoly); } @@ -2018,15 +2102,15 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) { // Bail out if the result would be too large. UInt128 result = matrixSize * deg; - if (result > (UInt128)(64*64*64)) + if (result > (UInt128)(64 * 64 * 64)) return poly; matrixSize = SaturatingMul(matrixSize, deg); matrixSize &= poly.moduloMask; } - + // Place a limit on the matrix size. - if (matrixSize > (ulong)(64*64*64)) + if (matrixSize > (ulong)(64 * 64 * 64)) return poly; var width = poly.bitWidth; @@ -2040,7 +2124,7 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) var degrees = new byte[orderedVars.Count]; foreach (var (monom, coeff) in poly.coeffs) { - for(int varIdx = 0; varIdx < orderedVars.Count; varIdx++) + for (int varIdx = 0; varIdx < orderedVars.Count; varIdx++) { var variable = orderedVars[varIdx]; ulong degree = 0; @@ -2059,18 +2143,18 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) var newCount = simplified.coeffs.Count(x => x.Value != 0); // If we got a result with more terms, skip it. // This is required when doing expansion, since expansion is exponential in the number of terms. - if(newCount > oldCount) + if (newCount > oldCount) return poly; var outPoly = new IntermediatePoly(width); // Otherwise we can convert the sparse polynomial back to an AST. - foreach(var (monom, coeff) in simplified.coeffs) + foreach (var (monom, coeff) in simplified.coeffs) { if (coeff == 0) continue; Dictionary varDegrees = new(); - for(int i = 0; i < orderedVars.Count; i++) + for (int i = 0; i < orderedVars.Count; i++) { var deg = monom.GetVarDeg(i); if (deg == 0) @@ -2079,7 +2163,7 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) } // Handle the case of a constant offset. - if(varDegrees.Count == 0) + if (varDegrees.Count == 0) { varDegrees.Add(ctx.Constant(1, width), 1); } @@ -2091,18 +2175,19 @@ private IntermediatePoly TryReduce(IntermediatePoly poly) return outPoly; } - public static AstIdx ApplyBackSubstitution(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache = null) + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions) + => BackSubstitute(ctx, id, backSubstitutions, new(16)); + + public static AstIdx BackSubstitute(AstCtx ctx, AstIdx id, Dictionary backSubstitutions, Dictionary cache) { - if (cache == null) - cache = new(); if (backSubstitutions.TryGetValue(id, out var backSub)) return backSub; if (cache.TryGetValue(id, out var existing)) return existing; - var op0 = () => ApplyBackSubstitution(ctx, ctx.GetOp0(id), backSubstitutions, cache); - var op1 = () => ApplyBackSubstitution(ctx, ctx.GetOp1(id), backSubstitutions, cache); + var op0 = () => BackSubstitute(ctx, ctx.GetOp0(id), backSubstitutions, cache); + var op1 = () => BackSubstitute(ctx, ctx.GetOp1(id), backSubstitutions, cache); var opcode = ctx.GetOpcode(id); var width = ctx.GetWidth(id); diff --git a/Mba.Simplifier/Pipeline/LinearSimplifier.cs b/Mba.Simplifier/Pipeline/LinearSimplifier.cs index b1989bf..639c144 100644 --- a/Mba.Simplifier/Pipeline/LinearSimplifier.cs +++ b/Mba.Simplifier/Pipeline/LinearSimplifier.cs @@ -46,7 +46,7 @@ public class LinearSimplifier private readonly bool tryDecomposeMultiBitBases; private readonly Action? resultVectorHook; - + private readonly int depth; private readonly ApInt moduloMask = 0; // Number of combinations of input variables(2^n), for a single bit index. @@ -69,14 +69,14 @@ public class LinearSimplifier private AstIdx? initialInput = null; - public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null) + public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList variables = null, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) { if (variables == null) variables = ctx.CollectVariables(ast.Value); - return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec).Simplify(false, alreadySplit); + return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth).Simplify(false, alreadySplit); } - public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null) + public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action? resultVectorHook = null, ApInt[] inVec = null, int depth = 0) { // If we are given an AST, verify that the correct width was passed. if (ast != null && bitSize != ctx.GetWidth(ast.Value)) @@ -90,6 +90,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables this.multiBit = multiBit; this.tryDecomposeMultiBitBases = tryDecomposeMultiBitBases; this.resultVectorHook = resultVectorHook; + this.depth = depth; moduloMask = (ApInt)ModuloReducer.GetMask(bitSize); groupSizes = GetGroupSizes(variables.Count); numCombinations = (ApInt)Math.Pow(2, variables.Count); @@ -124,7 +125,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList variables } } - private static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) + public static IReadOnlyList CastVariables(AstCtx ctx, IReadOnlyList variables, uint bitSize) { // If all variables are of a correct size, no casting is necessary. if (!variables.Any(x => ctx.GetWidth(x) != bitSize)) @@ -610,6 +611,20 @@ private AstIdx FindTwoTermsUnnegated(ApInt constant, ApInt a, ApInt b) private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demandedMask, ApInt[] variableCombinations, List> linearCombinations) { + + var vNames = this.variables.Select(x => ctx.GetAstString(x)); + //var expected = new List() { "subst594:i32", "subst595:i32", "subst596:i32", "subst597:i32", "(uns45:i64 tr i32)", "(uns48:i64 tr i32)"}; + + /* + var expected = new List() { "subst27:i32", "(uns41:i64 tr i32)", "(uns74:i64 tr i32)", "(uns75:i64 tr i32)" }; + if (expected.All(x => vNames.Any(y => y.Contains(x)))) + Debugger.Break(); + + if (depth > 10) + Debugger.Break(); + */ + + // Collect all variables used in the output expression. List mutVars = new(variables.Count); while (demandedMask != 0) @@ -619,6 +634,8 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded demandedMask &= ~(1ul << xorIdx); } + + var clone = variables.ToList(); AstIdx sum = ctx.Constant(constantOffset, width); for (int i = 0; i < linearCombinations.Count; i++) { @@ -628,14 +645,20 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded continue; var combMask = variableCombinations[i]; - var vComb = ctx.GetConjunctionFromVarMask(mutVars, combMask); + var widths = variables.Select(x => ctx.GetWidth(x)).ToList(); + //Console.WriteLine(widths.Distinct().Count()); + //foreach (var vIdx in variables) + // Console.WriteLine($"{ctx.GetAstString(vIdx)} => {ctx.GetWidth(vIdx)}"); + //Console.WriteLine("\n\n"); + + var vComb = ctx.GetConjunctionFromVarMask(clone, combMask); var term = Term(vComb, curr[0].coeff); sum = ctx.Add(sum, term); } // TODO: Instead of constructing a result vector inside the recursive linear simplifier call, we could instead convert the ANF vector back to DNF. // This should be much more efficient than constructing a result vector via JITing and evaluating an AST representation of the ANF vector. - return LinearSimplifier.Run(width, ctx, sum, false, false, false, variables); + return LinearSimplifier.Run(width, ctx, sum, false, false, false, mutVars, depth: depth + 1); } private void EliminateUniqueValues(Dictionary coeffToTable) diff --git a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs index 9d98e31..6888f21 100644 --- a/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs +++ b/Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs @@ -5,6 +5,7 @@ using Microsoft.Z3; using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Text; using System.Threading.Tasks; @@ -59,11 +60,11 @@ public ProbableEquivalenceChecker(AstCtx ctx, List variables, AstIdx bef public unsafe bool ProbablyEquivalent(bool slowHeuristics = false) { var jit1 = new Amd64OptimizingJit(ctx); - jit1.Compile(before, variables, pagePtr1, true); + jit1.Compile(before, variables, pagePtr1, false); func1 = (delegate* unmanaged[SuppressGCTransition])pagePtr1; var jit2 = new Amd64OptimizingJit(ctx); - jit2.Compile(after, variables, pagePtr2, true); + jit2.Compile(after, variables, pagePtr2, false); func2 = (delegate* unmanaged[SuppressGCTransition])pagePtr2; var vArray = stackalloc ulong[variables.Count]; @@ -99,6 +100,7 @@ public unsafe bool ProbablyEquivalent(bool slowHeuristics = false) return true; } + public static bool Log = false; private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses) { var clone = new ulong[variables.Count]; @@ -113,6 +115,10 @@ private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses) var op1 = func1(vArray); var op2 = func2(vArray); + + if(Log) + Console.WriteLine($"{op1}, {op2}"); + if (op1 != op2) return false; } @@ -128,6 +134,8 @@ private unsafe bool AllCombs(ulong* vArray, ulong a, ulong b) return false; if (!SignatureVectorEquivalent(vArray, a, b)) return false; + if (!SignatureVectorEquivalent(vArray, b, a)) + return false; return true; } @@ -168,34 +176,45 @@ private ulong Next() public static void ProbablyEquivalentZ3(AstCtx ctx, AstIdx before, AstIdx after) { - var z3Ctx = new Context(); - var translator = new Z3Translator(ctx, z3Ctx); - var beforeZ3 = translator.Translate(before); - var afterZ3 = translator.Translate(after); - var solver = z3Ctx.MkSolver("QF_BV"); - - // Set the maximum timeout to 10 seconds. - var p = z3Ctx.MkParams(); - uint solverLimit = 10000; - p.Add("timeout", solverLimit); - solver.Parameters = p; - - Console.WriteLine("Proving equivalence...\n"); - solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); - var check = solver.Check(); - - var printModel = (Model model) => + using (var z3Ctx = new Context()) { - var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); - return $"[{String.Join(", ", values)}]"; - }; - - if (check == Status.UNSATISFIABLE) - Console.WriteLine("Expressions are equivalent."); - else if (check == Status.SATISFIABLE) - Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); - else - Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms"); + var translator = new Z3Translator(ctx, z3Ctx); + var beforeZ3 = translator.Translate(before); + var afterZ3 = translator.Translate(after); + var solver = z3Ctx.MkSolver("QF_BV"); + + // Set the maximum timeout to 10 seconds. + var p = z3Ctx.MkParams(); + uint solverLimit = 5000; + p.Add("timeout", solverLimit); + solver.Parameters = p; + + Console.WriteLine("Proving equivalence...\n"); + solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3))); + var check = solver.Check(); + + var printModel = (Model model) => + { + var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}"); + return $"[{String.Join(", ", values)}]"; + }; + + if (check == Status.UNSATISFIABLE) + { + //Console.WriteLine("Expressions are equivalent."); + } + else if (check == Status.SATISFIABLE) + { + Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}"); + Debugger.Break(); + throw new InvalidOperationException(); + + } + else + { + //Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms"); + } + } } } diff --git a/Mba.Simplifier/Utility/DagFormatter.cs b/Mba.Simplifier/Utility/DagFormatter.cs new file mode 100644 index 0000000..936d147 --- /dev/null +++ b/Mba.Simplifier/Utility/DagFormatter.cs @@ -0,0 +1,102 @@ +using Mba.Simplifier.Bindings; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Mba.Simplifier.Utility +{ + public static class DagFormatter + { + public static string Format(AstCtx ctx, AstIdx idx) + { + var sb = new StringBuilder(); + Format(sb, ctx, idx, new()); + return sb.ToString(); + } + + private static void Format(StringBuilder sb, AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + // Allocate value numbers for the operands if necessary + var opc = ctx.GetOpcode(idx); + var opcount = GetOpCount(opc); + if (opcount >= 1 && !valueNumbers.ContainsKey(ctx.GetOp0(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp0(idx))) + Format(sb, ctx, ctx.GetOp0(idx), valueNumbers); + if (opcount >= 2 && !valueNumbers.ContainsKey(ctx.GetOp1(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp1(idx))) + Format(sb, ctx, ctx.GetOp1(idx), valueNumbers); + + var op0 = () => $"{Lookup(ctx, ctx.GetOp0(idx), valueNumbers)}"; + var op1 = () => $"{Lookup(ctx, ctx.GetOp1(idx), valueNumbers)}"; + + var vNum = valueNumbers.Count; + valueNumbers.Add(idx, vNum); + var width = ctx.GetWidth(idx); + if (opc == AstOp.Symbol) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetSymbolName(idx)}"); + else if (opc == AstOp.Constant) + sb.AppendLine($"i{width} t{vNum} = {ctx.GetConstantValue(idx)}"); + else if (opc == AstOp.Neg) + sb.AppendLine($"i{width} t{vNum} = ~{op0()}"); + else if (opc == AstOp.Zext || opc == AstOp.Trunc) + { + sb.AppendLine($"i{width} t{vNum} = {GetOperatorName(opc)} i{ctx.GetWidth(ctx.GetOp0(idx))} {op0()} to i{width}"); + } + else + { + sb.AppendLine($"i{width} t{vNum} = {op0()} {GetOperatorName(opc)} {op1()}"); + } + } + + private static bool IsConstOrSymbol(AstCtx ctx, AstIdx idx) + => ctx.GetOpcode(idx) == AstOp.Constant || ctx.GetOpcode(idx) == AstOp.Symbol; + + private static string Lookup(AstCtx ctx, AstIdx idx, Dictionary valueNumbers) + { + var opc = ctx.GetOpcode(idx); + if (opc == AstOp.Constant) + return ctx.GetConstantValue(idx).ToString(); + if (opc == AstOp.Symbol) + return ctx.GetSymbolName(idx); + return $"t{valueNumbers[idx]}"; + } + + private static int GetOpCount(AstOp opc) + { + return opc switch + { + AstOp.None => 0, + AstOp.Add => 2, + AstOp.Mul => 2, + AstOp.Pow => 2, + AstOp.And => 2, + AstOp.Or => 2, + AstOp.Xor => 2, + AstOp.Neg => 1, + AstOp.Lshr => 2, + AstOp.Constant => 0, + AstOp.Symbol => 0, + AstOp.Zext => 1, + AstOp.Trunc => 1, + }; + } + + private static string GetOperatorName(AstOp opc) + { + return opc switch + { + AstOp.Add => "+", + AstOp.Mul => "*", + AstOp.Pow => "**", + AstOp.And => "&", + AstOp.Or => "|", + AstOp.Xor => "^", + AstOp.Neg => "~", + AstOp.Lshr => ">>", + AstOp.Zext => "zext", + AstOp.Trunc => "trunc", + _ => throw new InvalidOperationException(), + }; + } + } +} diff --git a/Simplifier/DatasetTester.cs b/Simplifier/DatasetTester.cs new file mode 100644 index 0000000..f010f27 --- /dev/null +++ b/Simplifier/DatasetTester.cs @@ -0,0 +1,125 @@ +using Mba.Simplifier.Bindings; +using Mba.Simplifier.Pipeline; +using Mba.Simplifier.Utility; +using Mba.Utility; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Reflection.Metadata; +using System.Text; +using System.Threading.Tasks; + +namespace Simplifier +{ + public static class DatasetTester + { + public static void Run() + { + Console.WriteLine(" "); + var lines = File.ReadLines("C:\\Users\\colton\\source\\repos\\mba-database\\real-world-nonlinear-full.txt"); + var beforeAndAfter = lines.Select(x => (x.Split(",")[0], x.Split(",")[1])).ToList(); + + var ctx = new AstCtx(); + AstIdx.ctx = ctx; + + var asts = beforeAndAfter.Select(x => (RustAstParser.Parse(ctx, x.Item1, 64), RustAstParser.Parse(ctx, x.Item2, 64))).ToList(); + + + Parallel.ForEach(asts, x => + { + + ProbableEquivalenceChecker.ProbablyEquivalentZ3(ctx, x.Item1, x.Item2); + } + ); + + foreach(var (before, after) in asts) + { + ProbableEquivalenceChecker.ProbablyEquivalentZ3(ctx, before, after); + } + + Debugger.Break(); + + foreach (var (strBefore_, strAfter) in beforeAndAfter) + { + var strBefore = strBefore_; + + //if (strBefore != "((((((((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))^(213:i64|(-214:i64&RSI:i64)))*2199023256422:i64)+(((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&(~((213:i64&RSI:i64)|(-214:i64^(-214:i64&RSI:i64)))))*7378699388702425784:i64))+(((((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&RSI:i64)^-1:i64)|213:i64)*-7378698289190797573:i64))+(7378697189679169362:i64*(~(5:i64&RSI:i64))))+(((-1:i64*(~((-209:i64&RSI:i64)|(208:i64^(208:i64&RSI:i64)))))+((4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))|RSI:i64))*3689348594839584681:i64))+((((-9223372036854775808:i64^((-9223372036854775803:i64&RSI:i64)|(9223372036854775802:i64^(9223372036854775802:i64&RSI:i64))))+(0:i64+(((-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))&RSI:i64)*-1:i64)))+(0:i64+(((4040198467629586909:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))&(~RSI:i64))*-1:i64)))*3689349694351212892:i64))") + // continue; + + + //strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //strBefore = "(7610965373738707464:i64+((((956575116354345:i64*(5292288:i64&RBX:i64))+(3689349694351212892:i64*(4611686018427387690:i64&RSI:i64)))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + // strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //strBefore = "(228698418667888:i64+((((((((3689349694351212892:i64*(4611686018427387690:i64&RSI:i64))+(3689348594839584681:i64*(5:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(691752243057131259:i64*(208:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(-3689348594839584681:i64*(5:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-691754442080387681:i64*(208:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "(228698418667888:i64+((((3689349694351212892:i64*(4611686018427387690:i64&RSI:i64))+(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64))))))+(-3689349694351212892:i64*(4611686018427387690:i64&(RSI:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))))+(-3689349694351212892:i64*(RSI:i64&(-4040198467629586910:i64+(72056494526299725:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))"; + + //strBefore = "-214&((4040198467629586696+(1099511628211*(5292288&RBX))))"; + + //strBefore = "(1099511628211:i64*(-214:i64&(4040198467629586696:i64+(1099511628211:i64*(5292288:i64&RBX:i64)))))"; + + //strBefore = "(-3689349694351212892*(4611686018427387690&(RSI&(4040198467629586696+(1099511628211*(5292288&RBX))))))+(-3689349694351212892*(RSI&(-4040198467629586910+(72056494526299725*(5292288&RBX)))))"; + + //if (strBefore != "(((-433557024052896108:i64+(46702856230664876:i64*(5292288:i64&RBX:i64)))+((((0:i64+((((7610965373738707464:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))|5370260760:i64)&-125:i64)*-1:i64))+((-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))^5370260736:i64))+(-7610965373738707565:i64+(71101018921573591:i64*(5292288:i64&RBX:i64))))*2199023256422:i64))+((((7610965373738707572:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))^5370260760:i64)+0:i64)*3298534884633:i64))") + // continue; + //strBefore = "(((-433557024052896108:i64+(46702856230664876:i64*(5292288:i64&RBX:i64)))+((((0:i64+((((7610965373738707464:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))|5370260760:i64)&-125:i64)*-1:i64))+((-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))^5370260736:i64))+(-7610965373738707565:i64+(71101018921573591:i64*(5292288:i64&RBX:i64))))*2199023256422:i64))+((((7610965373738707572:i64+(956575116354345:i64*(5292288:i64&RBX:i64)))^5370260760:i64)+0:i64)*3298534884633:i64))"; + + //strBefore = "(2342386684228996530:i64+((((23351428115332438:i64*(5292288:i64&RBX:i64))+(1099511628211:i64*(-5370260861:i64&(7610965373738707456:i64+(956575116354345:i64*(5292288:i64&RBX:i64))))))+(-3298534884633:i64*(5370260744:i64&(7610965373738707456:i64+(956575116354345:i64*(5292288:i64&RBX:i64))))))+(2199023256422:i64*(5370260736:i64^(-7610965373738707465:i64+(71101018921573591:i64*(5292288:i64&RBX:i64)))))))"; + + //strBefore = "((1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64)))&-214:i64)"; + //strBefore = "(-214:i64&(1099511628211:i64*(659114373011020351:i64|(5292288:i64&RBX:i64))))"; + + var before = RustAstParser.Parse(ctx, strBefore, 64); + + + + var after = RustAstParser.Parse(ctx, strAfter, 64); + + + + // if (ctx.GetAstString(before) != "((~(-72056494526299725:i64*(5631088361628047935:i64|(5292288:i64&RBX:i64))))&213:i64)") + // continue; + + + + var cls = ctx.GetClass(before); + var clsAfter = ctx.GetClass(after); + + var generalSimplifier = new GeneralSimplifier(ctx); + var simplified = generalSimplifier.SimplifyGeneral(before); + for (int i = 0; i < 10; i++) + { + generalSimplifier = new(ctx); + simplified = generalSimplifier.SimplifyGeneral(simplified); + } + + var kb = ctx.GetKnownBits(simplified); + var kb2 = ctx.GetKnownBits(before); + Console.WriteLine(kb.ToString() == kb2.ToString()); + + var r = LinearSimplifier.Run(ctx.GetWidth(before), ctx, before, false, true); + var r2 = LinearSimplifier.Run(ctx.GetWidth(simplified), ctx, simplified, false, true); + if (r != r2) + Debugger.Break(); + var rClass = ctx.GetClass(r); + var simplifiedClass = ctx.GetClass(simplified); + if (ctx.GetClass(after) != simplifiedClass) + Debugger.Break(); + //if (ctx.GetClass(before) == AstClassification.Nonlinear) + // Debugger.Break(); + //if (ctx.GetClass(simplified) != AstClassification.Nonlinear || !ctx.IsConstant(r))) + // continue; + + //Debugger.Break(); + } + + Debugger.Break(); + } + } +} diff --git a/Simplifier/Program.cs b/Simplifier/Program.cs index 882b645..8d1e0df 100644 --- a/Simplifier/Program.cs +++ b/Simplifier/Program.cs @@ -9,6 +9,7 @@ using Mba.Simplifier.Utility; using Mba.Utility; using Microsoft.Z3; +using Simplifier; using System.ComponentModel; using System.Diagnostics; @@ -18,6 +19,31 @@ bool proveEquivalence = true; string inputText = null; +//DatasetTester.Run(); + + +inputText = "((((1:i32&((uns17:i8 zx i32)&(~uns18:i32)))|(4294964010:i32&(~((uns17:i8 zx i32)|(~uns18:i32)))))|(4294964011:i32&((uns17:i8 zx i32)&uns18:i32)))|(4:i32*(1:i32&(uns19:i8 zx i32))))"; + +inputText = "((2041933603239772578:i64+((((((((((((((((-27487790705275:i64*uns121:i64)+(-9223358842715237276:i64*(-860922984064492326:i64&uns121:i64)))+(9223354444668724432:i64*uns131:i64))+(-9223350046622211588:i64*(860922984064492325:i64&uns131:i64)))+(-8796093025688:i64*uns132:i64))+(4398046512844:i64*uns34:i64))+(17592186051376:i64*uns65:i64))+(-3298534884633:i64*uns91:i64))+(9223367638808262964:i64*(8362449052790283482:i64&uns91:i64)))+(13194139538532:i64*(860922984064492325:i64&(uns121:i64&uns130:i64))))+(14293651166743:i64*(-3750763034362895579:i64&(uns121:i64&uns67:i64))))+(4398046512844:i64*(uns130:i64&uns133:i64)))+(-8796093025688:i64*(1444920025149201626:i64&(uns130:i64&uns91:i64))))+(-4398046512844:i64*(uns131:i64&uns133:i64)))+(-9223350046622211588:i64*(3750763034362895578:i64&(uns131:i64&uns91:i64))))+(-9895604653899:i64*((uns121:i64&uns130:i64)&uns91:i64))))+(3062923494603851298:i64+(((((((-9895604653899:i64*uns130:i64)+(9895604653899:i64*(-3750763034362895579:i64&uns131:i64)))+(9895604653899:i64*(3750763034362895578:i64&uns17:i64)))+(-9895604653899:i64*(-3750763034362895579:i64&(uns121:i64&uns131:i64))))+(9895604653899:i64*(uns130:i64&uns134:i64)))+(-9895604653899:i64*(uns131:i64&uns134:i64)))+(9895604653899:i64*(-3750763034362895579:i64&(uns131:i64&uns91:i64))))))"; + +inputText = "((-1:i64*(~((((8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64))))&((~uns4:i64)^(((((~uns4:i64)&((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))|(((~uns4:i64)&((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&uns16:i64)))|(255:i64&(uns16:i64&((((~uns4:i64)&(~((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&(~((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64)))&(~(8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))))))|(-256:i64&(((uns16:i64&(8614007388540201639:i64+(((((6919028725695267695:i64*(183:i64&uns16:i64))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))&((11175:i16+(((((52335:i16*(183:i16&(uns16:i64 tr i16)))+(5265:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))&(((~uns4:i64)&((167:i16+(((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))|((~uns4:i64)&((167:i16+(((((111:i16*(183:i16&(uns16:i64 tr i16)))+(145:i16*(72:i16&(uns16:i64 tr i16))))+(256:i16*(uns22:i64 tr i16)))+(104:i16*((65451:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))+(65432:i16*((171:i16*(~((183:i16&(uns16:i64 tr i16))|(72:i16^(72:i16&(uns16:i64 tr i16))))))&(uns22:i64 tr i16))))) zx i64))))))))+(-1:i64*(((~uns16:i64)&(8614011786586714483:i64+(((((((4398046512844:i64*(5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64))))))+(4398046512844:i64*(-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))))+(6919028725695267695:i64*(183:i64&uns16:i64)))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))))|(uns4:i64&(8614011786586714483:i64+(((((((4398046512844:i64*(5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64))))))+(4398046512844:i64*(-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))))+(6919028725695267695:i64*(183:i64&uns16:i64)))+(2304343311159508113:i64*(72:i64&uns16:i64)))+(8796093025688:i64*uns22:i64))+(-8796093025688:i64*((5156503906449953109:i64+(-624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64)))+(-8796093025688:i64*((-5156503906449953110:i64+(624165263380053675:i64*((183:i64&uns16:i64)|(72:i64^(72:i64&uns16:i64)))))&uns22:i64))))))))"; + +inputText = "(((-1099511628211:i64*((uns173:i64&(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))|(uns174:i64&(~(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))))))+(uns175:i64*(-2:i64+(-1:i64*uns174:i64))))+(2199023256422:i64*((((-1:i64*(-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64)))))))+(-1:i64*uns173:i64))+(2:i64*((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns173:i64)))+((-46179488384862:i64+(((((((((((((((((((((-3298534884633:i64*uns158:i64)+(8796093025688:i64*uns159:i64))+(-4398046512844:i64*uns160:i64))+(-4398046512844:i64*(uns158:i64&uns159:i64)))+(17592186051376:i64*(uns159:i64&uns160:i64)))+(-19791209307798:i64*(uns159:i64&uns162:i64)))+(-21990232564220:i64*(uns159:i64&uns165:i64)))+(-13194139538532:i64*(uns160:i64&uns167:i64)))+(-14293651166743:i64*(uns167:i64&uns168:i64)))+(-8796093025688:i64*((uns158:i64&uns159:i64)&uns164:i64)))+(21990232564220:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns164:i64)&uns167:i64)))+(-13194139538532:i64*((uns159:i64&uns160:i64)&uns164:i64)))+(13194139538532:i64*((uns160:i64&uns164:i64)&uns167:i64)))+(-8796093025688:i64*uns166:i64))+(-4398046512844:i64*uns163:i64))+(-17592186051376:i64*uns161:i64))+(-4398046512844:i64*((uns169:i64&(~uns164:i64))|((~uns169:i64)&(~uns165:i64)))))+(-9895604653899:i64*(~uns171:i64)))+(9895604653899:i64+((((9895604653899:i64*(uns158:i64&uns165:i64))+(9895604653899:i64*(uns160:i64&uns165:i64)))+(-9895604653899:i64*((uns158:i64&uns159:i64)&uns165:i64)))+(-9895604653899:i64*((uns160:i64&uns165:i64)&uns167:i64)))))+(9895604653899:i64*(((~uns170:i64)&(~uns164:i64))|(uns170:i64&(~uns165:i64))))))&uns174:i64))))"; + +inputText = "(2374945116151681 + 1152921504606846706*(x&8796093022192)) + (-2383741209174193 + 271*(x&8796093022192))"; + +inputText = "(3*(x&0x7FFFFFFFFF0) - 0x180000000490) - 0x100000000350 + 0x178"; + +inputText = "(-43980465112680+(3*(-8796093022208|(8796093022192&(x&~15)))))"; + +inputText = "((0xFFFFFFFFFFFFFE71 * ~(~a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0)))) + (~a4 & (v37 + 0x570)) - ((~a4 & (v37 + 0x570)) | a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0)))) + (0xFFFFFFFFFFFFFE70 * (~a4 & (0x8000000023F - (a1 & 0x7FFFFFFFFF0))))"; + +inputText = "((2 * (a1 & 0x7FFFFFFFFF0)) - 0x100000000350 + 0x178)"; + +inputText = " (2 * (x & 0x7FFFFFFFFF0)) - 0x100000000350 + 0x178"; + +inputText = "-3 * ~e_cr3 + (mask ^ e_cr3) + -2 * (mask ^ (mask | e_cr3)) + 2 * (((byte_1400807D0 & 0x10 | 0x3F71D992FBB2CCEB) ^ 0xC08E266D044D3314) + (~e_cr3 | (mask ^ 0x3F71D992FBB2CCEB))) - (mask ^ ~e_cr3 ^ 0x3F71D992FBB2CCEB) - 0x3F71D992FBB2CCEB"; + var printHelp = () => { Console.WriteLine("Usage: Simplifier.exe"); @@ -72,6 +98,21 @@ Console.WriteLine($"\nExpression: {ctx.GetAstString(id)}\n\n\n"); +Console.WriteLine(DagFormatter.Format(ctx, id)); + + +var bx = LinearSimplifier.Run(bitWidth, ctx, id, false, true); + +while(false) +{ + var simplifier = new GeneralSimplifier(ctx); + + var sw = Stopwatch.StartNew(); + var r = simplifier.SimplifyGeneral(id); + sw.Stop(); + Console.WriteLine($"Took {sw.ElapsedMilliseconds}ms"); +} + var input = id; id = ctx.RecursiveSimplify(id); for (int i = 0; i < 3; i++)