Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion EqSat/src/simple_ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ pub struct AstIdx(pub u32);
pub struct Arena {
pub elements: Vec<(SimpleAst, AstData)>,
ast_to_idx: AHashMap<SimpleAst, AstIdx>,
isle_cache: AHashMap<AstIdx, AstIdx>,

// Map a name to it's corresponds symbol index.
symbol_ids: Vec<(String, AstIdx)>,
Expand All @@ -37,13 +38,15 @@ impl Arena {
pub fn new() -> Self {
let elements = Vec::with_capacity(65536);
let ast_to_idx = AHashMap::with_capacity(65536);
let isle_cache = AHashMap::with_capacity(65536);

let symbol_ids = Vec::with_capacity(255);
let name_to_symbol = AHashMap::with_capacity(255);

Arena {
elements: elements,
ast_to_idx: ast_to_idx,
isle_cache: isle_cache,

symbol_ids: symbol_ids,
name_to_symbol: name_to_symbol,
Expand Down Expand Up @@ -813,6 +816,9 @@ pub fn eval_ast(ctx: &Context, idx: AstIdx, value_mapping: &HashMap<AstIdx, u64>

// Recursively apply ISLE over an AST.
pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx {
if ctx.arena.isle_cache.get(&idx).is_some() {
return *ctx.arena.isle_cache.get(&idx).unwrap();
}
let mut ast = ctx.arena.get_node(idx).clone();

match ast {
Expand Down Expand Up @@ -862,7 +868,9 @@ pub fn recursive_simplify(ctx: &mut Context, idx: AstIdx) -> AstIdx {
ast = result.unwrap();
}

return ctx.arena.ast_to_idx[&ast];
let result = ctx.arena.ast_to_idx[&ast];
ctx.arena.isle_cache.insert(idx, result);
result
}

// Evaluate the current AST for all possible combinations of zeroes and ones as inputs.
Expand Down
5 changes: 5 additions & 0 deletions Mba.Simplifier/Bindings/AstIdx.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ public override string ToString()
return ctx.GetAstString(Idx);
}

public override int GetHashCode()
{
return Idx.GetHashCode();
}

public unsafe static implicit operator uint(AstIdx reg) => reg.Idx;

public unsafe static implicit operator AstIdx(uint reg) => new AstIdx(reg);
Expand Down
2 changes: 1 addition & 1 deletion Mba.Simplifier/Minimization/BooleanMinimizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ private static AstIdx MinimizeAnf(AstCtx ctx, IReadOnlyList<AstIdx> variables, T
}

var r = ctx.MinimizeAnf(TableDatabase.Instance.db, truthTable, tempVars, MultibitSiMBA.JitPage.Value);
var backSubst = GeneralSimplifier.ApplyBackSubstitution(ctx, r, invSubstMapping);
var backSubst = GeneralSimplifier.BackSubstitute(ctx, r, invSubstMapping);
return backSubst;
}
}
Expand Down
335 changes: 210 additions & 125 deletions Mba.Simplifier/Pipeline/GeneralSimplifier.cs

Large diffs are not rendered by default.

37 changes: 30 additions & 7 deletions Mba.Simplifier/Pipeline/LinearSimplifier.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public class LinearSimplifier
private readonly bool tryDecomposeMultiBitBases;

private readonly Action<ulong[], ulong>? resultVectorHook;

private readonly int depth;
private readonly ApInt moduloMask = 0;

// Number of combinations of input variables(2^n), for a single bit index.
Expand All @@ -69,14 +69,14 @@ public class LinearSimplifier

private AstIdx? initialInput = null;

public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList<AstIdx> variables = null, Action<ulong[], ApInt>? resultVectorHook = null, ApInt[] inVec = null)
public static AstIdx Run(uint bitSize, AstCtx ctx, AstIdx? ast, bool alreadySplit = false, bool multiBit = false, bool tryDecomposeMultiBitBases = false, IReadOnlyList<AstIdx> variables = null, Action<ulong[], ApInt>? resultVectorHook = null, ApInt[] inVec = null, int depth = 0)
{
if (variables == null)
variables = ctx.CollectVariables(ast.Value);
return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec).Simplify(false, alreadySplit);
return new LinearSimplifier(ctx, ast, variables, bitSize, refine: true, multiBit, tryDecomposeMultiBitBases, resultVectorHook, inVec, depth).Simplify(false, alreadySplit);
}

public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList<AstIdx> variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action<ulong[], ApInt>? resultVectorHook = null, ApInt[] inVec = null)
public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList<AstIdx> variables, uint bitSize, bool refine = true, bool multiBit = false, bool tryDecomposeMultiBitBases = true, Action<ulong[], ApInt>? resultVectorHook = null, ApInt[] inVec = null, int depth = 0)
{
// If we are given an AST, verify that the correct width was passed.
if (ast != null && bitSize != ctx.GetWidth(ast.Value))
Expand All @@ -90,6 +90,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList<AstIdx> variables
this.multiBit = multiBit;
this.tryDecomposeMultiBitBases = tryDecomposeMultiBitBases;
this.resultVectorHook = resultVectorHook;
this.depth = depth;
moduloMask = (ApInt)ModuloReducer.GetMask(bitSize);
groupSizes = GetGroupSizes(variables.Count);
numCombinations = (ApInt)Math.Pow(2, variables.Count);
Expand Down Expand Up @@ -124,7 +125,7 @@ public LinearSimplifier(AstCtx ctx, AstIdx? ast, IReadOnlyList<AstIdx> variables
}
}

