1//===- Utils.cpp - General utilities for Presburger library ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Utility functions required by the Presburger Library.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Analysis/Presburger/Utils.h"
14#include "mlir/Analysis/Presburger/IntegerRelation.h"
15#include "mlir/Analysis/Presburger/MPInt.h"
16#include "mlir/Analysis/Presburger/PresburgerSpace.h"
17#include "mlir/Support/LLVM.h"
18#include "mlir/Support/LogicalResult.h"
19#include "llvm/ADT/STLFunctionalExtras.h"
20#include "llvm/ADT/SmallBitVector.h"
21#include "llvm/Support/raw_ostream.h"
22#include <algorithm>
23#include <cassert>
24#include <cstddef>
25#include <cstdint>
26#include <functional>
27#include <numeric>
28
29#include <numeric>
30#include <optional>
31
32using namespace mlir;
33using namespace presburger;
34
35/// Normalize a division's `dividend` and the `divisor` by their GCD. For
36/// example: if the dividend and divisor are [2,0,4] and 4 respectively,
37/// they get normalized to [1,0,2] and 2. The divisor must be non-negative;
38/// it is allowed for the divisor to be zero, but nothing is done in this case.
39static void normalizeDivisionByGCD(MutableArrayRef<MPInt> dividend,
40 MPInt &divisor) {
41 assert(divisor > 0 && "divisor must be non-negative!");
42 if (divisor == 0 || dividend.empty())
43 return;
44 // We take the absolute value of dividend's coefficients to make sure that
45 // `gcd` is positive.
46 MPInt gcd = presburger::gcd(a: abs(x: dividend.front()), b: divisor);
47
48 // The reason for ignoring the constant term is as follows.
49 // For a division:
50 // floor((a + m.f(x))/(m.d))
51 // It can be replaced by:
52 // floor((floor(a/m) + f(x))/d)
53 // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not
54 // influence the result of the floor division and thus, can be ignored.
55 for (size_t i = 1, m = dividend.size() - 1; i < m; i++) {
56 gcd = presburger::gcd(a: abs(x: dividend[i]), b: gcd);
57 if (gcd == 1)
58 return;
59 }
60
61 // Normalize the dividend and the denominator.
62 std::transform(first: dividend.begin(), last: dividend.end(), result: dividend.begin(),
63 unary_op: [gcd](MPInt &n) { return floorDiv(lhs: n, rhs: gcd); });
64 divisor /= gcd;
65}
66
67/// Check if the pos^th variable can be represented as a division using upper
68/// bound inequality at position `ubIneq` and lower bound inequality at position
69/// `lbIneq`.
70///
71/// Let `var` be the pos^th variable, then `var` is equivalent to
72/// `expr floordiv divisor` if there are constraints of the form:
73/// 0 <= expr - divisor * var <= divisor - 1
74/// Rearranging, we have:
75/// divisor * var - expr + (divisor - 1) >= 0 <-- Lower bound for 'var'
76/// -divisor * var + expr >= 0 <-- Upper bound for 'var'
77///
78/// For example:
79/// 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
80/// 32*k <= 16*i + j <-- Upper bound for 'k'
81/// expr = 16*i + j, divisor = 32
82/// k = ( 16*i + j ) floordiv 32
83///
84/// 4q >= i + j - 2 <-- Lower bound for 'q'
85/// 4q <= i + j + 1 <-- Upper bound for 'q'
86/// expr = i + j + 1, divisor = 4
87/// q = (i + j + 1) floordiv 4
88//
89/// This function also supports detecting divisions from bounds that are
90/// strictly tighter than the division bounds described above, since tighter
91/// bounds imply the division bounds. For example:
92/// 4q - i - j + 2 >= 0 <-- Lower bound for 'q'
93/// -4q + i + j >= 0 <-- Tight upper bound for 'q'
94///
95/// To extract floor divisions with tighter bounds, we assume that the
96/// constraints are of the form:
97/// c <= expr - divisior * var <= divisor - 1, where 0 <= c <= divisor - 1
98/// Rearranging, we have:
99/// divisor * var - expr + (divisor - 1) >= 0 <-- Lower bound for 'var'
100/// -divisor * var + expr - c >= 0 <-- Upper bound for 'var'
101///
102/// If successful, `expr` is set to dividend of the division and `divisor` is
103/// set to the denominator of the division, which will be positive.
104/// The final division expression is normalized by GCD.
105static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
106 unsigned ubIneq, unsigned lbIneq,
107 MutableArrayRef<MPInt> expr, MPInt &divisor) {
108
109 assert(pos <= cst.getNumVars() && "Invalid variable position");
110 assert(ubIneq <= cst.getNumInequalities() &&
111 "Invalid upper bound inequality position");
112 assert(lbIneq <= cst.getNumInequalities() &&
113 "Invalid upper bound inequality position");
114 assert(expr.size() == cst.getNumCols() && "Invalid expression size");
115 assert(cst.atIneq(lbIneq, pos) > 0 && "lbIneq is not a lower bound!");
116 assert(cst.atIneq(ubIneq, pos) < 0 && "ubIneq is not an upper bound!");
117
118 // Extract divisor from the lower bound.
119 divisor = cst.atIneq(i: lbIneq, j: pos);
120
121 // First, check if the constraints are opposite of each other except the
122 // constant term.
123 unsigned i = 0, e = 0;
124 for (i = 0, e = cst.getNumVars(); i < e; ++i)
125 if (cst.atIneq(i: ubIneq, j: i) != -cst.atIneq(i: lbIneq, j: i))
126 break;
127
128 if (i < e)
129 return failure();
130
131 // Then, check if the constant term is of the proper form.
132 // Due to the form of the upper/lower bound inequalities, the sum of their
133 // constants is `divisor - 1 - c`. From this, we can extract c:
134 MPInt constantSum = cst.atIneq(i: lbIneq, j: cst.getNumCols() - 1) +
135 cst.atIneq(i: ubIneq, j: cst.getNumCols() - 1);
136 MPInt c = divisor - 1 - constantSum;
137
138 // Check if `c` satisfies the condition `0 <= c <= divisor - 1`.
139 // This also implictly checks that `divisor` is positive.
140 if (!(0 <= c && c <= divisor - 1)) // NOLINT
141 return failure();
142
143 // The inequality pair can be used to extract the division.
144 // Set `expr` to the dividend of the division except the constant term, which
145 // is set below.
146 for (i = 0, e = cst.getNumVars(); i < e; ++i)
147 if (i != pos)
148 expr[i] = cst.atIneq(i: ubIneq, j: i);
149
150 // From the upper bound inequality's form, its constant term is equal to the
151 // constant term of `expr`, minus `c`. From this,
152 // constant term of `expr` = constant term of upper bound + `c`.
153 expr.back() = cst.atIneq(i: ubIneq, j: cst.getNumCols() - 1) + c;
154 normalizeDivisionByGCD(dividend: expr, divisor);
155
156 return success();
157}
158
159/// Check if the pos^th variable can be represented as a division using
160/// equality at position `eqInd`.
161///
162/// For example:
163/// 32*k == 16*i + j - 31 <-- `eqInd` for 'k'
164/// expr = 16*i + j - 31, divisor = 32
165/// k = (16*i + j - 31) floordiv 32
166///
167/// If successful, `expr` is set to dividend of the division and `divisor` is
168/// set to the denominator of the division. The final division expression is
169/// normalized by GCD.
170static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
171 unsigned eqInd, MutableArrayRef<MPInt> expr,
172 MPInt &divisor) {
173
174 assert(pos <= cst.getNumVars() && "Invalid variable position");
175 assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
176 assert(expr.size() == cst.getNumCols() && "Invalid expression size");
177
178 // Extract divisor, the divisor can be negative and hence its sign information
179 // is stored in `signDiv` to reverse the sign of dividend's coefficients.
180 // Equality must involve the pos-th variable and hence `tempDiv` != 0.
181 MPInt tempDiv = cst.atEq(i: eqInd, j: pos);
182 if (tempDiv == 0)
183 return failure();
184 int signDiv = tempDiv < 0 ? -1 : 1;
185
186 // The divisor is always a positive integer.
187 divisor = tempDiv * signDiv;
188
189 for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
190 if (i != pos)
191 expr[i] = -signDiv * cst.atEq(i: eqInd, j: i);
192
193 expr.back() = -signDiv * cst.atEq(i: eqInd, j: cst.getNumCols() - 1);
194 normalizeDivisionByGCD(dividend: expr, divisor);
195
196 return success();
197}
198
199// Returns `false` if the constraints depends on a variable for which an
200// explicit representation has not been found yet, otherwise returns `true`.
201static bool checkExplicitRepresentation(const IntegerRelation &cst,
202 ArrayRef<bool> foundRepr,
203 ArrayRef<MPInt> dividend,
204 unsigned pos) {
205 // Exit to avoid circular dependencies between divisions.
206 for (unsigned c = 0, e = cst.getNumVars(); c < e; ++c) {
207 if (c == pos)
208 continue;
209
210 if (!foundRepr[c] && dividend[c] != 0) {
211 // Expression can't be constructed as it depends on a yet unknown
212 // variable.
213 //
214 // TODO: Visit/compute the variables in an order so that this doesn't
215 // happen. More complex but much more efficient.
216 return false;
217 }
218 }
219
220 return true;
221}
222
223/// Check if the pos^th variable can be expressed as a floordiv of an affine
224/// function of other variables (where the divisor is a positive constant).
225/// `foundRepr` contains a boolean for each variable indicating if the
226/// explicit representation for that variable has already been computed.
227/// Returns the `MaybeLocalRepr` struct which contains the indices of the
228/// constraints that can be expressed as a floordiv of an affine function. If
229/// the representation could be computed, `dividend` and `denominator` are set.
230/// If the representation could not be computed, the kind attribute in
231/// `MaybeLocalRepr` is set to None.
232MaybeLocalRepr presburger::computeSingleVarRepr(const IntegerRelation &cst,
233 ArrayRef<bool> foundRepr,
234 unsigned pos,
235 MutableArrayRef<MPInt> dividend,
236 MPInt &divisor) {
237 assert(pos < cst.getNumVars() && "invalid position");
238 assert(foundRepr.size() == cst.getNumVars() &&
239 "Size of foundRepr does not match total number of variables");
240 assert(dividend.size() == cst.getNumCols() && "Invalid dividend size");
241
242 SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
243 cst.getLowerAndUpperBoundIndices(pos, lbIndices: &lbIndices, ubIndices: &ubIndices, eqIndices: &eqIndices);
244 MaybeLocalRepr repr{};
245
246 for (unsigned ubPos : ubIndices) {
247 for (unsigned lbPos : lbIndices) {
248 // Attempt to get divison representation from ubPos, lbPos.
249 if (failed(result: getDivRepr(cst, pos, ubIneq: ubPos, lbIneq: lbPos, expr: dividend, divisor)))
250 continue;
251
252 if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
253 continue;
254
255 repr.kind = ReprKind::Inequality;
256 repr.repr.inequalityPair = {.lowerBoundIdx: ubPos, .upperBoundIdx: lbPos};
257 return repr;
258 }
259 }
260 for (unsigned eqPos : eqIndices) {
261 // Attempt to get divison representation from eqPos.
262 if (failed(result: getDivRepr(cst, pos, eqInd: eqPos, expr: dividend, divisor)))
263 continue;
264
265 if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
266 continue;
267
268 repr.kind = ReprKind::Equality;
269 repr.repr.equalityIdx = eqPos;
270 return repr;
271 }
272 return repr;
273}
274
275MaybeLocalRepr presburger::computeSingleVarRepr(
276 const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
277 SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
278 SmallVector<MPInt, 8> dividendMPInt(cst.getNumCols());
279 MPInt divisorMPInt;
280 MaybeLocalRepr result =
281 computeSingleVarRepr(cst, foundRepr, pos, dividend: dividendMPInt, divisor&: divisorMPInt);
282 dividend = getInt64Vec(range: dividendMPInt);
283 divisor = unsigned(int64_t(divisorMPInt));
284 return result;
285}
286
287llvm::SmallBitVector presburger::getSubrangeBitVector(unsigned len,
288 unsigned setOffset,
289 unsigned numSet) {
290 llvm::SmallBitVector vec(len, false);
291 vec.set(I: setOffset, E: setOffset + numSet);
292 return vec;
293}
294
295void presburger::mergeLocalVars(
296 IntegerRelation &relA, IntegerRelation &relB,
297 llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
298 assert(relA.getSpace().isCompatible(relB.getSpace()) &&
299 "Spaces should be compatible.");
300
301 // Merge local vars of relA and relB without using division information,
302 // i.e. append local vars of `relB` to `relA` and insert local vars of `relA`
303 // to `relB` at start of its local vars.
304 unsigned initLocals = relA.getNumLocalVars();
305 relA.insertVar(kind: VarKind::Local, pos: relA.getNumLocalVars(),
306 num: relB.getNumLocalVars());
307 relB.insertVar(kind: VarKind::Local, pos: 0, num: initLocals);
308
309 // Get division representations from each rel.
310 DivisionRepr divsA = relA.getLocalReprs();
311 DivisionRepr divsB = relB.getLocalReprs();
312
313 for (unsigned i = initLocals, e = divsB.getNumDivs(); i < e; ++i)
314 divsA.setDiv(i, dividend: divsB.getDividend(i), divisor: divsB.getDenom(i));
315
316 // Remove duplicate divisions from divsA. The removing duplicate divisions
317 // call, calls `merge` to effectively merge divisions in relA and relB.
318 divsA.removeDuplicateDivs(merge);
319}
320
321SmallVector<MPInt, 8> presburger::getDivUpperBound(ArrayRef<MPInt> dividend,
322 const MPInt &divisor,
323 unsigned localVarIdx) {
324 assert(divisor > 0 && "divisor must be positive!");
325 assert(dividend[localVarIdx] == 0 &&
326 "Local to be set to division must have zero coeff!");
327 SmallVector<MPInt, 8> ineq(dividend.begin(), dividend.end());
328 ineq[localVarIdx] = -divisor;
329 return ineq;
330}
331
332SmallVector<MPInt, 8> presburger::getDivLowerBound(ArrayRef<MPInt> dividend,
333 const MPInt &divisor,
334 unsigned localVarIdx) {
335 assert(divisor > 0 && "divisor must be positive!");
336 assert(dividend[localVarIdx] == 0 &&
337 "Local to be set to division must have zero coeff!");
338 SmallVector<MPInt, 8> ineq(dividend.size());
339 std::transform(first: dividend.begin(), last: dividend.end(), result: ineq.begin(),
340 unary_op: std::negate<MPInt>());
341 ineq[localVarIdx] = divisor;
342 ineq.back() += divisor - 1;
343 return ineq;
344}
345
346MPInt presburger::gcdRange(ArrayRef<MPInt> range) {
347 MPInt gcd(0);
348 for (const MPInt &elem : range) {
349 gcd = presburger::gcd(a: gcd, b: abs(x: elem));
350 if (gcd == 1)
351 return gcd;
352 }
353 return gcd;
354}
355
356MPInt presburger::normalizeRange(MutableArrayRef<MPInt> range) {
357 MPInt gcd = gcdRange(range);
358 if ((gcd == 0) || (gcd == 1))
359 return gcd;
360 for (MPInt &elem : range)
361 elem /= gcd;
362 return gcd;
363}
364
365void presburger::normalizeDiv(MutableArrayRef<MPInt> num, MPInt &denom) {
366 assert(denom > 0 && "denom must be positive!");
367 MPInt gcd = presburger::gcd(a: gcdRange(range: num), b: denom);
368 for (MPInt &coeff : num)
369 coeff /= gcd;
370 denom /= gcd;
371}
372
373SmallVector<MPInt, 8> presburger::getNegatedCoeffs(ArrayRef<MPInt> coeffs) {
374 SmallVector<MPInt, 8> negatedCoeffs;
375 negatedCoeffs.reserve(N: coeffs.size());
376 for (const MPInt &coeff : coeffs)
377 negatedCoeffs.emplace_back(Args: -coeff);
378 return negatedCoeffs;
379}
380
381SmallVector<MPInt, 8> presburger::getComplementIneq(ArrayRef<MPInt> ineq) {
382 SmallVector<MPInt, 8> coeffs;
383 coeffs.reserve(N: ineq.size());
384 for (const MPInt &coeff : ineq)
385 coeffs.emplace_back(Args: -coeff);
386 --coeffs.back();
387 return coeffs;
388}
389
390SmallVector<std::optional<MPInt>, 4>
391DivisionRepr::divValuesAt(ArrayRef<MPInt> point) const {
392 assert(point.size() == getNumNonDivs() && "Incorrect point size");
393
394 SmallVector<std::optional<MPInt>, 4> divValues(getNumDivs(), std::nullopt);
395 bool changed = true;
396 while (changed) {
397 changed = false;
398 for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
399 // If division value is found, continue;
400 if (divValues[i])
401 continue;
402
403 ArrayRef<MPInt> dividend = getDividend(i);
404 MPInt divVal(0);
405
406 // Check if we have all the division values required for this division.
407 unsigned j, f;
408 for (j = 0, f = getNumDivs(); j < f; ++j) {
409 if (dividend[getDivOffset() + j] == 0)
410 continue;
411 // Division value required, but not found yet.
412 if (!divValues[j])
413 break;
414 divVal += dividend[getDivOffset() + j] * *divValues[j];
415 }
416
417 // We have some division values that are still not found, but are required
418 // to find the value of this division.
419 if (j < f)
420 continue;
421
422 // Fill remaining values.
423 divVal = std::inner_product(first1: point.begin(), last1: point.end(), first2: dividend.begin(),
424 init: divVal);
425 // Add constant.
426 divVal += dividend.back();
427 // Take floor division with denominator.
428 divVal = floorDiv(lhs: divVal, rhs: denoms[i]);
429
430 // Set div value and continue.
431 divValues[i] = divVal;
432 changed = true;
433 }
434 }
435
436 return divValues;
437}
438
439void DivisionRepr::removeDuplicateDivs(
440 llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
441
442 // Find and merge duplicate divisions.
443 // TODO: Add division normalization to support divisions that differ by
444 // a constant.
445 // TODO: Add division ordering such that a division representation for local
446 // variable at position `i` only depends on local variables at position <
447 // `i`. This would make sure that all divisions depending on other local
448 // variables that can be merged, are merged.
449 normalizeDivs();
450 for (unsigned i = 0; i < getNumDivs(); ++i) {
451 // Check if a division representation exists for the `i^th` local var.
452 if (denoms[i] == 0)
453 continue;
454 // Check if a division exists which is a duplicate of the division at `i`.
455 for (unsigned j = i + 1; j < getNumDivs(); ++j) {
456 // Check if a division representation exists for the `j^th` local var.
457 if (denoms[j] == 0)
458 continue;
459 // Check if the denominators match.
460 if (denoms[i] != denoms[j])
461 continue;
462 // Check if the representations are equal.
463 if (dividends.getRow(row: i) != dividends.getRow(row: j))
464 continue;
465
466 // Merge divisions at position `j` into division at position `i`. If
467 // merge fails, do not merge these divs.
468 bool mergeResult = merge(i, j);
469 if (!mergeResult)
470 continue;
471
472 // Update division information to reflect merging.
473 unsigned divOffset = getDivOffset();
474 dividends.addToColumn(sourceColumn: divOffset + j, targetColumn: divOffset + i, /*scale=*/1);
475 dividends.removeColumn(pos: divOffset + j);
476 dividends.removeRow(pos: j);
477 denoms.erase(CI: denoms.begin() + j);
478
479 // Since `j` can never be zero, we do not need to worry about overflows.
480 --j;
481 }
482 }
483}
484
485void DivisionRepr::normalizeDivs() {
486 for (unsigned i = 0, e = getNumDivs(); i < e; ++i) {
487 if (getDenom(i) == 0 || getDividend(i).empty())
488 continue;
489 normalizeDiv(num: getDividend(i), denom&: getDenom(i));
490 }
491}
492
493void DivisionRepr::insertDiv(unsigned pos, ArrayRef<MPInt> dividend,
494 const MPInt &divisor) {
495 assert(pos <= getNumDivs() && "Invalid insertion position");
496 assert(dividend.size() == getNumVars() + 1 && "Incorrect dividend size");
497
498 dividends.appendExtraRow(elems: dividend);
499 denoms.insert(I: denoms.begin() + pos, Elt: divisor);
500 dividends.insertColumn(pos: getDivOffset() + pos);
501}
502
503void DivisionRepr::insertDiv(unsigned pos, unsigned num) {
504 assert(pos <= getNumDivs() && "Invalid insertion position");
505 dividends.insertColumns(pos: getDivOffset() + pos, count: num);
506 dividends.insertRows(pos, count: num);
507 denoms.insert(I: denoms.begin() + pos, NumToInsert: num, Elt: MPInt(0));
508}
509
510void DivisionRepr::print(raw_ostream &os) const {
511 os << "Dividends:\n";
512 dividends.print(os);
513 os << "Denominators\n";
514 for (const MPInt &denom : denoms)
515 os << denom << " ";
516 os << "\n";
517}
518
519void DivisionRepr::dump() const { print(os&: llvm::errs()); }
520
521SmallVector<MPInt, 8> presburger::getMPIntVec(ArrayRef<int64_t> range) {
522 SmallVector<MPInt, 8> result(range.size());
523 std::transform(first: range.begin(), last: range.end(), result: result.begin(), unary_op: mpintFromInt64);
524 return result;
525}
526
527SmallVector<int64_t, 8> presburger::getInt64Vec(ArrayRef<MPInt> range) {
528 SmallVector<int64_t, 8> result(range.size());
529 std::transform(first: range.begin(), last: range.end(), result: result.begin(), unary_op: int64FromMPInt);
530 return result;
531}
532
533Fraction presburger::dotProduct(ArrayRef<Fraction> a, ArrayRef<Fraction> b) {
534 assert(a.size() == b.size() &&
535 "dot product is only valid for vectors of equal sizes!");
536 Fraction sum = 0;
537 for (unsigned i = 0, e = a.size(); i < e; i++)
538 sum += a[i] * b[i];
539 return sum;
540}
541
542/// Find the product of two polynomials, each given by an array of
543/// coefficients, by taking the convolution.
544std::vector<Fraction> presburger::multiplyPolynomials(ArrayRef<Fraction> a,
545 ArrayRef<Fraction> b) {
546 // The length of the convolution is the sum of the lengths
547 // of the two sequences. We pad the shorter one with zeroes.
548 unsigned len = a.size() + b.size() - 1;
549
550 // We define accessors to avoid out-of-bounds errors.
551 auto getCoeff = [](ArrayRef<Fraction> arr, unsigned i) -> Fraction {
552 if (i < arr.size())
553 return arr[i];
554 else
555 return 0;
556 };
557
558 std::vector<Fraction> convolution;
559 convolution.reserve(n: len);
560 for (unsigned k = 0; k < len; ++k) {
561 Fraction sum(0, 1);
562 for (unsigned l = 0; l <= k; ++l)
563 sum += getCoeff(a, l) * getCoeff(b, k - l);
564 convolution.push_back(x: sum);
565 }
566 return convolution;
567}
568
569bool presburger::isRangeZero(ArrayRef<Fraction> arr) {
570 return llvm::all_of(Range&: arr, P: [&](Fraction f) { return f == 0; });
571}
572

source code of mlir/lib/Analysis/Presburger/Utils.cpp