From 8521e62fa5b1b076468ee3e12e9f4fe0d7eb5504 Mon Sep 17 00:00:00 2001 From: Yue Huang Date: Tue, 21 Oct 2025 17:03:02 +0100 Subject: [PATCH] [MLIR][Presburger] Fix Gaussian elimination --- mlir/lib/Analysis/Presburger/Barvinok.cpp | 14 ++++++------- .../Analysis/Presburger/IntegerRelation.cpp | 21 +++++++++++++++++-- .../Analysis/Presburger/BarvinokTest.cpp | 7 +++++++ .../Presburger/IntegerRelationTest.cpp | 15 +++++++++++++ 4 files changed, 48 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp index 75d592e976edf..c31b27794f01e 100644 --- a/mlir/lib/Analysis/Presburger/Barvinok.cpp +++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp @@ -178,13 +178,13 @@ mlir::presburger::detail::solveParametricEquations(FracMatrix equations) { for (unsigned i = 0; i < d; ++i) { // First ensure that the diagonal element is nonzero, by swapping // it with a row that is non-zero at column i. - if (equations(i, i) != 0) - continue; - for (unsigned j = i + 1; j < d; ++j) { - if (equations(j, i) == 0) - continue; - equations.swapRows(j, i); - break; + if (equations(i, i) == 0) { + for (unsigned j = i + 1; j < d; ++j) { + if (equations(j, i) == 0) + continue; + equations.swapRows(j, i); + break; + } } Fraction diagElement = equations(i, i); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 0dcdd5bb97bc8..e81fa7d568eb2 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1111,15 +1111,28 @@ unsigned IntegerRelation::gaussianEliminateVars(unsigned posStart, return posLimit - posStart; } +static std::optional +findEqualityWithNonZeroAfterRow(IntegerRelation &rel, unsigned fromRow, + unsigned colIdx) { + assert(fromRow < rel.getNumVars() && colIdx < rel.getNumCols() && + "position out of bounds"); + for (unsigned rowIdx = fromRow; rowIdx < rel.getNumEqualities(); ++rowIdx) { + if (rel.atEq(rowIdx, colIdx) != 0) + return rowIdx; + } + return std::nullopt; +} + bool IntegerRelation::gaussianEliminate() { gcdTightenInequalities(); unsigned firstVar = 0, vars = getNumVars(); unsigned nowDone, eqs; std::optional pivotRow; for (nowDone = 0, eqs = getNumEqualities(); nowDone < eqs; ++nowDone) { - // Finds the first non-empty column. + // Finds the first non-empty column that we haven't dealt with. for (; firstVar < vars; ++firstVar) { - if ((pivotRow = findConstraintWithNonZeroAt(firstVar, /*isEq=*/true))) + if ((pivotRow = + findEqualityWithNonZeroAfterRow(*this, nowDone, firstVar))) break; } // The matrix has been normalized to row echelon form. @@ -1142,6 +1155,10 @@ bool IntegerRelation::gaussianEliminate() { inequalities.normalizeRow(i); } gcdTightenInequalities(); + + // The column is finished. Tell the next iteration to start at the next + // column. + firstVar++; } // No redundant rows. diff --git a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp index eaf04379cb529..d687a0072a158 100644 --- a/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp +++ b/mlir/unittests/Analysis/Presburger/BarvinokTest.cpp @@ -301,3 +301,10 @@ TEST(BarvinokTest, computeNumTermsPolytope) { gf = count[0].second; EXPECT_EQ(gf.getNumerators().size(), 24u); } + +TEST(BarvinokTest, solveParametricEquations) { + FracMatrix equations = makeFracMatrix(2, 3, {{2, 3, -4}, {2, 6, -7}}); + FracMatrix solution = *solveParametricEquations(equations); + EXPECT_EQ(solution.at(0, 0), Fraction(1, 2)); + EXPECT_EQ(solution.at(1, 0), 1); +} diff --git a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp index 9ae90a4841f3c..dace588e80153 100644 --- a/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp +++ b/mlir/unittests/Analysis/Presburger/IntegerRelationTest.cpp @@ -725,3 +725,18 @@ TEST(IntegerRelationTest, addLocalModulo) { EXPECT_TRUE(rel.containsPointNoLocal({x, x % 32})); } } + +TEST(IntegerRelationTest, simplify) { + IntegerRelation rel = + parseRelationFromSet("(x, y)[N]: (2*x + y - 4*N - 3 == 0, 3*x - y - 3*N" + "+ 2 == 0, x + 3*y - 5*N - 8 == 0, x - y + N >= 0)", + 2); + IntegerRelation simplified = parseRelationFromSet( + "(x, y)[N]: (2*x + y - 4*N - 3 == 0, -5*y + 6*N + 13 == 0, N - 2 >= 0)", + 2); + rel.simplify(); + + EXPECT_TRUE(rel.isEqual(simplified)); + // The third equality is redundant and should be removed. + EXPECT_TRUE(rel.getNumEqualities() == 2); +}