private static IReadOnlyList<AstIdx> CastVariables(AstCtx ctx, IReadOnlyList<AstIdx> variables, uint bitSize)
public static IReadOnlyList<AstIdx> CastVariables(AstCtx ctx, IReadOnlyList<AstIdx> variables, uint bitSize)
{
// If all variables are of a correct size, no casting is necessary.
if (!variables.Any(x => ctx.GetWidth(x) != bitSize))
Expand Down Expand Up @@ -610,6 +611,20 @@ private AstIdx FindTwoTermsUnnegated(ApInt constant, ApInt a, ApInt b)

private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demandedMask, ApInt[] variableCombinations, List<List<(ApInt coeff, ApInt bitMask)>> linearCombinations)
{

var vNames = this.variables.Select(x => ctx.GetAstString(x));
//var expected = new List<string>() { "subst594:i32", "subst595:i32", "subst596:i32", "subst597:i32", "(uns45:i64 tr i32)", "(uns48:i64 tr i32)"};

/*
var expected = new List<string>() { "subst27:i32", "(uns41:i64 tr i32)", "(uns74:i64 tr i32)", "(uns75:i64 tr i32)" };
if (expected.All(x => vNames.Any(y => y.Contains(x))))
Debugger.Break();

if (depth > 10)
Debugger.Break();
*/


// Collect all variables used in the output expression.
List<AstIdx> mutVars = new(variables.Count);
while (demandedMask != 0)
Expand All @@ -619,6 +634,8 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded
demandedMask &= ~(1ul << xorIdx);
}


var clone = variables.ToList();
AstIdx sum = ctx.Constant(constantOffset, width);
for (int i = 0; i < linearCombinations.Count; i++)
{
Expand All @@ -628,14 +645,20 @@ private AstIdx EliminateDeadVarsAndSimplify(ApInt constantOffset, ApInt demanded
continue;

var combMask = variableCombinations[i];
var vComb = ctx.GetConjunctionFromVarMask(mutVars, combMask);
var widths = variables.Select(x => ctx.GetWidth(x)).ToList();
//Console.WriteLine(widths.Distinct().Count());
//foreach (var vIdx in variables)
// Console.WriteLine($"{ctx.GetAstString(vIdx)} => {ctx.GetWidth(vIdx)}");
//Console.WriteLine("\n\n");

var vComb = ctx.GetConjunctionFromVarMask(clone, combMask);
var term = Term(vComb, curr[0].coeff);
sum = ctx.Add(sum, term);
}

// TODO: Instead of constructing a result vector inside the recursive linear simplifier call, we could instead convert the ANF vector back to DNF.
// This should be much more efficient than constructing a result vector via JITing and evaluating an AST representation of the ANF vector.
return LinearSimplifier.Run(width, ctx, sum, false, false, false, variables);
return LinearSimplifier.Run(width, ctx, sum, false, false, false, mutVars, depth: depth + 1);
}

private void EliminateUniqueValues(Dictionary<ApInt, TruthTable> coeffToTable)
Expand Down
77 changes: 48 additions & 29 deletions Mba.Simplifier/Pipeline/ProbableEquivalenceChecker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using Microsoft.Z3;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
Expand Down Expand Up @@ -59,11 +60,11 @@ public ProbableEquivalenceChecker(AstCtx ctx, List<AstIdx> variables, AstIdx bef
public unsafe bool ProbablyEquivalent(bool slowHeuristics = false)
{
var jit1 = new Amd64OptimizingJit(ctx);
jit1.Compile(before, variables, pagePtr1, true);
jit1.Compile(before, variables, pagePtr1, false);
func1 = (delegate* unmanaged[SuppressGCTransition]<ulong*, ulong>)pagePtr1;

var jit2 = new Amd64OptimizingJit(ctx);
jit2.Compile(after, variables, pagePtr2, true);
jit2.Compile(after, variables, pagePtr2, false);
func2 = (delegate* unmanaged[SuppressGCTransition]<ulong*, ulong>)pagePtr2;

var vArray = stackalloc ulong[variables.Count];
Expand Down Expand Up @@ -99,6 +100,7 @@ public unsafe bool ProbablyEquivalent(bool slowHeuristics = false)
return true;
}

public static bool Log = false;
private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses)
{
var clone = new ulong[variables.Count];
Expand All @@ -113,6 +115,10 @@ private unsafe bool RandomlyEquivalent(ulong* vArray, int numGuesses)

var op1 = func1(vArray);
var op2 = func2(vArray);

if(Log)
Console.WriteLine($"{op1}, {op2}");

if (op1 != op2)
return false;
}
Expand All @@ -128,6 +134,8 @@ private unsafe bool AllCombs(ulong* vArray, ulong a, ulong b)
return false;
if (!SignatureVectorEquivalent(vArray, a, b))
return false;
if (!SignatureVectorEquivalent(vArray, b, a))
return false;

return true;
}
Expand Down Expand Up @@ -168,34 +176,45 @@ private ulong Next()

public static void ProbablyEquivalentZ3(AstCtx ctx, AstIdx before, AstIdx after)
{
var z3Ctx = new Context();
var translator = new Z3Translator(ctx, z3Ctx);
var beforeZ3 = translator.Translate(before);
var afterZ3 = translator.Translate(after);
var solver = z3Ctx.MkSolver("QF_BV");

// Set the maximum timeout to 10 seconds.
var p = z3Ctx.MkParams();
uint solverLimit = 10000;
p.Add("timeout", solverLimit);
solver.Parameters = p;

Console.WriteLine("Proving equivalence...\n");
solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3)));
var check = solver.Check();

var printModel = (Model model) =>
using (var z3Ctx = new Context())
{
var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}");
return $"[{String.Join(", ", values)}]";
};

if (check == Status.UNSATISFIABLE)
Console.WriteLine("Expressions are equivalent.");
else if (check == Status.SATISFIABLE)
Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}");
else
Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms");
var translator = new Z3Translator(ctx, z3Ctx);
var beforeZ3 = translator.Translate(before);
var afterZ3 = translator.Translate(after);
var solver = z3Ctx.MkSolver("QF_BV");

// Set the maximum timeout to 10 seconds.
var p = z3Ctx.MkParams();
uint solverLimit = 5000;
p.Add("timeout", solverLimit);
solver.Parameters = p;

Console.WriteLine("Proving equivalence...\n");
solver.Add(z3Ctx.MkNot(z3Ctx.MkEq(beforeZ3, afterZ3)));
var check = solver.Check();

var printModel = (Model model) =>
{
var values = model.Consts.Select(x => $"{x.Key.Name} = {(long)ulong.Parse(model.Eval(x.Value).ToString())}");
return $"[{String.Join(", ", values)}]";
};

if (check == Status.UNSATISFIABLE)
{
//Console.WriteLine("Expressions are equivalent.");
}
else if (check == Status.SATISFIABLE)
{
Console.WriteLine($"Expressions are not equivalent. Counterexample:\n{printModel(solver.Model)}");
Debugger.Break();
throw new InvalidOperationException();

}
else
{
//Console.WriteLine($"Solver timed out - expressions are probably equivalent. Could not find counterexample within {solverLimit}ms");
}
}
}

}
Expand Down
102 changes: 102 additions & 0 deletions Mba.Simplifier/Utility/DagFormatter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
using Mba.Simplifier.Bindings;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Mba.Simplifier.Utility
{
public static class DagFormatter
{
public static string Format(AstCtx ctx, AstIdx idx)
{
var sb = new StringBuilder();
Format(sb, ctx, idx, new());
return sb.ToString();
}

private static void Format(StringBuilder sb, AstCtx ctx, AstIdx idx, Dictionary<AstIdx, int> valueNumbers)
{
// Allocate value numbers for the operands if necessary
var opc = ctx.GetOpcode(idx);
var opcount = GetOpCount(opc);
if (opcount >= 1 && !valueNumbers.ContainsKey(ctx.GetOp0(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp0(idx)))
Format(sb, ctx, ctx.GetOp0(idx), valueNumbers);
if (opcount >= 2 && !valueNumbers.ContainsKey(ctx.GetOp1(idx)) && !IsConstOrSymbol(ctx, ctx.GetOp1(idx)))
Format(sb, ctx, ctx.GetOp1(idx), valueNumbers);

var op0 = () => $"{Lookup(ctx, ctx.GetOp0(idx), valueNumbers)}";
var op1 = () => $"{Lookup(ctx, ctx.GetOp1(idx), valueNumbers)}";

var vNum = valueNumbers.Count;
valueNumbers.Add(idx, vNum);
var width = ctx.GetWidth(idx);
if (opc == AstOp.Symbol)
sb.AppendLine($"i{width} t{vNum} = {ctx.GetSymbolName(idx)}");
else if (opc == AstOp.Constant)
sb.AppendLine($"i{width} t{vNum} = {ctx.GetConstantValue(idx)}");
else if (opc == AstOp.Neg)
sb.AppendLine($"i{width} t{vNum} = ~{op0()}");
else if (opc == AstOp.Zext || opc == AstOp.Trunc)
{
sb.AppendLine($"i{width} t{vNum} = {GetOperatorName(opc)} i{ctx.GetWidth(ctx.GetOp0(idx))} {op0()} to i{width}");
}
else
{
sb.AppendLine($"i{width} t{vNum} = {op0()} {GetOperatorName(opc)} {op1()}");
}
}

private static bool IsConstOrSymbol(AstCtx ctx, AstIdx idx)
=> ctx.GetOpcode(idx) == AstOp.Constant || ctx.GetOpcode(idx) == AstOp.Symbol;

private static string Lookup(AstCtx ctx, AstIdx idx, Dictionary<AstIdx, int> valueNumbers)
{
var opc = ctx.GetOpcode(idx);
if (opc == AstOp.Constant)
return ctx.GetConstantValue(idx).ToString();
if (opc == AstOp.Symbol)
return ctx.GetSymbolName(idx);
return $"t{valueNumbers[idx]}";
}

private static int GetOpCount(AstOp opc)
{
return opc switch
{
AstOp.None => 0,
AstOp.Add => 2,
AstOp.Mul => 2,
AstOp.Pow => 2,
AstOp.And => 2,
AstOp.Or => 2,
AstOp.Xor => 2,
AstOp.Neg => 1,
AstOp.Lshr => 2,
AstOp.Constant => 0,
AstOp.Symbol => 0,
AstOp.Zext => 1,
AstOp.Trunc => 1,
};
}

private static string GetOperatorName(AstOp opc)
{
return opc switch
{
AstOp.Add => "+",
AstOp.Mul => "*",
AstOp.Pow => "**",
AstOp.And => "&",
AstOp.Or => "|",
AstOp.Xor => "^",
AstOp.Neg => "~",
AstOp.Lshr => ">>",
AstOp.Zext => "zext",
AstOp.Trunc => "trunc",
_ => throw new InvalidOperationException(),
};
}
}
}
Loading